diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/LICENSE.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..e497a322f2091d022983b9c5c043082ab61d1a8c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/LICENSE.txt @@ -0,0 +1,13 @@ + Copyright aio-libs contributors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..fa6a4ff8494d6c1dd3a7f20d4735003760e5e194 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/METADATA @@ -0,0 +1,250 @@ +Metadata-Version: 2.2 +Name: aiohttp +Version: 3.11.13 +Summary: Async http client/server framework (asyncio) +Home-page: https://github.com/aio-libs/aiohttp +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache-2.0 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: GitHub Actions, https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiohttp +Project-URL: Docs: Changelog, https://docs.aiohttp.org/en/stable/changes.html +Project-URL: Docs: RTD, https://docs.aiohttp.org +Project-URL: GitHub: issues, https://github.com/aio-libs/aiohttp/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/aiohttp +Classifier: Development Status :: 5 - Production/Stable +Classifier: Framework :: AsyncIO +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: POSIX +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE.txt +Requires-Dist: aiohappyeyeballs>=2.3.0 +Requires-Dist: aiosignal>=1.1.2 +Requires-Dist: async-timeout<6.0,>=4.0; python_version < "3.11" +Requires-Dist: attrs>=17.3.0 +Requires-Dist: frozenlist>=1.1.1 +Requires-Dist: multidict<7.0,>=4.5 +Requires-Dist: propcache>=0.2.0 +Requires-Dist: yarl<2.0,>=1.17.0 +Provides-Extra: speedups +Requires-Dist: aiodns>=3.2.0; (sys_platform == "linux" or sys_platform == "darwin") and extra == "speedups" +Requires-Dist: Brotli; platform_python_implementation == "CPython" and extra == "speedups" +Requires-Dist: brotlicffi; platform_python_implementation != "CPython" and extra == "speedups" + +================================== +Async http client/server framework +================================== + +.. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/aiohttp-plain.svg + :height: 64px + :width: 64px + :alt: aiohttp logo + +| + +.. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg + :target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI + :alt: GitHub Actions status for master branch + +.. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/aiohttp + :alt: codecov.io status for master branch + +.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json + :target: https://codspeed.io/aio-libs/aiohttp + :alt: Codspeed.io status for aiohttp + +.. image:: https://badge.fury.io/py/aiohttp.svg + :target: https://pypi.org/project/aiohttp + :alt: Latest PyPI package version + +.. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest + :target: https://docs.aiohttp.org/ + :alt: Latest Read The Docs + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + + +Key Features +============ + +- Supports both client and server side of HTTP protocol. +- Supports both client and server Web-Sockets out-of-the-box and avoids + Callback Hell. +- Provides Web-server with middleware and pluggable routing. + + +Getting started +=============== + +Client +------ + +To get something from the web: + +.. code-block:: python + + import aiohttp + import asyncio + + async def main(): + + async with aiohttp.ClientSession() as session: + async with session.get('http://python.org') as response: + + print("Status:", response.status) + print("Content-type:", response.headers['content-type']) + + html = await response.text() + print("Body:", html[:15], "...") + + asyncio.run(main()) + +This prints: + +.. code-block:: + + Status: 200 + Content-type: text/html; charset=utf-8 + Body: ... + +Coming from `requests `_ ? Read `why we need so many lines `_. + +Server +------ + +An example using a simple server: + +.. code-block:: python + + # examples/server_simple.py + from aiohttp import web + + async def handle(request): + name = request.match_info.get('name', "Anonymous") + text = "Hello, " + name + return web.Response(text=text) + + async def wshandle(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == web.WSMsgType.text: + await ws.send_str("Hello, {}".format(msg.data)) + elif msg.type == web.WSMsgType.binary: + await ws.send_bytes(msg.data) + elif msg.type == web.WSMsgType.close: + break + + return ws + + + app = web.Application() + app.add_routes([web.get('/', handle), + web.get('/echo', wshandle), + web.get('/{name}', handle)]) + + if __name__ == '__main__': + web.run_app(app) + + +Documentation +============= + +https://aiohttp.readthedocs.io/ + + +Demos +===== + +https://github.com/aio-libs/aiohttp-demos + + +External links +============== + +* `Third party libraries + `_ +* `Built with aiohttp + `_ +* `Powered by aiohttp + `_ + +Feel free to make a Pull Request for adding your link to these pages! + + +Communication channels +====================== + +*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions + +*Matrix*: `#aio-libs:matrix.org `_ + +We support `Stack Overflow +`_. +Please add *aiohttp* tag to your question there. + +Requirements +============ + +- attrs_ +- multidict_ +- yarl_ +- frozenlist_ + +Optionally you may install the aiodns_ library (highly recommended for sake of speed). + +.. _aiodns: https://pypi.python.org/pypi/aiodns +.. _attrs: https://github.com/python-attrs/attrs +.. _multidict: https://pypi.python.org/pypi/multidict +.. _frozenlist: https://pypi.org/project/frozenlist/ +.. _yarl: https://pypi.python.org/pypi/yarl +.. _async-timeout: https://pypi.python.org/pypi/async_timeout + +License +======= + +``aiohttp`` is offered under the Apache 2 license. + + +Keepsafe +======== + +The aiohttp community would like to thank Keepsafe +(https://www.getkeepsafe.com) for its support in the early days of +the project. + + +Source code +=========== + +The latest developer version is available in a GitHub repository: +https://github.com/aio-libs/aiohttp + +Benchmarks +========== + +If you are interested in efficiency, the AsyncIO community maintains a +list of benchmarks on the official wiki: +https://github.com/python/asyncio/wiki/Benchmarks diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..9d43e4eb2ebd686dce04c3f7fad6013447835214 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/RECORD @@ -0,0 +1,131 @@ +aiohttp-3.11.13.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +aiohttp-3.11.13.dist-info/LICENSE.txt,sha256=n4DQ2311WpQdtFchcsJw7L2PCCuiFd3QlZhZQu2Uqes,588 +aiohttp-3.11.13.dist-info/METADATA,sha256=V_vS5w25_e4iRV4NQ4mjjJeaSpzg0IaWqru7XzJqJLc,7712 +aiohttp-3.11.13.dist-info/RECORD,, +aiohttp-3.11.13.dist-info/WHEEL,sha256=siqMuoWpRueIZ87ijidBxnOwHeSOOcxNwYCs-pC4Yv0,151 +aiohttp-3.11.13.dist-info/top_level.txt,sha256=iv-JIaacmTl-hSho3QmphcKnbRRYx1st47yjz_178Ro,8 +aiohttp/.hash/_cparser.pxd.hash,sha256=hYa9Vje-oMs2eh_7MfCPOh2QW_1x1yCjcZuc7AmwLd0,121 +aiohttp/.hash/_find_header.pxd.hash,sha256=_mbpD6vM-CVCKq3ulUvsOAz5Wdo88wrDzfpOsMQaMNA,125 +aiohttp/.hash/_http_parser.pyx.hash,sha256=GBgZjCNbtZApPhf9-gHpS5Z2WMIzM-vgp5VSZIEvZfk,125 +aiohttp/.hash/_http_writer.pyx.hash,sha256=-UgSF82qclpxjP0og_gcFEsstXRKF9e3Ou4wziAyDvI,125 +aiohttp/.hash/hdrs.py.hash,sha256=v6IaKbsxjsdQxBzhb5AjP0x_9G3rUe84D7avf7AI4cs,116 +aiohttp/__init__.py,sha256=612FENJ9aLTJsbeZNR9LF_PAJ4nLUY4NgyqlkDnO70c,7840 +aiohttp/__pycache__/__init__.cpython-312.pyc,, +aiohttp/__pycache__/abc.cpython-312.pyc,, +aiohttp/__pycache__/base_protocol.cpython-312.pyc,, +aiohttp/__pycache__/client.cpython-312.pyc,, +aiohttp/__pycache__/client_exceptions.cpython-312.pyc,, +aiohttp/__pycache__/client_proto.cpython-312.pyc,, +aiohttp/__pycache__/client_reqrep.cpython-312.pyc,, +aiohttp/__pycache__/client_ws.cpython-312.pyc,, +aiohttp/__pycache__/compression_utils.cpython-312.pyc,, +aiohttp/__pycache__/connector.cpython-312.pyc,, +aiohttp/__pycache__/cookiejar.cpython-312.pyc,, +aiohttp/__pycache__/formdata.cpython-312.pyc,, +aiohttp/__pycache__/hdrs.cpython-312.pyc,, +aiohttp/__pycache__/helpers.cpython-312.pyc,, +aiohttp/__pycache__/http.cpython-312.pyc,, +aiohttp/__pycache__/http_exceptions.cpython-312.pyc,, +aiohttp/__pycache__/http_parser.cpython-312.pyc,, +aiohttp/__pycache__/http_websocket.cpython-312.pyc,, +aiohttp/__pycache__/http_writer.cpython-312.pyc,, +aiohttp/__pycache__/log.cpython-312.pyc,, +aiohttp/__pycache__/multipart.cpython-312.pyc,, +aiohttp/__pycache__/payload.cpython-312.pyc,, +aiohttp/__pycache__/payload_streamer.cpython-312.pyc,, +aiohttp/__pycache__/pytest_plugin.cpython-312.pyc,, +aiohttp/__pycache__/resolver.cpython-312.pyc,, +aiohttp/__pycache__/streams.cpython-312.pyc,, +aiohttp/__pycache__/tcp_helpers.cpython-312.pyc,, +aiohttp/__pycache__/test_utils.cpython-312.pyc,, +aiohttp/__pycache__/tracing.cpython-312.pyc,, +aiohttp/__pycache__/typedefs.cpython-312.pyc,, +aiohttp/__pycache__/web.cpython-312.pyc,, +aiohttp/__pycache__/web_app.cpython-312.pyc,, +aiohttp/__pycache__/web_exceptions.cpython-312.pyc,, +aiohttp/__pycache__/web_fileresponse.cpython-312.pyc,, +aiohttp/__pycache__/web_log.cpython-312.pyc,, +aiohttp/__pycache__/web_middlewares.cpython-312.pyc,, +aiohttp/__pycache__/web_protocol.cpython-312.pyc,, +aiohttp/__pycache__/web_request.cpython-312.pyc,, +aiohttp/__pycache__/web_response.cpython-312.pyc,, +aiohttp/__pycache__/web_routedef.cpython-312.pyc,, +aiohttp/__pycache__/web_runner.cpython-312.pyc,, +aiohttp/__pycache__/web_server.cpython-312.pyc,, +aiohttp/__pycache__/web_urldispatcher.cpython-312.pyc,, +aiohttp/__pycache__/web_ws.cpython-312.pyc,, +aiohttp/__pycache__/worker.cpython-312.pyc,, +aiohttp/_cparser.pxd,sha256=8jGIg-VJ9p3llwCakUYDsPGxA4HiZe9dmK9Jmtlz-5g,4318 +aiohttp/_find_header.pxd,sha256=0GfwFCPN2zxEKTO1_MA5sYq2UfzsG8kcV3aTqvwlz3g,68 +aiohttp/_headers.pxi,sha256=n701k28dVPjwRnx5j6LpJhLTfj7dqu2vJt7f0O60Oyg,2007 +aiohttp/_http_parser.cpython-312-x86_64-linux-gnu.so,sha256=kZJwKEDTHDTxMWYND9KqRALWr8BaLSyyTBmQ6aRYuRA,2813904 +aiohttp/_http_parser.pyx,sha256=wQdADj5LizwC_7nFGr8nIlk6GpoaQeQ0359H0HMKGuM,28241 +aiohttp/_http_writer.cpython-312-x86_64-linux-gnu.so,sha256=6h2_x6dsZeFsvy_Iurvq9c-RfNLg-heJJAablu4uSSk,492232 +aiohttp/_http_writer.pyx,sha256=fiCck_EVgRiTX7VtAoV2AldjuesJMFPev4TWd9Fx8jo,4597 +aiohttp/_websocket/.hash/mask.pxd.hash,sha256=Y0zBddk_ck3pi9-BFzMcpkcvCKvwvZ4GTtZFb9u1nxQ,128 +aiohttp/_websocket/.hash/mask.pyx.hash,sha256=90owpXYM8_kIma4KUcOxhWSk-Uv4NVMBoCYeFM1B3d0,128 +aiohttp/_websocket/.hash/reader_c.pxd.hash,sha256=EoZjkF_tAFEbGvV0oRY2GZOSuAfWFWFjMhXgq6mQExo,132 +aiohttp/_websocket/__init__.py,sha256=Mar3R9_vBN_Ea4lsW7iTAVXD7OKswKPGqF5xgSyt77k,44 +aiohttp/_websocket/__pycache__/__init__.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/helpers.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/models.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/reader.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/reader_c.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/reader_py.cpython-312.pyc,, +aiohttp/_websocket/__pycache__/writer.cpython-312.pyc,, +aiohttp/_websocket/helpers.py,sha256=P-XLv8IUaihKzDenVUqfKU5DJbWE5HvG8uhvUZK8Ic4,5038 +aiohttp/_websocket/mask.cpython-312-x86_64-linux-gnu.so,sha256=rdCiOTakoDIcEYBPFRf7ncQUgXmhosFnNMxUYqJn8uA,265432 +aiohttp/_websocket/mask.pxd,sha256=sBmZ1Amym9kW4Ge8lj1fLZ7mPPya4LzLdpkQExQXv5M,112 +aiohttp/_websocket/mask.pyx,sha256=BHjOtV0O0w7xp9p0LNADRJvGmgfPn9sGeJvSs0fL__4,1397 +aiohttp/_websocket/models.py,sha256=XAzjs_8JYszWXIgZ6R3ZRrF-tX9Q_6LiD49WRYojopM,2121 +aiohttp/_websocket/reader.py,sha256=eC4qS0c5sOeQ2ebAHLaBpIaTVFaSKX79pY2xvh3Pqyw,1030 +aiohttp/_websocket/reader_c.cpython-312-x86_64-linux-gnu.so,sha256=wt49Wc5GzT9rRbxIN-8pgo3mQdTthCSVaoQR4NI6OTQ,1871856 +aiohttp/_websocket/reader_c.pxd,sha256=9rMWCpAC1jng7_gtqLjRlqQv9q7UkOn63tIQfq2k8Gc,2444 +aiohttp/_websocket/reader_c.py,sha256=anZsBKZWlL8SO8gArsZMDstH37qBuZOvJA7jtj0Z95M,17975 +aiohttp/_websocket/reader_py.py,sha256=anZsBKZWlL8SO8gArsZMDstH37qBuZOvJA7jtj0Z95M,17975 +aiohttp/_websocket/writer.py,sha256=T3P36iMrzVPPC2XeScserHMD5vd9an6yizWzqDUkRZ0,7077 +aiohttp/abc.py,sha256=JLMOxrKLGTDaPRLfraY1pl-xka53YiHhAH9yaF9QRXQ,6512 +aiohttp/base_protocol.py,sha256=Tp8cxUPQvv9kUPk3w6lAzk6d2MAzV3scwI_3Go3C47c,3025 +aiohttp/client.py,sha256=isdfGlM4O5ILr4F4gBABlybxo4MQ1tNaMm7zjMcrfrM,54309 +aiohttp/client_exceptions.py,sha256=uyKbxI2peZhKl7lELBMx3UeusNkfpemPWpGFq0r6JeM,11367 +aiohttp/client_proto.py,sha256=dV7u9floGWG-_xtD2fLUYqiANG6VsJtq0HMlTjf1g-g,10015 +aiohttp/client_reqrep.py,sha256=VAgh0NxP2HvYWx6nX1Pr8FINc1m-W8-5q2zKeZV68n8,43925 +aiohttp/client_ws.py,sha256=1CIjIXwyzOMIYw6AjUES4-qUwbyVHW1seJKQfg_Rta8,15109 +aiohttp/compression_utils.py,sha256=0J3EAOR-0HehlYIudJXRu_Kr6hrYCY0IfuJ1px9MhQs,5681 +aiohttp/connector.py,sha256=ZAXixLOyIl6zQEnetPLYKkbjP2BoHFEzGeEyvRBoGqI,60734 +aiohttp/cookiejar.py,sha256=PYR1K1mkLa24Hm6c9UEJnAitccNzz97CbsJyQ2ULAlU,17615 +aiohttp/formdata.py,sha256=CUJnCWDNHFcXSYZ_TupaT6rHkY-Q7ghssvWzaYBPIo0,6552 +aiohttp/hdrs.py,sha256=2rj5MyA-6yRdYPhW5UKkW4iNWhEAlGIOSBH5D4FmKNE,5111 +aiohttp/helpers.py,sha256=KqPQECeiJ_EhA93k7-5ZoVdZH0sk_4n0tCoM_E-iMnE,29091 +aiohttp/http.py,sha256=8o8j8xH70OWjnfTWA9V44NR785QPxEPrUtzMXiAVpwc,1842 +aiohttp/http_exceptions.py,sha256=RYmBycJvvPerKkgXXm8v145I1N-fbsgSpcsbNIC-gdE,2961 +aiohttp/http_parser.py,sha256=UqerYPJzA1MqLmeG1jURhTNO1YhwUASK3QVcNEz0me8,36851 +aiohttp/http_websocket.py,sha256=8VXFKw6KQUEmPg48GtRMB37v0gTK7A0inoxXuDxMZEc,842 +aiohttp/http_writer.py,sha256=pRIyfOmL3cZmdWDWBBJ2cZEwEJzLWzlPPAJInaPLThI,7595 +aiohttp/log.py,sha256=BbNKx9e3VMIm0xYjZI0IcBBoS7wjdeIeSaiJE7-qK2g,325 +aiohttp/multipart.py,sha256=SABIvo3vhXzG4bLDZ0C4V3yG_86vAb-3Zb9Li7BVmI8,36944 +aiohttp/payload.py,sha256=rCA9JJI_RMCik_7qNIaC1Bh21aXhABGYK2tsYeaHRQ4,15793 +aiohttp/payload_streamer.py,sha256=ZzEYyfzcjGWkVkK3XR2pBthSCSIykYvY3Wr5cGQ2eTc,2211 +aiohttp/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7 +aiohttp/pytest_plugin.py,sha256=AfJ6VIWzsp5KgpXRREsX3yqGUZrJyfb7zzcMqzWxz7I,12768 +aiohttp/resolver.py,sha256=sJ8-LYCtl_g9f6gn_5X2NFQ9FQ72Q2Mr4_rLxo9NVeI,6375 +aiohttp/streams.py,sha256=U-qTkuAqIfpJChuKEy-vYn8nQ_Z1MVcW0WO2DHiJz_o,22329 +aiohttp/tcp_helpers.py,sha256=BSadqVWaBpMFDRWnhaaR941N9MiDZ7bdTrxgCb0CW-M,961 +aiohttp/test_utils.py,sha256=r7kBasmZtC3tQY5OmyMaIl1B9P8Bnnq1oM3npVcAPKs,22811 +aiohttp/tracing.py,sha256=66XQwtdR5DHv8p953eeNL0l8o6iHDaNwH9bBaybHXD4,15137 +aiohttp/typedefs.py,sha256=wUlqwe9Mw9W8jT3HsYJcYk00qP3EMPz3nTkYXmeNN48,1657 +aiohttp/web.py,sha256=As5nqGQy4QXWMXSaOsh0JudSVVJVIt_nr3n0b8CaMb0,18422 +aiohttp/web_app.py,sha256=Zre0QHM9JAp4d7jrj5NRxlPnfTrKLNuA42Rdsh8Q2TI,19554 +aiohttp/web_exceptions.py,sha256=7nIuiwhZ39vJJ9KrWqArA5QcWbUdqkz2CLwEpJapeN8,10360 +aiohttp/web_fileresponse.py,sha256=FRsS0p9r1KU5y8ceG0QXBYnrL6xggjbxcXSmI6qIR4k,16504 +aiohttp/web_log.py,sha256=rX5D7xLOX2B6BMdiZ-chme_KfJfW5IXEoFwLfkfkajs,7865 +aiohttp/web_middlewares.py,sha256=sFI0AgeNjdyAjuz92QtMIpngmJSOxrqe2Jfbs4BNUu0,4165 +aiohttp/web_protocol.py,sha256=0MYjcaQishUyJxJ4lsH4IfHef4nIvHDf-DSZwI1Con4,25539 +aiohttp/web_request.py,sha256=j_SSX9s-d3ZeNyqUTpFIaPUaNdSqHwb7yfc0ufL8xFA,29750 +aiohttp/web_response.py,sha256=65aliDETi7rZ8P76ksuHQI0ZTu1cKpclCSailNu105M,28696 +aiohttp/web_routedef.py,sha256=VT1GAx6BrawoDh5RwBwBu5wSABSqgWwAe74AUCyZAEo,6110 +aiohttp/web_runner.py,sha256=v1G1nKiOOQgFnTSR4IMc6I9ReEFDMaHtMLvO_roDM-A,11786 +aiohttp/web_server.py,sha256=-9WDKUAiR9ll-rSdwXSqG6YjaoW79d1R4y0BGSqgUMA,2888 +aiohttp/web_urldispatcher.py,sha256=TIMxFmhLjERseG0xcZv2Ef9Xuo_GTBRqBqeMkCgL0K8,43825 +aiohttp/web_ws.py,sha256=Gr-UWgau41P-OoJUb3WJvaNYiDESXzrHmIw1Cqonupc,22612 +aiohttp/worker.py,sha256=0lvxRNMjGM47ddlQWtci53ri9YN42Su1Vdd_Z7zMMH0,8040 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..074b7f880946583239d6b4f8284e932b9e12c9e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: setuptools (75.8.0) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_17_x86_64 +Tag: cp312-cp312-manylinux2014_x86_64 + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee4ba4f3d739e094878215c84eb41ba85c80e4a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohttp-3.11.13.dist-info/top_level.txt @@ -0,0 +1 @@ +aiohttp diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..098312599f660abb288992a09df460125b899c31 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/__init__.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from ._core._eventloop import current_time as current_time +from ._core._eventloop import get_all_backends as get_all_backends +from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class +from ._core._eventloop import run as run +from ._core._eventloop import sleep as sleep +from ._core._eventloop import sleep_forever as sleep_forever +from ._core._eventloop import sleep_until as sleep_until +from ._core._exceptions import BrokenResourceError as BrokenResourceError +from ._core._exceptions import BrokenWorkerIntepreter as BrokenWorkerIntepreter +from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess +from ._core._exceptions import BusyResourceError as BusyResourceError +from ._core._exceptions import ClosedResourceError as ClosedResourceError +from ._core._exceptions import DelimiterNotFound as DelimiterNotFound +from ._core._exceptions import EndOfStream as EndOfStream +from ._core._exceptions import IncompleteRead as IncompleteRead +from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError +from ._core._exceptions import WouldBlock as WouldBlock +from ._core._fileio import AsyncFile as AsyncFile +from ._core._fileio import Path as Path +from ._core._fileio import open_file as open_file +from ._core._fileio import wrap_file as wrap_file +from ._core._resources import aclose_forcefully as aclose_forcefully +from ._core._signals import open_signal_receiver as open_signal_receiver +from ._core._sockets import connect_tcp as connect_tcp +from ._core._sockets import connect_unix as connect_unix +from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket +from ._core._sockets import ( + create_connected_unix_datagram_socket as create_connected_unix_datagram_socket, +) +from ._core._sockets import create_tcp_listener as create_tcp_listener +from ._core._sockets import create_udp_socket as create_udp_socket +from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket +from ._core._sockets import create_unix_listener as create_unix_listener +from ._core._sockets import getaddrinfo as getaddrinfo +from ._core._sockets import getnameinfo as getnameinfo +from ._core._sockets import wait_readable as wait_readable +from ._core._sockets import wait_socket_readable as wait_socket_readable +from ._core._sockets import wait_socket_writable as wait_socket_writable +from ._core._sockets import wait_writable as wait_writable +from ._core._streams import create_memory_object_stream as create_memory_object_stream +from ._core._subprocesses import open_process as open_process +from ._core._subprocesses import run_process as run_process +from ._core._synchronization import CapacityLimiter as CapacityLimiter +from ._core._synchronization import ( + CapacityLimiterStatistics as CapacityLimiterStatistics, +) +from ._core._synchronization import Condition as Condition +from ._core._synchronization import ConditionStatistics as ConditionStatistics +from ._core._synchronization import Event as Event +from ._core._synchronization import EventStatistics as EventStatistics +from ._core._synchronization import Lock as Lock +from ._core._synchronization import LockStatistics as LockStatistics +from ._core._synchronization import ResourceGuard as ResourceGuard +from ._core._synchronization import Semaphore as Semaphore +from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics +from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED +from ._core._tasks import CancelScope as CancelScope +from ._core._tasks import create_task_group as create_task_group +from ._core._tasks import current_effective_deadline as current_effective_deadline +from ._core._tasks import fail_after as fail_after +from ._core._tasks import move_on_after as move_on_after +from ._core._testing import TaskInfo as TaskInfo +from ._core._testing import get_current_task as get_current_task +from ._core._testing import get_running_tasks as get_running_tasks +from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked +from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider +from ._core._typedattr import TypedAttributeSet as TypedAttributeSet +from ._core._typedattr import typed_attribute as typed_attribute + +# Re-export imports so they look like they live directly in this package +for __value in list(locals().values()): + if getattr(__value, "__module__", "").startswith("anyio."): + __value.__module__ = __name__ + +del __value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/from_thread.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/from_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..93a4cfe8e49fc7dcd6d28fc0a794cd3fab0542d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/from_thread.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import sys +from collections.abc import Awaitable, Callable, Generator +from concurrent.futures import Future +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + contextmanager, +) +from dataclasses import dataclass, field +from inspect import isawaitable +from threading import Lock, Thread, get_ident +from types import TracebackType +from typing import ( + Any, + Generic, + TypeVar, + cast, + overload, +) + +from ._core import _eventloop +from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals +from ._core._synchronization import Event +from ._core._tasks import CancelScope, create_task_group +from .abc import AsyncBackend +from .abc._tasks import TaskStatus + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +T_Retval = TypeVar("T_Retval") +T_co = TypeVar("T_co", covariant=True) +PosArgsT = TypeVarTuple("PosArgsT") + + +def run( + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT] +) -> T_Retval: + """ + Call a coroutine function from a worker thread. + + :param func: a coroutine function + :param args: positional arguments for the callable + :return: the return value of the coroutine function + + """ + try: + async_backend = threadlocals.current_async_backend + token = threadlocals.current_token + except AttributeError: + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None + + return async_backend.run_async_from_thread(func, args, token=token) + + +def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] +) -> T_Retval: + """ + Call a function in the event loop thread from a worker thread. + + :param func: a callable + :param args: positional arguments for the callable + :return: the return value of the callable + + """ + try: + async_backend = threadlocals.current_async_backend + token = threadlocals.current_token + except AttributeError: + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None + + return async_backend.run_sync_from_thread(func, args, token=token) + + +class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager): + _enter_future: Future[T_co] + _exit_future: Future[bool | None] + _exit_event: Event + _exit_exc_info: tuple[ + type[BaseException] | None, BaseException | None, TracebackType | None + ] = (None, None, None) + + def __init__( + self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal + ): + self._async_cm = async_cm + self._portal = portal + + async def run_async_cm(self) -> bool | None: + try: + self._exit_event = Event() + value = await self._async_cm.__aenter__() + except BaseException as exc: + self._enter_future.set_exception(exc) + raise + else: + self._enter_future.set_result(value) + + try: + # Wait for the sync context manager to exit. + # This next statement can raise `get_cancelled_exc_class()` if + # something went wrong in a task group in this async context + # manager. + await self._exit_event.wait() + finally: + # In case of cancellation, it could be that we end up here before + # `_BlockingAsyncContextManager.__exit__` is called, and an + # `_exit_exc_info` has been set. + result = await self._async_cm.__aexit__(*self._exit_exc_info) + return result + + def __enter__(self) -> T_co: + self._enter_future = Future() + self._exit_future = self._portal.start_task_soon(self.run_async_cm) + return self._enter_future.result() + + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + self._exit_exc_info = __exc_type, __exc_value, __traceback + self._portal.call(self._exit_event.set) + return self._exit_future.result() + + +class _BlockingPortalTaskStatus(TaskStatus): + def __init__(self, future: Future): + self._future = future + + def started(self, value: object = None) -> None: + self._future.set_result(value) + + +class BlockingPortal: + """An object that lets external threads run code in an asynchronous event loop.""" + + def __new__(cls) -> BlockingPortal: + return get_async_backend().create_blocking_portal() + + def __init__(self) -> None: + self._event_loop_thread_id: int | None = get_ident() + self._stop_event = Event() + self._task_group = create_task_group() + self._cancelled_exc_class = get_cancelled_exc_class() + + async def __aenter__(self) -> BlockingPortal: + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.stop() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _check_running(self) -> None: + if self._event_loop_thread_id is None: + raise RuntimeError("This portal is not running") + if self._event_loop_thread_id == get_ident(): + raise RuntimeError( + "This method cannot be called from the event loop thread" + ) + + async def sleep_until_stopped(self) -> None: + """Sleep until :meth:`stop` is called.""" + await self._stop_event.wait() + + async def stop(self, cancel_remaining: bool = False) -> None: + """ + Signal the portal to shut down. + + This marks the portal as no longer accepting new calls and exits from + :meth:`sleep_until_stopped`. + + :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False`` + to let them finish before returning + + """ + self._event_loop_thread_id = None + self._stop_event.set() + if cancel_remaining: + self._task_group.cancel_scope.cancel() + + async def _call_func( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + future: Future[T_Retval], + ) -> None: + def callback(f: Future[T_Retval]) -> None: + if f.cancelled() and self._event_loop_thread_id not in ( + None, + get_ident(), + ): + self.call(scope.cancel) + + try: + retval_or_awaitable = func(*args, **kwargs) + if isawaitable(retval_or_awaitable): + with CancelScope() as scope: + if future.cancelled(): + scope.cancel() + else: + future.add_done_callback(callback) + + retval = await retval_or_awaitable + else: + retval = retval_or_awaitable + except self._cancelled_exc_class: + future.cancel() + future.set_running_or_notify_cancel() + except BaseException as exc: + if not future.cancelled(): + future.set_exception(exc) + + # Let base exceptions fall through + if not isinstance(exc, Exception): + raise + else: + if not future.cancelled(): + future.set_result(retval) + finally: + scope = None # type: ignore[assignment] + + def _spawn_task_from_thread( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + args: tuple[Unpack[PosArgsT]], + kwargs: dict[str, Any], + name: object, + future: Future[T_Retval], + ) -> None: + """ + Spawn a new task using the given callable. + + Implementors must ensure that the future is resolved when the task finishes. + + :param func: a callable + :param args: positional arguments to be passed to the callable + :param kwargs: keyword arguments to be passed to the callable + :param name: name of the task (will be coerced to a string if not ``None``) + :param future: a future that will resolve to the return value of the callable, + or the exception raised during its execution + + """ + raise NotImplementedError + + @overload + def call( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + ) -> T_Retval: ... + + @overload + def call( + self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT] + ) -> T_Retval: ... + + def call( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], + ) -> T_Retval: + """ + Call the given function in the event loop thread. + + If the callable returns a coroutine object, it is awaited on. + + :param func: any callable + :raises RuntimeError: if the portal is not running or if this method is called + from within the event loop thread + + """ + return cast(T_Retval, self.start_task_soon(func, *args).result()) + + @overload + def start_task_soon( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> Future[T_Retval]: ... + + @overload + def start_task_soon( + self, + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + name: object = None, + ) -> Future[T_Retval]: ... + + def start_task_soon( + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval], + *args: Unpack[PosArgsT], + name: object = None, + ) -> Future[T_Retval]: + """ + Start a task in the portal's task group. + + The task will be run inside a cancel scope which can be cancelled by cancelling + the returned future. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a future that resolves with the return value of the callable if the + task completes successfully, or with the exception raised in the task + :raises RuntimeError: if the portal is not running or if this method is called + from within the event loop thread + :rtype: concurrent.futures.Future[T_Retval] + + .. versionadded:: 3.0 + + """ + self._check_running() + f: Future[T_Retval] = Future() + self._spawn_task_from_thread(func, args, {}, name, f) + return f + + def start_task( + self, + func: Callable[..., Awaitable[T_Retval]], + *args: object, + name: object = None, + ) -> tuple[Future[T_Retval], Any]: + """ + Start a task in the portal's task group and wait until it signals for readiness. + + This method works the same way as :meth:`.abc.TaskGroup.start`. + + :param func: the target function + :param args: positional arguments passed to ``func`` + :param name: name of the task (will be coerced to a string if not ``None``) + :return: a tuple of (future, task_status_value) where the ``task_status_value`` + is the value passed to ``task_status.started()`` from within the target + function + :rtype: tuple[concurrent.futures.Future[T_Retval], Any] + + .. versionadded:: 3.0 + + """ + + def task_done(future: Future[T_Retval]) -> None: + if not task_status_future.done(): + if future.cancelled(): + task_status_future.cancel() + elif future.exception(): + task_status_future.set_exception(future.exception()) + else: + exc = RuntimeError( + "Task exited without calling task_status.started()" + ) + task_status_future.set_exception(exc) + + self._check_running() + task_status_future: Future = Future() + task_status = _BlockingPortalTaskStatus(task_status_future) + f: Future = Future() + f.add_done_callback(task_done) + self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f) + return f, task_status_future.result() + + def wrap_async_context_manager( + self, cm: AbstractAsyncContextManager[T_co] + ) -> AbstractContextManager[T_co]: + """ + Wrap an async context manager as a synchronous context manager via this portal. + + Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping + in the middle until the synchronous context manager exits. + + :param cm: an asynchronous context manager + :return: a synchronous context manager + + .. versionadded:: 2.1 + + """ + return _BlockingAsyncContextManager(cm, self) + + +@dataclass +class BlockingPortalProvider: + """ + A manager for a blocking portal. Used as a context manager. The first thread to + enter this context manager causes a blocking portal to be started with the specific + parameters, and the last thread to exit causes the portal to be shut down. Thus, + there will be exactly one blocking portal running in this context as long as at + least one thread has entered this context manager. + + The parameters are the same as for :func:`~anyio.run`. + + :param backend: name of the backend + :param backend_options: backend options + + .. versionadded:: 4.4 + """ + + backend: str = "asyncio" + backend_options: dict[str, Any] | None = None + _lock: Lock = field(init=False, default_factory=Lock) + _leases: int = field(init=False, default=0) + _portal: BlockingPortal = field(init=False) + _portal_cm: AbstractContextManager[BlockingPortal] | None = field( + init=False, default=None + ) + + def __enter__(self) -> BlockingPortal: + with self._lock: + if self._portal_cm is None: + self._portal_cm = start_blocking_portal( + self.backend, self.backend_options + ) + self._portal = self._portal_cm.__enter__() + + self._leases += 1 + return self._portal + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + portal_cm: AbstractContextManager[BlockingPortal] | None = None + with self._lock: + assert self._portal_cm + assert self._leases > 0 + self._leases -= 1 + if not self._leases: + portal_cm = self._portal_cm + self._portal_cm = None + del self._portal + + if portal_cm: + portal_cm.__exit__(None, None, None) + + +@contextmanager +def start_blocking_portal( + backend: str = "asyncio", backend_options: dict[str, Any] | None = None +) -> Generator[BlockingPortal, Any, None]: + """ + Start a new event loop in a new thread and run a blocking portal in its main task. + + The parameters are the same as for :func:`~anyio.run`. + + :param backend: name of the backend + :param backend_options: backend options + :return: a context manager that yields a blocking portal + + .. versionchanged:: 3.0 + Usage as a context manager is now required. + + """ + + async def run_portal() -> None: + async with BlockingPortal() as portal_: + future.set_result(portal_) + await portal_.sleep_until_stopped() + + def run_blocking_portal() -> None: + if future.set_running_or_notify_cancel(): + try: + _eventloop.run( + run_portal, backend=backend, backend_options=backend_options + ) + except BaseException as exc: + if not future.done(): + future.set_exception(exc) + + future: Future[BlockingPortal] = Future() + thread = Thread(target=run_blocking_portal, daemon=True) + thread.start() + try: + cancel_remaining_tasks = False + portal = future.result() + try: + yield portal + except BaseException: + cancel_remaining_tasks = True + raise + finally: + try: + portal.call(portal.stop, cancel_remaining_tasks) + except RuntimeError: + pass + finally: + thread.join() + + +def check_cancelled() -> None: + """ + Check if the cancel scope of the host task's running the current worker thread has + been cancelled. + + If the host task's current cancel scope has indeed been cancelled, the + backend-specific cancellation exception will be raised. + + :raises RuntimeError: if the current thread was not spawned by + :func:`.to_thread.run_sync` + + """ + try: + async_backend: AsyncBackend = threadlocals.current_async_backend + except AttributeError: + raise RuntimeError( + "This function can only be run from an AnyIO worker thread" + ) from None + + async_backend.check_cancelled() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/lowlevel.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/lowlevel.py new file mode 100644 index 0000000000000000000000000000000000000000..14c7668cb3fc50d71a79b1af1d29b7dc18742660 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/lowlevel.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, overload +from weakref import WeakKeyDictionary + +from ._core._eventloop import get_async_backend + +T = TypeVar("T") +D = TypeVar("D") + + +async def checkpoint() -> None: + """ + Check for cancellation and allow the scheduler to switch to another task. + + Equivalent to (but more efficient than):: + + await checkpoint_if_cancelled() + await cancel_shielded_checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_async_backend().checkpoint() + + +async def checkpoint_if_cancelled() -> None: + """ + Enter a checkpoint if the enclosing cancel scope has been cancelled. + + This does not allow the scheduler to switch to a different task. + + .. versionadded:: 3.0 + + """ + await get_async_backend().checkpoint_if_cancelled() + + +async def cancel_shielded_checkpoint() -> None: + """ + Allow the scheduler to switch to another task but without checking for cancellation. + + Equivalent to (but potentially more efficient than):: + + with CancelScope(shield=True): + await checkpoint() + + + .. versionadded:: 3.0 + + """ + await get_async_backend().cancel_shielded_checkpoint() + + +def current_token() -> object: + """ + Return a backend specific token object that can be used to get back to the event + loop. + + """ + return get_async_backend().current_token() + + +_run_vars: WeakKeyDictionary[Any, dict[str, Any]] = WeakKeyDictionary() +_token_wrappers: dict[Any, _TokenWrapper] = {} + + +@dataclass(frozen=True) +class _TokenWrapper: + __slots__ = "_token", "__weakref__" + _token: object + + +class _NoValueSet(enum.Enum): + NO_VALUE_SET = enum.auto() + + +class RunvarToken(Generic[T]): + __slots__ = "_var", "_value", "_redeemed" + + def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]): + self._var = var + self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value + self._redeemed = False + + +class RunVar(Generic[T]): + """ + Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop. + """ + + __slots__ = "_name", "_default" + + NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET + + _token_wrappers: set[_TokenWrapper] = set() + + def __init__( + self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET + ): + self._name = name + self._default = default + + @property + def _current_vars(self) -> dict[str, T]: + token = current_token() + try: + return _run_vars[token] + except KeyError: + run_vars = _run_vars[token] = {} + return run_vars + + @overload + def get(self, default: D) -> T | D: ... + + @overload + def get(self) -> T: ... + + def get( + self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET + ) -> T | D: + try: + return self._current_vars[self._name] + except KeyError: + if default is not RunVar.NO_VALUE_SET: + return default + elif self._default is not RunVar.NO_VALUE_SET: + return self._default + + raise LookupError( + f'Run variable "{self._name}" has no value and no default set' + ) + + def set(self, value: T) -> RunvarToken[T]: + current_vars = self._current_vars + token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET)) + current_vars[self._name] = value + return token + + def reset(self, token: RunvarToken[T]) -> None: + if token._var is not self: + raise ValueError("This token does not belong to this RunVar") + + if token._redeemed: + raise ValueError("This token has already been used") + + if token._value is _NoValueSet.NO_VALUE_SET: + try: + del self._current_vars[self._name] + except KeyError: + pass + else: + self._current_vars[self._name] = token._value + + token._redeemed = True + + def __repr__(self) -> str: + return f"" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/pytest_plugin.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/pytest_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0d59dd06736eee90e069aeba42c1a69d398a78 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/pytest_plugin.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import sys +from collections.abc import Generator, Iterator +from contextlib import ExitStack, contextmanager +from inspect import isasyncgenfunction, iscoroutinefunction, ismethod +from typing import Any, cast + +import pytest +import sniffio +from _pytest.fixtures import SubRequest +from _pytest.outcomes import Exit + +from ._core._eventloop import get_all_backends, get_async_backend +from ._core._exceptions import iterate_exceptions +from .abc import TestRunner + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + +_current_runner: TestRunner | None = None +_runner_stack: ExitStack | None = None +_runner_leases = 0 + + +def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: + if isinstance(backend, str): + return backend, {} + elif isinstance(backend, tuple) and len(backend) == 2: + if isinstance(backend[0], str) and isinstance(backend[1], dict): + return cast(tuple[str, dict[str, Any]], backend) + + raise TypeError("anyio_backend must be either a string or tuple of (string, dict)") + + +@contextmanager +def get_runner( + backend_name: str, backend_options: dict[str, Any] +) -> Iterator[TestRunner]: + global _current_runner, _runner_leases, _runner_stack + if _current_runner is None: + asynclib = get_async_backend(backend_name) + _runner_stack = ExitStack() + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the + # async library + token = sniffio.current_async_library_cvar.set(backend_name) + _runner_stack.callback(sniffio.current_async_library_cvar.reset, token) + + backend_options = backend_options or {} + _current_runner = _runner_stack.enter_context( + asynclib.create_test_runner(backend_options) + ) + + _runner_leases += 1 + try: + yield _current_runner + finally: + _runner_leases -= 1 + if not _runner_leases: + assert _runner_stack is not None + _runner_stack.close() + _runner_stack = _current_runner = None + + +def pytest_configure(config: Any) -> None: + config.addinivalue_line( + "markers", + "anyio: mark the (coroutine function) test to be run " + "asynchronously via anyio.", + ) + + +@pytest.hookimpl(hookwrapper=True) +def pytest_fixture_setup(fixturedef: Any, request: Any) -> Generator[Any]: + def wrapper( + *args: Any, anyio_backend: Any, request: SubRequest, **kwargs: Any + ) -> Any: + # Rebind any fixture methods to the request instance + if ( + request.instance + and ismethod(func) + and type(func.__self__) is type(request.instance) + ): + local_func = func.__func__.__get__(request.instance) + else: + local_func = func + + backend_name, backend_options = extract_backend_and_options(anyio_backend) + if has_backend_arg: + kwargs["anyio_backend"] = anyio_backend + + if has_request_arg: + kwargs["request"] = request + + with get_runner(backend_name, backend_options) as runner: + if isasyncgenfunction(local_func): + yield from runner.run_asyncgen_fixture(local_func, kwargs) + else: + yield runner.run_fixture(local_func, kwargs) + + # Only apply this to coroutine functions and async generator functions in requests + # that involve the anyio_backend fixture + func = fixturedef.func + if isasyncgenfunction(func) or iscoroutinefunction(func): + if "anyio_backend" in request.fixturenames: + fixturedef.func = wrapper + original_argname = fixturedef.argnames + + if not (has_backend_arg := "anyio_backend" in fixturedef.argnames): + fixturedef.argnames += ("anyio_backend",) + + if not (has_request_arg := "request" in fixturedef.argnames): + fixturedef.argnames += ("request",) + + try: + return (yield) + finally: + fixturedef.func = func + fixturedef.argnames = original_argname + + return (yield) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: + if collector.istestfunction(obj, name): + inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj + if iscoroutinefunction(inner_func): + marker = collector.get_closest_marker("anyio") + own_markers = getattr(obj, "pytestmark", ()) + if marker or any(marker.name == "anyio" for marker in own_markers): + pytest.mark.usefixtures("anyio_backend")(obj) + + +@pytest.hookimpl(tryfirst=True) +def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None: + def run_with_hypothesis(**kwargs: Any) -> None: + with get_runner(backend_name, backend_options) as runner: + runner.run_test(original_func, kwargs) + + backend = pyfuncitem.funcargs.get("anyio_backend") + if backend: + backend_name, backend_options = extract_backend_and_options(backend) + + if hasattr(pyfuncitem.obj, "hypothesis"): + # Wrap the inner test function unless it's already wrapped + original_func = pyfuncitem.obj.hypothesis.inner_test + if original_func.__qualname__ != run_with_hypothesis.__qualname__: + if iscoroutinefunction(original_func): + pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis + + return None + + if iscoroutinefunction(pyfuncitem.obj): + funcargs = pyfuncitem.funcargs + testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} + with get_runner(backend_name, backend_options) as runner: + try: + runner.run_test(pyfuncitem.obj, testargs) + except ExceptionGroup as excgrp: + for exc in iterate_exceptions(excgrp): + if isinstance(exc, (Exit, KeyboardInterrupt, SystemExit)): + raise exc from excgrp + + raise + + return True + + return None + + +@pytest.fixture(scope="module", params=get_all_backends()) +def anyio_backend(request: Any) -> Any: + return request.param + + +@pytest.fixture +def anyio_backend_name(anyio_backend: Any) -> str: + if isinstance(anyio_backend, str): + return anyio_backend + else: + return anyio_backend[0] + + +@pytest.fixture +def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]: + if isinstance(anyio_backend, str): + return {} + else: + return anyio_backend[1] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_interpreter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..bcde24d3d1d7af07beb169c50ac860c4c270a9d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_interpreter.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import atexit +import os +import pickle +import sys +from collections import deque +from collections.abc import Callable +from textwrap import dedent +from typing import Any, Final, TypeVar + +from . import current_time, to_thread +from ._core._exceptions import BrokenWorkerIntepreter +from ._core._synchronization import CapacityLimiter +from .lowlevel import RunVar + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +UNBOUND: Final = 2 # I have no clue how this works, but it was used in the stdlib +FMT_UNPICKLED: Final = 0 +FMT_PICKLED: Final = 1 +DEFAULT_CPU_COUNT: Final = 8 # this is just an arbitrarily selected value +MAX_WORKER_IDLE_TIME = ( + 30 # seconds a subinterpreter can be idle before becoming eligible for pruning +) + +T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + +_idle_workers = RunVar[deque["Worker"]]("_available_workers") +_default_interpreter_limiter = RunVar[CapacityLimiter]("_default_interpreter_limiter") + + +class Worker: + _run_func = compile( + dedent(""" + import _interpqueues as queues + import _interpreters as interpreters + from pickle import loads, dumps, HIGHEST_PROTOCOL + + item = queues.get(queue_id)[0] + try: + func, args = loads(item) + retval = func(*args) + except BaseException as exc: + is_exception = True + retval = exc + else: + is_exception = False + + try: + queues.put(queue_id, (retval, is_exception), FMT_UNPICKLED, UNBOUND) + except interpreters.NotShareableError: + retval = dumps(retval, HIGHEST_PROTOCOL) + queues.put(queue_id, (retval, is_exception), FMT_PICKLED, UNBOUND) + """), + "", + "exec", + ) + + last_used: float = 0 + + _initialized: bool = False + _interpreter_id: int + _queue_id: int + + def initialize(self) -> None: + import _interpqueues as queues + import _interpreters as interpreters + + self._interpreter_id = interpreters.create() + self._queue_id = queues.create(2, FMT_UNPICKLED, UNBOUND) # type: ignore[call-arg] + self._initialized = True + interpreters.set___main___attrs( + self._interpreter_id, + { + "queue_id": self._queue_id, + "FMT_PICKLED": FMT_PICKLED, + "FMT_UNPICKLED": FMT_UNPICKLED, + "UNBOUND": UNBOUND, + }, + ) + + def destroy(self) -> None: + import _interpqueues as queues + import _interpreters as interpreters + + if self._initialized: + interpreters.destroy(self._interpreter_id) + queues.destroy(self._queue_id) + + def _call( + self, + func: Callable[..., T_Retval], + args: tuple[Any], + ) -> tuple[Any, bool]: + import _interpqueues as queues + import _interpreters as interpreters + + if not self._initialized: + self.initialize() + + payload = pickle.dumps((func, args), pickle.HIGHEST_PROTOCOL) + queues.put(self._queue_id, payload, FMT_PICKLED, UNBOUND) # type: ignore[call-arg] + + res: Any + is_exception: bool + if exc_info := interpreters.exec(self._interpreter_id, self._run_func): # type: ignore[func-returns-value,arg-type] + raise BrokenWorkerIntepreter(exc_info) + + (res, is_exception), fmt = queues.get(self._queue_id)[:2] + if fmt == FMT_PICKLED: + res = pickle.loads(res) + + return res, is_exception + + async def call( + self, + func: Callable[..., T_Retval], + args: tuple[Any], + limiter: CapacityLimiter, + ) -> T_Retval: + result, is_exception = await to_thread.run_sync( + self._call, + func, + args, + limiter=limiter, + ) + if is_exception: + raise result + + return result + + +def _stop_workers(workers: deque[Worker]) -> None: + for worker in workers: + worker.destroy() + + workers.clear() + + +async def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a subinterpreter. + + If the ``cancellable`` option is enabled and the task waiting for its completion is + cancelled, the call will still run its course but its return value (or any raised + exception) will be ignored. + + .. warning:: This feature is **experimental**. The upstream interpreter API has not + yet been finalized or thoroughly tested, so don't rely on this for anything + mission critical. + + :param func: a callable + :param args: positional arguments for the callable + :param limiter: capacity limiter to use to limit the total amount of subinterpreters + running (if omitted, the default limiter is used) + :return: the result of the call + :raises BrokenWorkerIntepreter: if there's an internal error in a subinterpreter + + """ + if sys.version_info <= (3, 13): + raise RuntimeError("subinterpreters require at least Python 3.13") + + if limiter is None: + limiter = current_default_interpreter_limiter() + + try: + idle_workers = _idle_workers.get() + except LookupError: + idle_workers = deque() + _idle_workers.set(idle_workers) + atexit.register(_stop_workers, idle_workers) + + async with limiter: + try: + worker = idle_workers.pop() + except IndexError: + worker = Worker() + + try: + return await worker.call(func, args, limiter) + finally: + # Prune workers that have been idle for too long + now = current_time() + while idle_workers: + if now - idle_workers[0].last_used <= MAX_WORKER_IDLE_TIME: + break + + await to_thread.run_sync(idle_workers.popleft().destroy, limiter=limiter) + + worker.last_used = current_time() + idle_workers.append(worker) + + +def current_default_interpreter_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of + concurrently running subinterpreters. + + Defaults to the number of CPU cores. + + :return: a capacity limiter object + + """ + try: + return _default_interpreter_limiter.get() + except LookupError: + limiter = CapacityLimiter(os.cpu_count() or DEFAULT_CPU_COUNT) + _default_interpreter_limiter.set(limiter) + return limiter diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_process.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_process.py new file mode 100644 index 0000000000000000000000000000000000000000..495de2ae7111ef3b0382f5efe11a0e8e7cbd186b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_process.py @@ -0,0 +1,258 @@ +from __future__ import annotations + +import os +import pickle +import subprocess +import sys +from collections import deque +from collections.abc import Callable +from importlib.util import module_from_spec, spec_from_file_location +from typing import TypeVar, cast + +from ._core._eventloop import current_time, get_async_backend, get_cancelled_exc_class +from ._core._exceptions import BrokenWorkerProcess +from ._core._subprocesses import open_process +from ._core._synchronization import CapacityLimiter +from ._core._tasks import CancelScope, fail_after +from .abc import ByteReceiveStream, ByteSendStream, Process +from .lowlevel import RunVar, checkpoint_if_cancelled +from .streams.buffered import BufferedByteReceiveStream + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +WORKER_MAX_IDLE_TIME = 300 # 5 minutes + +T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + +_process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers") +_process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar( + "_process_pool_idle_workers" +) +_default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter") + + +async def run_sync( # type: ignore[return] + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker process. + + If the ``cancellable`` option is enabled and the task waiting for its completion is + cancelled, the worker process running it will be abruptly terminated using SIGKILL + (or ``terminateProcess()`` on Windows). + + :param func: a callable + :param args: positional arguments for the callable + :param cancellable: ``True`` to allow cancellation of the operation while it's + running + :param limiter: capacity limiter to use to limit the total amount of processes + running (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. + + """ + + async def send_raw_command(pickled_cmd: bytes) -> object: + try: + await stdin.send(pickled_cmd) + response = await buffered.receive_until(b"\n", 50) + status, length = response.split(b" ") + if status not in (b"RETURN", b"EXCEPTION"): + raise RuntimeError( + f"Worker process returned unexpected response: {response!r}" + ) + + pickled_response = await buffered.receive_exactly(int(length)) + except BaseException as exc: + workers.discard(process) + try: + process.kill() + with CancelScope(shield=True): + await process.aclose() + except ProcessLookupError: + pass + + if isinstance(exc, get_cancelled_exc_class()): + raise + else: + raise BrokenWorkerProcess from exc + + retval = pickle.loads(pickled_response) + if status == b"EXCEPTION": + assert isinstance(retval, BaseException) + raise retval + else: + return retval + + # First pickle the request before trying to reserve a worker process + await checkpoint_if_cancelled() + request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL) + + # If this is the first run in this event loop thread, set up the necessary variables + try: + workers = _process_pool_workers.get() + idle_workers = _process_pool_idle_workers.get() + except LookupError: + workers = set() + idle_workers = deque() + _process_pool_workers.set(workers) + _process_pool_idle_workers.set(idle_workers) + get_async_backend().setup_process_pool_exit_at_shutdown(workers) + + async with limiter or current_default_process_limiter(): + # Pop processes from the pool (starting from the most recently used) until we + # find one that hasn't exited yet + process: Process + while idle_workers: + process, idle_since = idle_workers.pop() + if process.returncode is None: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + + # Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME + # seconds or longer + now = current_time() + killed_processes: list[Process] = [] + while idle_workers: + if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME: + break + + process_to_kill, idle_since = idle_workers.popleft() + process_to_kill.kill() + workers.remove(process_to_kill) + killed_processes.append(process_to_kill) + + with CancelScope(shield=True): + for killed_process in killed_processes: + await killed_process.aclose() + + break + + workers.remove(process) + else: + command = [sys.executable, "-u", "-m", __name__] + process = await open_process( + command, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) + try: + stdin = cast(ByteSendStream, process.stdin) + buffered = BufferedByteReceiveStream( + cast(ByteReceiveStream, process.stdout) + ) + with fail_after(20): + message = await buffered.receive(6) + + if message != b"READY\n": + raise BrokenWorkerProcess( + f"Worker process returned unexpected response: {message!r}" + ) + + main_module_path = getattr(sys.modules["__main__"], "__file__", None) + pickled = pickle.dumps( + ("init", sys.path, main_module_path), + protocol=pickle.HIGHEST_PROTOCOL, + ) + await send_raw_command(pickled) + except (BrokenWorkerProcess, get_cancelled_exc_class()): + raise + except BaseException as exc: + process.kill() + raise BrokenWorkerProcess( + "Error during worker process initialization" + ) from exc + + workers.add(process) + + with CancelScope(shield=not cancellable): + try: + return cast(T_Retval, await send_raw_command(request)) + finally: + if process in workers: + idle_workers.append((process, current_time())) + + +def current_default_process_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of worker + processes. + + :return: a capacity limiter object + + """ + try: + return _default_process_limiter.get() + except LookupError: + limiter = CapacityLimiter(os.cpu_count() or 2) + _default_process_limiter.set(limiter) + return limiter + + +def process_worker() -> None: + # Redirect standard streams to os.devnull so that user code won't interfere with the + # parent-worker communication + stdin = sys.stdin + stdout = sys.stdout + sys.stdin = open(os.devnull) + sys.stdout = open(os.devnull, "w") + + stdout.buffer.write(b"READY\n") + while True: + retval = exception = None + try: + command, *args = pickle.load(stdin.buffer) + except EOFError: + return + except BaseException as exc: + exception = exc + else: + if command == "run": + func, args = args + try: + retval = func(*args) + except BaseException as exc: + exception = exc + elif command == "init": + main_module_path: str | None + sys.path, main_module_path = args + del sys.modules["__main__"] + if main_module_path and os.path.isfile(main_module_path): + # Load the parent's main module but as __mp_main__ instead of + # __main__ (like multiprocessing does) to avoid infinite recursion + try: + spec = spec_from_file_location("__mp_main__", main_module_path) + if spec and spec.loader: + main = module_from_spec(spec) + spec.loader.exec_module(main) + sys.modules["__main__"] = main + except BaseException as exc: + exception = exc + try: + if exception is not None: + status = b"EXCEPTION" + pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL) + else: + status = b"RETURN" + pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL) + except BaseException as exc: + exception = exc + status = b"EXCEPTION" + pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL) + + stdout.buffer.write(b"%s %d\n" % (status, len(pickled))) + stdout.buffer.write(pickled) + + # Respect SIGTERM + if isinstance(exception, SystemExit): + raise exception + + +if __name__ == "__main__": + process_worker() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_thread.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..5070516eb56679f863bd446c97cf76376d80d83b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/anyio/to_thread.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable +from typing import TypeVar +from warnings import warn + +from ._core._eventloop import get_async_backend +from .abc import CapacityLimiter + +if sys.version_info >= (3, 11): + from typing import TypeVarTuple, Unpack +else: + from typing_extensions import TypeVarTuple, Unpack + +T_Retval = TypeVar("T_Retval") +PosArgsT = TypeVarTuple("PosArgsT") + + +async def run_sync( + func: Callable[[Unpack[PosArgsT]], T_Retval], + *args: Unpack[PosArgsT], + abandon_on_cancel: bool = False, + cancellable: bool | None = None, + limiter: CapacityLimiter | None = None, +) -> T_Retval: + """ + Call the given function with the given arguments in a worker thread. + + If the ``cancellable`` option is enabled and the task waiting for its completion is + cancelled, the thread will still run its course but its return value (or any raised + exception) will be ignored. + + :param func: a callable + :param args: positional arguments for the callable + :param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run + unchecked on own) if the host task is cancelled, ``False`` to ignore + cancellations in the host task until the operation has completed in the worker + thread + :param cancellable: deprecated alias of ``abandon_on_cancel``; will override + ``abandon_on_cancel`` if both parameters are passed + :param limiter: capacity limiter to use to limit the total amount of threads running + (if omitted, the default limiter is used) + :return: an awaitable that yields the return value of the function. + + """ + if cancellable is not None: + abandon_on_cancel = cancellable + warn( + "The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is " + "deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead", + DeprecationWarning, + stacklevel=2, + ) + + return await get_async_backend().run_sync_in_worker_thread( + func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter + ) + + +def current_default_thread_limiter() -> CapacityLimiter: + """ + Return the capacity limiter that is used by default to limit the number of + concurrent threads. + + :return: a capacity limiter object + + """ + return get_async_backend().current_default_thread_limiter() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/_compat.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..974d431bd9828ef226e5c965dee56edd47d4f0ed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/_compat.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from collections.abc import Mapping +import sys +from typing import Any + +DATACLASS_KWARGS: Mapping[str, Any] +if sys.version_info >= (3, 10): + DATACLASS_KWARGS = {"slots": True} +else: + DATACLASS_KWARGS = {} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/parser_inline.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/parser_inline.py new file mode 100644 index 0000000000000000000000000000000000000000..0026c3839d14fe8b78ff033abfac0c299b8a0c54 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/parser_inline.py @@ -0,0 +1,147 @@ +"""Tokenizes paragraph content. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from . import rules_inline +from .ruler import Ruler +from .rules_inline.state_inline import StateInline +from .token import Token +from .utils import EnvType + +if TYPE_CHECKING: + from markdown_it import MarkdownIt + + +# Parser rules +RuleFuncInlineType = Callable[[StateInline, bool], bool] +"""(state: StateInline, silent: bool) -> matched: bool) + +`silent` disables token generation, useful for lookahead. +""" +_rules: list[tuple[str, RuleFuncInlineType]] = [ + ("text", rules_inline.text), + ("linkify", rules_inline.linkify), + ("newline", rules_inline.newline), + ("escape", rules_inline.escape), + ("backticks", rules_inline.backtick), + ("strikethrough", rules_inline.strikethrough.tokenize), + ("emphasis", rules_inline.emphasis.tokenize), + ("link", rules_inline.link), + ("image", rules_inline.image), + ("autolink", rules_inline.autolink), + ("html_inline", rules_inline.html_inline), + ("entity", rules_inline.entity), +] + +# Note `rule2` ruleset was created specifically for emphasis/strikethrough +# post-processing and may be changed in the future. +# +# Don't use this for anything except pairs (plugins working with `balance_pairs`). +# +RuleFuncInline2Type = Callable[[StateInline], None] +_rules2: list[tuple[str, RuleFuncInline2Type]] = [ + ("balance_pairs", rules_inline.link_pairs), + ("strikethrough", rules_inline.strikethrough.postProcess), + ("emphasis", rules_inline.emphasis.postProcess), + # rules for pairs separate '**' into its own text tokens, which may be left unused, + # rule below merges unused segments back with the rest of the text + ("fragments_join", rules_inline.fragments_join), +] + + +class ParserInline: + def __init__(self) -> None: + self.ruler = Ruler[RuleFuncInlineType]() + for name, rule in _rules: + self.ruler.push(name, rule) + # Second ruler used for post-processing (e.g. in emphasis-like rules) + self.ruler2 = Ruler[RuleFuncInline2Type]() + for name, rule2 in _rules2: + self.ruler2.push(name, rule2) + + def skipToken(self, state: StateInline) -> None: + """Skip single token by running all rules in validation mode; + returns `True` if any rule reported success + """ + ok = False + pos = state.pos + rules = self.ruler.getRules("") + maxNesting = state.md.options["maxNesting"] + cache = state.cache + + if pos in cache: + state.pos = cache[pos] + return + + if state.level < maxNesting: + for rule in rules: + # Increment state.level and decrement it later to limit recursion. + # It's harmless to do here, because no tokens are created. + # But ideally, we'd need a separate private state variable for this purpose. + state.level += 1 + ok = rule(state, True) + state.level -= 1 + if ok: + break + else: + # Too much nesting, just skip until the end of the paragraph. + # + # NOTE: this will cause links to behave incorrectly in the following case, + # when an amount of `[` is exactly equal to `maxNesting + 1`: + # + # [[[[[[[[[[[[[[[[[[[[[foo]() + # + # TODO: remove this workaround when CM standard will allow nested links + # (we can replace it by preventing links from being parsed in + # validation mode) + # + state.pos = state.posMax + + if not ok: + state.pos += 1 + cache[pos] = state.pos + + def tokenize(self, state: StateInline) -> None: + """Generate tokens for input range.""" + ok = False + rules = self.ruler.getRules("") + end = state.posMax + maxNesting = state.md.options["maxNesting"] + + while state.pos < end: + # Try all possible rules. + # On success, rule should: + # + # - update `state.pos` + # - update `state.tokens` + # - return true + + if state.level < maxNesting: + for rule in rules: + ok = rule(state, False) + if ok: + break + + if ok: + if state.pos >= end: + break + continue + + state.pending += state.src[state.pos] + state.pos += 1 + + if state.pending: + state.pushPending() + + def parse( + self, src: str, md: MarkdownIt, env: EnvType, tokens: list[Token] + ) -> list[Token]: + """Process input string and push inline tokens into `tokens`""" + state = StateInline(src, md, env, tokens) + self.tokenize(state) + rules2 = self.ruler2.getRules("") + for rule in rules2: + rule(state) + return state.tokens diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/port.yaml b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/port.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e289e9e27d53efea476622d1d6dc60799920835 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/markdown_it/port.yaml @@ -0,0 +1,48 @@ +- package: markdown-it/markdown-it + version: 13.0.1 + commit: e843acc9edad115cbf8cf85e676443f01658be08 + date: May 3, 2022 + notes: + - Rename variables that use python built-in names, e.g. + - `max` -> `maximum` + - `len` -> `length` + - `str` -> `string` + - | + Convert JS `for` loops to `while` loops + this is generally the main difference between the codes, + because in python you can't do e.g. `for {i=1;i Any: + """Convert Token.attrs set as ``None`` or ``[[key, value], ...]`` to a dict. + + This improves compatibility with upstream markdown-it. + """ + if not value: + return {} + if isinstance(value, list): + return dict(value) + return value + + +@dc.dataclass(**DATACLASS_KWARGS) +class Token: + type: str + """Type of the token (string, e.g. "paragraph_open")""" + + tag: str + """HTML tag name, e.g. 'p'""" + + nesting: Literal[-1, 0, 1] + """Level change (number in {-1, 0, 1} set), where: + - `1` means the tag is opening + - `0` means the tag is self-closing + - `-1` means the tag is closing + """ + + attrs: dict[str, str | int | float] = dc.field(default_factory=dict) + """HTML attributes. + Note this differs from the upstream "list of lists" format, + although than an instance can still be initialised with this format. + """ + + map: list[int] | None = None + """Source map info. Format: `[ line_begin, line_end ]`""" + + level: int = 0 + """Nesting level, the same as `state.level`""" + + children: list[Token] | None = None + """Array of child nodes (inline and img tokens).""" + + content: str = "" + """Inner content, in the case of a self-closing tag (code, html, fence, etc.),""" + + markup: str = "" + """'*' or '_' for emphasis, fence string for fence, etc.""" + + info: str = "" + """Additional information: + - Info string for "fence" tokens + - The value "auto" for autolink "link_open" and "link_close" tokens + - The string value of the item marker for ordered-list "list_item_open" tokens + """ + + meta: dict[Any, Any] = dc.field(default_factory=dict) + """A place for plugins to store any arbitrary data""" + + block: bool = False + """True for block-level tokens, false for inline tokens. + Used in renderer to calculate line breaks + """ + + hidden: bool = False + """If true, ignore this element when rendering. + Used for tight lists to hide paragraphs. + """ + + def __post_init__(self) -> None: + self.attrs = convert_attrs(self.attrs) + + def attrIndex(self, name: str) -> int: + warnings.warn( # noqa: B028 + "Token.attrIndex should not be used, since Token.attrs is a dictionary", + UserWarning, + ) + if name not in self.attrs: + return -1 + return list(self.attrs.keys()).index(name) + + def attrItems(self) -> list[tuple[str, str | int | float]]: + """Get (key, value) list of attrs.""" + return list(self.attrs.items()) + + def attrPush(self, attrData: tuple[str, str | int | float]) -> None: + """Add `[ name, value ]` attribute to list. Init attrs if necessary.""" + name, value = attrData + self.attrSet(name, value) + + def attrSet(self, name: str, value: str | int | float) -> None: + """Set `name` attribute to `value`. Override old value if exists.""" + self.attrs[name] = value + + def attrGet(self, name: str) -> None | str | int | float: + """Get the value of attribute `name`, or null if it does not exist.""" + return self.attrs.get(name, None) + + def attrJoin(self, name: str, value: str) -> None: + """Join value to existing attribute via space. + Or create new attribute if not exists. + Useful to operate with token classes. + """ + if name in self.attrs: + current = self.attrs[name] + if not isinstance(current, str): + raise TypeError( + f"existing attr 'name' is not a str: {self.attrs[name]}" + ) + self.attrs[name] = f"{current} {value}" + else: + self.attrs[name] = value + + def copy(self, **changes: Any) -> Token: + """Return a shallow copy of the instance.""" + return dc.replace(self, **changes) + + def as_dict( + self, + *, + children: bool = True, + as_upstream: bool = True, + meta_serializer: Callable[[dict[Any, Any]], Any] | None = None, + filter: Callable[[str, Any], bool] | None = None, + dict_factory: Callable[..., MutableMapping[str, Any]] = dict, + ) -> MutableMapping[str, Any]: + """Return the token as a dictionary. + + :param children: Also convert children to dicts + :param as_upstream: Ensure the output dictionary is equal to that created by markdown-it + For example, attrs are converted to null or lists + :param meta_serializer: hook for serializing ``Token.meta`` + :param filter: A callable whose return code determines whether an + attribute or element is included (``True``) or dropped (``False``). + Is called with the (key, value) pair. + :param dict_factory: A callable to produce dictionaries from. + For example, to produce ordered dictionaries instead of normal Python + dictionaries, pass in ``collections.OrderedDict``. + + """ + mapping = dict_factory((f.name, getattr(self, f.name)) for f in dc.fields(self)) + if filter: + mapping = dict_factory((k, v) for k, v in mapping.items() if filter(k, v)) + if as_upstream and "attrs" in mapping: + mapping["attrs"] = ( + None + if not mapping["attrs"] + else [[k, v] for k, v in mapping["attrs"].items()] + ) + if meta_serializer and "meta" in mapping: + mapping["meta"] = meta_serializer(mapping["meta"]) + if children and mapping.get("children", None): + mapping["children"] = [ + child.as_dict( + children=children, + filter=filter, + dict_factory=dict_factory, + as_upstream=as_upstream, + meta_serializer=meta_serializer, + ) + for child in mapping["children"] + ] + return mapping + + @classmethod + def from_dict(cls, dct: MutableMapping[str, Any]) -> Token: + """Convert a dict to a Token.""" + token = cls(**dct) + if token.children: + token.children = [cls.from_dict(c) for c in token.children] # type: ignore[arg-type] + return token diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/LICENSE b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..6209a43e0423510deb5a93e59dd8c2564d1df8e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/METADATA @@ -0,0 +1,317 @@ +Metadata-Version: 2.2 +Name: propcache +Version: 0.3.0 +Summary: Accelerated property cache +Home-page: https://github.com/aio-libs/propcache +Author: Andrew Svetlov +Author-email: andrew.svetlov@gmail.com +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache-2.0 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: GitHub Workflows, https://github.com/aio-libs/propcache/actions?query=branch:master +Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/propcache +Project-URL: Docs: Changelog, https://propcache.readthedocs.io/en/latest/changes/ +Project-URL: Docs: RTD, https://propcache.readthedocs.io +Project-URL: GitHub: issues, https://github.com/aio-libs/propcache/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/propcache +Keywords: cython,cext,propcache +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Cython +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: NOTICE + +propcache +========= + +The module provides a fast implementation of cached properties for Python 3.9+. + +.. image:: https://github.com/aio-libs/propcache/actions/workflows/ci-cd.yml/badge.svg + :target: https://github.com/aio-libs/propcache/actions?query=workflow%3ACI + :align: right + +.. image:: https://codecov.io/gh/aio-libs/propcache/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/propcache + +.. image:: https://badge.fury.io/py/propcache.svg + :target: https://badge.fury.io/py/propcache + + +.. image:: https://readthedocs.org/projects/propcache/badge/?version=latest + :target: https://propcache.readthedocs.io + + +.. image:: https://img.shields.io/pypi/pyversions/propcache.svg + :target: https://pypi.python.org/pypi/propcache + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + +Introduction +------------ + +The API is designed to be nearly identical to the built-in ``functools.cached_property`` class, +except for the additional ``under_cached_property`` class which uses ``self._cache`` +instead of ``self.__dict__`` to store the cached values and prevents ``__set__`` from being called. + +For full documentation please read https://propcache.readthedocs.io. + +Installation +------------ + +:: + + $ pip install propcache + +The library is Python 3 only! + +PyPI contains binary wheels for Linux, Windows and MacOS. If you want to install +``propcache`` on another operating system where wheels are not provided, +the the tarball will be used to compile the library from +the source code. It requires a C compiler and and Python headers installed. + +To skip the compilation you must explicitly opt-in by using a PEP 517 +configuration setting ``pure-python``, or setting the ``PROPCACHE_NO_EXTENSIONS`` +environment variable to a non-empty value, e.g.: + +.. code-block:: console + + $ pip install propcache --config-settings=pure-python=false + +Please note that the pure-Python (uncompiled) version is much slower. However, +PyPy always uses a pure-Python implementation, and, as such, it is unaffected +by this variable. + + +API documentation +------------------ + +The documentation is located at https://propcache.readthedocs.io. + +Source code +----------- + +The project is hosted on GitHub_ + +Please file an issue on the `bug tracker +`_ if you have found a bug +or have some suggestion in order to improve the library. + +Discussion list +--------------- + +*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs + +Feel free to post your questions and ideas here. + + +Authors and License +------------------- + +The ``propcache`` package is derived from ``yarl`` which is written by Andrew Svetlov. + +It's *Apache 2* licensed and freely available. + + +.. _GitHub: https://github.com/aio-libs/propcache + +========= +Changelog +========= + +.. + You should *NOT* be adding new change log entries to this file, this + file is managed by towncrier. You *may* edit previous change logs to + fix problems like typo corrections or such. + To add a new change log entry, please see + https://pip.pypa.io/en/latest/development/#adding-a-news-entry + we named the news folder "changes". + + WARNING: Don't drop the next directive! + +.. towncrier release notes start + +0.3.0 +===== + +*(2025-02-20)* + + +Features +-------- + +- Implemented support for the free-threaded build of CPython 3.13 -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#84 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Started building wheels for the free-threaded build of CPython 3.13 -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#84 `__. + + +Contributor-facing changes +-------------------------- + +- GitHub Actions CI/CD is now configured to manage caching pip-ecosystem + dependencies using `re-actors/cache-python-deps`_ -- an action by + `@webknjaz `__ that takes into account ABI stability and the exact + version of Python runtime. + + .. _`re-actors/cache-python-deps`: + https://github.com/marketplace/actions/cache-python-deps + + *Related issues and pull requests on GitHub:* + `#93 `__. + + +---- + + +0.2.1 +===== + +*(2024-12-01)* + + +Bug fixes +--------- + +- Stopped implicitly allowing the use of Cython pre-release versions when + building the distribution package -- by `@ajsanchezsanz `__ and + `@markgreene74 `__. + + *Related commits on GitHub:* + `64df0a6 `__. + +- Fixed ``wrapped`` and ``func`` not being accessible in the Cython versions of ``propcache.api.cached_property`` and ``propcache.api.under_cached_property`` decorators -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#72 `__. + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Removed support for Python 3.8 as it has reached end of life -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#57 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Stopped implicitly allowing the use of Cython pre-release versions when + building the distribution package -- by `@ajsanchezsanz `__ and + `@markgreene74 `__. + + *Related commits on GitHub:* + `64df0a6 `__. + + +---- + + +0.2.0 +===== + +*(2024-10-07)* + + +Bug fixes +--------- + +- Fixed loading the C-extensions on Python 3.8 -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#26 `__. + + +Features +-------- + +- Improved typing for the ``propcache.api.under_cached_property`` decorator -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#38 `__. + + +Improved documentation +---------------------- + +- Added API documentation for the ``propcache.api.cached_property`` and ``propcache.api.under_cached_property`` decorators -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#16 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Moved ``propcache.api.under_cached_property`` and ``propcache.api.cached_property`` to `propcache.api` -- by `@bdraco `__. + + Both decorators remain importable from the top-level package, however importing from `propcache.api` is now the recommended way to use them. + + *Related issues and pull requests on GitHub:* + `#19 `__, `#24 `__, `#32 `__. + +- Converted project to use a src layout -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#22 `__, `#29 `__, `#37 `__. + + +---- + + +0.1.0 +===== + +*(2024-10-03)* + + +Features +-------- + +- Added ``armv7l`` wheels -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#5 `__. + + +---- + + +0.0.0 +===== + +*(2024-10-02)* + + +- Initial release. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/NOTICE b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..fa53b2b138df881c4c95239d0e4bede831b36ab5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/NOTICE @@ -0,0 +1,13 @@ + Copyright 2016-2021, Andrew Svetlov and aio-libs team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..6c4b51f178813bca712f91561be90451084a3201 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/RECORD @@ -0,0 +1,18 @@ +propcache-0.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +propcache-0.3.0.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358 +propcache-0.3.0.dist-info/METADATA,sha256=EVZE30m1bHCrcoZ2RLtvfNkdXyHIInAIG39SNZEjLyE,10357 +propcache-0.3.0.dist-info/NOTICE,sha256=VtasbIEFwKUTBMIdsGDjYa-ajqCvmnXCOcKLXRNpODg,609 +propcache-0.3.0.dist-info/RECORD,, +propcache-0.3.0.dist-info/WHEEL,sha256=siqMuoWpRueIZ87ijidBxnOwHeSOOcxNwYCs-pC4Yv0,151 +propcache-0.3.0.dist-info/top_level.txt,sha256=pVF_GbqSAITPMiX27kfU3QP9-ufhRvkADmudDxWdF3w,10 +propcache/__init__.py,sha256=82yOKjJMHwsj2IpsIfiuDumvBEOckaz2HB823YDJH4Y,965 +propcache/__pycache__/__init__.cpython-312.pyc,, +propcache/__pycache__/_helpers.cpython-312.pyc,, +propcache/__pycache__/_helpers_py.cpython-312.pyc,, +propcache/__pycache__/api.cpython-312.pyc,, +propcache/_helpers.py,sha256=8CnlWmfTM6RDbMvNDXwL-VMHWiwIUjG8nbeqmvRsbh8,1579 +propcache/_helpers_c.cpython-312-x86_64-linux-gnu.so,sha256=mqnzbvcdqj2KOhx9xGMMjatXLlEuB69xfnsCF5O9G8I,844512 +propcache/_helpers_c.pyx,sha256=9UqfhVrbbkiZDGtEPFEOfT7qghPjAkNtJpgI1JYUPao,2518 +propcache/_helpers_py.py,sha256=jnK6W43iETLcW-A1WMroGUKnElzX8Drw2UQfbEqLlI8,1637 +propcache/api.py,sha256=wvgB-ypkkI5uf72VVYl2NFGc_TnzUQA2CxC7dTlL5ak,179 +propcache/py.typed,sha256=ay5OMO475PlcZ_Fbun9maHW7Y6MBTk0UXL4ztHx3Iug,14 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..074b7f880946583239d6b4f8284e932b9e12c9e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: setuptools (75.8.0) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_17_x86_64 +Tag: cp312-cp312-manylinux2014_x86_64 + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c9accf6226df7e4011a41ac5d6014223685cfed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/propcache-0.3.0.dist-info/top_level.txt @@ -0,0 +1 @@ +propcache diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d00a731324c92d51f0c421dbd64328904dce6dac --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.py @@ -0,0 +1,437 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + +""" +PyArrow is the python implementation of Apache Arrow. + +Apache Arrow is a cross-language development platform for in-memory data. +It specifies a standardized language-independent columnar memory format for +flat and hierarchical data, organized for efficient analytic operations on +modern hardware. It also provides computational libraries and zero-copy +streaming messaging and interprocess communication. + +For more information see the official page at https://arrow.apache.org +""" + +import gc as _gc +import importlib as _importlib +import os as _os +import platform as _platform +import sys as _sys +import warnings as _warnings + +try: + from ._generated_version import version as __version__ +except ImportError: + # Package is not installed, parse git tag at runtime + try: + import setuptools_scm + # Code duplicated from setup.py to avoid a dependency on each other + + def parse_git(root, **kwargs): + """ + Parse function for setuptools_scm that ignores tags for non-C++ + subprojects, e.g. apache-arrow-js-XXX tags. + """ + from setuptools_scm.git import parse + kwargs['describe_command'] = \ + "git describe --dirty --tags --long --match 'apache-arrow-[0-9]*.*'" + return parse(root, **kwargs) + __version__ = setuptools_scm.get_version('../', + parse=parse_git) + except ImportError: + __version__ = None + +# ARROW-8684: Disable GC while initializing Cython extension module, +# to workaround Cython bug in https://github.com/cython/cython/issues/3603 +_gc_enabled = _gc.isenabled() +_gc.disable() +import pyarrow.lib as _lib +if _gc_enabled: + _gc.enable() + +from pyarrow.lib import (BuildInfo, RuntimeInfo, set_timezone_db_path, + MonthDayNano, VersionInfo, cpp_build_info, + cpp_version, cpp_version_info, runtime_info, + cpu_count, set_cpu_count, enable_signal_handlers, + io_thread_count, set_io_thread_count) + + +def show_versions(): + """ + Print various version information, to help with error reporting. + """ + def print_entry(label, value): + print(f"{label: <26}: {value: <8}") + + print("pyarrow version info\n--------------------") + print_entry("Package kind", cpp_build_info.package_kind + if len(cpp_build_info.package_kind) > 0 + else "not indicated") + print_entry("Arrow C++ library version", cpp_build_info.version) + print_entry("Arrow C++ compiler", + f"{cpp_build_info.compiler_id} {cpp_build_info.compiler_version}") + print_entry("Arrow C++ compiler flags", cpp_build_info.compiler_flags) + print_entry("Arrow C++ git revision", cpp_build_info.git_id) + print_entry("Arrow C++ git description", cpp_build_info.git_description) + print_entry("Arrow C++ build type", cpp_build_info.build_type) + + +def _module_is_available(module): + try: + _importlib.import_module(f'pyarrow.{module}') + except ImportError: + return False + else: + return True + + +def _filesystem_is_available(fs): + try: + import pyarrow.fs + except ImportError: + return False + + try: + getattr(pyarrow.fs, fs) + except (ImportError, AttributeError): + return False + else: + return True + + +def show_info(): + """ + Print detailed version and platform information, for error reporting + """ + show_versions() + + def print_entry(label, value): + print(f" {label: <20}: {value: <8}") + + print("\nPlatform:") + print_entry("OS / Arch", f"{_platform.system()} {_platform.machine()}") + print_entry("SIMD Level", runtime_info().simd_level) + print_entry("Detected SIMD Level", runtime_info().detected_simd_level) + + pool = default_memory_pool() + print("\nMemory:") + print_entry("Default backend", pool.backend_name) + print_entry("Bytes allocated", f"{pool.bytes_allocated()} bytes") + print_entry("Max memory", f"{pool.max_memory()} bytes") + print_entry("Supported Backends", ', '.join(supported_memory_backends())) + + print("\nOptional modules:") + modules = ["csv", "cuda", "dataset", "feather", "flight", "fs", "gandiva", "json", + "orc", "parquet"] + for module in modules: + status = "Enabled" if _module_is_available(module) else "-" + print(f" {module: <20}: {status: <8}") + + print("\nFilesystems:") + filesystems = ["AzureFileSystem", "GcsFileSystem", + "HadoopFileSystem", "S3FileSystem"] + for fs in filesystems: + status = "Enabled" if _filesystem_is_available(fs) else "-" + print(f" {fs: <20}: {status: <8}") + + print("\nCompression Codecs:") + codecs = ["brotli", "bz2", "gzip", "lz4_frame", "lz4", "snappy", "zstd"] + for codec in codecs: + status = "Enabled" if Codec.is_available(codec) else "-" + print(f" {codec: <20}: {status: <8}") + + +from pyarrow.lib import (null, bool_, + int8, int16, int32, int64, + uint8, uint16, uint32, uint64, + time32, time64, timestamp, date32, date64, duration, + month_day_nano_interval, + float16, float32, float64, + binary, string, utf8, binary_view, string_view, + large_binary, large_string, large_utf8, + decimal32, decimal64, decimal128, decimal256, + list_, large_list, list_view, large_list_view, + map_, struct, + union, sparse_union, dense_union, + dictionary, + run_end_encoded, + bool8, fixed_shape_tensor, json_, opaque, uuid, + field, + type_for_alias, + DataType, DictionaryType, StructType, + ListType, LargeListType, FixedSizeListType, + ListViewType, LargeListViewType, + MapType, UnionType, SparseUnionType, DenseUnionType, + TimestampType, Time32Type, Time64Type, DurationType, + FixedSizeBinaryType, + Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, + BaseExtensionType, ExtensionType, + RunEndEncodedType, Bool8Type, FixedShapeTensorType, + JsonType, OpaqueType, UuidType, + PyExtensionType, UnknownExtensionType, + register_extension_type, unregister_extension_type, + DictionaryMemo, + KeyValueMetadata, + Field, + Schema, + schema, + unify_schemas, + Array, Tensor, + array, chunked_array, record_batch, nulls, repeat, + SparseCOOTensor, SparseCSRMatrix, SparseCSCMatrix, + SparseCSFTensor, + infer_type, from_numpy_dtype, + NullArray, + NumericArray, IntegerArray, FloatingPointArray, + BooleanArray, + Int8Array, UInt8Array, + Int16Array, UInt16Array, + Int32Array, UInt32Array, + Int64Array, UInt64Array, + HalfFloatArray, FloatArray, DoubleArray, + ListArray, LargeListArray, FixedSizeListArray, + ListViewArray, LargeListViewArray, + MapArray, UnionArray, + BinaryArray, StringArray, + LargeBinaryArray, LargeStringArray, + BinaryViewArray, StringViewArray, + FixedSizeBinaryArray, + DictionaryArray, + Date32Array, Date64Array, TimestampArray, + Time32Array, Time64Array, DurationArray, + MonthDayNanoIntervalArray, + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + StructArray, ExtensionArray, + RunEndEncodedArray, Bool8Array, FixedShapeTensorArray, + JsonArray, OpaqueArray, UuidArray, + scalar, NA, _NULL as NULL, Scalar, + NullScalar, BooleanScalar, + Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, + UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, + HalfFloatScalar, FloatScalar, DoubleScalar, + Decimal32Scalar, Decimal64Scalar, Decimal128Scalar, Decimal256Scalar, + ListScalar, LargeListScalar, FixedSizeListScalar, + ListViewScalar, LargeListViewScalar, + Date32Scalar, Date64Scalar, + Time32Scalar, Time64Scalar, + TimestampScalar, DurationScalar, + MonthDayNanoIntervalScalar, + BinaryScalar, LargeBinaryScalar, BinaryViewScalar, + StringScalar, LargeStringScalar, StringViewScalar, + FixedSizeBinaryScalar, DictionaryScalar, + MapScalar, StructScalar, UnionScalar, + RunEndEncodedScalar, Bool8Scalar, ExtensionScalar, + FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar) + +# Buffers, allocation +from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, + default_cpu_memory_manager) + +from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer, + Codec, compress, decompress, allocate_buffer) + +from pyarrow.lib import (MemoryPool, LoggingMemoryPool, ProxyMemoryPool, + total_allocated_bytes, set_memory_pool, + default_memory_pool, system_memory_pool, + jemalloc_memory_pool, mimalloc_memory_pool, + logging_memory_pool, proxy_memory_pool, + log_memory_allocations, jemalloc_set_decay_ms, + supported_memory_backends) + +# I/O +from pyarrow.lib import (NativeFile, PythonFile, + BufferedInputStream, BufferedOutputStream, CacheOptions, + CompressedInputStream, CompressedOutputStream, + TransformInputStream, transcoding_input_stream, + FixedSizeBufferWriter, + BufferReader, BufferOutputStream, + OSFile, MemoryMappedFile, memory_map, + create_memory_map, MockOutputStream, + input_stream, output_stream, + have_libhdfs) + +from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table, + concat_arrays, concat_tables, TableGroupBy, + RecordBatchReader, concat_batches) + +# Exceptions +from pyarrow.lib import (ArrowCancelled, + ArrowCapacityError, + ArrowException, + ArrowKeyError, + ArrowIndexError, + ArrowInvalid, + ArrowIOError, + ArrowMemoryError, + ArrowNotImplementedError, + ArrowTypeError, + ArrowSerializationError) + +from pyarrow.ipc import serialize_pandas, deserialize_pandas +import pyarrow.ipc as ipc + +import pyarrow.types as types + + +# ---------------------------------------------------------------------- +# Deprecations + +from pyarrow.util import _deprecate_api, _deprecate_class + + +# TODO: Deprecate these somehow in the pyarrow namespace +from pyarrow.ipc import (Message, MessageReader, MetadataVersion, + RecordBatchFileReader, RecordBatchFileWriter, + RecordBatchStreamReader, RecordBatchStreamWriter) + +# ---------------------------------------------------------------------- +# Returning absolute path to the pyarrow include directory (if bundled, e.g. in +# wheels) + + +def get_include(): + """ + Return absolute path to directory containing Arrow C++ include + headers. Similar to numpy.get_include + """ + return _os.path.join(_os.path.dirname(__file__), 'include') + + +def _get_pkg_config_executable(): + return _os.environ.get('PKG_CONFIG', 'pkg-config') + + +def _has_pkg_config(pkgname): + import subprocess + try: + return subprocess.call([_get_pkg_config_executable(), + '--exists', pkgname]) == 0 + except FileNotFoundError: + return False + + +def _read_pkg_config_variable(pkgname, cli_args): + import subprocess + cmd = [_get_pkg_config_executable(), pkgname] + cli_args + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = proc.communicate() + if proc.returncode != 0: + raise RuntimeError("pkg-config failed: " + err.decode('utf8')) + return out.rstrip().decode('utf8') + + +def get_libraries(): + """ + Return list of library names to include in the `libraries` argument for C + or Cython extensions using pyarrow + """ + return ['arrow_python', 'arrow'] + + +def create_library_symlinks(): + """ + With Linux and macOS wheels, the bundled shared libraries have an embedded + ABI version like libarrow.so.17 or libarrow.17.dylib and so linking to them + with -larrow won't work unless we create symlinks at locations like + site-packages/pyarrow/libarrow.so. This unfortunate workaround addresses + prior problems we had with shipping two copies of the shared libraries to + permit third party projects like turbodbc to build their C++ extensions + against the pyarrow wheels. + + This function must only be invoked once and only when the shared libraries + are bundled with the Python package, which should only apply to wheel-based + installs. It requires write access to the site-packages/pyarrow directory + and so depending on your system may need to be run with root. + """ + import glob + if _sys.platform == 'win32': + return + package_cwd = _os.path.dirname(__file__) + + if _sys.platform == 'linux': + bundled_libs = glob.glob(_os.path.join(package_cwd, '*.so.*')) + + def get_symlink_path(hard_path): + return hard_path.rsplit('.', 1)[0] + else: + bundled_libs = glob.glob(_os.path.join(package_cwd, '*.*.dylib')) + + def get_symlink_path(hard_path): + return '.'.join((hard_path.rsplit('.', 2)[0], 'dylib')) + + for lib_hard_path in bundled_libs: + symlink_path = get_symlink_path(lib_hard_path) + if _os.path.exists(symlink_path): + continue + try: + _os.symlink(lib_hard_path, symlink_path) + except PermissionError: + print("Tried creating symlink {}. If you need to link to " + "bundled shared libraries, run " + "pyarrow.create_library_symlinks() as root") + + +def get_library_dirs(): + """ + Return lists of directories likely to contain Arrow C++ libraries for + linking C or Cython extensions using pyarrow + """ + package_cwd = _os.path.dirname(__file__) + library_dirs = [package_cwd] + + def append_library_dir(library_dir): + if library_dir not in library_dirs: + library_dirs.append(library_dir) + + # Search library paths via pkg-config. This is necessary if the user + # installed libarrow and the other shared libraries manually and they + # are not shipped inside the pyarrow package (see also ARROW-2976). + pkg_config_executable = _os.environ.get('PKG_CONFIG') or 'pkg-config' + for pkgname in ["arrow", "arrow_python"]: + if _has_pkg_config(pkgname): + library_dir = _read_pkg_config_variable(pkgname, + ["--libs-only-L"]) + # pkg-config output could be empty if Arrow is installed + # as a system package. + if library_dir: + if not library_dir.startswith("-L"): + raise ValueError( + "pkg-config --libs-only-L returned unexpected " + "value {!r}".format(library_dir)) + append_library_dir(library_dir[2:]) + + if _sys.platform == 'win32': + # TODO(wesm): Is this necessary, or does setuptools within a conda + # installation add Library\lib to the linker path for MSVC? + python_base_install = _os.path.dirname(_sys.executable) + library_dir = _os.path.join(python_base_install, 'Library', 'lib') + + if _os.path.exists(_os.path.join(library_dir, 'arrow.lib')): + append_library_dir(library_dir) + + # ARROW-4074: Allow for ARROW_HOME to be set to some other directory + if _os.environ.get('ARROW_HOME'): + append_library_dir(_os.path.join(_os.environ['ARROW_HOME'], 'lib')) + else: + # Python wheels bundle the Arrow libraries in the pyarrow directory. + append_library_dir(_os.path.dirname(_os.path.abspath(__file__))) + + return library_dirs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_azurefs.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_azurefs.pyx new file mode 100644 index 0000000000000000000000000000000000000000..5cd6300c18c6a83e7036d84724666ba85396b530 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_azurefs.pyx @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cython cimport binding + + +from pyarrow.lib import frombytes, tobytes +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._fs cimport FileSystem + + +cdef class AzureFileSystem(FileSystem): + """ + Azure Blob Storage backed FileSystem implementation + + This implementation supports flat namespace and hierarchical namespace (HNS) a.k.a. + Data Lake Gen2 storage accounts. HNS will be automatically detected and HNS specific + features will be used when they provide a performance advantage. Azurite emulator is + also supported. Note: `/` is the only supported delimiter. + + The storage account is considered the root of the filesystem. When enabled, containers + will be created or deleted during relevant directory operations. Obviously, this also + requires authentication with the additional permissions. + + By default `DefaultAzureCredential `__ + is used for authentication. This means it will try several types of authentication + and go with the first one that works. If any authentication parameters are provided when + initialising the FileSystem, they will be used instead of the default credential. + + Parameters + ---------- + account_name : str + Azure Blob Storage account name. This is the globally unique identifier for the + storage account. + account_key : str, default None + Account key of the storage account. Pass None to use default credential. + blob_storage_authority : str, default None + hostname[:port] of the Blob Service. Defaults to `.blob.core.windows.net`. Useful + for connecting to a local emulator, like Azurite. + dfs_storage_authority : str, default None + hostname[:port] of the Data Lake Gen 2 Service. Defaults to + `.dfs.core.windows.net`. Useful for connecting to a local emulator, like Azurite. + blob_storage_scheme : str, default None + Either `http` or `https`. Defaults to `https`. Useful for connecting to a local + emulator, like Azurite. + dfs_storage_scheme : str, default None + Either `http` or `https`. Defaults to `https`. Useful for connecting to a local + emulator, like Azurite. + + Examples + -------- + >>> from pyarrow import fs + >>> azure_fs = fs.AzureFileSystem(account_name='myaccount') + >>> azurite_fs = fs.AzureFileSystem( + ... account_name='devstoreaccount1', + ... account_key='Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==', + ... blob_storage_authority='127.0.0.1:10000', + ... dfs_storage_authority='127.0.0.1:10000', + ... blob_storage_scheme='http', + ... dfs_storage_scheme='http', + ... ) + + For usage of the methods see examples for :func:`~pyarrow.fs.LocalFileSystem`. + """ + cdef: + CAzureFileSystem* azurefs + c_string account_key + + def __init__(self, account_name, *, account_key=None, blob_storage_authority=None, + dfs_storage_authority=None, blob_storage_scheme=None, + dfs_storage_scheme=None): + cdef: + CAzureOptions options + shared_ptr[CAzureFileSystem] wrapped + + options.account_name = tobytes(account_name) + if blob_storage_authority: + options.blob_storage_authority = tobytes(blob_storage_authority) + if dfs_storage_authority: + options.dfs_storage_authority = tobytes(dfs_storage_authority) + if blob_storage_scheme: + options.blob_storage_scheme = tobytes(blob_storage_scheme) + if dfs_storage_scheme: + options.dfs_storage_scheme = tobytes(dfs_storage_scheme) + + if account_key: + options.ConfigureAccountKeyCredential(tobytes(account_key)) + self.account_key = tobytes(account_key) + else: + options.ConfigureDefaultCredential() + + with nogil: + wrapped = GetResultValue(CAzureFileSystem.Make(options)) + + self.init( wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.azurefs = wrapped.get() + + @staticmethod + @binding(True) # Required for cython < 3 + def _reconstruct(kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return AzureFileSystem(**kwargs) + + def __reduce__(self): + cdef CAzureOptions opts = self.azurefs.options() + return ( + AzureFileSystem._reconstruct, (dict( + account_name=frombytes(opts.account_name), + account_key=frombytes(self.account_key), + blob_storage_authority=frombytes(opts.blob_storage_authority), + dfs_storage_authority=frombytes(opts.dfs_storage_authority), + blob_storage_scheme=frombytes(opts.blob_storage_scheme), + dfs_storage_scheme=frombytes(opts.dfs_storage_scheme) + ),)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pyx new file mode 100644 index 0000000000000000000000000000000000000000..658f6b6cac4b5cc3dc4dce8a60b1bd7be027670b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pyx @@ -0,0 +1,3274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +import sys + +from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE +from cython.operator cimport dereference as deref + +from collections import namedtuple + +from pyarrow.lib import frombytes, tobytes, ArrowInvalid +from pyarrow.lib cimport * +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +import pyarrow.lib as lib +from pyarrow.util import _DEPR_MSG +from libcpp cimport bool as c_bool + +import inspect +try: + import numpy as np +except ImportError: + np = None +import warnings + + +__pas = None +_substrait_msg = ( + "The pyarrow installation is not built with support for Substrait." +) + + +SUPPORTED_INPUT_ARR_TYPES = (list, tuple) +if np is not None: + SUPPORTED_INPUT_ARR_TYPES += (np.ndarray, ) + + +def _pas(): + global __pas + if __pas is None: + try: + import pyarrow.substrait as pas + __pas = pas + except ImportError: + raise ImportError(_substrait_msg) + return __pas + + +def _forbid_instantiation(klass, subclasses_instead=True): + msg = '{} is an abstract class thus cannot be initialized.'.format( + klass.__name__ + ) + if subclasses_instead: + subclasses = [cls.__name__ for cls in klass.__subclasses__] + msg += ' Use one of the subclasses instead: {}'.format( + ', '.join(subclasses) + ) + raise TypeError(msg) + + +cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ scalar Function in a ScalarFunction object. + """ + cdef ScalarFunction func = ScalarFunction.__new__(ScalarFunction) + func.init(sp_func) + return func + + +cdef wrap_vector_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ vector Function in a VectorFunction object. + """ + cdef VectorFunction func = VectorFunction.__new__(VectorFunction) + func.init(sp_func) + return func + + +cdef wrap_scalar_aggregate_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ aggregate Function in a ScalarAggregateFunction object. + """ + cdef ScalarAggregateFunction func = \ + ScalarAggregateFunction.__new__(ScalarAggregateFunction) + func.init(sp_func) + return func + + +cdef wrap_hash_aggregate_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ aggregate Function in a HashAggregateFunction object. + """ + cdef HashAggregateFunction func = \ + HashAggregateFunction.__new__(HashAggregateFunction) + func.init(sp_func) + return func + + +cdef wrap_meta_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ meta Function in a MetaFunction object. + """ + cdef MetaFunction func = MetaFunction.__new__(MetaFunction) + func.init(sp_func) + return func + + +cdef wrap_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ Function in a Function object. + + This dispatches to specialized wrappers depending on the function kind. + """ + if sp_func.get() == NULL: + raise ValueError("Function was NULL") + + cdef FunctionKind c_kind = sp_func.get().kind() + if c_kind == FunctionKind_SCALAR: + return wrap_scalar_function(sp_func) + elif c_kind == FunctionKind_VECTOR: + return wrap_vector_function(sp_func) + elif c_kind == FunctionKind_SCALAR_AGGREGATE: + return wrap_scalar_aggregate_function(sp_func) + elif c_kind == FunctionKind_HASH_AGGREGATE: + return wrap_hash_aggregate_function(sp_func) + elif c_kind == FunctionKind_META: + return wrap_meta_function(sp_func) + else: + raise NotImplementedError("Unknown Function::Kind") + + +cdef wrap_scalar_kernel(const CScalarKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef ScalarKernel kernel = ScalarKernel.__new__(ScalarKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_vector_kernel(const CVectorKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef VectorKernel kernel = VectorKernel.__new__(VectorKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_scalar_aggregate_kernel(const CScalarAggregateKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef ScalarAggregateKernel kernel = \ + ScalarAggregateKernel.__new__(ScalarAggregateKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_hash_aggregate_kernel(const CHashAggregateKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef HashAggregateKernel kernel = \ + HashAggregateKernel.__new__(HashAggregateKernel) + kernel.init(c_kernel) + return kernel + + +cdef class Kernel(_Weakrefable): + """ + A kernel object. + + Kernels handle the execution of a Function for a certain signature. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + +cdef class ScalarKernel(Kernel): + cdef const CScalarKernel* kernel + + cdef void init(self, const CScalarKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("ScalarKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class VectorKernel(Kernel): + cdef const CVectorKernel* kernel + + cdef void init(self, const CVectorKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("VectorKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class ScalarAggregateKernel(Kernel): + cdef const CScalarAggregateKernel* kernel + + cdef void init(self, const CScalarAggregateKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("ScalarAggregateKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class HashAggregateKernel(Kernel): + cdef const CHashAggregateKernel* kernel + + cdef void init(self, const CHashAggregateKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("HashAggregateKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +FunctionDoc = namedtuple( + "FunctionDoc", + ("summary", "description", "arg_names", "options_class", + "options_required")) + + +cdef class Function(_Weakrefable): + """ + A compute function. + + A function implements a certain logical computation over a range of + possible input signatures. Each signature accepts a range of input + types and is implemented by a given Kernel. + + Functions can be of different kinds: + + * "scalar" functions apply an item-wise computation over all items + of their inputs. Each item in the output only depends on the values + of the inputs at the same position. Examples: addition, comparisons, + string predicates... + + * "vector" functions apply a collection-wise computation, such that + each item in the output may depend on the values of several items + in each input. Examples: dictionary encoding, sorting, extracting + unique values... + + * "scalar_aggregate" functions reduce the dimensionality of the inputs by + applying a reduction function. Examples: sum, min_max, mode... + + * "hash_aggregate" functions apply a reduction function to an input + subdivided by grouping criteria. They may not be directly called. + Examples: hash_sum, hash_min_max... + + * "meta" functions dispatch to other functions. + """ + + cdef: + shared_ptr[CFunction] sp_func + CFunction* base_func + + _kind_map = { + FunctionKind_SCALAR: "scalar", + FunctionKind_VECTOR: "vector", + FunctionKind_SCALAR_AGGREGATE: "scalar_aggregate", + FunctionKind_HASH_AGGREGATE: "hash_aggregate", + FunctionKind_META: "meta", + } + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + self.sp_func = sp_func + self.base_func = sp_func.get() + + def __repr__(self): + return ("arrow.compute.Function" + .format(self.name, self.kind, self.arity, self.num_kernels)) + + def __reduce__(self): + # Reduction uses the global registry + return get_function, (self.name,) + + @property + def name(self): + """ + The function name. + """ + return frombytes(self.base_func.name()) + + @property + def arity(self): + """ + The function arity. + + If Ellipsis (i.e. `...`) is returned, the function takes a variable + number of arguments. + """ + cdef CArity arity = self.base_func.arity() + if arity.is_varargs: + return ... + else: + return arity.num_args + + @property + def kind(self): + """ + The function kind. + """ + cdef FunctionKind c_kind = self.base_func.kind() + try: + return self._kind_map[c_kind] + except KeyError: + raise NotImplementedError("Unknown Function::Kind") + + @property + def _doc(self): + """ + The C++-like function documentation (for internal use). + """ + cdef CFunctionDoc c_doc = self.base_func.doc() + return FunctionDoc(frombytes(c_doc.summary), + frombytes(c_doc.description), + [frombytes(s) for s in c_doc.arg_names], + frombytes(c_doc.options_class), + c_doc.options_required) + + @property + def num_kernels(self): + """ + The number of kernels implementing this function. + """ + return self.base_func.num_kernels() + + def call(self, args, FunctionOptions options=None, + MemoryPool memory_pool=None, length=None): + """ + Call the function on the given arguments. + + Parameters + ---------- + args : iterable + The arguments to pass to the function. Accepted types depend + on the specific function. + options : FunctionOptions, optional + Options instance for executing this function. This should have + the right concrete options type. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + length : int, optional + Batch size for execution, for nullary (no argument) functions. If + not passed, will be inferred from passed data. + """ + cdef: + const CFunctionOptions* c_options = NULL + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + CExecContext c_exec_ctx = CExecContext(pool) + CExecBatch c_batch + CDatum result + + _pack_compute_args(args, &c_batch.values) + + if options is not None: + c_options = options.get_options() + + if length is not None: + c_batch.length = length + with nogil: + result = GetResultValue( + self.base_func.Execute(c_batch, c_options, &c_exec_ctx) + ) + else: + with nogil: + result = GetResultValue( + self.base_func.Execute(c_batch.values, c_options, + &c_exec_ctx) + ) + + return wrap_datum(result) + + +cdef class ScalarFunction(Function): + cdef const CScalarFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CScalarKernel*] kernels = self.func.kernels() + return [wrap_scalar_kernel(k) for k in kernels] + + +cdef class VectorFunction(Function): + cdef const CVectorFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CVectorKernel*] kernels = self.func.kernels() + return [wrap_vector_kernel(k) for k in kernels] + + +cdef class ScalarAggregateFunction(Function): + cdef const CScalarAggregateFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CScalarAggregateKernel*] kernels = \ + self.func.kernels() + return [wrap_scalar_aggregate_kernel(k) for k in kernels] + + +cdef class HashAggregateFunction(Function): + cdef const CHashAggregateFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CHashAggregateKernel*] kernels = self.func.kernels() + return [wrap_hash_aggregate_kernel(k) for k in kernels] + + +cdef class MetaFunction(Function): + cdef const CMetaFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = sp_func.get() + + # Since num_kernels is exposed, also expose a kernels property + @property + def kernels(self): + """ + The kernels implementing this function. + """ + return [] + + +cdef _pack_compute_args(object values, vector[CDatum]* out): + for val in values: + if isinstance(val, SUPPORTED_INPUT_ARR_TYPES): + val = lib.asarray(val) + + if isinstance(val, Array): + out.push_back(CDatum(( val).sp_array)) + continue + elif isinstance(val, ChunkedArray): + out.push_back(CDatum(( val).sp_chunked_array)) + continue + elif isinstance(val, Scalar): + out.push_back(CDatum(( val).unwrap())) + continue + elif isinstance(val, RecordBatch): + out.push_back(CDatum(( val).sp_batch)) + continue + elif isinstance(val, Table): + out.push_back(CDatum(( val).sp_table)) + continue + else: + # Is it a Python scalar? + try: + scal = lib.scalar(val) + except Exception: + # Raise dedicated error below + pass + else: + out.push_back(CDatum(( scal).unwrap())) + continue + + raise TypeError(f"Got unexpected argument type {type(val)} " + "for compute function") + + +cdef class FunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + + def __init__(self): + self.registry = GetFunctionRegistry() + + def list_functions(self): + """ + Return all function names in the registry. + """ + cdef vector[c_string] names = self.registry.GetFunctionNames() + return [frombytes(name) for name in names] + + def get_function(self, name): + """ + Look up a function by name in the registry. + + Parameters + ---------- + name : str + The name of the function to lookup + """ + cdef: + c_string c_name = tobytes(name) + shared_ptr[CFunction] func + with nogil: + func = GetResultValue(self.registry.GetFunction(c_name)) + return wrap_function(func) + + +cdef FunctionRegistry _global_func_registry = FunctionRegistry() + + +def function_registry(): + return _global_func_registry + + +def get_function(name): + """ + Get a function by name. + + The function is looked up in the global registry + (as returned by `function_registry()`). + + Parameters + ---------- + name : str + The name of the function to lookup + """ + return _global_func_registry.get_function(name) + + +def list_functions(): + """ + Return all function names in the global registry. + """ + return _global_func_registry.list_functions() + + +def call_function(name, args, options=None, memory_pool=None, length=None): + """ + Call a named function. + + The function is looked up in the global registry + (as returned by `function_registry()`). + + Parameters + ---------- + name : str + The name of the function to call. + args : list + The arguments to the function. + options : optional + options provided to the function. + memory_pool : MemoryPool, optional + memory pool to use for allocations during function execution. + length : int, optional + Batch size for execution, for nullary (no argument) functions. If not + passed, inferred from data. + """ + func = _global_func_registry.get_function(name) + return func.call(args, options=options, memory_pool=memory_pool, + length=length) + + +cdef class FunctionOptions(_Weakrefable): + __slots__ = () # avoid mistakingly creating attributes + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.wrapped.get() + + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + self.wrapped = sp + + cdef inline shared_ptr[CFunctionOptions] unwrap(self): + return self.wrapped + + def serialize(self): + cdef: + CResult[shared_ptr[CBuffer]] res = self.get_options().Serialize() + shared_ptr[CBuffer] c_buf = GetResultValue(res) + return pyarrow_wrap_buffer(c_buf) + + @staticmethod + def deserialize(buf): + """ + Deserialize options for a function. + + Parameters + ---------- + buf : Buffer + The buffer containing the data to deserialize. + """ + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + CResult[unique_ptr[CFunctionOptions]] maybe_options = \ + DeserializeFunctionOptions(deref(c_buf)) + shared_ptr[CFunctionOptions] c_options + c_options = to_shared(GetResultValue(move(maybe_options))) + type_name = frombytes(c_options.get().options_type().type_name()) + module = globals() + if type_name not in module: + raise ValueError(f'Cannot deserialize "{type_name}"') + klass = module[type_name] + options = klass.__new__(klass) + ( options).init(c_options) + return options + + def __repr__(self): + type_name = self.__class__.__name__ + # Remove {} so we can use our own braces + string_repr = frombytes(self.get_options().ToString())[1:-1] + return f"{type_name}({string_repr})" + + def __eq__(self, FunctionOptions other): + return self.get_options().Equals(deref(other.get_options())) + + +def _raise_invalid_function_option(value, description, *, + exception_class=ValueError): + raise exception_class(f"\"{value}\" is not a valid {description}") + + +# NOTE: +# To properly expose the constructor signature of FunctionOptions +# subclasses, we use a two-level inheritance: +# 1. a C extension class that implements option validation and setting +# (won't expose function signatures because of +# https://github.com/cython/cython/issues/3873) +# 2. a Python derived class that implements the constructor + +cdef class _CastOptions(FunctionOptions): + cdef CCastOptions* options + + cdef void init(self, const shared_ptr[CFunctionOptions]& sp): + FunctionOptions.init(self, sp) + self.options = self.wrapped.get() + + def _set_options(self, DataType target_type, allow_int_overflow, + allow_time_truncate, allow_time_overflow, + allow_decimal_truncate, allow_float_truncate, + allow_invalid_utf8): + cdef: + shared_ptr[CCastOptions] wrapped = make_shared[CCastOptions]() + self.init( wrapped) + self._set_type(target_type) + if allow_int_overflow is not None: + self.allow_int_overflow = allow_int_overflow + if allow_time_truncate is not None: + self.allow_time_truncate = allow_time_truncate + if allow_time_overflow is not None: + self.allow_time_overflow = allow_time_overflow + if allow_decimal_truncate is not None: + self.allow_decimal_truncate = allow_decimal_truncate + if allow_float_truncate is not None: + self.allow_float_truncate = allow_float_truncate + if allow_invalid_utf8 is not None: + self.allow_invalid_utf8 = allow_invalid_utf8 + + def _set_type(self, target_type=None): + if target_type is not None: + deref(self.options).to_type = \ + ( ensure_type(target_type)).sp_type + + def _set_safe(self): + self.init(shared_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Safe()))) + + def _set_unsafe(self): + self.init(shared_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Unsafe()))) + + def is_safe(self): + return not (deref(self.options).allow_int_overflow or + deref(self.options).allow_time_truncate or + deref(self.options).allow_time_overflow or + deref(self.options).allow_decimal_truncate or + deref(self.options).allow_float_truncate or + deref(self.options).allow_invalid_utf8) + + @property + def allow_int_overflow(self): + return deref(self.options).allow_int_overflow + + @allow_int_overflow.setter + def allow_int_overflow(self, c_bool flag): + deref(self.options).allow_int_overflow = flag + + @property + def allow_time_truncate(self): + return deref(self.options).allow_time_truncate + + @allow_time_truncate.setter + def allow_time_truncate(self, c_bool flag): + deref(self.options).allow_time_truncate = flag + + @property + def allow_time_overflow(self): + return deref(self.options).allow_time_overflow + + @allow_time_overflow.setter + def allow_time_overflow(self, c_bool flag): + deref(self.options).allow_time_overflow = flag + + @property + def allow_decimal_truncate(self): + return deref(self.options).allow_decimal_truncate + + @allow_decimal_truncate.setter + def allow_decimal_truncate(self, c_bool flag): + deref(self.options).allow_decimal_truncate = flag + + @property + def allow_float_truncate(self): + return deref(self.options).allow_float_truncate + + @allow_float_truncate.setter + def allow_float_truncate(self, c_bool flag): + deref(self.options).allow_float_truncate = flag + + @property + def allow_invalid_utf8(self): + return deref(self.options).allow_invalid_utf8 + + @allow_invalid_utf8.setter + def allow_invalid_utf8(self, c_bool flag): + deref(self.options).allow_invalid_utf8 = flag + + +class CastOptions(_CastOptions): + """ + Options for the `cast` function. + + Parameters + ---------- + target_type : DataType, optional + The PyArrow type to cast to. + allow_int_overflow : bool, default False + Whether integer overflow is allowed when casting. + allow_time_truncate : bool, default False + Whether time precision truncation is allowed when casting. + allow_time_overflow : bool, default False + Whether date/time range overflow is allowed when casting. + allow_decimal_truncate : bool, default False + Whether decimal precision truncation is allowed when casting. + allow_float_truncate : bool, default False + Whether floating-point precision truncation is allowed when casting. + allow_invalid_utf8 : bool, default False + Whether producing invalid utf8 data is allowed when casting. + """ + + def __init__(self, target_type=None, *, allow_int_overflow=None, + allow_time_truncate=None, allow_time_overflow=None, + allow_decimal_truncate=None, allow_float_truncate=None, + allow_invalid_utf8=None): + self._set_options(target_type, allow_int_overflow, allow_time_truncate, + allow_time_overflow, allow_decimal_truncate, + allow_float_truncate, allow_invalid_utf8) + + @staticmethod + def safe(target_type=None): + """" + Create a CastOptions for a safe cast. + + Parameters + ---------- + target_type : optional + Target cast type for the safe cast. + """ + self = CastOptions() + self._set_safe() + self._set_type(target_type) + return self + + @staticmethod + def unsafe(target_type=None): + """" + Create a CastOptions for an unsafe cast. + + Parameters + ---------- + target_type : optional + Target cast type for the unsafe cast. + """ + self = CastOptions() + self._set_unsafe() + self._set_type(target_type) + return self + + +def _skip_nulls_doc(): + # (note the weird indent because of how the string is inserted + # by callers) + return """skip_nulls : bool, default True + Whether to skip (ignore) nulls in the input. + If False, any null in the input forces the output to null. +""" + + +def _min_count_doc(*, default): + return f"""min_count : int, default {default} + Minimum number of non-null values in the input. If the number + of non-null values is below `min_count`, the output is null. +""" + + +cdef class _ElementWiseAggregateOptions(FunctionOptions): + def _set_options(self, skip_nulls): + self.wrapped.reset(new CElementWiseAggregateOptions(skip_nulls)) + + +class ElementWiseAggregateOptions(_ElementWiseAggregateOptions): + __doc__ = f""" + Options for element-wise aggregate functions. + + Parameters + ---------- + {_skip_nulls_doc()} + """ + + def __init__(self, *, skip_nulls=True): + self._set_options(skip_nulls) + + +cdef CRoundMode unwrap_round_mode(round_mode) except *: + if round_mode == "down": + return CRoundMode_DOWN + elif round_mode == "up": + return CRoundMode_UP + elif round_mode == "towards_zero": + return CRoundMode_TOWARDS_ZERO + elif round_mode == "towards_infinity": + return CRoundMode_TOWARDS_INFINITY + elif round_mode == "half_down": + return CRoundMode_HALF_DOWN + elif round_mode == "half_up": + return CRoundMode_HALF_UP + elif round_mode == "half_towards_zero": + return CRoundMode_HALF_TOWARDS_ZERO + elif round_mode == "half_towards_infinity": + return CRoundMode_HALF_TOWARDS_INFINITY + elif round_mode == "half_to_even": + return CRoundMode_HALF_TO_EVEN + elif round_mode == "half_to_odd": + return CRoundMode_HALF_TO_ODD + _raise_invalid_function_option(round_mode, "round mode") + + +cdef class _RoundOptions(FunctionOptions): + def _set_options(self, ndigits, round_mode): + self.wrapped.reset( + new CRoundOptions(ndigits, unwrap_round_mode(round_mode)) + ) + + +class RoundOptions(_RoundOptions): + """ + Options for rounding numbers. + + Parameters + ---------- + ndigits : int, default 0 + Number of fractional digits to round to. + round_mode : str, default "half_to_even" + Rounding and tie-breaking mode. + Accepted values are "down", "up", "towards_zero", "towards_infinity", + "half_down", "half_up", "half_towards_zero", "half_towards_infinity", + "half_to_even", "half_to_odd". + """ + + def __init__(self, ndigits=0, round_mode="half_to_even"): + self._set_options(ndigits, round_mode) + + +cdef class _RoundBinaryOptions(FunctionOptions): + def _set_options(self, round_mode): + self.wrapped.reset( + new CRoundBinaryOptions(unwrap_round_mode(round_mode)) + ) + + +class RoundBinaryOptions(_RoundBinaryOptions): + """ + Options for rounding numbers when ndigits is provided by a second array + + Parameters + ---------- + round_mode : str, default "half_to_even" + Rounding and tie-breaking mode. + Accepted values are "down", "up", "towards_zero", "towards_infinity", + "half_down", "half_up", "half_towards_zero", "half_towards_infinity", + "half_to_even", "half_to_odd". + """ + + def __init__(self, round_mode="half_to_even"): + self._set_options(round_mode) + + +cdef CCalendarUnit unwrap_round_temporal_unit(unit) except *: + if unit == "nanosecond": + return CCalendarUnit_NANOSECOND + elif unit == "microsecond": + return CCalendarUnit_MICROSECOND + elif unit == "millisecond": + return CCalendarUnit_MILLISECOND + elif unit == "second": + return CCalendarUnit_SECOND + elif unit == "minute": + return CCalendarUnit_MINUTE + elif unit == "hour": + return CCalendarUnit_HOUR + elif unit == "day": + return CCalendarUnit_DAY + elif unit == "week": + return CCalendarUnit_WEEK + elif unit == "month": + return CCalendarUnit_MONTH + elif unit == "quarter": + return CCalendarUnit_QUARTER + elif unit == "year": + return CCalendarUnit_YEAR + _raise_invalid_function_option(unit, "Calendar unit") + + +cdef class _RoundTemporalOptions(FunctionOptions): + def _set_options(self, multiple, unit, week_starts_monday, + ceil_is_strictly_greater, calendar_based_origin): + self.wrapped.reset( + new CRoundTemporalOptions( + multiple, unwrap_round_temporal_unit(unit), + week_starts_monday, ceil_is_strictly_greater, + calendar_based_origin) + ) + + +class RoundTemporalOptions(_RoundTemporalOptions): + """ + Options for rounding temporal values. + + Parameters + ---------- + multiple : int, default 1 + Number of units to round to. + unit : str, default "day" + The unit in which `multiple` is expressed. + Accepted values are "year", "quarter", "month", "week", "day", + "hour", "minute", "second", "millisecond", "microsecond", + "nanosecond". + week_starts_monday : bool, default True + If True, weeks start on Monday; if False, on Sunday. + ceil_is_strictly_greater : bool, default False + If True, ceil returns a rounded value that is strictly greater than the + input. For example: ceiling 1970-01-01T00:00:00 to 3 hours would + yield 1970-01-01T03:00:00 if set to True and 1970-01-01T00:00:00 + if set to False. + This applies to the ceil_temporal function only. + calendar_based_origin : bool, default False + By default, the origin is 1970-01-01T00:00:00. By setting this to True, + rounding origin will be beginning of one less precise calendar unit. + E.g.: rounding to hours will use beginning of day as origin. + + By default time is rounded to a multiple of units since + 1970-01-01T00:00:00. By setting calendar_based_origin to true, + time will be rounded to number of units since the last greater + calendar unit. + For example: rounding to multiple of days since the beginning of the + month or to hours since the beginning of the day. + Exceptions: week and quarter are not used as greater units, + therefore days will be rounded to the beginning of the month not + week. Greater unit of week is a year. + Note that ceiling and rounding might change sorting order of an array + near greater unit change. For example rounding YYYY-mm-dd 23:00:00 to + 5 hours will ceil and round to YYYY-mm-dd+1 01:00:00 and floor to + YYYY-mm-dd 20:00:00. On the other hand YYYY-mm-dd+1 00:00:00 will + ceil, round and floor to YYYY-mm-dd+1 00:00:00. This can break the + order of an already ordered array. + + """ + + def __init__(self, multiple=1, unit="day", *, week_starts_monday=True, + ceil_is_strictly_greater=False, + calendar_based_origin=False): + self._set_options(multiple, unit, week_starts_monday, + ceil_is_strictly_greater, + calendar_based_origin) + + +cdef class _RoundToMultipleOptions(FunctionOptions): + def _set_options(self, multiple, round_mode): + if not isinstance(multiple, Scalar): + try: + multiple = lib.scalar(multiple) + except Exception: + _raise_invalid_function_option( + multiple, "multiple type for RoundToMultipleOptions", + exception_class=TypeError) + + self.wrapped.reset( + new CRoundToMultipleOptions( + pyarrow_unwrap_scalar(multiple), unwrap_round_mode(round_mode)) + ) + + +class RoundToMultipleOptions(_RoundToMultipleOptions): + """ + Options for rounding numbers to a multiple. + + Parameters + ---------- + multiple : numeric scalar, default 1.0 + Multiple to round to. Should be a scalar of a type compatible + with the argument to be rounded. + round_mode : str, default "half_to_even" + Rounding and tie-breaking mode. + Accepted values are "down", "up", "towards_zero", "towards_infinity", + "half_down", "half_up", "half_towards_zero", "half_towards_infinity", + "half_to_even", "half_to_odd". + """ + + def __init__(self, multiple=1.0, round_mode="half_to_even"): + self._set_options(multiple, round_mode) + + +cdef class _JoinOptions(FunctionOptions): + _null_handling_map = { + "emit_null": CJoinNullHandlingBehavior_EMIT_NULL, + "skip": CJoinNullHandlingBehavior_SKIP, + "replace": CJoinNullHandlingBehavior_REPLACE, + } + + def _set_options(self, null_handling, null_replacement): + try: + self.wrapped.reset( + new CJoinOptions(self._null_handling_map[null_handling], + tobytes(null_replacement)) + ) + except KeyError: + _raise_invalid_function_option(null_handling, "null handling") + + +class JoinOptions(_JoinOptions): + """ + Options for the `binary_join_element_wise` function. + + Parameters + ---------- + null_handling : str, default "emit_null" + How to handle null values in the inputs. + Accepted values are "emit_null", "skip", "replace". + null_replacement : str, default "" + Replacement string to emit for null inputs if `null_handling` + is "replace". + """ + + def __init__(self, null_handling="emit_null", null_replacement=""): + self._set_options(null_handling, null_replacement) + + +cdef class _MatchSubstringOptions(FunctionOptions): + def _set_options(self, pattern, ignore_case): + self.wrapped.reset( + new CMatchSubstringOptions(tobytes(pattern), ignore_case) + ) + + +class MatchSubstringOptions(_MatchSubstringOptions): + """ + Options for looking for a substring. + + Parameters + ---------- + pattern : str + Substring pattern to look for inside input values. + ignore_case : bool, default False + Whether to perform a case-insensitive match. + """ + + def __init__(self, pattern, *, ignore_case=False): + self._set_options(pattern, ignore_case) + + +cdef class _PadOptions(FunctionOptions): + def _set_options(self, width, padding, lean_left_on_odd_padding): + self.wrapped.reset(new CPadOptions(width, tobytes(padding), lean_left_on_odd_padding)) + + +class PadOptions(_PadOptions): + """ + Options for padding strings. + + Parameters + ---------- + width : int + Desired string length. + padding : str, default " " + What to pad the string with. Should be one byte or codepoint. + lean_left_on_odd_padding : bool, default True + What to do if there is an odd number of padding characters (in case + of centered padding). Defaults to aligning on the left (i.e. adding + the extra padding character on the right). + """ + + def __init__(self, width, padding=' ', lean_left_on_odd_padding=True): + self._set_options(width, padding, lean_left_on_odd_padding) + + +cdef class _TrimOptions(FunctionOptions): + def _set_options(self, characters): + self.wrapped.reset(new CTrimOptions(tobytes(characters))) + + +class TrimOptions(_TrimOptions): + """ + Options for trimming characters from strings. + + Parameters + ---------- + characters : str + Individual characters to be trimmed from the string. + """ + + def __init__(self, characters): + self._set_options(tobytes(characters)) + + +cdef class _ReplaceSubstringOptions(FunctionOptions): + def _set_options(self, pattern, replacement, max_replacements): + self.wrapped.reset( + new CReplaceSubstringOptions(tobytes(pattern), + tobytes(replacement), + max_replacements) + ) + + +class ReplaceSubstringOptions(_ReplaceSubstringOptions): + """ + Options for replacing matched substrings. + + Parameters + ---------- + pattern : str + Substring pattern to look for inside input values. + replacement : str + What to replace the pattern with. + max_replacements : int or None, default None + The maximum number of strings to replace in each + input value (unlimited if None). + """ + + def __init__(self, pattern, replacement, *, max_replacements=None): + if max_replacements is None: + max_replacements = -1 + self._set_options(pattern, replacement, max_replacements) + + +cdef class _ExtractRegexOptions(FunctionOptions): + def _set_options(self, pattern): + self.wrapped.reset(new CExtractRegexOptions(tobytes(pattern))) + + +class ExtractRegexOptions(_ExtractRegexOptions): + """ + Options for the `extract_regex` function. + + Parameters + ---------- + pattern : str + Regular expression with named capture fields. + """ + + def __init__(self, pattern): + self._set_options(pattern) + + +cdef class _SliceOptions(FunctionOptions): + def _set_options(self, start, stop, step): + self.wrapped.reset(new CSliceOptions(start, stop, step)) + + +class SliceOptions(_SliceOptions): + """ + Options for slicing. + + Parameters + ---------- + start : int + Index to start slicing at (inclusive). + stop : int or None, default None + If given, index to stop slicing at (exclusive). + If not given, slicing will stop at the end. + step : int, default 1 + Slice step. + """ + + def __init__(self, start, stop=None, step=1): + if stop is None: + stop = sys.maxsize + if step < 0: + stop = -stop + self._set_options(start, stop, step) + + +cdef class _ListSliceOptions(FunctionOptions): + cpdef _set_options(self, start, stop=None, step=1, return_fixed_size_list=None): + cdef: + CListSliceOptions* opts + opts = new CListSliceOptions( + start, + nullopt if stop is None + else (stop), + step, + nullopt if return_fixed_size_list is None + else (return_fixed_size_list) + ) + self.wrapped.reset(opts) + + +class ListSliceOptions(_ListSliceOptions): + """ + Options for list array slicing. + + Parameters + ---------- + start : int + Index to start slicing inner list elements (inclusive). + stop : Optional[int], default None + If given, index to stop slicing at (exclusive). + If not given, slicing will stop at the end. (NotImplemented) + step : int, default 1 + Slice step. + return_fixed_size_list : Optional[bool], default None + Whether to return a FixedSizeListArray. If true _and_ stop is after + a list element's length, nulls will be appended to create the + requested slice size. The default of `None` will return the same + type which was passed in. + """ + + def __init__(self, start, stop=None, step=1, return_fixed_size_list=None): + self._set_options(start, stop, step, return_fixed_size_list) + + +cdef class _ReplaceSliceOptions(FunctionOptions): + def _set_options(self, start, stop, replacement): + self.wrapped.reset( + new CReplaceSliceOptions(start, stop, tobytes(replacement)) + ) + + +class ReplaceSliceOptions(_ReplaceSliceOptions): + """ + Options for replacing slices. + + Parameters + ---------- + start : int + Index to start slicing at (inclusive). + stop : int + Index to stop slicing at (exclusive). + replacement : str + What to replace the slice with. + """ + + def __init__(self, start, stop, replacement): + self._set_options(start, stop, replacement) + + +cdef class _FilterOptions(FunctionOptions): + _null_selection_map = { + "drop": CFilterNullSelectionBehavior_DROP, + "emit_null": CFilterNullSelectionBehavior_EMIT_NULL, + } + + def _set_options(self, null_selection_behavior): + try: + self.wrapped.reset( + new CFilterOptions( + self._null_selection_map[null_selection_behavior] + ) + ) + except KeyError: + _raise_invalid_function_option(null_selection_behavior, + "null selection behavior") + + +class FilterOptions(_FilterOptions): + """ + Options for selecting with a boolean filter. + + Parameters + ---------- + null_selection_behavior : str, default "drop" + How to handle nulls in the selection filter. + Accepted values are "drop", "emit_null". + """ + + def __init__(self, null_selection_behavior="drop"): + self._set_options(null_selection_behavior) + + +cdef class _DictionaryEncodeOptions(FunctionOptions): + _null_encoding_map = { + "encode": CDictionaryEncodeNullEncodingBehavior_ENCODE, + "mask": CDictionaryEncodeNullEncodingBehavior_MASK, + } + + def _set_options(self, null_encoding): + try: + self.wrapped.reset( + new CDictionaryEncodeOptions( + self._null_encoding_map[null_encoding] + ) + ) + except KeyError: + _raise_invalid_function_option(null_encoding, "null encoding") + + +class DictionaryEncodeOptions(_DictionaryEncodeOptions): + """ + Options for dictionary encoding. + + Parameters + ---------- + null_encoding : str, default "mask" + How to encode nulls in the input. + Accepted values are "mask" (null inputs emit a null in the indices + array), "encode" (null inputs emit a non-null index pointing to + a null value in the dictionary array). + """ + + def __init__(self, null_encoding="mask"): + self._set_options(null_encoding) + + +cdef class _RunEndEncodeOptions(FunctionOptions): + def _set_options(self, run_end_type): + run_end_ty = ensure_type(run_end_type) + self.wrapped.reset(new CRunEndEncodeOptions(pyarrow_unwrap_data_type(run_end_ty))) + + +class RunEndEncodeOptions(_RunEndEncodeOptions): + """ + Options for run-end encoding. + + Parameters + ---------- + run_end_type : DataType, default pyarrow.int32() + The data type of the run_ends array. + + Accepted values are pyarrow.{int16(), int32(), int64()}. + """ + + def __init__(self, run_end_type=lib.int32()): + self._set_options(run_end_type) + + +cdef class _TakeOptions(FunctionOptions): + def _set_options(self, boundscheck): + self.wrapped.reset(new CTakeOptions(boundscheck)) + + +class TakeOptions(_TakeOptions): + """ + Options for the `take` and `array_take` functions. + + Parameters + ---------- + boundscheck : boolean, default True + Whether to check indices are within bounds. If False and an + index is out of bounds, behavior is undefined (the process + may crash). + """ + + def __init__(self, *, boundscheck=True): + self._set_options(boundscheck) + + +cdef class _MakeStructOptions(FunctionOptions): + def _set_options(self, field_names, field_nullability, field_metadata): + cdef: + vector[c_string] c_field_names + vector[shared_ptr[const CKeyValueMetadata]] c_field_metadata + for name in field_names: + c_field_names.push_back(tobytes(name)) + for metadata in field_metadata: + c_field_metadata.push_back(pyarrow_unwrap_metadata(metadata)) + self.wrapped.reset( + new CMakeStructOptions(c_field_names, field_nullability, + c_field_metadata) + ) + + +class MakeStructOptions(_MakeStructOptions): + """ + Options for the `make_struct` function. + + Parameters + ---------- + field_names : sequence of str + Names of the struct fields to create. + field_nullability : sequence of bool, optional + Nullability information for each struct field. + If omitted, all fields are nullable. + field_metadata : sequence of KeyValueMetadata, optional + Metadata for each struct field. + """ + + def __init__(self, field_names=(), *, field_nullability=None, + field_metadata=None): + if field_nullability is None: + field_nullability = [True] * len(field_names) + if field_metadata is None: + field_metadata = [None] * len(field_names) + self._set_options(field_names, field_nullability, field_metadata) + + +cdef CFieldRef _ensure_field_ref(value) except *: + cdef: + CFieldRef field_ref + const CFieldRef* field_ref_ptr + + if isinstance(value, (list, tuple)): + value = Expression._nested_field(tuple(value)) + + if isinstance(value, Expression): + field_ref_ptr = (value).unwrap().field_ref() + if field_ref_ptr is NULL: + raise ValueError("Unable to get FieldRef from Expression") + field_ref = deref(field_ref_ptr) + elif isinstance(value, (bytes, str)): + if value.startswith(b'.' if isinstance(value, bytes) else '.'): + field_ref = GetResultValue( + CFieldRef.FromDotPath(tobytes(value))) + else: + field_ref = CFieldRef(tobytes(value)) + elif isinstance(value, int): + field_ref = CFieldRef( value) + else: + raise TypeError("Expected a field reference as a str or int, list of " + f"str or int, or Expression. Got {type(value)} instead.") + return field_ref + + +cdef class _StructFieldOptions(FunctionOptions): + def _set_options(self, indices): + + if isinstance(indices, (list, tuple)) and not len(indices): + # Allow empty indices; effectively return same array + self.wrapped.reset( + new CStructFieldOptions(indices)) + return + + cdef CFieldRef field_ref = _ensure_field_ref(indices) + self.wrapped.reset(new CStructFieldOptions(field_ref)) + + +class StructFieldOptions(_StructFieldOptions): + """ + Options for the `struct_field` function. + + Parameters + ---------- + indices : List[str], List[bytes], List[int], Expression, bytes, str, or int + List of indices for chained field lookup, for example `[4, 1]` + will look up the second nested field in the fifth outer field. + """ + + def __init__(self, indices): + self._set_options(indices) + + +cdef class _ScalarAggregateOptions(FunctionOptions): + def _set_options(self, skip_nulls, min_count): + self.wrapped.reset(new CScalarAggregateOptions(skip_nulls, min_count)) + + +class ScalarAggregateOptions(_ScalarAggregateOptions): + __doc__ = f""" + Options for scalar aggregations. + + Parameters + ---------- + {_skip_nulls_doc()} + {_min_count_doc(default=1)} + """ + + def __init__(self, *, skip_nulls=True, min_count=1): + self._set_options(skip_nulls, min_count) + + +cdef class _CountOptions(FunctionOptions): + _mode_map = { + "only_valid": CCountMode_ONLY_VALID, + "only_null": CCountMode_ONLY_NULL, + "all": CCountMode_ALL, + } + + def _set_options(self, mode): + try: + self.wrapped.reset(new CCountOptions(self._mode_map[mode])) + except KeyError: + _raise_invalid_function_option(mode, "count mode") + + +class CountOptions(_CountOptions): + """ + Options for the `count` function. + + Parameters + ---------- + mode : str, default "only_valid" + Which values to count in the input. + Accepted values are "only_valid", "only_null", "all". + """ + + def __init__(self, mode="only_valid"): + self._set_options(mode) + + +cdef class _IndexOptions(FunctionOptions): + def _set_options(self, scalar): + self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar))) + + +class IndexOptions(_IndexOptions): + """ + Options for the `index` function. + + Parameters + ---------- + value : Scalar + The value to search for. + """ + + def __init__(self, value): + self._set_options(value) + + +cdef class _MapLookupOptions(FunctionOptions): + _occurrence_map = { + "all": CMapLookupOccurrence_ALL, + "first": CMapLookupOccurrence_FIRST, + "last": CMapLookupOccurrence_LAST, + } + + def _set_options(self, query_key, occurrence): + try: + self.wrapped.reset( + new CMapLookupOptions( + pyarrow_unwrap_scalar(query_key), + self._occurrence_map[occurrence] + ) + ) + except KeyError: + _raise_invalid_function_option(occurrence, + "Should either be first, last, or all") + + +class MapLookupOptions(_MapLookupOptions): + """ + Options for the `map_lookup` function. + + Parameters + ---------- + query_key : Scalar or Object can be converted to Scalar + The key to search for. + occurrence : str + The occurrence(s) to return from the Map + Accepted values are "first", "last", or "all". + """ + + def __init__(self, query_key, occurrence): + if not isinstance(query_key, lib.Scalar): + query_key = lib.scalar(query_key) + + self._set_options(query_key, occurrence) + + +cdef class _ModeOptions(FunctionOptions): + def _set_options(self, n, skip_nulls, min_count): + self.wrapped.reset(new CModeOptions(n, skip_nulls, min_count)) + + +class ModeOptions(_ModeOptions): + __doc__ = f""" + Options for the `mode` function. + + Parameters + ---------- + n : int, default 1 + Number of distinct most-common values to return. + {_skip_nulls_doc()} + {_min_count_doc(default=0)} + """ + + def __init__(self, n=1, *, skip_nulls=True, min_count=0): + self._set_options(n, skip_nulls, min_count) + + +cdef class _SetLookupOptions(FunctionOptions): + def _set_options(self, value_set, c_bool skip_nulls): + cdef unique_ptr[CDatum] valset + if isinstance(value_set, Array): + valset.reset(new CDatum(( value_set).sp_array)) + elif isinstance(value_set, ChunkedArray): + valset.reset( + new CDatum(( value_set).sp_chunked_array) + ) + elif isinstance(value_set, Scalar): + valset.reset(new CDatum(( value_set).unwrap())) + else: + _raise_invalid_function_option(value_set, "value set", + exception_class=TypeError) + + self.wrapped.reset(new CSetLookupOptions(deref(valset), skip_nulls)) + + +class SetLookupOptions(_SetLookupOptions): + """ + Options for the `is_in` and `index_in` functions. + + Parameters + ---------- + value_set : Array + Set of values to look for in the input. + skip_nulls : bool, default False + If False, nulls in the input are matched in the value_set just + like regular values. + If True, nulls in the input always fail matching. + """ + + def __init__(self, value_set, *, skip_nulls=False): + self._set_options(value_set, skip_nulls) + + +cdef class _StrptimeOptions(FunctionOptions): + _unit_map = { + "s": TimeUnit_SECOND, + "ms": TimeUnit_MILLI, + "us": TimeUnit_MICRO, + "ns": TimeUnit_NANO, + } + + def _set_options(self, format, unit, error_is_null): + try: + self.wrapped.reset( + new CStrptimeOptions(tobytes(format), self._unit_map[unit], + error_is_null) + ) + except KeyError: + _raise_invalid_function_option(unit, "time unit") + + +class StrptimeOptions(_StrptimeOptions): + """ + Options for the `strptime` function. + + Parameters + ---------- + format : str + Pattern for parsing input strings as timestamps, such as "%Y/%m/%d". + Note that the semantics of the format follow the C/C++ strptime, not the Python one. + There are differences in behavior, for example how the "%y" placeholder + handles years with less than four digits. + unit : str + Timestamp unit of the output. + Accepted values are "s", "ms", "us", "ns". + error_is_null : boolean, default False + Return null on parsing errors if true or raise if false. + """ + + def __init__(self, format, unit, error_is_null=False): + self._set_options(format, unit, error_is_null) + + +cdef class _StrftimeOptions(FunctionOptions): + def _set_options(self, format, locale): + self.wrapped.reset( + new CStrftimeOptions(tobytes(format), tobytes(locale)) + ) + + +class StrftimeOptions(_StrftimeOptions): + """ + Options for the `strftime` function. + + Parameters + ---------- + format : str, default "%Y-%m-%dT%H:%M:%S" + Pattern for formatting input values. + locale : str, default "C" + Locale to use for locale-specific format specifiers. + """ + + def __init__(self, format="%Y-%m-%dT%H:%M:%S", locale="C"): + self._set_options(format, locale) + + +cdef class _DayOfWeekOptions(FunctionOptions): + def _set_options(self, count_from_zero, week_start): + self.wrapped.reset( + new CDayOfWeekOptions(count_from_zero, week_start) + ) + + +class DayOfWeekOptions(_DayOfWeekOptions): + """ + Options for the `day_of_week` function. + + Parameters + ---------- + count_from_zero : bool, default True + If True, number days from 0, otherwise from 1. + week_start : int, default 1 + Which day does the week start with (Monday=1, Sunday=7). + How this value is numbered is unaffected by `count_from_zero`. + """ + + def __init__(self, *, count_from_zero=True, week_start=1): + self._set_options(count_from_zero, week_start) + + +cdef class _WeekOptions(FunctionOptions): + def _set_options(self, week_starts_monday, count_from_zero, + first_week_is_fully_in_year): + self.wrapped.reset( + new CWeekOptions(week_starts_monday, count_from_zero, + first_week_is_fully_in_year) + ) + + +class WeekOptions(_WeekOptions): + """ + Options for the `week` function. + + Parameters + ---------- + week_starts_monday : bool, default True + If True, weeks start on Monday; if False, on Sunday. + count_from_zero : bool, default False + If True, dates at the start of a year that fall into the last week + of the previous year emit 0. + If False, they emit 52 or 53 (the week number of the last week + of the previous year). + first_week_is_fully_in_year : bool, default False + If True, week number 0 is fully in January. + If False, a week that begins on December 29, 30 or 31 is considered + to be week number 0 of the following year. + """ + + def __init__(self, *, week_starts_monday=True, count_from_zero=False, + first_week_is_fully_in_year=False): + self._set_options(week_starts_monday, + count_from_zero, first_week_is_fully_in_year) + + +cdef class _AssumeTimezoneOptions(FunctionOptions): + _ambiguous_map = { + "raise": CAssumeTimezoneAmbiguous_AMBIGUOUS_RAISE, + "earliest": CAssumeTimezoneAmbiguous_AMBIGUOUS_EARLIEST, + "latest": CAssumeTimezoneAmbiguous_AMBIGUOUS_LATEST, + } + _nonexistent_map = { + "raise": CAssumeTimezoneNonexistent_NONEXISTENT_RAISE, + "earliest": CAssumeTimezoneNonexistent_NONEXISTENT_EARLIEST, + "latest": CAssumeTimezoneNonexistent_NONEXISTENT_LATEST, + } + + def _set_options(self, timezone, ambiguous, nonexistent): + if ambiguous not in self._ambiguous_map: + _raise_invalid_function_option(ambiguous, + "'ambiguous' timestamp handling") + if nonexistent not in self._nonexistent_map: + _raise_invalid_function_option(nonexistent, + "'nonexistent' timestamp handling") + self.wrapped.reset( + new CAssumeTimezoneOptions(tobytes(timezone), + self._ambiguous_map[ambiguous], + self._nonexistent_map[nonexistent]) + ) + + +class AssumeTimezoneOptions(_AssumeTimezoneOptions): + """ + Options for the `assume_timezone` function. + + Parameters + ---------- + timezone : str + Timezone to assume for the input. + ambiguous : str, default "raise" + How to handle timestamps that are ambiguous in the assumed timezone. + Accepted values are "raise", "earliest", "latest". + nonexistent : str, default "raise" + How to handle timestamps that don't exist in the assumed timezone. + Accepted values are "raise", "earliest", "latest". + """ + + def __init__(self, timezone, *, ambiguous="raise", nonexistent="raise"): + self._set_options(timezone, ambiguous, nonexistent) + + +cdef class _NullOptions(FunctionOptions): + def _set_options(self, nan_is_null): + self.wrapped.reset(new CNullOptions(nan_is_null)) + + +class NullOptions(_NullOptions): + """ + Options for the `is_null` function. + + Parameters + ---------- + nan_is_null : bool, default False + Whether floating-point NaN values are considered null. + """ + + def __init__(self, *, nan_is_null=False): + self._set_options(nan_is_null) + + +cdef class _VarianceOptions(FunctionOptions): + def _set_options(self, ddof, skip_nulls, min_count): + self.wrapped.reset(new CVarianceOptions(ddof, skip_nulls, min_count)) + + +class VarianceOptions(_VarianceOptions): + __doc__ = f""" + Options for the `variance` and `stddev` functions. + + Parameters + ---------- + ddof : int, default 0 + Number of degrees of freedom. + {_skip_nulls_doc()} + {_min_count_doc(default=0)} + """ + + def __init__(self, *, ddof=0, skip_nulls=True, min_count=0): + self._set_options(ddof, skip_nulls, min_count) + + +cdef class _SplitOptions(FunctionOptions): + def _set_options(self, max_splits, reverse): + self.wrapped.reset(new CSplitOptions(max_splits, reverse)) + + +class SplitOptions(_SplitOptions): + """ + Options for splitting on whitespace. + + Parameters + ---------- + max_splits : int or None, default None + Maximum number of splits for each input value (unlimited if None). + reverse : bool, default False + Whether to start splitting from the end of each input value. + This only has an effect if `max_splits` is not None. + """ + + def __init__(self, *, max_splits=None, reverse=False): + if max_splits is None: + max_splits = -1 + self._set_options(max_splits, reverse) + + +cdef class _SplitPatternOptions(FunctionOptions): + def _set_options(self, pattern, max_splits, reverse): + self.wrapped.reset( + new CSplitPatternOptions(tobytes(pattern), max_splits, reverse) + ) + + +class SplitPatternOptions(_SplitPatternOptions): + """ + Options for splitting on a string pattern. + + Parameters + ---------- + pattern : str + String pattern to split on. + max_splits : int or None, default None + Maximum number of splits for each input value (unlimited if None). + reverse : bool, default False + Whether to start splitting from the end of each input value. + This only has an effect if `max_splits` is not None. + """ + + def __init__(self, pattern, *, max_splits=None, reverse=False): + if max_splits is None: + max_splits = -1 + self._set_options(pattern, max_splits, reverse) + + +cdef CSortOrder unwrap_sort_order(order) except *: + if order == "ascending": + return CSortOrder_Ascending + elif order == "descending": + return CSortOrder_Descending + _raise_invalid_function_option(order, "sort order") + + +cdef CNullPlacement unwrap_null_placement(null_placement) except *: + if null_placement == "at_start": + return CNullPlacement_AtStart + elif null_placement == "at_end": + return CNullPlacement_AtEnd + _raise_invalid_function_option(null_placement, "null placement") + + +cdef class _PartitionNthOptions(FunctionOptions): + def _set_options(self, pivot, null_placement): + self.wrapped.reset(new CPartitionNthOptions( + pivot, unwrap_null_placement(null_placement))) + + +class PartitionNthOptions(_PartitionNthOptions): + """ + Options for the `partition_nth_indices` function. + + Parameters + ---------- + pivot : int + Index into the equivalent sorted array of the pivot element. + null_placement : str, default "at_end" + Where nulls in the input should be partitioned. + Accepted values are "at_start", "at_end". + """ + + def __init__(self, pivot, *, null_placement="at_end"): + self._set_options(pivot, null_placement) + + +cdef class _CumulativeOptions(FunctionOptions): + def _set_options(self, start, skip_nulls): + if start is None: + self.wrapped.reset(new CCumulativeOptions(skip_nulls)) + elif isinstance(start, Scalar): + self.wrapped.reset(new CCumulativeOptions( + pyarrow_unwrap_scalar(start), skip_nulls)) + else: + try: + start = lib.scalar(start) + self.wrapped.reset(new CCumulativeOptions( + pyarrow_unwrap_scalar(start), skip_nulls)) + except Exception: + _raise_invalid_function_option( + start, "`start` type for CumulativeOptions", TypeError) + + +class CumulativeOptions(_CumulativeOptions): + """ + Options for `cumulative_*` functions. + + - cumulative_sum + - cumulative_sum_checked + - cumulative_prod + - cumulative_prod_checked + - cumulative_max + - cumulative_min + + Parameters + ---------- + start : Scalar, default None + Starting value for the cumulative operation. If none is given, + a default value depending on the operation and input type is used. + skip_nulls : bool, default False + When false, the first encountered null is propagated. + """ + + def __init__(self, start=None, *, skip_nulls=False): + self._set_options(start, skip_nulls) + + +class CumulativeSumOptions(_CumulativeOptions): + """ + Options for `cumulative_sum` function. + + Parameters + ---------- + start : Scalar, default None + Starting value for sum computation + skip_nulls : bool, default False + When false, the first encountered null is propagated. + """ + + def __init__(self, start=None, *, skip_nulls=False): + warnings.warn( + _DEPR_MSG.format("CumulativeSumOptions", "14.0", "CumulativeOptions"), + FutureWarning, + stacklevel=2 + ) + self._set_options(start, skip_nulls) + + +cdef class _PairwiseOptions(FunctionOptions): + def _set_options(self, period): + self.wrapped.reset(new CPairwiseOptions(period)) + + +class PairwiseOptions(_PairwiseOptions): + """ + Options for `pairwise` functions. + + Parameters + ---------- + period : int, default 1 + Period for applying the period function. + """ + + def __init__(self, period=1): + self._set_options(period) + + +cdef class _ListFlattenOptions(FunctionOptions): + def _set_options(self, recursive): + self.wrapped.reset(new CListFlattenOptions(recursive)) + + +class ListFlattenOptions(_ListFlattenOptions): + """ + Options for `list_flatten` function + + Parameters + ---------- + recursive : bool, default False + When True, the list array is flattened recursively until an array + of non-list values is formed. + """ + + def __init__(self, recursive=False): + self._set_options(recursive) + + +cdef class _ArraySortOptions(FunctionOptions): + def _set_options(self, order, null_placement): + self.wrapped.reset(new CArraySortOptions( + unwrap_sort_order(order), unwrap_null_placement(null_placement))) + + +class ArraySortOptions(_ArraySortOptions): + """ + Options for the `array_sort_indices` function. + + Parameters + ---------- + order : str, default "ascending" + Which order to sort values in. + Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Where nulls in the input should be sorted. + Accepted values are "at_start", "at_end". + """ + + def __init__(self, order="ascending", *, null_placement="at_end"): + self._set_options(order, null_placement) + + +cdef class _SortOptions(FunctionOptions): + def _set_options(self, sort_keys, null_placement): + cdef vector[CSortKey] c_sort_keys + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + ) + self.wrapped.reset(new CSortOptions( + c_sort_keys, unwrap_null_placement(null_placement))) + + +class SortOptions(_SortOptions): + """ + Options for the `sort_indices` function. + + Parameters + ---------- + sort_keys : sequence of (name, order) tuples + Names of field/column keys to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + The field name can be a string column name or expression. + null_placement : str, default "at_end" + Where nulls in input should be sorted, only applying to + columns/fields mentioned in `sort_keys`. + Accepted values are "at_start", "at_end". + """ + + def __init__(self, sort_keys=(), *, null_placement="at_end"): + self._set_options(sort_keys, null_placement) + + +cdef class _SelectKOptions(FunctionOptions): + def _set_options(self, k, sort_keys): + cdef vector[CSortKey] c_sort_keys + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + ) + self.wrapped.reset(new CSelectKOptions(k, c_sort_keys)) + + +class SelectKOptions(_SelectKOptions): + """ + Options for top/bottom k-selection. + + Parameters + ---------- + k : int + Number of leading values to select in sorted order + (i.e. the largest values if sort order is "descending", + the smallest otherwise). + sort_keys : sequence of (name, order) tuples + Names of field/column keys to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + The field name can be a string column name or expression. + """ + + def __init__(self, k, sort_keys): + self._set_options(k, sort_keys) + + +cdef class _QuantileOptions(FunctionOptions): + _interp_map = { + "linear": CQuantileInterp_LINEAR, + "lower": CQuantileInterp_LOWER, + "higher": CQuantileInterp_HIGHER, + "nearest": CQuantileInterp_NEAREST, + "midpoint": CQuantileInterp_MIDPOINT, + } + + def _set_options(self, quantiles, interp, skip_nulls, min_count): + try: + self.wrapped.reset( + new CQuantileOptions(quantiles, self._interp_map[interp], + skip_nulls, min_count) + ) + except KeyError: + _raise_invalid_function_option(interp, "quantile interpolation") + + +class QuantileOptions(_QuantileOptions): + __doc__ = f""" + Options for the `quantile` function. + + Parameters + ---------- + q : double or sequence of double, default 0.5 + Probability levels of the quantiles to compute. All values must be in + [0, 1]. + interpolation : str, default "linear" + How to break ties between competing data points for a given quantile. + Accepted values are: + + - "linear": compute an interpolation + - "lower": always use the smallest of the two data points + - "higher": always use the largest of the two data points + - "nearest": select the data point that is closest to the quantile + - "midpoint": compute the (unweighted) mean of the two data points + {_skip_nulls_doc()} + {_min_count_doc(default=0)} + """ + + def __init__(self, q=0.5, *, interpolation="linear", skip_nulls=True, + min_count=0): + if not isinstance(q, SUPPORTED_INPUT_ARR_TYPES): + q = [q] + self._set_options(q, interpolation, skip_nulls, min_count) + + +cdef class _TDigestOptions(FunctionOptions): + def _set_options(self, quantiles, delta, buffer_size, skip_nulls, + min_count): + self.wrapped.reset( + new CTDigestOptions(quantiles, delta, buffer_size, skip_nulls, + min_count) + ) + + +class TDigestOptions(_TDigestOptions): + __doc__ = f""" + Options for the `tdigest` function. + + Parameters + ---------- + q : double or sequence of double, default 0.5 + Probability levels of the quantiles to approximate. All values must be + in [0, 1]. + delta : int, default 100 + Compression parameter for the T-digest algorithm. + buffer_size : int, default 500 + Buffer size for the T-digest algorithm. + {_skip_nulls_doc()} + {_min_count_doc(default=0)} + """ + + def __init__(self, q=0.5, *, delta=100, buffer_size=500, skip_nulls=True, + min_count=0): + if not isinstance(q, SUPPORTED_INPUT_ARR_TYPES): + q = [q] + self._set_options(q, delta, buffer_size, skip_nulls, min_count) + + +cdef class _Utf8NormalizeOptions(FunctionOptions): + _form_map = { + "NFC": CUtf8NormalizeForm_NFC, + "NFKC": CUtf8NormalizeForm_NFKC, + "NFD": CUtf8NormalizeForm_NFD, + "NFKD": CUtf8NormalizeForm_NFKD, + } + + def _set_options(self, form): + try: + self.wrapped.reset( + new CUtf8NormalizeOptions(self._form_map[form]) + ) + except KeyError: + _raise_invalid_function_option(form, + "Unicode normalization form") + + +class Utf8NormalizeOptions(_Utf8NormalizeOptions): + """ + Options for the `utf8_normalize` function. + + Parameters + ---------- + form : str + Unicode normalization form. + Accepted values are "NFC", "NFKC", "NFD", NFKD". + """ + + def __init__(self, form): + self._set_options(form) + + +cdef class _RandomOptions(FunctionOptions): + def _set_options(self, initializer): + if initializer == 'system': + self.wrapped.reset(new CRandomOptions( + CRandomOptions.FromSystemRandom())) + return + + if not isinstance(initializer, int): + try: + initializer = hash(initializer) + except TypeError: + raise TypeError( + f"initializer should be 'system', an integer, " + f"or a hashable object; got {initializer!r}") + + if initializer < 0: + initializer += 2**64 + self.wrapped.reset(new CRandomOptions( + CRandomOptions.FromSeed(initializer))) + + +class RandomOptions(_RandomOptions): + """ + Options for random generation. + + Parameters + ---------- + initializer : int or str + How to initialize the underlying random generator. + If an integer is given, it is used as a seed. + If "system" is given, the random generator is initialized with + a system-specific source of (hopefully true) randomness. + Other values are invalid. + """ + + def __init__(self, *, initializer='system'): + self._set_options(initializer) + + +cdef class _RankOptions(FunctionOptions): + + _tiebreaker_map = { + "min": CRankOptionsTiebreaker_Min, + "max": CRankOptionsTiebreaker_Max, + "first": CRankOptionsTiebreaker_First, + "dense": CRankOptionsTiebreaker_Dense, + } + + def _set_options(self, sort_keys, null_placement, tiebreaker): + cdef vector[CSortKey] c_sort_keys + if isinstance(sort_keys, str): + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) + ) + else: + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + ) + try: + self.wrapped.reset( + new CRankOptions(c_sort_keys, + unwrap_null_placement(null_placement), + self._tiebreaker_map[tiebreaker]) + ) + except KeyError: + _raise_invalid_function_option(tiebreaker, "tiebreaker") + + +class RankOptions(_RankOptions): + """ + Options for the `rank` function. + + Parameters + ---------- + sort_keys : sequence of (name, order) tuples or str, default "ascending" + Names of field/column keys to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + The field name can be a string column name or expression. + Alternatively, one can simply pass "ascending" or "descending" as a string + if the input is array-like. + null_placement : str, default "at_end" + Where nulls in input should be sorted. + Accepted values are "at_start", "at_end". + tiebreaker : str, default "first" + Configure how ties between equal values are handled. + Accepted values are: + + - "min": Ties get the smallest possible rank in sorted order. + - "max": Ties get the largest possible rank in sorted order. + - "first": Ranks are assigned in order of when ties appear in the + input. This ensures the ranks are a stable permutation + of the input. + - "dense": The ranks span a dense [1, M] interval where M is the + number of distinct values in the input. + """ + + def __init__(self, sort_keys="ascending", *, null_placement="at_end", tiebreaker="first"): + self._set_options(sort_keys, null_placement, tiebreaker) + + +cdef class Expression(_Weakrefable): + """ + A logical expression to be evaluated against some input. + + To create an expression: + + - Use the factory function ``pyarrow.compute.scalar()`` to create a + scalar (not necessary when combined, see example below). + - Use the factory function ``pyarrow.compute.field()`` to reference + a field (column in table). + - Compare fields and scalars with ``<``, ``<=``, ``==``, ``>=``, ``>``. + - Combine expressions using python operators ``&`` (logical and), + ``|`` (logical or) and ``~`` (logical not). + Note: python keywords ``and``, ``or`` and ``not`` cannot be used + to combine expressions. + - Create expression predicates using Expression methods such as + ``pyarrow.compute.Expression.isin()``. + + Examples + -------- + + >>> import pyarrow.compute as pc + >>> (pc.field("a") < pc.scalar(3)) | (pc.field("b") > 7) + 7))> + >>> pc.field('a') != 3 + + >>> pc.field('a').isin([1, 2, 3]) + + """ + + def __init__(self): + msg = 'Expression is an abstract class thus cannot be initialized.' + raise TypeError(msg) + + cdef void init(self, const CExpression& sp): + self.expr = sp + + @staticmethod + cdef wrap(const CExpression& sp): + cdef Expression self = Expression.__new__(Expression) + self.init(sp) + return self + + cdef inline CExpression unwrap(self): + return self.expr + + def equals(self, Expression other): + """ + Parameters + ---------- + other : pyarrow.dataset.Expression + + Returns + ------- + bool + """ + return self.expr.Equals(other.unwrap()) + + def __str__(self): + return frombytes(self.expr.ToString()) + + def __repr__(self): + return "".format( + self.__class__.__name__, str(self) + ) + + @staticmethod + def from_substrait(object message not None): + """ + Deserialize an expression from Substrait + + The serialized message must be an ExtendedExpression message that has + only a single expression. The name of the expression and the schema + the expression was bound to will be ignored. Use + pyarrow.substrait.deserialize_expressions if this information is needed + or if the message might contain multiple expressions. + + Parameters + ---------- + message : bytes or Buffer or a protobuf Message + The Substrait message to deserialize + + Returns + ------- + Expression + The deserialized expression + """ + expressions = _pas().BoundExpressions.from_substrait(message).expressions + if len(expressions) == 0: + raise ValueError("Substrait message did not contain any expressions") + if len(expressions) > 1: + raise ValueError( + "Substrait message contained multiple expressions. Use pyarrow.substrait.deserialize_expressions instead") + return next(iter(expressions.values())) + + def to_substrait(self, Schema schema not None, c_bool allow_arrow_extensions=False): + """ + Serialize the expression using Substrait + + The expression will be serialized as an ExtendedExpression message that has a + single expression named "expression" + + Parameters + ---------- + schema : Schema + The input schema the expression will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. Set this to True to allow pyarrow-specific functions + but the result may not be accepted by other compute libraries. + + Returns + ------- + Buffer + A buffer containing the serialized Protobuf plan. + """ + return _pas().serialize_expressions([self], ["expression"], schema, allow_arrow_extensions=allow_arrow_extensions) + + @staticmethod + def _deserialize(Buffer buffer not None): + return Expression.wrap(GetResultValue(CDeserializeExpression( + pyarrow_unwrap_buffer(buffer)))) + + def __reduce__(self): + buffer = pyarrow_wrap_buffer(GetResultValue( + CSerializeExpression(self.expr))) + return Expression._deserialize, (buffer,) + + @staticmethod + cdef Expression _expr_or_scalar(object expr): + if isinstance(expr, Expression): + return ( expr) + return ( Expression._scalar(expr)) + + @staticmethod + def _call(str function_name, list arguments, FunctionOptions options=None): + cdef: + vector[CExpression] c_arguments + shared_ptr[CFunctionOptions] c_options + + for argument in arguments: + if not isinstance(argument, Expression): + # Attempt to help convert this to an expression + try: + argument = Expression._scalar(argument) + except ArrowInvalid: + raise TypeError( + "only other expressions allowed as arguments") + c_arguments.push_back(( argument).expr) + + if options is not None: + c_options = options.unwrap() + + return Expression.wrap(CMakeCallExpression( + tobytes(function_name), move(c_arguments), c_options)) + + def __richcmp__(self, other, int op): + other = Expression._expr_or_scalar(other) + return Expression._call({ + Py_EQ: "equal", + Py_NE: "not_equal", + Py_GT: "greater", + Py_GE: "greater_equal", + Py_LT: "less", + Py_LE: "less_equal", + }[op], [self, other]) + + def __bool__(self): + raise ValueError( + "An Expression cannot be evaluated to python True or False. " + "If you are using the 'and', 'or' or 'not' operators, use '&', " + "'|' or '~' instead." + ) + + def __invert__(self): + return Expression._call("invert", [self]) + + def __and__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("and_kleene", [self, other]) + + def __or__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("or_kleene", [self, other]) + + def __add__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("add_checked", [self, other]) + + def __mul__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("multiply_checked", [self, other]) + + def __sub__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("subtract_checked", [self, other]) + + def __truediv__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("divide_checked", [self, other]) + + def is_valid(self): + """ + Check whether the expression is not-null (valid). + + This creates a new expression equivalent to calling the + `is_valid` compute function on this expression. + + Returns + ------- + is_valid : Expression + """ + return Expression._call("is_valid", [self]) + + def is_null(self, bint nan_is_null=False): + """ + Check whether the expression is null. + + This creates a new expression equivalent to calling the + `is_null` compute function on this expression. + + Parameters + ---------- + nan_is_null : boolean, default False + Whether floating-point NaNs are considered null. + + Returns + ------- + is_null : Expression + """ + options = NullOptions(nan_is_null=nan_is_null) + return Expression._call("is_null", [self], options) + + def is_nan(self): + """ + Check whether the expression is NaN. + + This creates a new expression equivalent to calling the + `is_nan` compute function on this expression. + + Returns + ------- + is_nan : Expression + """ + return Expression._call("is_nan", [self]) + + def cast(self, type=None, safe=None, options=None): + """ + Explicitly set or change the expression's data type. + + This creates a new expression equivalent to calling the + `cast` compute function on this expression. + + Parameters + ---------- + type : DataType, default None + Type to cast array to. + safe : boolean, default True + Whether to check for conversion errors such as overflow. + options : CastOptions, default None + Additional checks pass by CastOptions + + Returns + ------- + cast : Expression + """ + safe_vars_passed = (safe is not None) or (type is not None) + + if safe_vars_passed and (options is not None): + raise ValueError("Must either pass values for 'type' and 'safe' or pass a " + "value for 'options'") + + if options is None: + type = ensure_type(type, allow_none=False) + if safe is False: + options = CastOptions.unsafe(type) + else: + options = CastOptions.safe(type) + return Expression._call("cast", [self], options) + + def isin(self, values): + """ + Check whether the expression is contained in values. + + This creates a new expression equivalent to calling the + `is_in` compute function on this expression. + + Parameters + ---------- + values : Array or iterable + The values to check for. + + Returns + ------- + isin : Expression + A new expression that, when evaluated, checks whether + this expression's value is contained in `values`. + """ + if not isinstance(values, Array): + values = lib.array(values) + + options = SetLookupOptions(values) + return Expression._call("is_in", [self], options) + + @staticmethod + def _field(name_or_idx not None): + cdef: + CFieldRef c_field + + if isinstance(name_or_idx, int): + return Expression.wrap(CMakeFieldExpressionByIndex(name_or_idx)) + else: + c_field = CFieldRef( tobytes(name_or_idx)) + return Expression.wrap(CMakeFieldExpression(c_field)) + + @staticmethod + def _nested_field(tuple names not None): + cdef: + vector[CFieldRef] nested + + if len(names) == 0: + raise ValueError("nested field reference should be non-empty") + nested.reserve(len(names)) + for name in names: + if isinstance(name, int): + nested.push_back(CFieldRef(name)) + else: + nested.push_back(CFieldRef( tobytes(name))) + return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested)))) + + @staticmethod + def _scalar(value): + cdef: + Scalar scalar + + if isinstance(value, Scalar): + scalar = value + else: + scalar = lib.scalar(value) + + return Expression.wrap(CMakeScalarExpression(scalar.unwrap())) + + +_deserialize = Expression._deserialize +cdef CExpression _true = CMakeScalarExpression( + make_shared[CBooleanScalar](True) +) + + +cdef CExpression _bind(Expression filter, Schema schema) except *: + assert schema is not None + + if filter is None: + return _true + + return GetResultValue(filter.unwrap().Bind( + deref(pyarrow_unwrap_schema(schema).get()))) + + +cdef class UdfContext: + """ + Per-invocation function context/state. + + This object will always be the first argument to a user-defined + function. It should not be used outside of a call to the function. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + cdef void init(self, const CUdfContext &c_context): + self.c_context = c_context + + @property + def batch_length(self): + """ + The common length of all input arguments (int). + + In the case that all arguments are scalars, this value + is used to pass the "actual length" of the arguments, + e.g. because the scalar values are encoding a column + with a constant value. + """ + return self.c_context.batch_length + + @property + def memory_pool(self): + """ + A memory pool for allocations (:class:`MemoryPool`). + + This is the memory pool supplied by the user when they invoked + the function and it should be used in any calls to arrow that the + UDF makes if that call accepts a memory_pool. + """ + return box_memory_pool(self.c_context.pool) + + +cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: + """ + Helper function to generate the FunctionDoc + This function accepts a dictionary and expects the + summary(str), description(str) and arg_names(List[str]) keys. + """ + cdef: + CFunctionDoc f_doc + vector[c_string] c_arg_names + + f_doc.summary = tobytes(func_doc["summary"]) + f_doc.description = tobytes(func_doc["description"]) + for arg_name in func_doc["arg_names"]: + c_arg_names.push_back(tobytes(arg_name)) + f_doc.arg_names = c_arg_names + # UDFOptions integration: + # TODO: https://issues.apache.org/jira/browse/ARROW-16041 + f_doc.options_class = b"" + f_doc.options_required = False + return f_doc + + +cdef object box_udf_context(const CUdfContext& c_context): + cdef UdfContext context = UdfContext.__new__(UdfContext) + context.init(c_context) + return context + + +cdef _udf_callback(user_function, const CUdfContext& c_context, inputs): + """ + Helper callback function used to wrap the UdfContext from Python to C++ + execution. + """ + context = box_udf_context(c_context) + return user_function(context, *inputs) + + +def _get_udf_context(memory_pool, batch_length): + cdef CUdfContext c_context + c_context.pool = maybe_unbox_memory_pool(memory_pool) + c_context.batch_length = batch_length + context = box_udf_context(c_context) + return context + + +ctypedef CStatus (*CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper, + const CUdfOptions& options, CFunctionRegistry* registry) + +cdef class RegisterUdf(_Weakrefable): + cdef CRegisterUdf register_func + + cdef void init(self, const CRegisterUdf register_func): + self.register_func = register_func + + +cdef get_register_scalar_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterScalarFunction + return reg + + +cdef get_register_tabular_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterTabularFunction + return reg + + +cdef get_register_aggregate_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterAggregateFunction + return reg + +cdef get_register_vector_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterVectorFunction + return reg + + +def register_scalar_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined scalar function. + + This API is EXPERIMENTAL. + + A scalar function is a function that executes elementwise + operations on arrays or scalars, i.e. a scalar function must + be computed row-by-row with no state where each output row + is computed only from its corresponding input row. + In other words, all argument arrays have the same length, + and the output array is of the same length as the arguments. + Scalar functions are the only functions allowed in query engine + expressions. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The first argument is the context argument of type + UdfContext. + Then, it must take arguments equal to the number of + in_types defined. It must return an Array or Scalar + matching the out_type. It must return a Scalar if + all arguments are scalar, else it must return an Array. + + To define a varargs function, pass a callable that takes + *args. The last in_type will be the type of all varargs + arguments. + function_name : str + Name of the function. There should only be one function + registered with this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> + >>> func_doc = {} + >>> func_doc["summary"] = "simple udf" + >>> func_doc["description"] = "add a constant to a scalar" + >>> + >>> def add_constant(ctx, array): + ... return pc.add(array, 1, memory_pool=ctx.memory_pool) + >>> + >>> func_name = "py_add_func" + >>> in_types = {"array": pa.int64()} + >>> out_type = pa.int64() + >>> pc.register_scalar_function(add_constant, func_name, func_doc, + ... in_types, out_type) + >>> + >>> func = pc.get_function(func_name) + >>> func.name + 'py_add_func' + >>> answer = pc.call_function(func_name, [pa.array([20])]) + >>> answer + + [ + 21 + ] + """ + return _register_user_defined_function(get_register_scalar_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_vector_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined vector function. + + This API is EXPERIMENTAL. + + A vector function is a function that executes vector + operations on arrays. Vector function is often used + when compute doesn't fit other more specific types of + functions (e.g., scalar and aggregate). + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The first argument is the context argument of type + UdfContext. + Then, it must take arguments equal to the number of + in_types defined. It must return an Array or Scalar + matching the out_type. It must return a Scalar if + all arguments are scalar, else it must return an Array. + + To define a varargs function, pass a callable that takes + *args. The last in_type will be the type of all varargs + arguments. + function_name : str + Name of the function. There should only be one function + registered with this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> + >>> func_doc = {} + >>> func_doc["summary"] = "percent rank" + >>> func_doc["description"] = "compute percent rank" + >>> + >>> def list_flatten_udf(ctx, x): + ... return pc.list_flatten(x) + >>> + >>> func_name = "list_flatten_udf" + >>> in_types = {"array": pa.list_(pa.int64())} + >>> out_type = pa.int64() + >>> pc.register_vector_function(list_flatten_udf, func_name, func_doc, + ... in_types, out_type) + >>> + >>> answer = pc.call_function(func_name, [pa.array([[1, 2], [3, 4]])]) + >>> answer + + [ + 1, + 2, + 3, + 4 + ] + """ + return _register_user_defined_function(get_register_vector_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_aggregate_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined non-decomposable aggregate function. + + This API is EXPERIMENTAL. + + A non-decomposable aggregation function is a function that executes + aggregate operations on the whole data that it is aggregating. + In other words, non-decomposable aggregate function cannot be + split into consume/merge/finalize steps. + + This is often used with ordered or segmented aggregation where groups + can be emit before accumulating all of the input data. + + Note that currently the size of any input column cannot exceed 2 GB + for a single segment (all groups combined). + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The first argument is the context argument of type + UdfContext. + Then, it must take arguments equal to the number of + in_types defined. It must return a Scalar matching the + out_type. + To define a varargs function, pass a callable that takes + *args. The in_type needs to match in type of inputs when + the function gets called. + function_name : str + Name of the function. This name must be unique, i.e., + there should only be one function registered with + this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + The argument names will be used to generate + documentation for the function. The number of + arguments specified here determines the function + arity. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + + Examples + -------- + >>> import numpy as np + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> + >>> func_doc = {} + >>> func_doc["summary"] = "simple median udf" + >>> func_doc["description"] = "compute median" + >>> + >>> def compute_median(ctx, array): + ... return pa.scalar(np.median(array)) + >>> + >>> func_name = "py_compute_median" + >>> in_types = {"array": pa.int64()} + >>> out_type = pa.float64() + >>> pc.register_aggregate_function(compute_median, func_name, func_doc, + ... in_types, out_type) + >>> + >>> func = pc.get_function(func_name) + >>> func.name + 'py_compute_median' + >>> answer = pc.call_function(func_name, [pa.array([20, 40])]) + >>> answer + + >>> table = pa.table([pa.array([1, 1, 2, 2]), pa.array([10, 20, 30, 40])], names=['k', 'v']) + >>> result = table.group_by('k').aggregate([('v', 'py_compute_median')]) + >>> result + pyarrow.Table + k: int64 + v_py_compute_median: double + ---- + k: [[1,2]] + v_py_compute_median: [[15,35]] + """ + return _register_user_defined_function(get_register_aggregate_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_tabular_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined tabular function. + + This API is EXPERIMENTAL. + + A tabular function is one accepting a context argument of type + UdfContext and returning a generator of struct arrays. + The in_types argument must be empty and the out_type argument + specifies a schema. Each struct array must have field types + corresponding to the schema. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The only argument is the context argument of type + UdfContext. It must return a callable that + returns on each invocation a StructArray matching + the out_type, where an empty array indicates end. + function_name : str + Name of the function. There should only be one function + registered with this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + Must be an empty dictionary (reserved for future use). + out_type : Union[Schema, DataType] + Schema of the function's output, or a corresponding flat struct type. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CDataType] c_type + + if isinstance(out_type, Schema): + c_schema = pyarrow_unwrap_schema(out_type) + with nogil: + c_type = make_shared[CStructType](deref(c_schema).fields()) + out_type = pyarrow_wrap_data_type(c_type) + return _register_user_defined_function(get_register_tabular_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def _register_user_defined_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): + """ + Register a user-defined function. + + This method itself doesn't care about the type of the UDF + (i.e., scalar vs tabular vs aggregate) + + Parameters + ---------- + register_func: object + An object holding a CRegisterUdf in a "register_func" attribute. + func : callable + A callable implementing the user-defined function. + function_name : str + Name of the function. There should only be one function + registered with this name in the function registry. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ + cdef: + CRegisterUdf c_register_func + c_string c_func_name + CArity c_arity + CFunctionDoc c_func_doc + vector[shared_ptr[CDataType]] c_in_types + PyObject* c_function + shared_ptr[CDataType] c_out_type + CUdfOptions c_options + CFunctionRegistry* c_func_registry + + if callable(func): + c_function = func + else: + raise TypeError("func must be a callable") + + c_func_name = tobytes(function_name) + + func_spec = inspect.getfullargspec(func) + num_args = -1 + if isinstance(in_types, dict): + for in_type in in_types.values(): + c_in_types.push_back( + pyarrow_unwrap_data_type(ensure_type(in_type))) + function_doc["arg_names"] = in_types.keys() + num_args = len(in_types) + else: + raise TypeError( + "in_types must be a dictionary of DataType") + + c_arity = CArity( num_args, func_spec.varargs) + + if "summary" not in function_doc: + raise ValueError("Function doc must contain a summary") + + if "description" not in function_doc: + raise ValueError("Function doc must contain a description") + + if "arg_names" not in function_doc: + raise ValueError("Function doc must contain arg_names") + + c_func_doc = _make_function_doc(function_doc) + + c_out_type = pyarrow_unwrap_data_type(ensure_type(out_type)) + + c_options.func_name = c_func_name + c_options.arity = c_arity + c_options.func_doc = c_func_doc + c_options.input_types = c_in_types + c_options.output_type = c_out_type + + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + + c_register_func = (register_func).register_func + + check_status(c_register_func(c_function, + &_udf_callback, + c_options, c_func_registry)) + + +def call_tabular_function(function_name, args=None, func_registry=None): + """ + Get a record batch iterator from a tabular function. + + Parameters + ---------- + function_name : str + Name of the function. + args : iterable + The arguments to pass to the function. Accepted types depend + on the specific function. Currently, only an empty args is supported. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ + cdef: + c_string c_func_name + vector[CDatum] c_args + CFunctionRegistry* c_func_registry + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader + + c_func_name = tobytes(function_name) + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + if args is None: + args = [] + _pack_compute_args(args, &c_args) + + with nogil: + c_reader = GetResultValue(CallTabularFunction( + c_func_name, c_args, c_func_registry)) + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return RecordBatchReader.from_batches(pyarrow_wrap_schema(deref(c_reader).schema()), reader) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute_docstrings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute_docstrings.py new file mode 100644 index 0000000000000000000000000000000000000000..150dbdb1175803e3c40a1bd2469a4df34ea57e4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute_docstrings.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Custom documentation additions for compute functions. +""" + +function_doc_additions = {} + +function_doc_additions["filter"] = """ + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array(["a", "b", "c", None, "e"]) + >>> mask = pa.array([True, False, None, False, True]) + >>> arr.filter(mask) + + [ + "a", + "e" + ] + >>> arr.filter(mask, null_selection_behavior='emit_null') + + [ + "a", + null, + "e" + ] + """ + +function_doc_additions["mode"] = """ + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array([1, 1, 2, 2, 3, 2, 2, 2]) + >>> modes = pc.mode(arr, 2) + >>> modes[0] + + >>> modes[1] + + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_cuda.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_cuda.pyx new file mode 100644 index 0000000000000000000000000000000000000000..5aed9f8a285188d4f3fa173cffa7d1188bc9006a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_cuda.pyx @@ -0,0 +1,1080 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from pyarrow.lib cimport * +from pyarrow.includes.libarrow_cuda cimport * +from pyarrow.lib import allocate_buffer, as_buffer, ArrowTypeError +from pyarrow.util import get_contiguous_span +cimport cpython as cp + + +cdef class Context(_Weakrefable): + """ + CUDA driver context. + """ + + def __init__(self, *args, **kwargs): + """ + Create a CUDA driver context for a particular device. + + If a CUDA context handle is passed, it is wrapped, otherwise + a default CUDA context for the given device is requested. + + Parameters + ---------- + device_number : int (default 0) + Specify the GPU device for which the CUDA driver context is + requested. + handle : int, optional + Specify CUDA handle for a shared context that has been created + by another library. + """ + # This method exposed because autodoc doesn't pick __cinit__ + + def __cinit__(self, int device_number=0, uintptr_t handle=0): + cdef CCudaDeviceManager* manager + manager = GetResultValue(CCudaDeviceManager.Instance()) + cdef int n = manager.num_devices() + if device_number >= n or device_number < 0: + self.context.reset() + raise ValueError('device_number argument must be ' + 'non-negative less than %s' % (n)) + if handle == 0: + self.context = GetResultValue(manager.GetContext(device_number)) + else: + self.context = GetResultValue(manager.GetSharedContext( + device_number, handle)) + self.device_number = device_number + + @staticmethod + def from_numba(context=None): + """ + Create a Context instance from a Numba CUDA context. + + Parameters + ---------- + context : {numba.cuda.cudadrv.driver.Context, None} + A Numba CUDA context instance. + If None, the current Numba context is used. + + Returns + ------- + shared_context : pyarrow.cuda.Context + Context instance. + """ + if context is None: + import numba.cuda + context = numba.cuda.current_context() + return Context(device_number=context.device.id, + handle=context.handle.value) + + def to_numba(self): + """ + Convert Context to a Numba CUDA context. + + Returns + ------- + context : numba.cuda.cudadrv.driver.Context + Numba CUDA context instance. + """ + import ctypes + import numba.cuda + device = numba.cuda.gpus[self.device_number] + handle = ctypes.c_void_p(self.handle) + context = numba.cuda.cudadrv.driver.Context(device, handle) + + class DummyPendingDeallocs(object): + # Context is managed by pyarrow + def add_item(self, *args, **kwargs): + pass + + context.deallocations = DummyPendingDeallocs() + return context + + @staticmethod + def get_num_devices(): + """ Return the number of GPU devices. + """ + cdef CCudaDeviceManager* manager + manager = GetResultValue(CCudaDeviceManager.Instance()) + return manager.num_devices() + + @property + def device_number(self): + """ Return context device number. + """ + return self.device_number + + @property + def handle(self): + """ Return pointer to context handle. + """ + return self.context.get().handle() + + cdef void init(self, const shared_ptr[CCudaContext]& ctx): + self.context = ctx + + def synchronize(self): + """Blocks until the device has completed all preceding requested + tasks. + """ + check_status(self.context.get().Synchronize()) + + @property + def bytes_allocated(self): + """Return the number of allocated bytes. + """ + return self.context.get().bytes_allocated() + + def get_device_address(self, uintptr_t address): + """Return the device address that is reachable from kernels running in + the context + + Parameters + ---------- + address : int + Specify memory address value + + Returns + ------- + device_address : int + Device address accessible from device context + + Notes + ----- + The device address is defined as a memory address accessible + by device. While it is often a device memory address but it + can be also a host memory address, for instance, when the + memory is allocated as host memory (using cudaMallocHost or + cudaHostAlloc) or as managed memory (using cudaMallocManaged) + or the host memory is page-locked (using cudaHostRegister). + """ + return GetResultValue(self.context.get().GetDeviceAddress(address)) + + def new_buffer(self, int64_t nbytes): + """Return new device buffer. + + Parameters + ---------- + nbytes : int + Specify the number of bytes to be allocated. + + Returns + ------- + buf : CudaBuffer + Allocated buffer. + """ + cdef: + shared_ptr[CCudaBuffer] cudabuf + with nogil: + cudabuf = GetResultValue(self.context.get().Allocate(nbytes)) + return pyarrow_wrap_cudabuffer(cudabuf) + + @property + def memory_manager(self): + """ + The default memory manager tied to this context's device. + + Returns + ------- + MemoryManager + """ + return MemoryManager.wrap(self.context.get().memory_manager()) + + @property + def device(self): + """ + The device instance associated with this context. + + Returns + ------- + Device + """ + return Device.wrap(self.context.get().device()) + + def foreign_buffer(self, address, size, base=None): + """ + Create device buffer from address and size as a view. + + The caller is responsible for allocating and freeing the + memory. When `address==size==0` then a new zero-sized buffer + is returned. + + Parameters + ---------- + address : int + Specify the starting address of the buffer. The address can + refer to both device or host memory but it must be + accessible from device after mapping it with + `get_device_address` method. + size : int + Specify the size of device buffer in bytes. + base : {None, object} + Specify object that owns the referenced memory. + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of device reachable memory. + + """ + if not address and size == 0: + return self.new_buffer(0) + cdef: + uintptr_t c_addr = self.get_device_address(address) + int64_t c_size = size + shared_ptr[CCudaBuffer] cudabuf + + cudabuf = GetResultValue(self.context.get().View( + c_addr, c_size)) + return pyarrow_wrap_cudabuffer_base(cudabuf, base) + + def open_ipc_buffer(self, ipc_handle): + """ Open existing CUDA IPC memory handle + + Parameters + ---------- + ipc_handle : IpcMemHandle + Specify opaque pointer to CUipcMemHandle (driver API). + + Returns + ------- + buf : CudaBuffer + referencing device buffer + """ + handle = pyarrow_unwrap_cudaipcmemhandle(ipc_handle) + cdef shared_ptr[CCudaBuffer] cudabuf + with nogil: + cudabuf = GetResultValue( + self.context.get().OpenIpcBuffer(handle.get()[0])) + return pyarrow_wrap_cudabuffer(cudabuf) + + def buffer_from_data(self, object data, int64_t offset=0, int64_t size=-1): + """Create device buffer and initialize with data. + + Parameters + ---------- + data : {CudaBuffer, HostBuffer, Buffer, array-like} + Specify data to be copied to device buffer. + offset : int + Specify the offset of input buffer for device data + buffering. Default: 0. + size : int + Specify the size of device buffer in bytes. Default: all + (starting from input offset) + + Returns + ------- + cbuf : CudaBuffer + Device buffer with copied data. + """ + is_host_data = not pyarrow_is_cudabuffer(data) + buf = as_buffer(data) if is_host_data else data + + bsize = buf.size + if offset < 0 or (bsize and offset >= bsize): + raise ValueError('offset argument is out-of-range') + if size < 0: + size = bsize - offset + elif offset + size > bsize: + raise ValueError( + 'requested larger slice than available in device buffer') + + if offset != 0 or size != bsize: + buf = buf.slice(offset, size) + + result = self.new_buffer(size) + if is_host_data: + result.copy_from_host(buf, position=0, nbytes=size) + else: + result.copy_from_device(buf, position=0, nbytes=size) + return result + + def buffer_from_object(self, obj): + """Create device buffer view of arbitrary object that references + device accessible memory. + + When the object contains a non-contiguous view of device + accessible memory then the returned device buffer will contain + contiguous view of the memory, that is, including the + intermediate data that is otherwise invisible to the input + object. + + Parameters + ---------- + obj : {object, Buffer, HostBuffer, CudaBuffer, ...} + Specify an object that holds (device or host) address that + can be accessed from device. This includes objects with + types defined in pyarrow.cuda as well as arbitrary objects + that implement the CUDA array interface as defined by numba. + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of device accessible memory. + + """ + if isinstance(obj, HostBuffer): + return self.foreign_buffer(obj.address, obj.size, base=obj) + elif isinstance(obj, Buffer): + return CudaBuffer.from_buffer(obj) + elif isinstance(obj, CudaBuffer): + return obj + elif hasattr(obj, '__cuda_array_interface__'): + desc = obj.__cuda_array_interface__ + addr = desc['data'][0] + if addr is None: + return self.new_buffer(0) + import numpy as np + start, end = get_contiguous_span( + desc['shape'], desc.get('strides'), + np.dtype(desc['typestr']).itemsize) + return self.foreign_buffer(addr + start, end - start, base=obj) + raise ArrowTypeError('cannot create device buffer view from' + ' `%s` object' % (type(obj))) + + +cdef class IpcMemHandle(_Weakrefable): + """A serializable container for a CUDA IPC handle. + """ + cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h): + self.handle = h + + @staticmethod + def from_buffer(Buffer opaque_handle): + """Create IpcMemHandle from opaque buffer (e.g. from another + process) + + Parameters + ---------- + opaque_handle : + a CUipcMemHandle as a const void* + + Returns + ------- + ipc_handle : IpcMemHandle + """ + c_buf = pyarrow_unwrap_buffer(opaque_handle) + cdef: + shared_ptr[CCudaIpcMemHandle] handle + + handle = GetResultValue( + CCudaIpcMemHandle.FromBuffer(c_buf.get().data())) + return pyarrow_wrap_cudaipcmemhandle(handle) + + def serialize(self, pool=None): + """Write IpcMemHandle to a Buffer + + Parameters + ---------- + pool : {MemoryPool, None} + Specify a pool to allocate memory from + + Returns + ------- + buf : Buffer + The serialized buffer. + """ + cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + cdef shared_ptr[CBuffer] buf + cdef CCudaIpcMemHandle* h = self.handle.get() + with nogil: + buf = GetResultValue(h.Serialize(pool_)) + return pyarrow_wrap_buffer(buf) + + +cdef class CudaBuffer(Buffer): + """An Arrow buffer with data located in a GPU device. + + To create a CudaBuffer instance, use Context.device_buffer(). + + The memory allocated in a CudaBuffer is freed when the buffer object + is deleted. + """ + + def __init__(self): + raise TypeError("Do not call CudaBuffer's constructor directly, use " + "`.device_buffer`" + " method instead.") + + cdef void init_cuda(self, + const shared_ptr[CCudaBuffer]& buffer, + object base): + self.cuda_buffer = buffer + self.init( buffer) + self.base = base + + @staticmethod + def from_buffer(buf): + """ Convert back generic buffer into CudaBuffer + + Parameters + ---------- + buf : Buffer + Specify buffer containing CudaBuffer + + Returns + ------- + dbuf : CudaBuffer + Resulting device buffer. + """ + c_buf = pyarrow_unwrap_buffer(buf) + cuda_buffer = GetResultValue(CCudaBuffer.FromBuffer(c_buf)) + return pyarrow_wrap_cudabuffer(cuda_buffer) + + @staticmethod + def from_numba(mem): + """Create a CudaBuffer view from numba MemoryPointer instance. + + Parameters + ---------- + mem : numba.cuda.cudadrv.driver.MemoryPointer + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of numba MemoryPointer. + """ + ctx = Context.from_numba(mem.context) + if mem.device_pointer.value is None and mem.size==0: + return ctx.new_buffer(0) + return ctx.foreign_buffer(mem.device_pointer.value, mem.size, base=mem) + + def to_numba(self): + """Return numba memory pointer of CudaBuffer instance. + """ + import ctypes + from numba.cuda.cudadrv.driver import MemoryPointer + return MemoryPointer(self.context.to_numba(), + pointer=ctypes.c_void_p(self.address), + size=self.size) + + cdef getitem(self, int64_t i): + return self.copy_to_host(position=i, nbytes=1)[0] + + def copy_to_host(self, int64_t position=0, int64_t nbytes=-1, + Buffer buf=None, + MemoryPool memory_pool=None, c_bool resizable=False): + """Copy memory from GPU device to CPU host + + Caller is responsible for ensuring that all tasks affecting + the memory are finished. Use + + `.context.synchronize()` + + when needed. + + Parameters + ---------- + position : int + Specify the starting position of the source data in GPU + device buffer. Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + the position until host buffer is full). + buf : Buffer + Specify a pre-allocated output buffer in host. Default: None + (allocate new output buffer). + memory_pool : MemoryPool + resizable : bool + Specify extra arguments to allocate_buffer. Used only when + buf is None. + + Returns + ------- + buf : Buffer + Output buffer in host. + + """ + if position < 0 or (self.size and position > self.size) \ + or (self.size == 0 and position != 0): + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + if buf is None: + if nbytes < 0: + # copy all starting from position to new host buffer + c_nbytes = self.size - position + else: + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available from ' + 'device buffer') + # copy nbytes starting from position to new host buffer + c_nbytes = nbytes + buf = allocate_buffer(c_nbytes, memory_pool=memory_pool, + resizable=resizable) + else: + if nbytes < 0: + # copy all from position until given host buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested copy does not fit into host buffer') + # copy nbytes from position to given host buffer + c_nbytes = nbytes + + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + int64_t c_position = position + with nogil: + check_status(self.cuda_buffer.get() + .CopyToHost(c_position, c_nbytes, + c_buf.get().mutable_data())) + return buf + + def copy_from_host(self, data, int64_t position=0, int64_t nbytes=-1): + """Copy data from host to device. + + The device buffer must be pre-allocated. + + Parameters + ---------- + data : {Buffer, array-like} + Specify data in host. It can be array-like that is valid + argument to py_buffer + position : int + Specify the starting position of the copy in device buffer. + Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + source until device buffer, starting from position, is full) + + Returns + ------- + nbytes : int + Number of bytes copied. + """ + if position < 0 or position > self.size: + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + buf = as_buffer(data) + + if nbytes < 0: + # copy from host buffer to device buffer starting from + # position until device buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested more to copy than available from host buffer') + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available in device buffer') + # copy nbytes from host buffer to device buffer starting + # from position + c_nbytes = nbytes + + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + int64_t c_position = position + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromHost(c_position, c_buf.get().data(), + c_nbytes)) + return c_nbytes + + def copy_from_device(self, buf, int64_t position=0, int64_t nbytes=-1): + """Copy data from device to device. + + Parameters + ---------- + buf : CudaBuffer + Specify source device buffer. + position : int + Specify the starting position of the copy in device buffer. + Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + source until device buffer, starting from position, is full) + + Returns + ------- + nbytes : int + Number of bytes copied. + + """ + if position < 0 or position > self.size: + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + + if nbytes < 0: + # copy from source device buffer to device buffer starting + # from position until device buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested more to copy than available from device buffer') + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available in device buffer') + # copy nbytes from source device buffer to device buffer + # starting from position + c_nbytes = nbytes + + cdef: + shared_ptr[CCudaBuffer] c_buf = pyarrow_unwrap_cudabuffer(buf) + int64_t c_position = position + shared_ptr[CCudaContext] c_src_ctx = pyarrow_unwrap_cudacontext( + buf.context) + void* c_source_data = (c_buf.get().address()) + + if self.context.handle != buf.context.handle: + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromAnotherDevice(c_src_ctx, c_position, + c_source_data, c_nbytes)) + else: + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromDevice(c_position, c_source_data, + c_nbytes)) + return c_nbytes + + def export_for_ipc(self): + """ + Expose this device buffer as IPC memory which can be used in other + processes. + + After calling this function, this device memory will not be + freed when the CudaBuffer is destructed. + + Returns + ------- + ipc_handle : IpcMemHandle + The exported IPC handle + + """ + cdef shared_ptr[CCudaIpcMemHandle] handle + with nogil: + handle = GetResultValue(self.cuda_buffer.get().ExportForIpc()) + return pyarrow_wrap_cudaipcmemhandle(handle) + + @property + def context(self): + """Returns the CUDA driver context of this buffer. + """ + return pyarrow_wrap_cudacontext(self.cuda_buffer.get().context()) + + def slice(self, offset=0, length=None): + """Return slice of device buffer + + Parameters + ---------- + offset : int, default 0 + Specify offset from the start of device buffer to slice + length : int, default None + Specify the length of slice (default is until end of device + buffer starting from offset). If the length is larger than + the data available, the returned slice will have a size of + the available data starting from the offset. + + Returns + ------- + sliced : CudaBuffer + Zero-copy slice of device buffer. + + """ + if offset < 0 or (self.size and offset >= self.size): + raise ValueError('offset argument is out-of-range') + cdef int64_t offset_ = offset + cdef int64_t size + if length is None: + size = self.size - offset_ + elif offset + length <= self.size: + size = length + else: + size = self.size - offset + parent = pyarrow_unwrap_cudabuffer(self) + return pyarrow_wrap_cudabuffer(make_shared[CCudaBuffer](parent, + offset_, size)) + + def to_pybytes(self): + """Return device buffer content as Python bytes. + """ + return self.copy_to_host().to_pybytes() + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + # Device buffer contains data pointers on the device. Hence, + # cannot support buffer protocol PEP-3118 for CudaBuffer. + raise BufferError('buffer protocol for device buffer not supported') + + +cdef class HostBuffer(Buffer): + """Device-accessible CPU memory created using cudaHostAlloc. + + To create a HostBuffer instance, use + + cuda.new_host_buffer() + """ + + def __init__(self): + raise TypeError("Do not call HostBuffer's constructor directly," + " use `cuda.new_host_buffer` function instead.") + + cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer): + self.host_buffer = buffer + self.init( buffer) + + @property + def size(self): + return self.host_buffer.get().size() + + +cdef class BufferReader(NativeFile): + """File interface for zero-copy read from CUDA buffers. + + Note: Read methods return pointers to device memory. This means + you must be careful using this interface with any Arrow code which + may expect to be able to do anything other than pointer arithmetic + on the returned buffers. + """ + + def __cinit__(self, CudaBuffer obj): + self.buffer = obj + self.reader = new CCudaBufferReader(self.buffer.buffer) + self.set_random_access_file( + shared_ptr[CRandomAccessFile](self.reader)) + self.is_readable = True + + def read_buffer(self, nbytes=None): + """Return a slice view of the underlying device buffer. + + The slice will start at the current reader position and will + have specified size in bytes. + + Parameters + ---------- + nbytes : int, default None + Specify the number of bytes to read. Default: None (read all + remaining bytes). + + Returns + ------- + cbuf : CudaBuffer + New device buffer. + + """ + cdef: + int64_t c_nbytes + shared_ptr[CCudaBuffer] output + + if nbytes is None: + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + with nogil: + output = static_pointer_cast[CCudaBuffer, CBuffer]( + GetResultValue(self.reader.Read(c_nbytes))) + + return pyarrow_wrap_cudabuffer(output) + + +cdef class BufferWriter(NativeFile): + """File interface for writing to CUDA buffers. + + By default writes are unbuffered. Use set_buffer_size to enable + buffering. + """ + + def __cinit__(self, CudaBuffer buffer): + self.buffer = buffer + self.writer = new CCudaBufferWriter(self.buffer.cuda_buffer) + self.set_output_stream(shared_ptr[COutputStream](self.writer)) + self.is_writable = True + + def writeat(self, int64_t position, object data): + """Write data to buffer starting from position. + + Parameters + ---------- + position : int + Specify device buffer position where the data will be + written. + data : array-like + Specify data, the data instance must implement buffer + protocol. + """ + cdef: + Buffer buf = as_buffer(data) + const uint8_t* c_data = buf.buffer.get().data() + int64_t c_size = buf.buffer.get().size() + + with nogil: + check_status(self.writer.WriteAt(position, c_data, c_size)) + + def flush(self): + """ Flush the buffer stream """ + with nogil: + check_status(self.writer.Flush()) + + def seek(self, int64_t position, int whence=0): + # TODO: remove this method after NativeFile.seek supports + # writable files. + cdef int64_t offset + + with nogil: + if whence == 0: + offset = position + elif whence == 1: + offset = GetResultValue(self.writer.Tell()) + offset = offset + position + else: + with gil: + raise ValueError("Invalid value of whence: {0}" + .format(whence)) + check_status(self.writer.Seek(offset)) + return self.tell() + + @property + def buffer_size(self): + """Returns size of host (CPU) buffer, 0 for unbuffered + """ + return self.writer.buffer_size() + + @buffer_size.setter + def buffer_size(self, int64_t buffer_size): + """Set CPU buffer size to limit calls to cudaMemcpy + + Parameters + ---------- + buffer_size : int + Specify the size of CPU buffer to allocate in bytes. + """ + with nogil: + check_status(self.writer.SetBufferSize(buffer_size)) + + @property + def num_bytes_buffered(self): + """Returns number of bytes buffered on host + """ + return self.writer.num_bytes_buffered() + +# Functions + + +def new_host_buffer(const int64_t size, int device=0): + """Return buffer with CUDA-accessible memory on CPU host + + Parameters + ---------- + size : int + Specify the number of bytes to be allocated. + device : int + Specify GPU device number. + + Returns + ------- + dbuf : HostBuffer + Allocated host buffer + """ + cdef shared_ptr[CCudaHostBuffer] buffer + with nogil: + buffer = GetResultValue(AllocateCudaHostBuffer(device, size)) + return pyarrow_wrap_cudahostbuffer(buffer) + + +def serialize_record_batch(object batch, object ctx): + """ Write record batch message to GPU device memory + + Parameters + ---------- + batch : RecordBatch + Record batch to write + ctx : Context + CUDA Context to allocate device memory from + + Returns + ------- + dbuf : CudaBuffer + device buffer which contains the record batch message + """ + cdef shared_ptr[CCudaBuffer] buffer + cdef CRecordBatch* batch_ = pyarrow_unwrap_batch(batch).get() + cdef CCudaContext* ctx_ = pyarrow_unwrap_cudacontext(ctx).get() + with nogil: + buffer = GetResultValue(CudaSerializeRecordBatch(batch_[0], ctx_)) + return pyarrow_wrap_cudabuffer(buffer) + + +def read_message(object source, pool=None): + """ Read Arrow IPC message located on GPU device + + Parameters + ---------- + source : {CudaBuffer, cuda.BufferReader} + Device buffer or reader of device buffer. + pool : MemoryPool (optional) + Pool to allocate CPU memory for the metadata + + Returns + ------- + message : Message + The deserialized message, body still on device + """ + cdef: + Message result = Message.__new__(Message) + cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + if not isinstance(source, BufferReader): + reader = BufferReader(source) + with nogil: + result.message = move( + GetResultValue(ReadMessage(reader.reader, pool_))) + return result + + +def read_record_batch(object buffer, object schema, *, + DictionaryMemo dictionary_memo=None, pool=None): + """Construct RecordBatch referencing IPC message located on CUDA device. + + While the metadata is copied to host memory for deserialization, + the record batch data remains on the device. + + Parameters + ---------- + buffer : + Device buffer containing the complete IPC message + schema : Schema + The schema for the record batch + dictionary_memo : DictionaryMemo, optional + If message contains dictionaries, must pass a populated + DictionaryMemo + pool : MemoryPool (optional) + Pool to allocate metadata from + + Returns + ------- + batch : RecordBatch + Reconstructed record batch, with device pointers + + """ + cdef: + shared_ptr[CSchema] schema_ = pyarrow_unwrap_schema(schema) + shared_ptr[CCudaBuffer] buffer_ = pyarrow_unwrap_cudabuffer(buffer) + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + shared_ptr[CRecordBatch] batch + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + with nogil: + batch = GetResultValue(CudaReadRecordBatch( + schema_, arg_dict_memo, buffer_, pool_)) + return pyarrow_wrap_batch(batch) + + +# Public API + + +cdef public api bint pyarrow_is_buffer(object buffer): + return isinstance(buffer, Buffer) + +# cudabuffer + +cdef public api bint pyarrow_is_cudabuffer(object buffer): + return isinstance(buffer, CudaBuffer) + + +cdef public api object \ + pyarrow_wrap_cudabuffer_base(const shared_ptr[CCudaBuffer]& buf, base): + cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer) + result.init_cuda(buf, base) + return result + + +cdef public api object \ + pyarrow_wrap_cudabuffer(const shared_ptr[CCudaBuffer]& buf): + cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer) + result.init_cuda(buf, None) + return result + + +cdef public api shared_ptr[CCudaBuffer] pyarrow_unwrap_cudabuffer(object obj): + if pyarrow_is_cudabuffer(obj): + return (obj).cuda_buffer + raise TypeError('expected CudaBuffer instance, got %s' + % (type(obj).__name__)) + +# cudahostbuffer + +cdef public api bint pyarrow_is_cudahostbuffer(object buffer): + return isinstance(buffer, HostBuffer) + + +cdef public api object \ + pyarrow_wrap_cudahostbuffer(const shared_ptr[CCudaHostBuffer]& buf): + cdef HostBuffer result = HostBuffer.__new__(HostBuffer) + result.init_host(buf) + return result + + +cdef public api shared_ptr[CCudaHostBuffer] \ + pyarrow_unwrap_cudahostbuffer(object obj): + if pyarrow_is_cudahostbuffer(obj): + return (obj).host_buffer + raise TypeError('expected HostBuffer instance, got %s' + % (type(obj).__name__)) + +# cudacontext + +cdef public api bint pyarrow_is_cudacontext(object ctx): + return isinstance(ctx, Context) + + +cdef public api object \ + pyarrow_wrap_cudacontext(const shared_ptr[CCudaContext]& ctx): + cdef Context result = Context.__new__(Context) + result.init(ctx) + return result + + +cdef public api shared_ptr[CCudaContext] \ + pyarrow_unwrap_cudacontext(object obj): + if pyarrow_is_cudacontext(obj): + return (obj).context + raise TypeError('expected Context instance, got %s' + % (type(obj).__name__)) + +# cudaipcmemhandle + +cdef public api bint pyarrow_is_cudaipcmemhandle(object handle): + return isinstance(handle, IpcMemHandle) + + +cdef public api object \ + pyarrow_wrap_cudaipcmemhandle(shared_ptr[CCudaIpcMemHandle]& h): + cdef IpcMemHandle result = IpcMemHandle.__new__(IpcMemHandle) + result.init(h) + return result + + +cdef public api shared_ptr[CCudaIpcMemHandle] \ + pyarrow_unwrap_cudaipcmemhandle(object obj): + if pyarrow_is_cudaipcmemhandle(obj): + return (obj).handle + raise TypeError('expected IpcMemHandle instance, got %s' + % (type(obj).__name__)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet.pyx new file mode 100644 index 0000000000000000000000000000000000000000..8fe9f30d33af9bc5cbf7cb25978334292f5ae9dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet.pyx @@ -0,0 +1,1053 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset support for Parquet file format.""" + +from cython cimport binding +from cython.operator cimport dereference as deref + +import os +import warnings + +import pyarrow as pa +from pyarrow.lib cimport * +from pyarrow.lib import frombytes, tobytes, is_threading_enabled +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_dataset cimport * +from pyarrow.includes.libarrow_dataset_parquet cimport * +from pyarrow._fs cimport FileSystem + +from pyarrow._compute cimport Expression, _bind +from pyarrow._dataset cimport ( + _make_file_source, + DatasetFactory, + FileFormat, + FileFragment, + FileWriteOptions, + Fragment, + FragmentScanOptions, + CacheOptions, + Partitioning, + PartitioningFactory, + WrittenFile +) + +from pyarrow._parquet cimport ( + _create_writer_properties, _create_arrow_writer_properties, + FileMetaData, +) + + +try: + from pyarrow._dataset_parquet_encryption import ( + set_encryption_config, set_decryption_config, set_decryption_properties + ) + parquet_encryption_enabled = True +except ImportError: + parquet_encryption_enabled = False + + +cdef Expression _true = Expression._scalar(True) + +ctypedef CParquetFileWriter* _CParquetFileWriterPtr + + +cdef class ParquetFileFormat(FileFormat): + """ + FileFormat for Parquet + + Parameters + ---------- + read_options : ParquetReadOptions + Read options for the file. + default_fragment_scan_options : ParquetFragmentScanOptions + Scan Options for the file. + **kwargs : dict + Additional options for read option or scan option + """ + + cdef: + CParquetFileFormat* parquet_format + + def __init__(self, read_options=None, + default_fragment_scan_options=None, + **kwargs): + cdef: + shared_ptr[CParquetFileFormat] wrapped + CParquetFileFormatReaderOptions* options + + # Read/scan options + read_options_args = {option: kwargs[option] for option in kwargs + if option in _PARQUET_READ_OPTIONS} + scan_args = {option: kwargs[option] for option in kwargs + if option not in _PARQUET_READ_OPTIONS} + if read_options and read_options_args: + duplicates = ', '.join(sorted(read_options_args)) + raise ValueError(f'If `read_options` is given, ' + f'cannot specify {duplicates}') + if default_fragment_scan_options and scan_args: + duplicates = ', '.join(sorted(scan_args)) + raise ValueError(f'If `default_fragment_scan_options` is given, ' + f'cannot specify {duplicates}') + + if read_options is None: + read_options = ParquetReadOptions(**read_options_args) + elif isinstance(read_options, dict): + # For backwards compatibility + duplicates = [] + for option, value in read_options.items(): + if option in _PARQUET_READ_OPTIONS: + read_options_args[option] = value + else: + duplicates.append(option) + scan_args[option] = value + if duplicates: + duplicates = ", ".join(duplicates) + warnings.warn(f'The scan options {duplicates} should be ' + 'specified directly as keyword arguments') + read_options = ParquetReadOptions(**read_options_args) + elif not isinstance(read_options, ParquetReadOptions): + raise TypeError('`read_options` must be either a dictionary or an ' + 'instance of ParquetReadOptions') + + if default_fragment_scan_options is None: + default_fragment_scan_options = ParquetFragmentScanOptions(**scan_args) + elif isinstance(default_fragment_scan_options, dict): + default_fragment_scan_options = ParquetFragmentScanOptions( + **default_fragment_scan_options) + elif not isinstance(default_fragment_scan_options, + ParquetFragmentScanOptions): + raise TypeError('`default_fragment_scan_options` must be either a ' + 'dictionary or an instance of ' + 'ParquetFragmentScanOptions') + + wrapped = make_shared[CParquetFileFormat]() + + options = &(wrapped.get().reader_options) + if read_options.dictionary_columns is not None: + for column in read_options.dictionary_columns: + options.dict_columns.insert(tobytes(column)) + options.coerce_int96_timestamp_unit = \ + read_options._coerce_int96_timestamp_unit + + self.init( wrapped) + self.default_fragment_scan_options = default_fragment_scan_options + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + FileFormat.init(self, sp) + self.parquet_format = sp.get() + + cdef WrittenFile _finish_write(self, path, base_dir, + CFileWriter* file_writer): + cdef: + FileMetaData parquet_metadata + CParquetFileWriter* parquet_file_writer + + parquet_metadata = None + parquet_file_writer = dynamic_cast[_CParquetFileWriterPtr](file_writer) + with nogil: + metadata = deref( + deref(parquet_file_writer).parquet_writer()).metadata() + if metadata: + parquet_metadata = FileMetaData() + parquet_metadata.init(metadata) + parquet_metadata.set_file_path(os.path.relpath(path, base_dir)) + + size = GetResultValue(file_writer.GetBytesWritten()) + + return WrittenFile(path, parquet_metadata, size) + + @property + def read_options(self): + cdef CParquetFileFormatReaderOptions* options + options = &self.parquet_format.reader_options + parquet_read_options = ParquetReadOptions( + dictionary_columns={frombytes(col) + for col in options.dict_columns}, + ) + # Read options getter/setter works with strings so setting + # the private property which uses the C Type + parquet_read_options._coerce_int96_timestamp_unit = \ + options.coerce_int96_timestamp_unit + return parquet_read_options + + def make_write_options(self, **kwargs): + """ + Parameters + ---------- + **kwargs : dict + + Returns + ------- + pyarrow.dataset.FileWriteOptions + """ + # Safeguard from calling make_write_options as a static class method + if not isinstance(self, ParquetFileFormat): + raise TypeError("make_write_options() should be called on " + "an instance of ParquetFileFormat") + opts = FileFormat.make_write_options(self) + ( opts).update(**kwargs) + return opts + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + if options.type_name == 'parquet': + self.parquet_format.default_fragment_scan_options = options.wrapped + else: + super()._set_default_fragment_scan_options(options) + + def equals(self, ParquetFileFormat other): + """ + Parameters + ---------- + other : pyarrow.dataset.ParquetFileFormat + + Returns + ------- + bool + """ + return ( + self.read_options.equals(other.read_options) and + self.default_fragment_scan_options == + other.default_fragment_scan_options + ) + + @property + def default_extname(self): + return "parquet" + + def __reduce__(self): + return ParquetFileFormat, (self.read_options, + self.default_fragment_scan_options) + + def __repr__(self): + return f"" + + def make_fragment(self, file, filesystem=None, + Expression partition_expression=None, row_groups=None, *, file_size=None): + """ + Make a FileFragment from a given file. + + Parameters + ---------- + file : file-like object, path-like or str + The file or file path to make a fragment from. + filesystem : Filesystem, optional + If `filesystem` is given, `file` must be a string and specifies + the path of the file to read from the filesystem. + partition_expression : Expression, optional + An expression that is guaranteed true for all rows in the fragment. Allows + fragment to be potentially skipped while scanning with a filter. + row_groups : Iterable, optional + The indices of the row groups to include + file_size : int, optional + The size of the file in bytes. Can improve performance with high-latency filesystems + when file size needs to be known before reading. + + Returns + ------- + fragment : Fragment + The file fragment + """ + cdef: + vector[int] c_row_groups + if partition_expression is None: + partition_expression = _true + if row_groups is None: + return super().make_fragment(file, filesystem, + partition_expression, file_size=file_size) + + c_source = _make_file_source(file, filesystem, file_size) + c_row_groups = [ row_group for row_group in set(row_groups)] + + c_fragment = GetResultValue( + self.parquet_format.MakeFragment(move(c_source), + partition_expression.unwrap(), + nullptr, + move(c_row_groups))) + return Fragment.wrap(move(c_fragment)) + + +class RowGroupInfo: + """ + A wrapper class for RowGroup information + + Parameters + ---------- + id : integer + The group ID. + metadata : FileMetaData + The rowgroup metadata. + schema : Schema + Schema of the rows. + """ + + def __init__(self, id, metadata, schema): + self.id = id + self.metadata = metadata + self.schema = schema + + @property + def num_rows(self): + return self.metadata.num_rows + + @property + def total_byte_size(self): + return self.metadata.total_byte_size + + @property + def statistics(self): + def name_stats(i): + col = self.metadata.column(i) + + stats = col.statistics + if stats is None or not stats.has_min_max: + return None, None + + name = col.path_in_schema + field_index = self.schema.get_field_index(name) + if field_index < 0: + return None, None + + typ = self.schema.field(field_index).type + return col.path_in_schema, { + 'min': pa.scalar(stats.min, type=typ).as_py(), + 'max': pa.scalar(stats.max, type=typ).as_py() + } + + return { + name: stats for name, stats + in map(name_stats, range(self.metadata.num_columns)) + if stats is not None + } + + def __repr__(self): + return "RowGroupInfo({})".format(self.id) + + def __eq__(self, other): + if isinstance(other, int): + return self.id == other + if not isinstance(other, RowGroupInfo): + return False + return self.id == other.id + + +cdef class ParquetFileFragment(FileFragment): + """A Fragment representing a parquet file.""" + + cdef: + CParquetFileFragment* parquet_file_fragment + + cdef void init(self, const shared_ptr[CFragment]& sp): + FileFragment.init(self, sp) + self.parquet_file_fragment = sp.get() + + def __reduce__(self): + buffer = self.buffer + # parquet_file_fragment.row_groups() is empty if the metadata + # information of the file is not yet populated + if not bool(self.parquet_file_fragment.row_groups()): + row_groups = None + else: + row_groups = [row_group.id for row_group in self.row_groups] + + return self.format.make_fragment, ( + self.path if buffer is None else buffer, + self.filesystem, + self.partition_expression, + row_groups + ) + + def ensure_complete_metadata(self): + """ + Ensure that all metadata (statistics, physical schema, ...) have + been read and cached in this fragment. + """ + with nogil: + check_status(self.parquet_file_fragment.EnsureCompleteMetadata()) + + @property + def row_groups(self): + metadata = self.metadata + cdef vector[int] row_groups = self.parquet_file_fragment.row_groups() + return [RowGroupInfo(i, metadata.row_group(i), self.physical_schema) + for i in row_groups] + + @property + def metadata(self): + self.ensure_complete_metadata() + cdef FileMetaData metadata = FileMetaData() + metadata.init(self.parquet_file_fragment.metadata()) + return metadata + + @property + def num_row_groups(self): + """ + Return the number of row groups viewed by this fragment (not the + number of row groups in the origin file). + """ + self.ensure_complete_metadata() + return self.parquet_file_fragment.row_groups().size() + + def split_by_row_group(self, Expression filter=None, + Schema schema=None): + """ + Split the fragment into multiple fragments. + + Yield a Fragment wrapping each row group in this ParquetFileFragment. + Row groups will be excluded whose metadata contradicts the optional + filter. + + Parameters + ---------- + filter : Expression, default None + Only include the row groups which satisfy this predicate (using + the Parquet RowGroup statistics). + schema : Schema, default None + Schema to use when filtering row groups. Defaults to the + Fragment's physical schema + + Returns + ------- + A list of Fragments + """ + cdef: + vector[shared_ptr[CFragment]] c_fragments + CExpression c_filter + shared_ptr[CFragment] c_fragment + + schema = schema or self.physical_schema + c_filter = _bind(filter, schema) + with nogil: + c_fragments = move(GetResultValue( + self.parquet_file_fragment.SplitByRowGroup(move(c_filter)))) + + return [Fragment.wrap(c_fragment) for c_fragment in c_fragments] + + def subset(self, Expression filter=None, Schema schema=None, + object row_group_ids=None): + """ + Create a subset of the fragment (viewing a subset of the row groups). + + Subset can be specified by either a filter predicate (with optional + schema) or by a list of row group IDs. Note that when using a filter, + the resulting fragment can be empty (viewing no row groups). + + Parameters + ---------- + filter : Expression, default None + Only include the row groups which satisfy this predicate (using + the Parquet RowGroup statistics). + schema : Schema, default None + Schema to use when filtering row groups. Defaults to the + Fragment's physical schema + row_group_ids : list of ints + The row group IDs to include in the subset. Can only be specified + if `filter` is None. + + Returns + ------- + ParquetFileFragment + """ + cdef: + CExpression c_filter + vector[int] c_row_group_ids + shared_ptr[CFragment] c_fragment + + if filter is not None and row_group_ids is not None: + raise ValueError( + "Cannot specify both 'filter' and 'row_group_ids'." + ) + + if filter is not None: + schema = schema or self.physical_schema + c_filter = _bind(filter, schema) + with nogil: + c_fragment = move(GetResultValue( + self.parquet_file_fragment.SubsetWithFilter( + move(c_filter)))) + elif row_group_ids is not None: + c_row_group_ids = [ + row_group for row_group in sorted(set(row_group_ids)) + ] + with nogil: + c_fragment = move(GetResultValue( + self.parquet_file_fragment.SubsetWithIds( + move(c_row_group_ids)))) + else: + raise ValueError( + "Need to specify one of 'filter' or 'row_group_ids'" + ) + + return Fragment.wrap(c_fragment) + + +cdef class ParquetReadOptions(_Weakrefable): + """ + Parquet format specific options for reading. + + Parameters + ---------- + dictionary_columns : list of string, default None + Names of columns which should be dictionary encoded as + they are read + coerce_int96_timestamp_unit : str, default None + Cast timestamps that are stored in INT96 format to a particular + resolution (e.g. 'ms'). Setting to None is equivalent to 'ns' + and therefore INT96 timestamps will be inferred as timestamps + in nanoseconds + """ + + cdef public: + set dictionary_columns + TimeUnit _coerce_int96_timestamp_unit + + # Also see _PARQUET_READ_OPTIONS + def __init__(self, dictionary_columns=None, + coerce_int96_timestamp_unit=None): + self.dictionary_columns = set(dictionary_columns or set()) + self.coerce_int96_timestamp_unit = coerce_int96_timestamp_unit + + @property + def coerce_int96_timestamp_unit(self): + return timeunit_to_string(self._coerce_int96_timestamp_unit) + + @coerce_int96_timestamp_unit.setter + def coerce_int96_timestamp_unit(self, unit): + if unit is not None: + self._coerce_int96_timestamp_unit = string_to_timeunit(unit) + else: + self._coerce_int96_timestamp_unit = TimeUnit_NANO + + def equals(self, ParquetReadOptions other): + """ + Parameters + ---------- + other : pyarrow.dataset.ParquetReadOptions + + Returns + ------- + bool + """ + return (self.dictionary_columns == other.dictionary_columns and + self.coerce_int96_timestamp_unit == + other.coerce_int96_timestamp_unit) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + def __repr__(self): + return ( + f"" + ) + + +cdef class ParquetFileWriteOptions(FileWriteOptions): + + def update(self, **kwargs): + """ + Parameters + ---------- + **kwargs : dict + """ + arrow_fields = { + "use_deprecated_int96_timestamps", + "coerce_timestamps", + "allow_truncated_timestamps", + "use_compliant_nested_type", + } + + setters = set() + for name, value in kwargs.items(): + if name not in self._properties: + raise TypeError("unexpected parquet write option: " + name) + self._properties[name] = value + if name in arrow_fields: + setters.add(self._set_arrow_properties) + elif name == "encryption_config" and value is not None: + setters.add(self._set_encryption_config) + else: + setters.add(self._set_properties) + + for setter in setters: + setter() + + def _set_properties(self): + cdef CParquetFileWriteOptions* opts = self.parquet_options + + opts.writer_properties = _create_writer_properties( + use_dictionary=self._properties["use_dictionary"], + compression=self._properties["compression"], + version=self._properties["version"], + write_statistics=self._properties["write_statistics"], + data_page_size=self._properties["data_page_size"], + compression_level=self._properties["compression_level"], + use_byte_stream_split=( + self._properties["use_byte_stream_split"] + ), + column_encoding=self._properties["column_encoding"], + data_page_version=self._properties["data_page_version"], + encryption_properties=self._properties["encryption_properties"], + write_batch_size=self._properties["write_batch_size"], + dictionary_pagesize_limit=self._properties["dictionary_pagesize_limit"], + write_page_index=self._properties["write_page_index"], + write_page_checksum=self._properties["write_page_checksum"], + sorting_columns=self._properties["sorting_columns"], + store_decimal_as_integer=self._properties["store_decimal_as_integer"], + ) + + def _set_arrow_properties(self): + cdef CParquetFileWriteOptions* opts = self.parquet_options + + opts.arrow_writer_properties = _create_arrow_writer_properties( + use_deprecated_int96_timestamps=( + self._properties["use_deprecated_int96_timestamps"] + ), + coerce_timestamps=self._properties["coerce_timestamps"], + allow_truncated_timestamps=( + self._properties["allow_truncated_timestamps"] + ), + writer_engine_version="V2", + use_compliant_nested_type=( + self._properties["use_compliant_nested_type"] + ) + ) + + def _set_encryption_config(self): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Encryption is not enabled in your installation of pyarrow, but an " + "encryption_config was provided." + ) + set_encryption_config(self, self._properties["encryption_config"]) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + FileWriteOptions.init(self, sp) + self.parquet_options = sp.get() + self._properties = dict( + use_dictionary=True, + compression="snappy", + version="2.6", + write_statistics=None, + data_page_size=None, + compression_level=None, + use_byte_stream_split=False, + column_encoding=None, + data_page_version="1.0", + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + allow_truncated_timestamps=False, + use_compliant_nested_type=True, + encryption_properties=None, + write_batch_size=None, + dictionary_pagesize_limit=None, + write_page_index=False, + encryption_config=None, + write_page_checksum=False, + sorting_columns=None, + store_decimal_as_integer=False, + ) + + self._set_properties() + self._set_arrow_properties() + + def __repr__(self): + return "".format( + " ".join([f"{key}={value}" for key, value in self._properties.items()]) + ) + + +cdef set _PARQUET_READ_OPTIONS = { + 'dictionary_columns', 'coerce_int96_timestamp_unit' +} + + +cdef class ParquetFragmentScanOptions(FragmentScanOptions): + """ + Scan-specific options for Parquet fragments. + + Parameters + ---------- + use_buffered_stream : bool, default False + Read files through buffered input streams rather than loading entire + row groups at once. This may be enabled to reduce memory overhead. + Disabled by default. + buffer_size : int, default 8192 + Size of buffered stream, if enabled. Default is 8KB. + pre_buffer : bool, default True + If enabled, pre-buffer the raw Parquet data instead of issuing one + read per column chunk. This can improve performance on high-latency + filesystems (e.g. S3, GCS) by coalescing and issuing file reads in + parallel using a background I/O thread pool. + Set to False if you want to prioritize minimal memory usage + over maximum speed. + cache_options : pyarrow.CacheOptions, default None + Cache options used when pre_buffer is enabled. The default values should + be good for most use cases. You may want to adjust these for example if + you have exceptionally high latency to the file system. + thrift_string_size_limit : int, default None + If not None, override the maximum total string size allocated + when decoding Thrift structures. The default limit should be + sufficient for most Parquet files. + thrift_container_size_limit : int, default None + If not None, override the maximum total size of containers allocated + when decoding Thrift structures. The default limit should be + sufficient for most Parquet files. + decryption_config : pyarrow.dataset.ParquetDecryptionConfig, default None + If not None, use the provided ParquetDecryptionConfig to decrypt the + Parquet file. + decryption_properties : pyarrow.parquet.FileDecryptionProperties, default None + If not None, use the provided FileDecryptionProperties to decrypt encrypted + Parquet file. + page_checksum_verification : bool, default False + If True, verify the page checksum for each page read from the file. + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, bint use_buffered_stream=False, + buffer_size=8192, + bint pre_buffer=True, + cache_options=None, + thrift_string_size_limit=None, + thrift_container_size_limit=None, + decryption_config=None, + decryption_properties=None, + bint page_checksum_verification=False): + self.init(shared_ptr[CFragmentScanOptions]( + new CParquetFragmentScanOptions())) + self.use_buffered_stream = use_buffered_stream + self.buffer_size = buffer_size + if pre_buffer and not is_threading_enabled(): + pre_buffer = False + self.pre_buffer = pre_buffer + if cache_options is not None: + self.cache_options = cache_options + if thrift_string_size_limit is not None: + self.thrift_string_size_limit = thrift_string_size_limit + if thrift_container_size_limit is not None: + self.thrift_container_size_limit = thrift_container_size_limit + if decryption_config is not None: + self.parquet_decryption_config = decryption_config + if decryption_properties is not None: + self.decryption_properties = decryption_properties + self.page_checksum_verification = page_checksum_verification + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + FragmentScanOptions.init(self, sp) + self.parquet_options = sp.get() + + cdef CReaderProperties* reader_properties(self): + return self.parquet_options.reader_properties.get() + + cdef ArrowReaderProperties* arrow_reader_properties(self): + return self.parquet_options.arrow_reader_properties.get() + + @property + def use_buffered_stream(self): + return self.reader_properties().is_buffered_stream_enabled() + + @use_buffered_stream.setter + def use_buffered_stream(self, bint use_buffered_stream): + if use_buffered_stream: + self.reader_properties().enable_buffered_stream() + else: + self.reader_properties().disable_buffered_stream() + + @property + def buffer_size(self): + return self.reader_properties().buffer_size() + + @buffer_size.setter + def buffer_size(self, buffer_size): + if buffer_size <= 0: + raise ValueError("Buffer size must be larger than zero") + self.reader_properties().set_buffer_size(buffer_size) + + @property + def pre_buffer(self): + return self.arrow_reader_properties().pre_buffer() + + @pre_buffer.setter + def pre_buffer(self, bint pre_buffer): + if pre_buffer and not is_threading_enabled(): + return + self.arrow_reader_properties().set_pre_buffer(pre_buffer) + + @property + def cache_options(self): + return CacheOptions.wrap(self.arrow_reader_properties().cache_options()) + + @cache_options.setter + def cache_options(self, CacheOptions options): + self.arrow_reader_properties().set_cache_options(options.unwrap()) + + @property + def thrift_string_size_limit(self): + return self.reader_properties().thrift_string_size_limit() + + @thrift_string_size_limit.setter + def thrift_string_size_limit(self, size): + if size <= 0: + raise ValueError("size must be larger than zero") + self.reader_properties().set_thrift_string_size_limit(size) + + @property + def thrift_container_size_limit(self): + return self.reader_properties().thrift_container_size_limit() + + @thrift_container_size_limit.setter + def thrift_container_size_limit(self, size): + if size <= 0: + raise ValueError("size must be larger than zero") + self.reader_properties().set_thrift_container_size_limit(size) + + @property + def decryption_properties(self): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Unable to access encryption features. " + "Encryption is not enabled in your installation of pyarrow." + ) + return self._decryption_properties + + @decryption_properties.setter + def decryption_properties(self, config): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Encryption is not enabled in your installation of pyarrow, but " + "decryption_properties were provided." + ) + set_decryption_properties(self, config) + self._decryption_properties = config + + @property + def parquet_decryption_config(self): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Unable to access encryption features. " + "Encryption is not enabled in your installation of pyarrow." + ) + return self._parquet_decryption_config + + @parquet_decryption_config.setter + def parquet_decryption_config(self, config): + if not parquet_encryption_enabled: + raise NotImplementedError( + "Encryption is not enabled in your installation of pyarrow, but a " + "decryption_config was provided." + ) + set_decryption_config(self, config) + self._parquet_decryption_config = config + + @property + def page_checksum_verification(self): + return self.reader_properties().page_checksum_verification() + + @page_checksum_verification.setter + def page_checksum_verification(self, bint page_checksum_verification): + self.reader_properties().set_page_checksum_verification(page_checksum_verification) + + def equals(self, ParquetFragmentScanOptions other): + """ + Parameters + ---------- + other : pyarrow.dataset.ParquetFragmentScanOptions + + Returns + ------- + bool + """ + attrs = ( + self.use_buffered_stream, self.buffer_size, self.pre_buffer, self.cache_options, + self.thrift_string_size_limit, self.thrift_container_size_limit, + self.page_checksum_verification) + other_attrs = ( + other.use_buffered_stream, other.buffer_size, other.pre_buffer, other.cache_options, + other.thrift_string_size_limit, + other.thrift_container_size_limit, other.page_checksum_verification) + return attrs == other_attrs + + @staticmethod + @binding(True) # Required for Cython < 3 + def _reconstruct(kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return ParquetFragmentScanOptions(**kwargs) + + def __reduce__(self): + kwargs = dict( + use_buffered_stream=self.use_buffered_stream, + buffer_size=self.buffer_size, + pre_buffer=self.pre_buffer, + cache_options=self.cache_options, + thrift_string_size_limit=self.thrift_string_size_limit, + thrift_container_size_limit=self.thrift_container_size_limit, + page_checksum_verification=self.page_checksum_verification + ) + return ParquetFragmentScanOptions._reconstruct, (kwargs,) + + +cdef class ParquetFactoryOptions(_Weakrefable): + """ + Influences the discovery of parquet dataset. + + Parameters + ---------- + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + partitioning : Partitioning, PartitioningFactory, optional + The partitioning scheme applied to fragments, see ``Partitioning``. + validate_column_chunk_paths : bool, default False + Assert that all ColumnChunk paths are consistent. The parquet spec + allows for ColumnChunk data to be stored in multiple files, but + ParquetDatasetFactory supports only a single file with all ColumnChunk + data. If this flag is set construction of a ParquetDatasetFactory will + raise an error if ColumnChunk data is not resident in a single file. + """ + + cdef: + CParquetFactoryOptions options + + __slots__ = () # avoid mistakingly creating attributes + + def __init__(self, partition_base_dir=None, partitioning=None, + validate_column_chunk_paths=False): + if isinstance(partitioning, PartitioningFactory): + self.partitioning_factory = partitioning + elif isinstance(partitioning, Partitioning): + self.partitioning = partitioning + + if partition_base_dir is not None: + self.partition_base_dir = partition_base_dir + + self.options.validate_column_chunk_paths = validate_column_chunk_paths + + cdef inline CParquetFactoryOptions unwrap(self): + return self.options + + @property + def partitioning(self): + """Partitioning to apply to discovered files. + + NOTE: setting this property will overwrite partitioning_factory. + """ + c_partitioning = self.options.partitioning.partitioning() + if c_partitioning.get() == nullptr: + return None + return Partitioning.wrap(c_partitioning) + + @partitioning.setter + def partitioning(self, Partitioning value): + self.options.partitioning = ( value).unwrap() + + @property + def partitioning_factory(self): + """PartitioningFactory to apply to discovered files and + discover a Partitioning. + + NOTE: setting this property will overwrite partitioning. + """ + c_factory = self.options.partitioning.factory() + if c_factory.get() == nullptr: + return None + return PartitioningFactory.wrap(c_factory, None, None) + + @partitioning_factory.setter + def partitioning_factory(self, PartitioningFactory value): + self.options.partitioning = ( value).unwrap() + + @property + def partition_base_dir(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return frombytes(self.options.partition_base_dir) + + @partition_base_dir.setter + def partition_base_dir(self, value): + self.options.partition_base_dir = tobytes(value) + + @property + def validate_column_chunk_paths(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return self.options.validate_column_chunk_paths + + @validate_column_chunk_paths.setter + def validate_column_chunk_paths(self, value): + self.options.validate_column_chunk_paths = value + + +cdef class ParquetDatasetFactory(DatasetFactory): + """ + Create a ParquetDatasetFactory from a Parquet `_metadata` file. + + Parameters + ---------- + metadata_path : str + Path to the `_metadata` parquet metadata-only file generated with + `pyarrow.parquet.write_metadata`. + filesystem : pyarrow.fs.FileSystem + Filesystem to read the metadata_path from, and subsequent parquet + files. + format : ParquetFileFormat + Parquet format options. + options : ParquetFactoryOptions, optional + Various flags influencing the discovery of filesystem paths. + """ + + cdef: + CParquetDatasetFactory* parquet_factory + + def __init__(self, metadata_path, FileSystem filesystem not None, + FileFormat format not None, + ParquetFactoryOptions options=None): + cdef: + c_string c_path + shared_ptr[CFileSystem] c_filesystem + shared_ptr[CParquetFileFormat] c_format + CResult[shared_ptr[CDatasetFactory]] result + CParquetFactoryOptions c_options + + c_path = tobytes(metadata_path) + c_filesystem = filesystem.unwrap() + c_format = static_pointer_cast[CParquetFileFormat, CFileFormat]( + format.unwrap()) + options = options or ParquetFactoryOptions() + c_options = options.unwrap() + + with nogil: + result = CParquetDatasetFactory.MakeFromMetaDataPath( + c_path, c_filesystem, c_format, c_options) + self.init(GetResultValue(result)) + + cdef init(self, shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.parquet_factory = sp.get() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_fs.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_fs.pxd new file mode 100644 index 0000000000000000000000000000000000000000..0df75530bbd6ec3552131e11acc5b0406627fe65 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_fs.pxd @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow.lib import _detect_compression, frombytes, tobytes +from pyarrow.lib cimport * + + +cpdef enum FileType: + NotFound = CFileType_NotFound + Unknown = CFileType_Unknown + File = CFileType_File + Directory = CFileType_Directory + + +cdef class FileInfo(_Weakrefable): + cdef: + CFileInfo info + + @staticmethod + cdef wrap(CFileInfo info) + + cdef inline CFileInfo unwrap(self) nogil + + @staticmethod + cdef CFileInfo unwrap_safe(obj) + + +cdef class FileSelector(_Weakrefable): + cdef: + CFileSelector selector + + @staticmethod + cdef FileSelector wrap(CFileSelector selector) + + cdef inline CFileSelector unwrap(self) nogil + + +cdef class FileSystem(_Weakrefable): + cdef: + shared_ptr[CFileSystem] wrapped + CFileSystem* fs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[CFileSystem]& sp) + + cdef inline shared_ptr[CFileSystem] unwrap(self) nogil + + +cdef class LocalFileSystem(FileSystem): + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class SubTreeFileSystem(FileSystem): + cdef: + CSubTreeFileSystem* subtreefs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class _MockFileSystem(FileSystem): + cdef: + CMockFileSystem* mockfs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class PyFileSystem(FileSystem): + cdef: + CPyFileSystem* pyfs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_hdfs.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_hdfs.pyx new file mode 100644 index 0000000000000000000000000000000000000000..c426337a12ec184feb2d699e1e685228c249466e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_hdfs.pyx @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cython cimport binding + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._fs cimport FileSystem + +from pyarrow.lib import frombytes, tobytes +from pyarrow.util import _stringify_path + + +cdef class HadoopFileSystem(FileSystem): + """ + HDFS backed FileSystem implementation + + Parameters + ---------- + host : str + HDFS host to connect to. Set to "default" for fs.defaultFS from + core-site.xml. + port : int, default 8020 + HDFS port to connect to. Set to 0 for default or logical (HA) nodes. + user : str, default None + Username when connecting to HDFS; None implies login user. + replication : int, default 3 + Number of copies each block will have. + buffer_size : int, default 0 + If 0, no buffering will happen otherwise the size of the temporary read + and write buffer. + default_block_size : int, default None + None means the default configuration for HDFS, a typical block size is + 128 MB. + kerb_ticket : string or path, default None + If not None, the path to the Kerberos ticket cache. + extra_conf : dict, default None + Extra key/value pairs for configuration; will override any + hdfs-site.xml properties. + + Examples + -------- + >>> from pyarrow import fs + >>> hdfs = fs.HadoopFileSystem(host, port, user=user, kerb_ticket=ticket_cache_path) # doctest: +SKIP + + For usage of the methods see examples for :func:`~pyarrow.fs.LocalFileSystem`. + """ + + cdef: + CHadoopFileSystem* hdfs + + def __init__(self, str host, int port=8020, *, str user=None, + int replication=3, int buffer_size=0, + default_block_size=None, kerb_ticket=None, + extra_conf=None): + cdef: + CHdfsOptions options + shared_ptr[CHadoopFileSystem] wrapped + + if not host.startswith(('hdfs://', 'viewfs://')) and host != "default": + # TODO(kszucs): do more sanitization + host = 'hdfs://{}'.format(host) + + options.ConfigureEndPoint(tobytes(host), int(port)) + options.ConfigureReplication(replication) + options.ConfigureBufferSize(buffer_size) + + if user is not None: + options.ConfigureUser(tobytes(user)) + if default_block_size is not None: + options.ConfigureBlockSize(default_block_size) + if kerb_ticket is not None: + options.ConfigureKerberosTicketCachePath( + tobytes(_stringify_path(kerb_ticket))) + if extra_conf is not None: + for k, v in extra_conf.items(): + options.ConfigureExtraConf(tobytes(k), tobytes(v)) + + with nogil: + wrapped = GetResultValue(CHadoopFileSystem.Make(options)) + self.init( wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.hdfs = wrapped.get() + + @staticmethod + def from_uri(uri): + """ + Instantiate HadoopFileSystem object from an URI string. + + The following two calls are equivalent + + * ``HadoopFileSystem.from_uri('hdfs://localhost:8020/?user=test\ +&replication=1')`` + * ``HadoopFileSystem('localhost', port=8020, user='test', \ +replication=1)`` + + Parameters + ---------- + uri : str + A string URI describing the connection to HDFS. + In order to change the user, replication, buffer_size or + default_block_size pass the values as query parts. + + Returns + ------- + HadoopFileSystem + """ + cdef: + HadoopFileSystem self = HadoopFileSystem.__new__(HadoopFileSystem) + shared_ptr[CHadoopFileSystem] wrapped + CHdfsOptions options + + options = GetResultValue(CHdfsOptions.FromUriString(tobytes(uri))) + with nogil: + wrapped = GetResultValue(CHadoopFileSystem.Make(options)) + + self.init( wrapped) + return self + + @staticmethod + @binding(True) # Required for cython < 3 + def _reconstruct(kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return HadoopFileSystem(**kwargs) + + def __reduce__(self): + cdef CHdfsOptions opts = self.hdfs.options() + return ( + HadoopFileSystem._reconstruct, (dict( + host=frombytes(opts.connection_config.host), + port=opts.connection_config.port, + user=frombytes(opts.connection_config.user), + replication=opts.replication, + buffer_size=opts.buffer_size, + default_block_size=opts.default_block_size, + kerb_ticket=frombytes(opts.connection_config.kerb_ticket), + extra_conf={frombytes(k): frombytes(v) + for k, v in opts.connection_config.extra_conf}, + ),) + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pxd new file mode 100644 index 0000000000000000000000000000000000000000..42a0a678a9b6a543c657c905f3eb4fa4490b6edf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pxd @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport _Weakrefable + + +cdef class ParseOptions(_Weakrefable): + cdef: + CJSONParseOptions options + + @staticmethod + cdef ParseOptions wrap(CJSONParseOptions options) + +cdef class ReadOptions(_Weakrefable): + cdef: + CJSONReadOptions options + + @staticmethod + cdef ReadOptions wrap(CJSONReadOptions options) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pyx new file mode 100644 index 0000000000000000000000000000000000000000..d36dad67abbaa575d8963273c884dd9e8f047b13 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_json.pyx @@ -0,0 +1,310 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (_Weakrefable, MemoryPool, + maybe_unbox_memory_pool, + get_input_stream, pyarrow_wrap_table, + pyarrow_wrap_schema, pyarrow_unwrap_schema) + + +cdef class ReadOptions(_Weakrefable): + """ + Options for reading JSON files. + + Parameters + ---------- + use_threads : bool, optional (default True) + Whether to use multiple threads to accelerate reading + block_size : int, optional + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual chunks in the Table. + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, use_threads=None, block_size=None): + self.options = CJSONReadOptions.Defaults() + if use_threads is not None: + self.use_threads = use_threads + if block_size is not None: + self.block_size = block_size + + @property + def use_threads(self): + """ + Whether to use multiple threads to accelerate reading. + """ + return self.options.use_threads + + @use_threads.setter + def use_threads(self, value): + self.options.use_threads = value + + @property + def block_size(self): + """ + How much bytes to process at a time from the input stream. + + This will determine multi-threading granularity as well as the size of + individual chunks in the Table. + """ + return self.options.block_size + + @block_size.setter + def block_size(self, value): + self.options.block_size = value + + def __reduce__(self): + return ReadOptions, ( + self.use_threads, + self.block_size + ) + + def equals(self, ReadOptions other): + """ + Parameters + ---------- + other : pyarrow.json.ReadOptions + + Returns + ------- + bool + """ + return ( + self.use_threads == other.use_threads and + self.block_size == other.block_size + ) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + @staticmethod + cdef ReadOptions wrap(CJSONReadOptions options): + out = ReadOptions() + out.options = options # shallow copy + return out + + +cdef class ParseOptions(_Weakrefable): + """ + Options for parsing JSON files. + + Parameters + ---------- + explicit_schema : Schema, optional (default None) + Optional explicit schema (no type inference, ignores other fields). + newlines_in_values : bool, optional (default False) + Whether objects may be printed across multiple lines (for example + pretty printed). If false, input must end with an empty line. + unexpected_field_behavior : str, default "infer" + How JSON fields outside of explicit_schema (if given) are treated. + + Possible behaviors: + + - "ignore": unexpected JSON fields are ignored + - "error": error out on unexpected JSON fields + - "infer": unexpected JSON fields are type-inferred and included in + the output + """ + + __slots__ = () + + def __init__(self, explicit_schema=None, newlines_in_values=None, + unexpected_field_behavior=None): + self.options = CJSONParseOptions.Defaults() + if explicit_schema is not None: + self.explicit_schema = explicit_schema + if newlines_in_values is not None: + self.newlines_in_values = newlines_in_values + if unexpected_field_behavior is not None: + self.unexpected_field_behavior = unexpected_field_behavior + + def __reduce__(self): + return ParseOptions, ( + self.explicit_schema, + self.newlines_in_values, + self.unexpected_field_behavior + ) + + @property + def explicit_schema(self): + """ + Optional explicit schema (no type inference, ignores other fields) + """ + if self.options.explicit_schema.get() == NULL: + return None + else: + return pyarrow_wrap_schema(self.options.explicit_schema) + + @explicit_schema.setter + def explicit_schema(self, value): + self.options.explicit_schema = pyarrow_unwrap_schema(value) + + @property + def newlines_in_values(self): + """ + Whether newline characters are allowed in JSON values. + Setting this to True reduces the performance of multi-threaded + JSON reading. + """ + return self.options.newlines_in_values + + @newlines_in_values.setter + def newlines_in_values(self, value): + self.options.newlines_in_values = value + + @property + def unexpected_field_behavior(self): + """ + How JSON fields outside of explicit_schema (if given) are treated. + + Possible behaviors: + + - "ignore": unexpected JSON fields are ignored + - "error": error out on unexpected JSON fields + - "infer": unexpected JSON fields are type-inferred and included in + the output + + Set to "infer" by default. + """ + v = self.options.unexpected_field_behavior + if v == CUnexpectedFieldBehavior_Ignore: + return "ignore" + elif v == CUnexpectedFieldBehavior_Error: + return "error" + elif v == CUnexpectedFieldBehavior_InferType: + return "infer" + else: + raise ValueError('Unexpected value for unexpected_field_behavior') + + @unexpected_field_behavior.setter + def unexpected_field_behavior(self, value): + cdef CUnexpectedFieldBehavior v + + if value == "ignore": + v = CUnexpectedFieldBehavior_Ignore + elif value == "error": + v = CUnexpectedFieldBehavior_Error + elif value == "infer": + v = CUnexpectedFieldBehavior_InferType + else: + raise ValueError( + "Unexpected value `{}` for `unexpected_field_behavior`, pass " + "either `ignore`, `error` or `infer`.".format(value) + ) + + self.options.unexpected_field_behavior = v + + def equals(self, ParseOptions other): + """ + Parameters + ---------- + other : pyarrow.json.ParseOptions + + Returns + ------- + bool + """ + return ( + self.explicit_schema == other.explicit_schema and + self.newlines_in_values == other.newlines_in_values and + self.unexpected_field_behavior == other.unexpected_field_behavior + ) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + @staticmethod + cdef ParseOptions wrap(CJSONParseOptions options): + out = ParseOptions() + out.options = options # shallow copy + return out + + +cdef _get_reader(input_file, shared_ptr[CInputStream]* out): + use_memory_map = False + get_input_stream(input_file, use_memory_map, out) + +cdef _get_read_options(ReadOptions read_options, CJSONReadOptions* out): + if read_options is None: + out[0] = CJSONReadOptions.Defaults() + else: + out[0] = read_options.options + +cdef _get_parse_options(ParseOptions parse_options, CJSONParseOptions* out): + if parse_options is None: + out[0] = CJSONParseOptions.Defaults() + else: + out[0] = parse_options.options + + +def read_json(input_file, read_options=None, parse_options=None, + MemoryPool memory_pool=None): + """ + Read a Table from a stream of JSON data. + + Parameters + ---------- + input_file : str, path or file-like object + The location of JSON data. Currently only the line-delimited JSON + format is supported. + read_options : pyarrow.json.ReadOptions, optional + Options for the JSON reader (see ReadOptions constructor for defaults). + parse_options : pyarrow.json.ParseOptions, optional + Options for the JSON parser + (see ParseOptions constructor for defaults). + memory_pool : MemoryPool, optional + Pool to allocate Table memory from. + + Returns + ------- + :class:`pyarrow.Table` + Contents of the JSON file as a in-memory table. + """ + cdef: + shared_ptr[CInputStream] stream + CJSONReadOptions c_read_options + CJSONParseOptions c_parse_options + shared_ptr[CJSONReader] reader + shared_ptr[CTable] table + + _get_reader(input_file, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + + reader = GetResultValue( + CJSONReader.Make(maybe_unbox_memory_pool(memory_pool), + stream, c_read_options, c_parse_options)) + + with nogil: + table = GetResultValue(reader.get().Read()) + + return pyarrow_wrap_table(table) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_orc.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_orc.pyx new file mode 100644 index 0000000000000000000000000000000000000000..1dd6848122c2d4d5d2a40faf70bbb4647329f9d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_orc.pyx @@ -0,0 +1,445 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ + +from cython.operator cimport dereference as deref +from libcpp.vector cimport vector as std_vector +from libcpp.utility cimport move +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (check_status, _Weakrefable, + MemoryPool, maybe_unbox_memory_pool, + pyarrow_wrap_schema, + pyarrow_wrap_batch, + Table, + pyarrow_wrap_table, + pyarrow_wrap_metadata, + pyarrow_unwrap_table, + get_reader, + get_writer) +from pyarrow.lib import frombytes, tobytes +from pyarrow.util import _stringify_path + + +cdef compression_type_from_enum(CCompressionType compression_type): + compression_map = { + CCompressionType_UNCOMPRESSED: 'UNCOMPRESSED', + CCompressionType_GZIP: 'ZLIB', + CCompressionType_SNAPPY: 'SNAPPY', + CCompressionType_LZ4: 'LZ4', + CCompressionType_ZSTD: 'ZSTD', + } + if compression_type in compression_map: + return compression_map[compression_type] + raise ValueError('Unsupported compression') + + +cdef CCompressionType compression_type_from_name(name) except *: + if not isinstance(name, str): + raise TypeError('compression must be a string') + name = name.upper() + if name == 'ZLIB': + return CCompressionType_GZIP + elif name == 'SNAPPY': + return CCompressionType_SNAPPY + elif name == 'LZ4': + return CCompressionType_LZ4 + elif name == 'ZSTD': + return CCompressionType_ZSTD + elif name == 'UNCOMPRESSED': + return CCompressionType_UNCOMPRESSED + raise ValueError(f'Unknown CompressionKind: {name}') + + +cdef compression_strategy_from_enum( + CompressionStrategy compression_strategy +): + compression_strategy_map = { + _CompressionStrategy_SPEED: 'SPEED', + _CompressionStrategy_COMPRESSION: 'COMPRESSION', + } + if compression_strategy in compression_strategy_map: + return compression_strategy_map[compression_strategy] + raise ValueError('Unsupported compression strategy') + + +cdef CompressionStrategy compression_strategy_from_name(name) except *: + if not isinstance(name, str): + raise TypeError('compression strategy must be a string') + name = name.upper() + if name == 'COMPRESSION': + return _CompressionStrategy_COMPRESSION + elif name == 'SPEED': + return _CompressionStrategy_SPEED + raise ValueError(f'Unknown CompressionStrategy: {name}') + + +cdef file_version_from_class(FileVersion file_version): + return frombytes(file_version.ToString()) + + +cdef writer_id_from_enum(WriterId writer_id): + writer_id_map = { + _WriterId_ORC_JAVA_WRITER: 'ORC_JAVA', + _WriterId_ORC_CPP_WRITER: 'ORC_CPP', + _WriterId_PRESTO_WRITER: 'PRESTO', + _WriterId_SCRITCHLEY_GO: 'SCRITCHLEY_GO', + _WriterId_TRINO_WRITER: 'TRINO', + } + if writer_id in writer_id_map: + return writer_id_map[writer_id] + raise ValueError('Unsupported writer ID') + + +cdef writer_version_from_enum(WriterVersion writer_version): + writer_version_map = { + _WriterVersion_ORIGINAL: 'ORIGINAL', + _WriterVersion_HIVE_8732: 'HIVE_8732', + _WriterVersion_HIVE_4243: 'HIVE_4243', + _WriterVersion_HIVE_12055: 'HIVE_12055', + _WriterVersion_HIVE_13083: 'HIVE_13083', + _WriterVersion_ORC_101: 'ORC_101', + _WriterVersion_ORC_135: 'ORC_135', + _WriterVersion_ORC_517: 'ORC_517', + _WriterVersion_ORC_203: 'ORC_203', + _WriterVersion_ORC_14: 'ORC_14', + } + if writer_version in writer_version_map: + return writer_version_map[writer_version] + raise ValueError('Unsupported writer version') + + +cdef shared_ptr[WriteOptions] _create_write_options( + file_version=None, + batch_size=None, + stripe_size=None, + compression=None, + compression_block_size=None, + compression_strategy=None, + row_index_stride=None, + padding_tolerance=None, + dictionary_key_size_threshold=None, + bloom_filter_columns=None, + bloom_filter_fpp=None +) except *: + """General writer options""" + cdef: + shared_ptr[WriteOptions] options + options = make_shared[WriteOptions]() + # batch_size + if batch_size is not None: + if isinstance(batch_size, int) and batch_size > 0: + deref(options).batch_size = batch_size + else: + raise ValueError(f"Invalid ORC writer batch size: {batch_size}") + # file_version + if file_version is not None: + if file_version == "0.12": + deref(options).file_version = FileVersion(0, 12) + elif file_version == "0.11": + deref(options).file_version = FileVersion(0, 11) + else: + raise ValueError(f"Unsupported ORC file version: {file_version}") + # stripe_size + if stripe_size is not None: + if isinstance(stripe_size, int) and stripe_size > 0: + deref(options).stripe_size = stripe_size + else: + raise ValueError(f"Invalid ORC stripe size: {stripe_size}") + # compression + if compression is not None: + if isinstance(compression, str): + deref(options).compression = compression_type_from_name( + compression) + else: + raise TypeError("Unsupported ORC compression type: " + f"{compression}") + # compression_block_size + if compression_block_size is not None: + if (isinstance(compression_block_size, int) and + compression_block_size > 0): + deref(options).compression_block_size = compression_block_size + else: + raise ValueError("Invalid ORC compression block size: " + f"{compression_block_size}") + # compression_strategy + if compression_strategy is not None: + if isinstance(compression, str): + deref(options).compression_strategy = \ + compression_strategy_from_name(compression_strategy) + else: + raise TypeError("Unsupported ORC compression strategy: " + f"{compression_strategy}") + # row_index_stride + if row_index_stride is not None: + if isinstance(row_index_stride, int) and row_index_stride > 0: + deref(options).row_index_stride = row_index_stride + else: + raise ValueError("Invalid ORC row index stride: " + f"{row_index_stride}") + # padding_tolerance + if padding_tolerance is not None: + try: + padding_tolerance = float(padding_tolerance) + deref(options).padding_tolerance = padding_tolerance + except Exception: + raise ValueError("Invalid ORC padding tolerance: " + f"{padding_tolerance}") + # dictionary_key_size_threshold + if dictionary_key_size_threshold is not None: + try: + dictionary_key_size_threshold = float( + dictionary_key_size_threshold) + assert 0 <= dictionary_key_size_threshold <= 1 + deref(options).dictionary_key_size_threshold = \ + dictionary_key_size_threshold + except Exception: + raise ValueError("Invalid ORC dictionary key size threshold: " + f"{dictionary_key_size_threshold}") + # bloom_filter_columns + if bloom_filter_columns is not None: + try: + bloom_filter_columns = list(bloom_filter_columns) + for col in bloom_filter_columns: + assert isinstance(col, int) and col >= 0 + deref(options).bloom_filter_columns = bloom_filter_columns + except Exception: + raise ValueError("Invalid ORC BloomFilter columns: " + f"{bloom_filter_columns}") + # Max false positive rate of the Bloom Filter + if bloom_filter_fpp is not None: + try: + bloom_filter_fpp = float(bloom_filter_fpp) + assert 0 <= bloom_filter_fpp <= 1 + deref(options).bloom_filter_fpp = bloom_filter_fpp + except Exception: + raise ValueError("Invalid ORC BloomFilter false positive rate: " + f"{bloom_filter_fpp}") + return options + + +cdef class ORCReader(_Weakrefable): + cdef: + object source + CMemoryPool* allocator + unique_ptr[ORCFileReader] reader + + def __cinit__(self, MemoryPool memory_pool=None): + self.allocator = maybe_unbox_memory_pool(memory_pool) + + def open(self, object source, c_bool use_memory_map=True): + cdef: + shared_ptr[CRandomAccessFile] rd_handle + + self.source = source + + get_reader(source, use_memory_map, &rd_handle) + with nogil: + self.reader = move(GetResultValue( + ORCFileReader.Open(rd_handle, self.allocator) + )) + + def metadata(self): + """ + The arrow metadata for this file. + + Returns + ------- + metadata : pyarrow.KeyValueMetadata + """ + cdef: + shared_ptr[const CKeyValueMetadata] sp_arrow_metadata + + with nogil: + sp_arrow_metadata = GetResultValue( + deref(self.reader).ReadMetadata() + ) + + return pyarrow_wrap_metadata(sp_arrow_metadata) + + def schema(self): + """ + The arrow schema for this file. + + Returns + ------- + schema : pyarrow.Schema + """ + cdef: + shared_ptr[CSchema] sp_arrow_schema + + with nogil: + sp_arrow_schema = GetResultValue(deref(self.reader).ReadSchema()) + + return pyarrow_wrap_schema(sp_arrow_schema) + + def nrows(self): + return deref(self.reader).NumberOfRows() + + def nstripes(self): + return deref(self.reader).NumberOfStripes() + + def file_version(self): + return file_version_from_class(deref(self.reader).GetFileVersion()) + + def software_version(self): + return frombytes(deref(self.reader).GetSoftwareVersion()) + + def compression(self): + return compression_type_from_enum( + GetResultValue(deref(self.reader).GetCompression())) + + def compression_size(self): + return deref(self.reader).GetCompressionSize() + + def row_index_stride(self): + return deref(self.reader).GetRowIndexStride() + + def writer(self): + writer_name = writer_id_from_enum(deref(self.reader).GetWriterId()) + if writer_name == 'UNKNOWN': + return deref(self.reader).GetWriterIdValue() + else: + return writer_name + + def writer_version(self): + return writer_version_from_enum(deref(self.reader).GetWriterVersion()) + + def nstripe_statistics(self): + return deref(self.reader).GetNumberOfStripeStatistics() + + def content_length(self): + return deref(self.reader).GetContentLength() + + def stripe_statistics_length(self): + return deref(self.reader).GetStripeStatisticsLength() + + def file_footer_length(self): + return deref(self.reader).GetFileFooterLength() + + def file_postscript_length(self): + return deref(self.reader).GetFilePostscriptLength() + + def file_length(self): + return deref(self.reader).GetFileLength() + + def serialized_file_tail(self): + return deref(self.reader).GetSerializedFileTail() + + def read_stripe(self, n, columns=None): + cdef: + shared_ptr[CRecordBatch] sp_record_batch + int64_t stripe + std_vector[c_string] c_names + + stripe = n + + if columns is None: + with nogil: + sp_record_batch = GetResultValue( + deref(self.reader).ReadStripe(stripe) + ) + else: + c_names = [tobytes(name) for name in columns] + with nogil: + sp_record_batch = GetResultValue( + deref(self.reader).ReadStripe(stripe, c_names) + ) + + return pyarrow_wrap_batch(sp_record_batch) + + def read(self, columns=None): + cdef: + shared_ptr[CTable] sp_table + std_vector[c_string] c_names + + if columns is None: + with nogil: + sp_table = GetResultValue(deref(self.reader).Read()) + else: + c_names = [tobytes(name) for name in columns] + with nogil: + sp_table = GetResultValue(deref(self.reader).Read(c_names)) + + return pyarrow_wrap_table(sp_table) + + +cdef class ORCWriter(_Weakrefable): + cdef: + unique_ptr[ORCFileWriter] writer + shared_ptr[COutputStream] sink + c_bool own_sink + + def open(self, object where, *, + file_version=None, + batch_size=None, + stripe_size=None, + compression=None, + compression_block_size=None, + compression_strategy=None, + row_index_stride=None, + padding_tolerance=None, + dictionary_key_size_threshold=None, + bloom_filter_columns=None, + bloom_filter_fpp=None): + cdef: + shared_ptr[WriteOptions] write_options + c_string c_where + try: + where = _stringify_path(where) + except TypeError: + get_writer(where, &self.sink) + self.own_sink = False + else: + c_where = tobytes(where) + with nogil: + self.sink = GetResultValue(FileOutputStream.Open(c_where)) + self.own_sink = True + + write_options = _create_write_options( + file_version=file_version, + batch_size=batch_size, + stripe_size=stripe_size, + compression=compression, + compression_block_size=compression_block_size, + compression_strategy=compression_strategy, + row_index_stride=row_index_stride, + padding_tolerance=padding_tolerance, + dictionary_key_size_threshold=dictionary_key_size_threshold, + bloom_filter_columns=bloom_filter_columns, + bloom_filter_fpp=bloom_filter_fpp + ) + + with nogil: + self.writer = move(GetResultValue( + ORCFileWriter.Open(self.sink.get(), + deref(write_options)))) + + def write(self, Table table): + cdef: + shared_ptr[CTable] sp_table + sp_table = pyarrow_unwrap_table(table) + with nogil: + check_status(deref(self.writer).Write(deref(sp_table))) + + def close(self): + with nogil: + check_status(deref(self.writer).Close()) + if self.own_sink: + check_status(deref(self.sink).Close()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pyx new file mode 100644 index 0000000000000000000000000000000000000000..a3abf1865b7b5423820fcd5e1898d396c8ac94ec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pyx @@ -0,0 +1,2266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ + +from collections.abc import Sequence +from textwrap import indent +import warnings + +from cython.operator cimport dereference as deref +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * +from pyarrow.lib cimport (_Weakrefable, Buffer, Schema, + check_status, + MemoryPool, maybe_unbox_memory_pool, + Table, KeyValueMetadata, + pyarrow_wrap_chunked_array, + pyarrow_wrap_schema, + pyarrow_unwrap_metadata, + pyarrow_unwrap_schema, + pyarrow_wrap_table, + pyarrow_wrap_batch, + pyarrow_wrap_scalar, + NativeFile, get_reader, get_writer, + string_to_timeunit) + +from pyarrow.lib import (ArrowException, NativeFile, BufferOutputStream, + _stringify_path, + tobytes, frombytes, is_threading_enabled) + +cimport cpython as cp + +_DEFAULT_ROW_GROUP_SIZE = 1024*1024 +_MAX_ROW_GROUP_SIZE = 64*1024*1024 + +cdef class Statistics(_Weakrefable): + """Statistics for a single column in a single row group.""" + + def __cinit__(self): + pass + + def __repr__(self): + return """{} + has_min_max: {} + min: {} + max: {} + null_count: {} + distinct_count: {} + num_values: {} + physical_type: {} + logical_type: {} + converted_type (legacy): {}""".format(object.__repr__(self), + self.has_min_max, + self.min, + self.max, + self.null_count, + self.distinct_count, + self.num_values, + self.physical_type, + str(self.logical_type), + self.converted_type) + + def to_dict(self): + """ + Get dictionary representation of statistics. + + Returns + ------- + dict + Dictionary with a key for each attribute of this class. + """ + d = dict( + has_min_max=self.has_min_max, + min=self.min, + max=self.max, + null_count=self.null_count, + distinct_count=self.distinct_count, + num_values=self.num_values, + physical_type=self.physical_type + ) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, Statistics other): + """ + Return whether the two column statistics objects are equal. + + Parameters + ---------- + other : Statistics + Statistics to compare against. + + Returns + ------- + are_equal : bool + """ + return self.statistics.get().Equals(deref(other.statistics.get())) + + @property + def has_min_max(self): + """Whether min and max are present (bool).""" + return self.statistics.get().HasMinMax() + + @property + def has_null_count(self): + """Whether null count is present (bool).""" + return self.statistics.get().HasNullCount() + + @property + def has_distinct_count(self): + """Whether distinct count is preset (bool).""" + return self.statistics.get().HasDistinctCount() + + @property + def min_raw(self): + """Min value as physical type (bool, int, float, or bytes).""" + if self.has_min_max: + return _cast_statistic_raw_min(self.statistics.get()) + else: + return None + + @property + def max_raw(self): + """Max value as physical type (bool, int, float, or bytes).""" + if self.has_min_max: + return _cast_statistic_raw_max(self.statistics.get()) + else: + return None + + @property + def min(self): + """ + Min value as logical type. + + Returned as the Python equivalent of logical type, such as datetime.date + for dates and decimal.Decimal for decimals. + """ + if self.has_min_max: + min_scalar, _ = _cast_statistics(self.statistics.get()) + return min_scalar.as_py() + else: + return None + + @property + def max(self): + """ + Max value as logical type. + + Returned as the Python equivalent of logical type, such as datetime.date + for dates and decimal.Decimal for decimals. + """ + if self.has_min_max: + _, max_scalar = _cast_statistics(self.statistics.get()) + return max_scalar.as_py() + else: + return None + + @property + def null_count(self): + """Number of null values in chunk (int).""" + if self.has_null_count: + return self.statistics.get().null_count() + else: + return None + + @property + def distinct_count(self): + """Distinct number of values in chunk (int).""" + if self.has_distinct_count: + return self.statistics.get().distinct_count() + else: + return None + + @property + def num_values(self): + """Number of non-null values (int).""" + return self.statistics.get().num_values() + + @property + def physical_type(self): + """Physical type of column (str).""" + raw_physical_type = self.statistics.get().physical_type() + return physical_type_name_from_enum(raw_physical_type) + + @property + def logical_type(self): + """Logical type of column (:class:`ParquetLogicalType`).""" + return wrap_logical_type(self.statistics.get().descr().logical_type()) + + @property + def converted_type(self): + """Legacy converted type (str or None).""" + raw_converted_type = self.statistics.get().descr().converted_type() + return converted_type_name_from_enum(raw_converted_type) + + +cdef class ParquetLogicalType(_Weakrefable): + """Logical type of parquet type.""" + cdef: + shared_ptr[const CParquetLogicalType] type + + def __cinit__(self): + pass + + cdef init(self, const shared_ptr[const CParquetLogicalType]& type): + self.type = type + + def __repr__(self): + return "{}\n {}".format(object.__repr__(self), str(self)) + + def __str__(self): + return frombytes(self.type.get().ToString(), safe=True) + + def to_json(self): + """ + Get a JSON string containing type and type parameters. + + Returns + ------- + json : str + JSON representation of type, with at least a field called 'Type' + which contains the type name. If the type is parameterized, such + as a decimal with scale and precision, will contain those as fields + as well. + """ + return frombytes(self.type.get().ToJSON()) + + @property + def type(self): + """Name of the logical type (str).""" + return logical_type_name_from_enum(self.type.get().type()) + + +cdef wrap_logical_type(const shared_ptr[const CParquetLogicalType]& type): + cdef ParquetLogicalType out = ParquetLogicalType() + out.init(type) + return out + + +cdef _cast_statistic_raw_min(CStatistics* statistics): + cdef ParquetType physical_type = statistics.physical_type() + cdef uint32_t type_length = statistics.descr().type_length() + if physical_type == ParquetType_BOOLEAN: + return ( statistics).min() + elif physical_type == ParquetType_INT32: + return ( statistics).min() + elif physical_type == ParquetType_INT64: + return ( statistics).min() + elif physical_type == ParquetType_FLOAT: + return ( statistics).min() + elif physical_type == ParquetType_DOUBLE: + return ( statistics).min() + elif physical_type == ParquetType_BYTE_ARRAY: + return _box_byte_array(( statistics).min()) + elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY: + return _box_flba(( statistics).min(), type_length) + + +cdef _cast_statistic_raw_max(CStatistics* statistics): + cdef ParquetType physical_type = statistics.physical_type() + cdef uint32_t type_length = statistics.descr().type_length() + if physical_type == ParquetType_BOOLEAN: + return ( statistics).max() + elif physical_type == ParquetType_INT32: + return ( statistics).max() + elif physical_type == ParquetType_INT64: + return ( statistics).max() + elif physical_type == ParquetType_FLOAT: + return ( statistics).max() + elif physical_type == ParquetType_DOUBLE: + return ( statistics).max() + elif physical_type == ParquetType_BYTE_ARRAY: + return _box_byte_array(( statistics).max()) + elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY: + return _box_flba(( statistics).max(), type_length) + + +cdef _cast_statistics(CStatistics* statistics): + cdef: + shared_ptr[CScalar] c_min + shared_ptr[CScalar] c_max + check_status(StatisticsAsScalars(statistics[0], &c_min, &c_max)) + return (pyarrow_wrap_scalar(c_min), pyarrow_wrap_scalar(c_max)) + + +cdef _box_byte_array(ParquetByteArray val): + return cp.PyBytes_FromStringAndSize( val.ptr, val.len) + + +cdef _box_flba(ParquetFLBA val, uint32_t len): + return cp.PyBytes_FromStringAndSize( val.ptr, len) + + +cdef class ColumnChunkMetaData(_Weakrefable): + """Column metadata for a single row group.""" + + def __cinit__(self): + pass + + def __repr__(self): + statistics = indent(repr(self.statistics), 4 * ' ') + return """{0} + file_offset: {1} + file_path: {2} + physical_type: {3} + num_values: {4} + path_in_schema: {5} + is_stats_set: {6} + statistics: +{7} + compression: {8} + encodings: {9} + has_dictionary_page: {10} + dictionary_page_offset: {11} + data_page_offset: {12} + total_compressed_size: {13} + total_uncompressed_size: {14}""".format(object.__repr__(self), + self.file_offset, + self.file_path, + self.physical_type, + self.num_values, + self.path_in_schema, + self.is_stats_set, + statistics, + self.compression, + self.encodings, + self.has_dictionary_page, + self.dictionary_page_offset, + self.data_page_offset, + self.total_compressed_size, + self.total_uncompressed_size) + + def to_dict(self): + """ + Get dictionary representation of the column chunk metadata. + + Returns + ------- + dict + Dictionary with a key for each attribute of this class. + """ + statistics = self.statistics.to_dict() if self.is_stats_set else None + d = dict( + file_offset=self.file_offset, + file_path=self.file_path, + physical_type=self.physical_type, + num_values=self.num_values, + path_in_schema=self.path_in_schema, + is_stats_set=self.is_stats_set, + statistics=statistics, + compression=self.compression, + encodings=self.encodings, + has_dictionary_page=self.has_dictionary_page, + dictionary_page_offset=self.dictionary_page_offset, + data_page_offset=self.data_page_offset, + total_compressed_size=self.total_compressed_size, + total_uncompressed_size=self.total_uncompressed_size + ) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, ColumnChunkMetaData other): + """ + Return whether the two column chunk metadata objects are equal. + + Parameters + ---------- + other : ColumnChunkMetaData + Metadata to compare against. + + Returns + ------- + are_equal : bool + """ + return self.metadata.Equals(deref(other.metadata)) + + @property + def file_offset(self): + """Offset into file where column chunk is located (int).""" + return self.metadata.file_offset() + + @property + def file_path(self): + """Optional file path if set (str or None).""" + return frombytes(self.metadata.file_path()) + + @property + def physical_type(self): + """Physical type of column (str).""" + return physical_type_name_from_enum(self.metadata.type()) + + @property + def num_values(self): + """Total number of values (int).""" + return self.metadata.num_values() + + @property + def path_in_schema(self): + """Nested path to field, separated by periods (str).""" + path = self.metadata.path_in_schema().get().ToDotString() + return frombytes(path) + + @property + def is_stats_set(self): + """Whether or not statistics are present in metadata (bool).""" + return self.metadata.is_stats_set() + + @property + def statistics(self): + """Statistics for column chunk (:class:`Statistics`).""" + if not self.metadata.is_stats_set(): + return None + statistics = Statistics() + statistics.init(self.metadata.statistics(), self) + return statistics + + @property + def compression(self): + """ + Type of compression used for column (str). + + One of 'UNCOMPRESSED', 'SNAPPY', 'GZIP', 'LZO', 'BROTLI', 'LZ4', 'ZSTD', + or 'UNKNOWN'. + """ + return compression_name_from_enum(self.metadata.compression()) + + @property + def encodings(self): + """ + Encodings used for column (tuple of str). + + One of 'PLAIN', 'BIT_PACKED', 'RLE', 'BYTE_STREAM_SPLIT', 'DELTA_BINARY_PACKED', + 'DELTA_LENGTH_BYTE_ARRAY', 'DELTA_BYTE_ARRAY'. + """ + return tuple(map(encoding_name_from_enum, self.metadata.encodings())) + + @property + def has_dictionary_page(self): + """Whether there is dictionary data present in the column chunk (bool).""" + return bool(self.metadata.has_dictionary_page()) + + @property + def dictionary_page_offset(self): + """Offset of dictionary page relative to beginning of the file (int).""" + if self.has_dictionary_page: + return self.metadata.dictionary_page_offset() + else: + return None + + @property + def data_page_offset(self): + """Offset of data page relative to beginning of the file (int).""" + return self.metadata.data_page_offset() + + @property + def has_index_page(self): + """Not yet supported.""" + raise NotImplementedError('not supported in parquet-cpp') + + @property + def index_page_offset(self): + """Not yet supported.""" + raise NotImplementedError("parquet-cpp doesn't return valid values") + + @property + def total_compressed_size(self): + """Compressed size in bytes (int).""" + return self.metadata.total_compressed_size() + + @property + def total_uncompressed_size(self): + """Uncompressed size in bytes (int).""" + return self.metadata.total_uncompressed_size() + + @property + def has_offset_index(self): + """Whether the column chunk has an offset index""" + return self.metadata.GetOffsetIndexLocation().has_value() + + @property + def has_column_index(self): + """Whether the column chunk has a column index""" + return self.metadata.GetColumnIndexLocation().has_value() + + @property + def metadata(self): + """Additional metadata as key value pairs (dict[bytes, bytes]).""" + cdef: + unordered_map[c_string, c_string] metadata + const CKeyValueMetadata* underlying_metadata + underlying_metadata = self.metadata.key_value_metadata().get() + if underlying_metadata != NULL: + underlying_metadata.ToUnorderedMap(&metadata) + return metadata + else: + return None + + +cdef class SortingColumn: + """ + Sorting specification for a single column. + + Returned by :meth:`RowGroupMetaData.sorting_columns` and used in + :class:`ParquetWriter` to specify the sort order of the data. + + Parameters + ---------- + column_index : int + Index of column that data is sorted by. + descending : bool, default False + Whether column is sorted in descending order. + nulls_first : bool, default False + Whether null values appear before valid values. + + Notes + ----- + + Column indices are zero-based, refer only to leaf fields, and are in + depth-first order. This may make the column indices for nested schemas + different from what you expect. In most cases, it will be easier to + specify the sort order using column names instead of column indices + and converting using the ``from_ordering`` method. + + Examples + -------- + + In other APIs, sort order is specified by names, such as: + + >>> sort_order = [('id', 'ascending'), ('timestamp', 'descending')] + + For Parquet, the column index must be used instead: + + >>> import pyarrow.parquet as pq + >>> [pq.SortingColumn(0), pq.SortingColumn(1, descending=True)] + [SortingColumn(column_index=0, descending=False, nulls_first=False), SortingColumn(column_index=1, descending=True, nulls_first=False)] + + Convert the sort_order into the list of sorting columns with + ``from_ordering`` (note that the schema must be provided as well): + + >>> import pyarrow as pa + >>> schema = pa.schema([('id', pa.int64()), ('timestamp', pa.timestamp('ms'))]) + >>> sorting_columns = pq.SortingColumn.from_ordering(schema, sort_order) + >>> sorting_columns + (SortingColumn(column_index=0, descending=False, nulls_first=False), SortingColumn(column_index=1, descending=True, nulls_first=False)) + + Convert back to the sort order with ``to_ordering``: + + >>> pq.SortingColumn.to_ordering(schema, sorting_columns) + ((('id', 'ascending'), ('timestamp', 'descending')), 'at_end') + + See Also + -------- + RowGroupMetaData.sorting_columns + """ + cdef int column_index + cdef c_bool descending + cdef c_bool nulls_first + + def __init__(self, int column_index, c_bool descending=False, c_bool nulls_first=False): + self.column_index = column_index + self.descending = descending + self.nulls_first = nulls_first + + @classmethod + def from_ordering(cls, Schema schema, sort_keys, null_placement='at_end'): + """ + Create a tuple of SortingColumn objects from the same arguments as + :class:`pyarrow.compute.SortOptions`. + + Parameters + ---------- + schema : Schema + Schema of the input data. + sort_keys : Sequence of (name, order) tuples + Names of field/column keys (str) to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + null_placement : {'at_start', 'at_end'}, default 'at_end' + Where null values should appear in the sort order. + + Returns + ------- + sorting_columns : tuple of SortingColumn + """ + if null_placement == 'at_start': + nulls_first = True + elif null_placement == 'at_end': + nulls_first = False + else: + raise ValueError('null_placement must be "at_start" or "at_end"') + + col_map = _name_to_index_map(schema) + + sorting_columns = [] + + for sort_key in sort_keys: + if isinstance(sort_key, str): + name = sort_key + descending = False + elif (isinstance(sort_key, tuple) and len(sort_key) == 2 and + isinstance(sort_key[0], str) and + isinstance(sort_key[1], str)): + name, descending = sort_key + if descending == "descending": + descending = True + elif descending == "ascending": + descending = False + else: + raise ValueError("Invalid sort key direction: {0}" + .format(descending)) + else: + raise ValueError("Invalid sort key: {0}".format(sort_key)) + + try: + column_index = col_map[name] + except KeyError: + raise ValueError("Sort key name '{0}' not found in schema:\n{1}" + .format(name, schema)) + + sorting_columns.append( + cls(column_index, descending=descending, nulls_first=nulls_first) + ) + + return tuple(sorting_columns) + + @staticmethod + def to_ordering(Schema schema, sorting_columns): + """ + Convert a tuple of SortingColumn objects to the same format as + :class:`pyarrow.compute.SortOptions`. + + Parameters + ---------- + schema : Schema + Schema of the input data. + sorting_columns : tuple of SortingColumn + Columns to sort the input on. + + Returns + ------- + sort_keys : tuple of (name, order) tuples + null_placement : {'at_start', 'at_end'} + """ + col_map = {i: name for name, i in _name_to_index_map(schema).items()} + + sort_keys = [] + nulls_first = None + + for sorting_column in sorting_columns: + name = col_map[sorting_column.column_index] + if sorting_column.descending: + order = "descending" + else: + order = "ascending" + sort_keys.append((name, order)) + if nulls_first is None: + nulls_first = sorting_column.nulls_first + elif nulls_first != sorting_column.nulls_first: + raise ValueError("Sorting columns have inconsistent null placement") + + if nulls_first: + null_placement = "at_start" + else: + null_placement = "at_end" + + return tuple(sort_keys), null_placement + + def __repr__(self): + return """{}(column_index={}, descending={}, nulls_first={})""".format( + self.__class__.__name__, + self.column_index, self.descending, self.nulls_first) + + def __eq__(self, SortingColumn other): + return (self.column_index == other.column_index and + self.descending == other.descending and + self.nulls_first == other.nulls_first) + + def __hash__(self): + return hash((self.column_index, self.descending, self.nulls_first)) + + @property + def column_index(self): + """"Index of column data is sorted by (int).""" + return self.column_index + + @property + def descending(self): + """Whether column is sorted in descending order (bool).""" + return self.descending + + @property + def nulls_first(self): + """Whether null values appear before valid values (bool).""" + return self.nulls_first + + def to_dict(self): + """ + Get dictionary representation of the SortingColumn. + + Returns + ------- + dict + Dictionary with a key for each attribute of this class. + """ + d = dict( + column_index=self.column_index, + descending=self.descending, + nulls_first=self.nulls_first + ) + return d + + +cdef class RowGroupMetaData(_Weakrefable): + """Metadata for a single row group.""" + + def __cinit__(self, FileMetaData parent, int index): + if index < 0 or index >= parent.num_row_groups: + raise IndexError('{0} out of bounds'.format(index)) + self.up_metadata = parent._metadata.RowGroup(index) + self.metadata = self.up_metadata.get() + self.parent = parent + self.index = index + + def __reduce__(self): + return RowGroupMetaData, (self.parent, self.index) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, RowGroupMetaData other): + """ + Return whether the two row group metadata objects are equal. + + Parameters + ---------- + other : RowGroupMetaData + Metadata to compare against. + + Returns + ------- + are_equal : bool + """ + return self.metadata.Equals(deref(other.metadata)) + + def column(self, int i): + """ + Get column metadata at given index. + + Parameters + ---------- + i : int + Index of column to get metadata for. + + Returns + ------- + ColumnChunkMetaData + Metadata for column within this chunk. + """ + if i < 0 or i >= self.num_columns: + raise IndexError('{0} out of bounds'.format(i)) + chunk = ColumnChunkMetaData() + chunk.init(self, i) + return chunk + + def __repr__(self): + return """{0} + num_columns: {1} + num_rows: {2} + total_byte_size: {3} + sorting_columns: {4}""".format(object.__repr__(self), + self.num_columns, + self.num_rows, + self.total_byte_size, + self.sorting_columns) + + def to_dict(self): + """ + Get dictionary representation of the row group metadata. + + Returns + ------- + dict + Dictionary with a key for each attribute of this class. + """ + columns = [] + d = dict( + num_columns=self.num_columns, + num_rows=self.num_rows, + total_byte_size=self.total_byte_size, + columns=columns, + sorting_columns=[col.to_dict() for col in self.sorting_columns] + ) + for i in range(self.num_columns): + columns.append(self.column(i).to_dict()) + return d + + @property + def num_columns(self): + """Number of columns in this row group (int).""" + return self.metadata.num_columns() + + @property + def num_rows(self): + """Number of rows in this row group (int).""" + return self.metadata.num_rows() + + @property + def total_byte_size(self): + """Total byte size of all the uncompressed column data in this row group (int).""" + return self.metadata.total_byte_size() + + @property + def sorting_columns(self): + """Columns the row group is sorted by (tuple of :class:`SortingColumn`)).""" + out = [] + cdef vector[CSortingColumn] sorting_columns = self.metadata.sorting_columns() + for sorting_col in sorting_columns: + out.append(SortingColumn( + sorting_col.column_idx, + sorting_col.descending, + sorting_col.nulls_first + )) + return tuple(out) + + +def _reconstruct_filemetadata(Buffer serialized): + cdef: + FileMetaData metadata = FileMetaData.__new__(FileMetaData) + CBuffer *buffer = serialized.buffer.get() + uint32_t metadata_len = buffer.size() + + metadata.init(CFileMetaData_Make(buffer.data(), &metadata_len)) + + return metadata + + +cdef class FileMetaData(_Weakrefable): + """Parquet metadata for a single file.""" + + def __cinit__(self): + pass + + def __reduce__(self): + cdef: + NativeFile sink = BufferOutputStream() + COutputStream* c_sink = sink.get_output_stream().get() + with nogil: + self._metadata.WriteTo(c_sink) + + cdef Buffer buffer = sink.getvalue() + return _reconstruct_filemetadata, (buffer,) + + def __hash__(self): + return hash((self.schema, + self.num_rows, + self.num_row_groups, + self.format_version, + self.serialized_size)) + + def __repr__(self): + return """{0} + created_by: {1} + num_columns: {2} + num_rows: {3} + num_row_groups: {4} + format_version: {5} + serialized_size: {6}""".format(object.__repr__(self), + self.created_by, self.num_columns, + self.num_rows, self.num_row_groups, + self.format_version, + self.serialized_size) + + def to_dict(self): + """ + Get dictionary representation of the file metadata. + + Returns + ------- + dict + Dictionary with a key for each attribute of this class. + """ + row_groups = [] + d = dict( + created_by=self.created_by, + num_columns=self.num_columns, + num_rows=self.num_rows, + num_row_groups=self.num_row_groups, + row_groups=row_groups, + format_version=self.format_version, + serialized_size=self.serialized_size + ) + for i in range(self.num_row_groups): + row_groups.append(self.row_group(i).to_dict()) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, FileMetaData other not None): + """ + Return whether the two file metadata objects are equal. + + Parameters + ---------- + other : FileMetaData + Metadata to compare against. + + Returns + ------- + are_equal : bool + """ + return self._metadata.Equals(deref(other._metadata)) + + @property + def schema(self): + """Schema of the file (:class:`ParquetSchema`).""" + if self._schema is None: + self._schema = ParquetSchema(self) + return self._schema + + @property + def serialized_size(self): + """Size of the original thrift encoded metadata footer (int).""" + return self._metadata.size() + + @property + def num_columns(self): + """Number of columns in file (int).""" + return self._metadata.num_columns() + + @property + def num_rows(self): + """Total number of rows in file (int).""" + return self._metadata.num_rows() + + @property + def num_row_groups(self): + """Number of row groups in file (int).""" + return self._metadata.num_row_groups() + + @property + def format_version(self): + """ + Parquet format version used in file (str, such as '1.0', '2.4'). + + If version is missing or unparsable, will default to assuming '2.6'. + """ + cdef ParquetVersion version = self._metadata.version() + if version == ParquetVersion_V1: + return '1.0' + elif version == ParquetVersion_V2_0: + return 'pseudo-2.0' + elif version == ParquetVersion_V2_4: + return '2.4' + elif version == ParquetVersion_V2_6: + return '2.6' + else: + warnings.warn('Unrecognized file version, assuming 2.6: {}' + .format(version)) + return '2.6' + + @property + def created_by(self): + """ + String describing source of the parquet file (str). + + This typically includes library name and version number. For example, Arrow 7.0's + writer returns 'parquet-cpp-arrow version 7.0.0'. + """ + return frombytes(self._metadata.created_by()) + + @property + def metadata(self): + """Additional metadata as key value pairs (dict[bytes, bytes]).""" + cdef: + unordered_map[c_string, c_string] metadata + const CKeyValueMetadata* underlying_metadata + underlying_metadata = self._metadata.key_value_metadata().get() + if underlying_metadata != NULL: + underlying_metadata.ToUnorderedMap(&metadata) + return metadata + else: + return None + + def row_group(self, int i): + """ + Get metadata for row group at index i. + + Parameters + ---------- + i : int + Row group index to get. + + Returns + ------- + row_group_metadata : RowGroupMetaData + """ + return RowGroupMetaData(self, i) + + def set_file_path(self, path): + """ + Set ColumnChunk file paths to the given value. + + This method modifies the ``file_path`` field of each ColumnChunk + in the FileMetaData to be a particular value. + + Parameters + ---------- + path : str + The file path to set on all ColumnChunks. + """ + cdef: + c_string c_path = tobytes(path) + self._metadata.set_file_path(c_path) + + def append_row_groups(self, FileMetaData other): + """ + Append row groups from other FileMetaData object. + + Parameters + ---------- + other : FileMetaData + Other metadata to append row groups from. + """ + cdef shared_ptr[CFileMetaData] c_metadata + + c_metadata = other.sp_metadata + self._metadata.AppendRowGroups(deref(c_metadata)) + + def write_metadata_file(self, where): + """ + Write the metadata to a metadata-only Parquet file. + + Parameters + ---------- + where : path or file-like object + Where to write the metadata. Should be a writable path on + the local filesystem, or a writable file-like object. + """ + cdef: + shared_ptr[COutputStream] sink + c_string c_where + + try: + where = _stringify_path(where) + except TypeError: + get_writer(where, &sink) + else: + c_where = tobytes(where) + with nogil: + sink = GetResultValue(FileOutputStream.Open(c_where)) + + with nogil: + check_status( + WriteMetaDataFile(deref(self._metadata), sink.get())) + + +cdef class ParquetSchema(_Weakrefable): + """A Parquet schema.""" + + def __cinit__(self, FileMetaData container): + self.parent = container + self.schema = container._metadata.schema() + + def __repr__(self): + return "{0}\n{1}".format( + object.__repr__(self), + frombytes(self.schema.ToString(), safe=True)) + + def __reduce__(self): + return ParquetSchema, (self.parent,) + + def __len__(self): + return self.schema.num_columns() + + def __getitem__(self, i): + return self.column(i) + + def __hash__(self): + return hash(self.schema.ToString()) + + @property + def names(self): + """Name of each field (list of str).""" + return [self[i].name for i in range(len(self))] + + def to_arrow_schema(self): + """ + Convert Parquet schema to effective Arrow schema. + + Returns + ------- + schema : Schema + """ + cdef shared_ptr[CSchema] sp_arrow_schema + + with nogil: + check_status(FromParquetSchema( + self.schema, default_arrow_reader_properties(), + self.parent._metadata.key_value_metadata(), + &sp_arrow_schema)) + + return pyarrow_wrap_schema(sp_arrow_schema) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, ParquetSchema other): + """ + Return whether the two schemas are equal. + + Parameters + ---------- + other : ParquetSchema + Schema to compare against. + + Returns + ------- + are_equal : bool + """ + return self.schema.Equals(deref(other.schema)) + + def column(self, i): + """ + Return the schema for a single column. + + Parameters + ---------- + i : int + Index of column in schema. + + Returns + ------- + column_schema : ColumnSchema + """ + if i < 0 or i >= len(self): + raise IndexError('{0} out of bounds'.format(i)) + + return ColumnSchema(self, i) + + +cdef class ColumnSchema(_Weakrefable): + """Schema for a single column.""" + cdef: + int index + ParquetSchema parent + const ColumnDescriptor* descr + + def __cinit__(self, ParquetSchema schema, int index): + self.parent = schema + self.index = index # for pickling support + self.descr = schema.schema.Column(index) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __reduce__(self): + return ColumnSchema, (self.parent, self.index) + + def equals(self, ColumnSchema other): + """ + Return whether the two column schemas are equal. + + Parameters + ---------- + other : ColumnSchema + Schema to compare against. + + Returns + ------- + are_equal : bool + """ + return self.descr.Equals(deref(other.descr)) + + def __repr__(self): + physical_type = self.physical_type + converted_type = self.converted_type + if converted_type == 'DECIMAL': + converted_type = 'DECIMAL({0}, {1})'.format(self.precision, + self.scale) + elif physical_type == 'FIXED_LEN_BYTE_ARRAY': + converted_type = ('FIXED_LEN_BYTE_ARRAY(length={0})' + .format(self.length)) + + return """ + name: {0} + path: {1} + max_definition_level: {2} + max_repetition_level: {3} + physical_type: {4} + logical_type: {5} + converted_type (legacy): {6}""".format(self.name, self.path, + self.max_definition_level, + self.max_repetition_level, + physical_type, + str(self.logical_type), + converted_type) + + @property + def name(self): + """Name of field (str).""" + return frombytes(self.descr.name()) + + @property + def path(self): + """Nested path to field, separated by periods (str).""" + return frombytes(self.descr.path().get().ToDotString()) + + @property + def max_definition_level(self): + """Maximum definition level (int).""" + return self.descr.max_definition_level() + + @property + def max_repetition_level(self): + """Maximum repetition level (int).""" + return self.descr.max_repetition_level() + + @property + def physical_type(self): + """Name of physical type (str).""" + return physical_type_name_from_enum(self.descr.physical_type()) + + @property + def logical_type(self): + """Logical type of column (:class:`ParquetLogicalType`).""" + return wrap_logical_type(self.descr.logical_type()) + + @property + def converted_type(self): + """Legacy converted type (str or None).""" + return converted_type_name_from_enum(self.descr.converted_type()) + + # FIXED_LEN_BYTE_ARRAY attribute + @property + def length(self): + """Array length if fixed length byte array type, None otherwise (int or None).""" + return self.descr.type_length() + + # Decimal attributes + @property + def precision(self): + """Precision if decimal type, None otherwise (int or None).""" + return self.descr.type_precision() + + @property + def scale(self): + """Scale if decimal type, None otherwise (int or None).""" + return self.descr.type_scale() + + +cdef physical_type_name_from_enum(ParquetType type_): + return { + ParquetType_BOOLEAN: 'BOOLEAN', + ParquetType_INT32: 'INT32', + ParquetType_INT64: 'INT64', + ParquetType_INT96: 'INT96', + ParquetType_FLOAT: 'FLOAT', + ParquetType_DOUBLE: 'DOUBLE', + ParquetType_BYTE_ARRAY: 'BYTE_ARRAY', + ParquetType_FIXED_LEN_BYTE_ARRAY: 'FIXED_LEN_BYTE_ARRAY', + }.get(type_, 'UNKNOWN') + + +cdef logical_type_name_from_enum(ParquetLogicalTypeId type_): + return { + ParquetLogicalType_UNDEFINED: 'UNDEFINED', + ParquetLogicalType_STRING: 'STRING', + ParquetLogicalType_MAP: 'MAP', + ParquetLogicalType_LIST: 'LIST', + ParquetLogicalType_ENUM: 'ENUM', + ParquetLogicalType_DECIMAL: 'DECIMAL', + ParquetLogicalType_DATE: 'DATE', + ParquetLogicalType_TIME: 'TIME', + ParquetLogicalType_TIMESTAMP: 'TIMESTAMP', + ParquetLogicalType_INT: 'INT', + ParquetLogicalType_FLOAT16: 'FLOAT16', + ParquetLogicalType_JSON: 'JSON', + ParquetLogicalType_BSON: 'BSON', + ParquetLogicalType_UUID: 'UUID', + ParquetLogicalType_NONE: 'NONE', + }.get(type_, 'UNKNOWN') + + +cdef converted_type_name_from_enum(ParquetConvertedType type_): + return { + ParquetConvertedType_NONE: 'NONE', + ParquetConvertedType_UTF8: 'UTF8', + ParquetConvertedType_MAP: 'MAP', + ParquetConvertedType_MAP_KEY_VALUE: 'MAP_KEY_VALUE', + ParquetConvertedType_LIST: 'LIST', + ParquetConvertedType_ENUM: 'ENUM', + ParquetConvertedType_DECIMAL: 'DECIMAL', + ParquetConvertedType_DATE: 'DATE', + ParquetConvertedType_TIME_MILLIS: 'TIME_MILLIS', + ParquetConvertedType_TIME_MICROS: 'TIME_MICROS', + ParquetConvertedType_TIMESTAMP_MILLIS: 'TIMESTAMP_MILLIS', + ParquetConvertedType_TIMESTAMP_MICROS: 'TIMESTAMP_MICROS', + ParquetConvertedType_UINT_8: 'UINT_8', + ParquetConvertedType_UINT_16: 'UINT_16', + ParquetConvertedType_UINT_32: 'UINT_32', + ParquetConvertedType_UINT_64: 'UINT_64', + ParquetConvertedType_INT_8: 'INT_8', + ParquetConvertedType_INT_16: 'INT_16', + ParquetConvertedType_INT_32: 'INT_32', + ParquetConvertedType_INT_64: 'INT_64', + ParquetConvertedType_JSON: 'JSON', + ParquetConvertedType_BSON: 'BSON', + ParquetConvertedType_INTERVAL: 'INTERVAL', + }.get(type_, 'UNKNOWN') + + +cdef encoding_name_from_enum(ParquetEncoding encoding_): + return { + ParquetEncoding_PLAIN: 'PLAIN', + ParquetEncoding_PLAIN_DICTIONARY: 'PLAIN_DICTIONARY', + ParquetEncoding_RLE: 'RLE', + ParquetEncoding_BIT_PACKED: 'BIT_PACKED', + ParquetEncoding_DELTA_BINARY_PACKED: 'DELTA_BINARY_PACKED', + ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY: 'DELTA_LENGTH_BYTE_ARRAY', + ParquetEncoding_DELTA_BYTE_ARRAY: 'DELTA_BYTE_ARRAY', + ParquetEncoding_RLE_DICTIONARY: 'RLE_DICTIONARY', + ParquetEncoding_BYTE_STREAM_SPLIT: 'BYTE_STREAM_SPLIT', + }.get(encoding_, 'UNKNOWN') + + +cdef encoding_enum_from_name(str encoding_name): + enc = { + 'PLAIN': ParquetEncoding_PLAIN, + 'BIT_PACKED': ParquetEncoding_BIT_PACKED, + 'RLE': ParquetEncoding_RLE, + 'BYTE_STREAM_SPLIT': ParquetEncoding_BYTE_STREAM_SPLIT, + 'DELTA_BINARY_PACKED': ParquetEncoding_DELTA_BINARY_PACKED, + 'DELTA_LENGTH_BYTE_ARRAY': ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY, + 'DELTA_BYTE_ARRAY': ParquetEncoding_DELTA_BYTE_ARRAY, + 'RLE_DICTIONARY': 'dict', + 'PLAIN_DICTIONARY': 'dict', + }.get(encoding_name, None) + if enc is None: + raise ValueError(f"Unsupported column encoding: {encoding_name!r}") + elif enc == 'dict': + raise ValueError(f"{encoding_name!r} is already used by default.") + else: + return enc + + +cdef compression_name_from_enum(ParquetCompression compression_): + return { + ParquetCompression_UNCOMPRESSED: 'UNCOMPRESSED', + ParquetCompression_SNAPPY: 'SNAPPY', + ParquetCompression_GZIP: 'GZIP', + ParquetCompression_LZO: 'LZO', + ParquetCompression_BROTLI: 'BROTLI', + ParquetCompression_LZ4: 'LZ4', + ParquetCompression_ZSTD: 'ZSTD', + }.get(compression_, 'UNKNOWN') + + +cdef int check_compression_name(name) except -1: + if name.upper() not in {'NONE', 'SNAPPY', 'GZIP', 'LZO', 'BROTLI', 'LZ4', + 'ZSTD'}: + raise ArrowException("Unsupported compression: " + name) + return 0 + + +cdef ParquetCompression compression_from_name(name): + name = name.upper() + if name == 'SNAPPY': + return ParquetCompression_SNAPPY + elif name == 'GZIP': + return ParquetCompression_GZIP + elif name == 'LZO': + return ParquetCompression_LZO + elif name == 'BROTLI': + return ParquetCompression_BROTLI + elif name == 'LZ4': + return ParquetCompression_LZ4 + elif name == 'ZSTD': + return ParquetCompression_ZSTD + else: + return ParquetCompression_UNCOMPRESSED + + +cdef class ParquetReader(_Weakrefable): + cdef: + object source + CMemoryPool* pool + UniquePtrNoGIL[FileReader] reader + FileMetaData _metadata + shared_ptr[CRandomAccessFile] rd_handle + + cdef public: + _column_idx_map + + def __cinit__(self, MemoryPool memory_pool=None): + self.pool = maybe_unbox_memory_pool(memory_pool) + self._metadata = None + + def open(self, object source not None, *, bint use_memory_map=False, + read_dictionary=None, FileMetaData metadata=None, + int buffer_size=0, bint pre_buffer=False, + coerce_int96_timestamp_unit=None, + FileDecryptionProperties decryption_properties=None, + thrift_string_size_limit=None, + thrift_container_size_limit=None, + page_checksum_verification=False): + """ + Open a parquet file for reading. + + Parameters + ---------- + source : str, pathlib.Path, pyarrow.NativeFile, or file-like object + use_memory_map : bool, default False + read_dictionary : iterable[int or str], optional + metadata : FileMetaData, optional + buffer_size : int, default 0 + pre_buffer : bool, default False + coerce_int96_timestamp_unit : str, optional + decryption_properties : FileDecryptionProperties, optional + thrift_string_size_limit : int, optional + thrift_container_size_limit : int, optional + page_checksum_verification : bool, default False + """ + cdef: + shared_ptr[CFileMetaData] c_metadata + CReaderProperties properties = default_reader_properties() + ArrowReaderProperties arrow_props = ( + default_arrow_reader_properties()) + FileReaderBuilder builder + + if pre_buffer and not is_threading_enabled(): + pre_buffer = False + + if metadata is not None: + c_metadata = metadata.sp_metadata + + if buffer_size > 0: + properties.enable_buffered_stream() + properties.set_buffer_size(buffer_size) + elif buffer_size == 0: + properties.disable_buffered_stream() + else: + raise ValueError('Buffer size must be larger than zero') + + if thrift_string_size_limit is not None: + if thrift_string_size_limit <= 0: + raise ValueError("thrift_string_size_limit " + "must be larger than zero") + properties.set_thrift_string_size_limit(thrift_string_size_limit) + if thrift_container_size_limit is not None: + if thrift_container_size_limit <= 0: + raise ValueError("thrift_container_size_limit " + "must be larger than zero") + properties.set_thrift_container_size_limit( + thrift_container_size_limit) + + if decryption_properties is not None: + properties.file_decryption_properties( + decryption_properties.unwrap()) + + arrow_props.set_pre_buffer(pre_buffer) + + properties.set_page_checksum_verification(page_checksum_verification) + + if coerce_int96_timestamp_unit is None: + # use the default defined in default_arrow_reader_properties() + pass + else: + arrow_props.set_coerce_int96_timestamp_unit( + string_to_timeunit(coerce_int96_timestamp_unit)) + + self.source = source + get_reader(source, use_memory_map, &self.rd_handle) + + with nogil: + check_status(builder.Open(self.rd_handle, properties, c_metadata)) + + # Set up metadata + with nogil: + c_metadata = builder.raw_reader().metadata() + self._metadata = result = FileMetaData() + result.init(c_metadata) + + if read_dictionary is not None: + self._set_read_dictionary(read_dictionary, &arrow_props) + + with nogil: + check_status(builder.memory_pool(self.pool) + .properties(arrow_props) + .Build(&self.reader)) + + cdef _set_read_dictionary(self, read_dictionary, + ArrowReaderProperties* props): + for column in read_dictionary: + if not isinstance(column, int): + column = self.column_name_idx(column) + props.set_read_dictionary(column, True) + + @property + def column_paths(self): + cdef: + FileMetaData container = self.metadata + const CFileMetaData* metadata = container._metadata + vector[c_string] path + int i = 0 + + paths = [] + for i in range(0, metadata.num_columns()): + path = (metadata.schema().Column(i) + .path().get().ToDotVector()) + paths.append([frombytes(x) for x in path]) + + return paths + + @property + def metadata(self): + return self._metadata + + @property + def schema_arrow(self): + cdef shared_ptr[CSchema] out + with nogil: + check_status(self.reader.get().GetSchema(&out)) + return pyarrow_wrap_schema(out) + + @property + def num_row_groups(self): + return self.reader.get().num_row_groups() + + def set_use_threads(self, bint use_threads): + """ + Parameters + ---------- + use_threads : bool + """ + if is_threading_enabled(): + self.reader.get().set_use_threads(use_threads) + else: + self.reader.get().set_use_threads(False) + + def set_batch_size(self, int64_t batch_size): + """ + Parameters + ---------- + batch_size : int64 + """ + self.reader.get().set_batch_size(batch_size) + + def iter_batches(self, int64_t batch_size, row_groups, column_indices=None, + bint use_threads=True): + """ + Parameters + ---------- + batch_size : int64 + row_groups : list[int] + column_indices : list[int], optional + use_threads : bool, default True + + Yields + ------ + next : RecordBatch + """ + cdef: + vector[int] c_row_groups + vector[int] c_column_indices + shared_ptr[CRecordBatch] record_batch + UniquePtrNoGIL[CRecordBatchReader] recordbatchreader + + self.set_batch_size(batch_size) + + if use_threads: + self.set_use_threads(use_threads) + + for row_group in row_groups: + c_row_groups.push_back(row_group) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + with nogil: + check_status( + self.reader.get().GetRecordBatchReader( + c_row_groups, c_column_indices, &recordbatchreader + ) + ) + else: + with nogil: + check_status( + self.reader.get().GetRecordBatchReader( + c_row_groups, &recordbatchreader + ) + ) + + while True: + with nogil: + check_status( + recordbatchreader.get().ReadNext(&record_batch) + ) + if record_batch.get() == NULL: + break + + yield pyarrow_wrap_batch(record_batch) + + def read_row_group(self, int i, column_indices=None, + bint use_threads=True): + """ + Parameters + ---------- + i : int + column_indices : list[int], optional + use_threads : bool, default True + + Returns + ------- + table : pyarrow.Table + """ + return self.read_row_groups([i], column_indices, use_threads) + + def read_row_groups(self, row_groups not None, column_indices=None, + bint use_threads=True): + """ + Parameters + ---------- + row_groups : list[int] + column_indices : list[int], optional + use_threads : bool, default True + + Returns + ------- + table : pyarrow.Table + """ + cdef: + shared_ptr[CTable] ctable + vector[int] c_row_groups + vector[int] c_column_indices + + self.set_use_threads(use_threads) + + for row_group in row_groups: + c_row_groups.push_back(row_group) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + with nogil: + check_status(self.reader.get() + .ReadRowGroups(c_row_groups, c_column_indices, + &ctable)) + else: + # Read all columns + with nogil: + check_status(self.reader.get() + .ReadRowGroups(c_row_groups, &ctable)) + return pyarrow_wrap_table(ctable) + + def read_all(self, column_indices=None, bint use_threads=True): + """ + Parameters + ---------- + column_indices : list[int], optional + use_threads : bool, default True + + Returns + ------- + table : pyarrow.Table + """ + cdef: + shared_ptr[CTable] ctable + vector[int] c_column_indices + + self.set_use_threads(use_threads) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + with nogil: + check_status(self.reader.get() + .ReadTable(c_column_indices, &ctable)) + else: + # Read all columns + with nogil: + check_status(self.reader.get() + .ReadTable(&ctable)) + return pyarrow_wrap_table(ctable) + + def scan_contents(self, column_indices=None, batch_size=65536): + """ + Parameters + ---------- + column_indices : list[int], optional + batch_size : int32, default 65536 + + Returns + ------- + num_rows : int64 + """ + cdef: + vector[int] c_column_indices + int32_t c_batch_size + int64_t c_num_rows + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + c_batch_size = batch_size + + with nogil: + check_status(self.reader.get() + .ScanContents(c_column_indices, c_batch_size, + &c_num_rows)) + + return c_num_rows + + def column_name_idx(self, column_name): + """ + Find the index of a column by its name. + + Parameters + ---------- + column_name : str + Name of the column; separation of nesting levels is done via ".". + + Returns + ------- + column_idx : int + Integer index of the column in the schema. + """ + cdef: + FileMetaData container = self.metadata + const CFileMetaData* metadata = container._metadata + int i = 0 + + if self._column_idx_map is None: + self._column_idx_map = {} + for i in range(0, metadata.num_columns()): + col_bytes = tobytes(metadata.schema().Column(i) + .path().get().ToDotString()) + self._column_idx_map[col_bytes] = i + + return self._column_idx_map[tobytes(column_name)] + + def read_column(self, int column_index): + """ + Read the column at the specified index. + + Parameters + ---------- + column_index : int + Index of the column. + + Returns + ------- + column : pyarrow.ChunkedArray + """ + cdef shared_ptr[CChunkedArray] out + with nogil: + check_status(self.reader.get() + .ReadColumn(column_index, &out)) + return pyarrow_wrap_chunked_array(out) + + def close(self): + if not self.closed: + with nogil: + check_status(self.rd_handle.get().Close()) + + @property + def closed(self): + if self.rd_handle == NULL: + return True + with nogil: + closed = self.rd_handle.get().closed() + return closed + + +cdef CSortingColumn _convert_sorting_column(SortingColumn sorting_column): + cdef CSortingColumn c_sorting_column + + c_sorting_column.column_idx = sorting_column.column_index + c_sorting_column.descending = sorting_column.descending + c_sorting_column.nulls_first = sorting_column.nulls_first + + return c_sorting_column + + +cdef vector[CSortingColumn] _convert_sorting_columns(sorting_columns) except *: + if not (isinstance(sorting_columns, Sequence) + and all(isinstance(col, SortingColumn) for col in sorting_columns)): + raise ValueError( + "'sorting_columns' must be a list of `SortingColumn`") + + cdef vector[CSortingColumn] c_sorting_columns = [_convert_sorting_column(col) + for col in sorting_columns] + + return c_sorting_columns + + +cdef shared_ptr[WriterProperties] _create_writer_properties( + use_dictionary=None, + compression=None, + version=None, + write_statistics=None, + data_page_size=None, + compression_level=None, + use_byte_stream_split=False, + column_encoding=None, + data_page_version=None, + FileEncryptionProperties encryption_properties=None, + write_batch_size=None, + dictionary_pagesize_limit=None, + write_page_index=False, + write_page_checksum=False, + sorting_columns=None, + store_decimal_as_integer=False) except *: + + """General writer properties""" + cdef: + shared_ptr[WriterProperties] properties + WriterProperties.Builder props + + # data_page_version + + if data_page_version is not None: + if data_page_version == "1.0": + props.data_page_version(ParquetDataPageVersion_V1) + elif data_page_version == "2.0": + props.data_page_version(ParquetDataPageVersion_V2) + else: + raise ValueError("Unsupported Parquet data page version: {0}" + .format(data_page_version)) + + # version + + if version is not None: + if version == "1.0": + props.version(ParquetVersion_V1) + elif version in ("2.0", "pseudo-2.0"): + warnings.warn( + "Parquet format '2.0' pseudo version is deprecated, use " + "'2.4' or '2.6' for fine-grained feature selection", + FutureWarning, stacklevel=2) + props.version(ParquetVersion_V2_0) + elif version == "2.4": + props.version(ParquetVersion_V2_4) + elif version == "2.6": + props.version(ParquetVersion_V2_6) + else: + raise ValueError("Unsupported Parquet format version: {0}" + .format(version)) + + # compression + + if isinstance(compression, basestring): + check_compression_name(compression) + props.compression(compression_from_name(compression)) + elif compression is not None: + for column, codec in compression.iteritems(): + check_compression_name(codec) + props.compression(tobytes(column), compression_from_name(codec)) + + if isinstance(compression_level, int): + props.compression_level(compression_level) + elif compression_level is not None: + for column, level in compression_level.iteritems(): + props.compression_level(tobytes(column), level) + + # use_dictionary + + if isinstance(use_dictionary, bool): + if use_dictionary: + props.enable_dictionary() + if column_encoding is not None: + raise ValueError( + "To use 'column_encoding' set 'use_dictionary' to False") + else: + props.disable_dictionary() + elif use_dictionary is not None: + # Deactivate dictionary encoding by default + props.disable_dictionary() + for column in use_dictionary: + props.enable_dictionary(tobytes(column)) + if (column_encoding is not None and + column_encoding.get(column) is not None): + raise ValueError( + "To use 'column_encoding' set 'use_dictionary' to False") + + # write_statistics + + if isinstance(write_statistics, bool): + if write_statistics: + props.enable_statistics() + else: + props.disable_statistics() + elif write_statistics is not None: + # Deactivate statistics by default and enable for specified columns + props.disable_statistics() + for column in write_statistics: + props.enable_statistics(tobytes(column)) + + # sorting_columns + + if sorting_columns is not None: + props.set_sorting_columns(_convert_sorting_columns(sorting_columns)) + + # use_byte_stream_split + + if isinstance(use_byte_stream_split, bool): + if use_byte_stream_split: + if column_encoding is not None: + raise ValueError( + "'use_byte_stream_split' cannot be passed" + "together with 'column_encoding'") + else: + props.encoding(ParquetEncoding_BYTE_STREAM_SPLIT) + elif use_byte_stream_split is not None: + for column in use_byte_stream_split: + if column_encoding is None: + column_encoding = {column: 'BYTE_STREAM_SPLIT'} + elif column_encoding.get(column, None) is None: + column_encoding[column] = 'BYTE_STREAM_SPLIT' + else: + raise ValueError( + "'use_byte_stream_split' cannot be passed" + "together with 'column_encoding'") + + # store_decimal_as_integer + + if isinstance(store_decimal_as_integer, bool): + if store_decimal_as_integer: + props.enable_store_decimal_as_integer() + else: + props.disable_store_decimal_as_integer() + else: + raise TypeError("'store_decimal_as_integer' must be a boolean") + + # column_encoding + # encoding map - encode individual columns + + if column_encoding is not None: + if isinstance(column_encoding, dict): + for column, _encoding in column_encoding.items(): + props.encoding(tobytes(column), + encoding_enum_from_name(_encoding)) + elif isinstance(column_encoding, str): + props.encoding(encoding_enum_from_name(column_encoding)) + else: + raise TypeError( + "'column_encoding' should be a dictionary or a string") + + if data_page_size is not None: + props.data_pagesize(data_page_size) + + if write_batch_size is not None: + props.write_batch_size(write_batch_size) + + if dictionary_pagesize_limit is not None: + props.dictionary_pagesize_limit(dictionary_pagesize_limit) + + # encryption + + if encryption_properties is not None: + props.encryption( + (encryption_properties).unwrap()) + + # For backwards compatibility reasons we cap the maximum row group size + # at 64Mi rows. This could be changed in the future, though it would be + # a breaking change. + # + # The user can always specify a smaller row group size (and the default + # is smaller) when calling write_table. If the call to write_table uses + # a size larger than this then it will be latched to this value. + props.max_row_group_length(_MAX_ROW_GROUP_SIZE) + + # checksum + + if write_page_checksum: + props.enable_page_checksum() + else: + props.disable_page_checksum() + + # page index + + if write_page_index: + props.enable_write_page_index() + else: + props.disable_write_page_index() + + properties = props.build() + + return properties + + +cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties( + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + allow_truncated_timestamps=False, + writer_engine_version=None, + use_compliant_nested_type=True, + store_schema=True) except *: + """Arrow writer properties""" + cdef: + shared_ptr[ArrowWriterProperties] arrow_properties + ArrowWriterProperties.Builder arrow_props + + # Store the original Arrow schema so things like dictionary types can + # be automatically reconstructed + if store_schema: + arrow_props.store_schema() + + # int96 support + + if use_deprecated_int96_timestamps: + arrow_props.enable_deprecated_int96_timestamps() + else: + arrow_props.disable_deprecated_int96_timestamps() + + # coerce_timestamps + + if coerce_timestamps == 'ms': + arrow_props.coerce_timestamps(TimeUnit_MILLI) + elif coerce_timestamps == 'us': + arrow_props.coerce_timestamps(TimeUnit_MICRO) + elif coerce_timestamps is not None: + raise ValueError('Invalid value for coerce_timestamps: {0}' + .format(coerce_timestamps)) + + # allow_truncated_timestamps + + if allow_truncated_timestamps: + arrow_props.allow_truncated_timestamps() + else: + arrow_props.disallow_truncated_timestamps() + + # use_compliant_nested_type + + if use_compliant_nested_type: + arrow_props.enable_compliant_nested_types() + else: + arrow_props.disable_compliant_nested_types() + + # writer_engine_version + + if writer_engine_version == "V1": + warnings.warn("V1 parquet writer engine is a no-op. Use V2.") + arrow_props.set_engine_version(ArrowWriterEngineVersion.V1) + elif writer_engine_version != "V2": + raise ValueError("Unsupported Writer Engine Version: {0}" + .format(writer_engine_version)) + + arrow_properties = arrow_props.build() + + return arrow_properties + +cdef _name_to_index_map(Schema arrow_schema): + cdef: + shared_ptr[CSchema] sp_arrow_schema + shared_ptr[SchemaDescriptor] sp_parquet_schema + shared_ptr[WriterProperties] props = _create_writer_properties() + shared_ptr[ArrowWriterProperties] arrow_props = _create_arrow_writer_properties( + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + allow_truncated_timestamps=False, + writer_engine_version="V2" + ) + + sp_arrow_schema = pyarrow_unwrap_schema(arrow_schema) + + with nogil: + check_status(ToParquetSchema( + sp_arrow_schema.get(), deref(props.get()), deref(arrow_props.get()), &sp_parquet_schema)) + + out = dict() + + cdef SchemaDescriptor* parquet_schema = sp_parquet_schema.get() + + for i in range(parquet_schema.num_columns()): + name = frombytes(parquet_schema.Column(i).path().get().ToDotString()) + out[name] = i + + return out + + +cdef class ParquetWriter(_Weakrefable): + cdef: + unique_ptr[FileWriter] writer + shared_ptr[COutputStream] sink + bint own_sink + + cdef readonly: + object use_dictionary + object use_deprecated_int96_timestamps + object use_byte_stream_split + object column_encoding + object coerce_timestamps + object allow_truncated_timestamps + object compression + object compression_level + object data_page_version + object use_compliant_nested_type + object version + object write_statistics + object writer_engine_version + int row_group_size + int64_t data_page_size + FileEncryptionProperties encryption_properties + int64_t write_batch_size + int64_t dictionary_pagesize_limit + object store_schema + object store_decimal_as_integer + + def __cinit__(self, where, Schema schema not None, use_dictionary=None, + compression=None, version=None, + write_statistics=None, + MemoryPool memory_pool=None, + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + data_page_size=None, + allow_truncated_timestamps=False, + compression_level=None, + use_byte_stream_split=False, + column_encoding=None, + writer_engine_version=None, + data_page_version=None, + use_compliant_nested_type=True, + encryption_properties=None, + write_batch_size=None, + dictionary_pagesize_limit=None, + store_schema=True, + write_page_index=False, + write_page_checksum=False, + sorting_columns=None, + store_decimal_as_integer=False): + cdef: + shared_ptr[WriterProperties] properties + shared_ptr[ArrowWriterProperties] arrow_properties + c_string c_where + CMemoryPool* pool + + try: + where = _stringify_path(where) + except TypeError: + get_writer(where, &self.sink) + self.own_sink = False + else: + c_where = tobytes(where) + with nogil: + self.sink = GetResultValue(FileOutputStream.Open(c_where)) + self.own_sink = True + + properties = _create_writer_properties( + use_dictionary=use_dictionary, + compression=compression, + version=version, + write_statistics=write_statistics, + data_page_size=data_page_size, + compression_level=compression_level, + use_byte_stream_split=use_byte_stream_split, + column_encoding=column_encoding, + data_page_version=data_page_version, + encryption_properties=encryption_properties, + write_batch_size=write_batch_size, + dictionary_pagesize_limit=dictionary_pagesize_limit, + write_page_index=write_page_index, + write_page_checksum=write_page_checksum, + sorting_columns=sorting_columns, + store_decimal_as_integer=store_decimal_as_integer, + ) + arrow_properties = _create_arrow_writer_properties( + use_deprecated_int96_timestamps=use_deprecated_int96_timestamps, + coerce_timestamps=coerce_timestamps, + allow_truncated_timestamps=allow_truncated_timestamps, + writer_engine_version=writer_engine_version, + use_compliant_nested_type=use_compliant_nested_type, + store_schema=store_schema, + ) + + pool = maybe_unbox_memory_pool(memory_pool) + with nogil: + self.writer = move(GetResultValue( + FileWriter.Open(deref(schema.schema), pool, + self.sink, properties, arrow_properties))) + + def close(self): + with nogil: + check_status(self.writer.get().Close()) + if self.own_sink: + check_status(self.sink.get().Close()) + + def write_table(self, Table table, row_group_size=None): + cdef: + CTable* ctable = table.table + int64_t c_row_group_size + + if row_group_size is None or row_group_size == -1: + c_row_group_size = min(ctable.num_rows(), _DEFAULT_ROW_GROUP_SIZE) + elif row_group_size == 0: + raise ValueError('Row group size cannot be 0') + else: + c_row_group_size = row_group_size + + with nogil: + check_status(self.writer.get() + .WriteTable(deref(ctable), c_row_group_size)) + + def add_key_value_metadata(self, key_value_metadata): + cdef: + shared_ptr[const CKeyValueMetadata] c_metadata + + c_metadata = pyarrow_unwrap_metadata(KeyValueMetadata(key_value_metadata)) + with nogil: + check_status(self.writer.get() + .AddKeyValueMetadata(c_metadata)) + + @property + def metadata(self): + cdef: + shared_ptr[CFileMetaData] metadata + FileMetaData result + with nogil: + metadata = self.writer.get().metadata() + if metadata: + result = FileMetaData() + result.init(metadata) + return result + raise RuntimeError( + 'file metadata is only available after writer close') diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pxd new file mode 100644 index 0000000000000000000000000000000000000000..91c0220d7310870a7803ecceb2c32b8b32f8c11d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pxd @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport CStatus + + +ctypedef CStatus cb_test_func() + +cdef extern from "arrow/python/python_test.h" namespace "arrow::py::testing" nogil: + + cdef cppclass CTestCase "arrow::py::testing::TestCase": + c_string name + cb_test_func func + + vector[CTestCase] GetCppTestCases() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_substrait.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_substrait.pyx new file mode 100644 index 0000000000000000000000000000000000000000..d9359c8e77d00e8067c11517d84c1f908b8745df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_substrait.pyx @@ -0,0 +1,481 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 +from cython.operator cimport dereference as deref +from libcpp.vector cimport vector as std_vector + +from pyarrow import Buffer, py_buffer +from pyarrow._compute cimport Expression +from pyarrow.lib import frombytes, tobytes +from pyarrow.lib cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_substrait cimport * + +try: + import substrait as py_substrait +except ImportError: + py_substrait = None +else: + import substrait.proto # no-cython-lint + + +# TODO GH-37235: Fix exception handling +cdef CDeclaration _create_named_table_provider( + dict named_args, const std_vector[c_string]& names, const CSchema& schema +) noexcept: + cdef: + c_string c_name + shared_ptr[CTable] c_in_table + shared_ptr[CTableSourceNodeOptions] c_tablesourceopts + shared_ptr[CExecNodeOptions] c_input_node_opts + vector[CDeclaration.Input] no_c_inputs + + py_names = [] + for i in range(names.size()): + c_name = names[i] + py_names.append(frombytes(c_name)) + py_schema = pyarrow_wrap_schema(make_shared[CSchema](schema)) + + py_table = named_args["provider"](py_names, py_schema) + c_in_table = pyarrow_unwrap_table(py_table) + c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table) + c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions]( + c_tablesourceopts) + return CDeclaration(tobytes("table_source"), + no_c_inputs, c_input_node_opts) + + +def run_query(plan, *, table_provider=None, use_threads=True): + """ + Execute a Substrait plan and read the results as a RecordBatchReader. + + Parameters + ---------- + plan : Union[Buffer, bytes] + The serialized Substrait plan to execute. + table_provider : object (optional) + A function to resolve any NamedTable relation to a table. + The function will receive two arguments which will be a list + of strings representing the table name and a pyarrow.Schema representing + the expected schema and should return a pyarrow.Table. + use_threads : bool, default True + If True then multiple threads will be used to run the query. If False then + all CPU intensive work will be done on the calling thread. + + Returns + ------- + RecordBatchReader + A reader containing the result of the executed query + + Examples + -------- + >>> import pyarrow as pa + >>> from pyarrow.lib import tobytes + >>> import pyarrow.substrait as substrait + >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) + >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]}) + >>> def table_provider(names, schema): + ... if not names: + ... raise Exception("No names provided") + ... elif names[0] == "t1": + ... return test_table_1 + ... elif names[1] == "t2": + ... return test_table_2 + ... else: + ... raise Exception("Unrecognized table name") + ... + >>> substrait_query = ''' + ... { + ... "relations": [ + ... {"rel": { + ... "read": { + ... "base_schema": { + ... "struct": { + ... "types": [ + ... {"i64": {}} + ... ] + ... }, + ... "names": [ + ... "x" + ... ] + ... }, + ... "namedTable": { + ... "names": ["t1"] + ... } + ... } + ... }} + ... ] + ... } + ... ''' + >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) + >>> reader = pa.substrait.run_query(buf, table_provider=table_provider) + >>> reader.read_all() + pyarrow.Table + x: int64 + ---- + x: [[1,2,3]] + """ + + cdef: + CResult[shared_ptr[CRecordBatchReader]] c_res_reader + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader + shared_ptr[CBuffer] c_buf_plan + CConversionOptions c_conversion_options + c_bool c_use_threads + + c_use_threads = use_threads + if isinstance(plan, (bytes, memoryview)): + c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan)) + elif isinstance(plan, Buffer): + c_buf_plan = pyarrow_unwrap_buffer(plan) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(plan)}'") + + if table_provider is not None: + named_table_args = { + "provider": table_provider + } + c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider]( + &_create_named_table_provider, named_table_args) + + with nogil: + c_res_reader = ExecuteSerializedPlan( + deref(c_buf_plan), default_extension_id_registry(), + GetFunctionRegistry(), c_conversion_options, c_use_threads) + + c_reader = GetResultValue(c_res_reader) + + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return reader + + +def _parse_json_plan(plan): + """ + Parse a JSON plan into equivalent serialized Protobuf. + + Parameters + ---------- + plan : bytes + Substrait plan in JSON. + + Returns + ------- + Buffer + A buffer containing the serialized Protobuf plan. + """ + + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + c_string c_str_plan + shared_ptr[CBuffer] c_buf_plan + + c_str_plan = plan + c_res_buffer = SerializeJsonPlan(c_str_plan) + with nogil: + c_buf_plan = GetResultValue(c_res_buffer) + return pyarrow_wrap_buffer(c_buf_plan) + + +class SubstraitSchema: + """A Schema encoded for Substrait usage. + + The SubstraitSchema contains a schema represented + both as a substrait ``NamedStruct`` and as an + ``ExtendedExpression``. + + The ``ExtendedExpression`` is available for cases where types + used by the schema require extensions to decode them. + In such case the schema will be the ``base_schema`` of the + ``ExtendedExpression`` and all extensions will be provided. + """ + + def __init__(self, schema, expression): + self.schema = schema + self.expression = expression + + def to_pysubstrait(self): + """Convert the schema to a substrait-python ExtendedExpression object.""" + if py_substrait is None: + raise ImportError("The 'substrait' package is required.") + return py_substrait.proto.ExtendedExpression.FromString(self.expression) + + +def serialize_schema(schema): + """ + Serialize a schema into a SubstraitSchema object. + + Parameters + ---------- + schema : Schema + The schema to serialize + + Returns + ------- + SubstraitSchema + The schema stored in a SubstraitSchema object. + """ + return SubstraitSchema( + schema=_serialize_namedstruct_schema(schema), + expression=serialize_expressions([], [], schema, allow_arrow_extensions=True) + ) + + +def _serialize_namedstruct_schema(schema): + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + shared_ptr[CBuffer] c_buffer + CConversionOptions c_conversion_options + CExtensionSet c_extensions + + with nogil: + c_res_buffer = SerializeSchema(deref(( schema).sp_schema), &c_extensions, c_conversion_options) + c_buffer = GetResultValue(c_res_buffer) + + return memoryview(pyarrow_wrap_buffer(c_buffer)) + + +def deserialize_schema(buf): + """ + Deserialize a ``NamedStruct`` Substrait message + or a SubstraitSchema object into an Arrow Schema object + + Parameters + ---------- + buf : Buffer or bytes or SubstraitSchema + The message to deserialize + + Returns + ------- + Schema + The deserialized schema + """ + cdef: + shared_ptr[CBuffer] c_buffer + CResult[shared_ptr[CSchema]] c_res_schema + shared_ptr[CSchema] c_schema + CConversionOptions c_conversion_options + CExtensionSet c_extensions + + if isinstance(buf, SubstraitSchema): + return deserialize_expressions(buf.expression).schema + + if isinstance(buf, (bytes, memoryview)): + c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) + elif isinstance(buf, Buffer): + c_buffer = pyarrow_unwrap_buffer(buf) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") + + with nogil: + c_res_schema = DeserializeSchema( + deref(c_buffer), c_extensions, c_conversion_options) + c_schema = GetResultValue(c_res_schema) + + return pyarrow_wrap_schema(c_schema) + + +def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False): + """ + Serialize a collection of expressions into Substrait + + Substrait expressions must be bound to a schema. For example, + the Substrait expression ``a:i32 + b:i32`` is different from the + Substrait expression ``a:i64 + b:i64``. Pyarrow expressions are + typically unbound. For example, both of the above expressions + would be represented as ``a + b`` in pyarrow. + + This means a schema must be provided when serializing an expression. + It also means that the serialization may fail if a matching function + call cannot be found for the expression. + + Parameters + ---------- + exprs : list of Expression + The expressions to serialize + names : list of str + Names for the expressions + schema : Schema + The schema the expressions will be bound to + allow_arrow_extensions : bool, default False + If False then only functions that are part of the core Substrait function + definitions will be allowed. Set this to True to allow pyarrow-specific functions + and user defined functions but the result may not be accepted by other + compute libraries. + + Returns + ------- + Buffer + An ExtendedExpression message containing the serialized expressions + """ + cdef: + CResult[shared_ptr[CBuffer]] c_res_buffer + shared_ptr[CBuffer] c_buffer + CNamedExpression c_named_expr + CBoundExpressions c_bound_exprs + CConversionOptions c_conversion_options + + if len(exprs) != len(names): + raise ValueError("exprs and names need to have the same length") + for expr, name in zip(exprs, names): + if not isinstance(expr, Expression): + raise TypeError(f"Expected Expression, got '{type(expr)}' in exprs") + if not isinstance(name, str): + raise TypeError(f"Expected str, got '{type(name)}' in names") + c_named_expr.expression = ( expr).unwrap() + c_named_expr.name = tobytes( name) + c_bound_exprs.named_expressions.push_back(c_named_expr) + + c_bound_exprs.schema = ( schema).sp_schema + + c_conversion_options.allow_arrow_extensions = allow_arrow_extensions + + with nogil: + c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options) + c_buffer = GetResultValue(c_res_buffer) + return memoryview(pyarrow_wrap_buffer(c_buffer)) + + +cdef class BoundExpressions(_Weakrefable): + """ + A collection of named expressions and the schema they are bound to + + This is equivalent to the Substrait ExtendedExpression message + """ + + cdef: + CBoundExpressions c_bound_exprs + + def __init__(self): + msg = 'BoundExpressions is an abstract class thus cannot be initialized.' + raise TypeError(msg) + + cdef void init(self, CBoundExpressions bound_expressions): + self.c_bound_exprs = bound_expressions + + @property + def schema(self): + """ + The common schema that all expressions are bound to + """ + return pyarrow_wrap_schema(self.c_bound_exprs.schema) + + @property + def expressions(self): + """ + A dict from expression name to expression + """ + expr_dict = {} + for named_expr in self.c_bound_exprs.named_expressions: + name = frombytes(named_expr.name) + expr = Expression.wrap(named_expr.expression) + expr_dict[name] = expr + return expr_dict + + @staticmethod + cdef wrap(const CBoundExpressions& bound_expressions): + cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions) + self.init(bound_expressions) + return self + + @classmethod + def from_substrait(cls, message): + """ + Convert a Substrait message into a BoundExpressions object + + Parameters + ---------- + message : Buffer or bytes or protobuf Message + The message to convert to a BoundExpressions object + + Returns + ------- + BoundExpressions + The converted expressions, their names, and the bound schema + """ + if isinstance(message, (bytes, memoryview)): + return deserialize_expressions(message) + elif isinstance(message, Buffer): + return deserialize_expressions(message) + else: + try: + return deserialize_expressions(message.SerializeToString()) + except AttributeError: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes or protobuf Message, got '{type(message)}'") + + +def deserialize_expressions(buf): + """ + Deserialize an ExtendedExpression Substrait message into a BoundExpressions object + + Parameters + ---------- + buf : Buffer or bytes + The message to deserialize + + Returns + ------- + BoundExpressions + The deserialized expressions, their names, and the bound schema + """ + cdef: + shared_ptr[CBuffer] c_buffer + CResult[CBoundExpressions] c_res_bound_exprs + CBoundExpressions c_bound_exprs + + if isinstance(buf, (bytes, memoryview)): + c_buffer = pyarrow_unwrap_buffer(py_buffer(buf)) + elif isinstance(buf, Buffer): + c_buffer = pyarrow_unwrap_buffer(buf) + else: + raise TypeError( + f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") + + with nogil: + c_res_bound_exprs = DeserializeExpressions(deref(c_buffer)) + c_bound_exprs = GetResultValue(c_res_bound_exprs) + + return BoundExpressions.wrap(c_bound_exprs) + + +def get_supported_functions(): + """ + Get a list of Substrait functions that the underlying + engine currently supports. + + Returns + ------- + list[str] + A list of function ids encoded as '{uri}#{name}' + """ + + cdef: + ExtensionIdRegistry* c_id_registry + std_vector[c_string] c_ids + + c_id_registry = default_extension_id_registry() + c_ids = c_id_registry.GetSupportedSubstraitFunctions() + + functions_list = [] + for c_id in c_ids: + functions_list.append(frombytes(c_id)) + return functions_list diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cffi.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cffi.py new file mode 100644 index 0000000000000000000000000000000000000000..1da1a916914049513b89c68bd60f08ba32b67edb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cffi.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import + +import cffi + +c_source = """ + struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArrayStream { + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback + void (*release)(struct ArrowArrayStream*); + // Opaque producer-specific data + void* private_data; + }; + + typedef int32_t ArrowDeviceType; + + struct ArrowDeviceArray { + struct ArrowArray array; + int64_t device_id; + ArrowDeviceType device_type; + void* sync_event; + int64_t reserved[3]; + }; + """ + +# TODO use out-of-line mode for faster import and avoid C parsing +ffi = cffi.FFI() +ffi.cdef(c_source) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compute.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compute.py new file mode 100644 index 0000000000000000000000000000000000000000..426ecae31c039797d3429922e262a90a70af7fea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compute.py @@ -0,0 +1,744 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow._compute import ( # noqa + Function, + FunctionOptions, + FunctionRegistry, + HashAggregateFunction, + HashAggregateKernel, + Kernel, + ScalarAggregateFunction, + ScalarAggregateKernel, + ScalarFunction, + ScalarKernel, + VectorFunction, + VectorKernel, + # Option classes + ArraySortOptions, + AssumeTimezoneOptions, + CastOptions, + CountOptions, + CumulativeOptions, + CumulativeSumOptions, + DayOfWeekOptions, + DictionaryEncodeOptions, + RunEndEncodeOptions, + ElementWiseAggregateOptions, + ExtractRegexOptions, + FilterOptions, + IndexOptions, + JoinOptions, + ListSliceOptions, + ListFlattenOptions, + MakeStructOptions, + MapLookupOptions, + MatchSubstringOptions, + ModeOptions, + NullOptions, + PadOptions, + PairwiseOptions, + PartitionNthOptions, + QuantileOptions, + RandomOptions, + RankOptions, + ReplaceSliceOptions, + ReplaceSubstringOptions, + RoundBinaryOptions, + RoundOptions, + RoundTemporalOptions, + RoundToMultipleOptions, + ScalarAggregateOptions, + SelectKOptions, + SetLookupOptions, + SliceOptions, + SortOptions, + SplitOptions, + SplitPatternOptions, + StrftimeOptions, + StrptimeOptions, + StructFieldOptions, + TakeOptions, + TDigestOptions, + TrimOptions, + Utf8NormalizeOptions, + VarianceOptions, + WeekOptions, + # Functions + call_function, + function_registry, + get_function, + list_functions, + # Udf + call_tabular_function, + register_scalar_function, + register_tabular_function, + register_aggregate_function, + register_vector_function, + UdfContext, + # Expressions + Expression, +) + +from collections import namedtuple +import inspect +from textwrap import dedent +import warnings + +import pyarrow as pa +from pyarrow import _compute_docstrings +from pyarrow.vendored import docscrape + + +def _get_arg_names(func): + return func._doc.arg_names + + +_OptionsClassDoc = namedtuple('_OptionsClassDoc', ('params',)) + + +def _scrape_options_class_doc(options_class): + if not options_class.__doc__: + return None + doc = docscrape.NumpyDocString(options_class.__doc__) + return _OptionsClassDoc(doc['Parameters']) + + +def _decorate_compute_function(wrapper, exposed_name, func, options_class): + # Decorate the given compute function wrapper with useful metadata + # and documentation. + cpp_doc = func._doc + + wrapper.__arrow_compute_function__ = dict( + name=func.name, + arity=func.arity, + options_class=cpp_doc.options_class, + options_required=cpp_doc.options_required) + wrapper.__name__ = exposed_name + wrapper.__qualname__ = exposed_name + + doc_pieces = [] + + # 1. One-line summary + summary = cpp_doc.summary + if not summary: + arg_str = "arguments" if func.arity > 1 else "argument" + summary = ("Call compute function {!r} with the given {}" + .format(func.name, arg_str)) + + doc_pieces.append(f"{summary}.\n\n") + + # 2. Multi-line description + description = cpp_doc.description + if description: + doc_pieces.append(f"{description}\n\n") + + doc_addition = _compute_docstrings.function_doc_additions.get(func.name) + + # 3. Parameter description + doc_pieces.append(dedent("""\ + Parameters + ---------- + """)) + + # 3a. Compute function parameters + arg_names = _get_arg_names(func) + for arg_name in arg_names: + if func.kind in ('vector', 'scalar_aggregate'): + arg_type = 'Array-like' + else: + arg_type = 'Array-like or scalar-like' + doc_pieces.append(f"{arg_name} : {arg_type}\n") + doc_pieces.append(" Argument to compute function.\n") + + # 3b. Compute function option values + if options_class is not None: + options_class_doc = _scrape_options_class_doc(options_class) + if options_class_doc: + for p in options_class_doc.params: + doc_pieces.append(f"{p.name} : {p.type}\n") + for s in p.desc: + doc_pieces.append(f" {s}\n") + else: + warnings.warn(f"Options class {options_class.__name__} " + f"does not have a docstring", RuntimeWarning) + options_sig = inspect.signature(options_class) + for p in options_sig.parameters.values(): + doc_pieces.append(dedent("""\ + {0} : optional + Parameter for {1} constructor. Either `options` + or `{0}` can be passed, but not both at the same time. + """.format(p.name, options_class.__name__))) + doc_pieces.append(dedent(f"""\ + options : pyarrow.compute.{options_class.__name__}, optional + Alternative way of passing options. + """)) + + doc_pieces.append(dedent("""\ + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + """)) + + # 4. Custom addition (e.g. examples) + if doc_addition is not None: + doc_pieces.append("\n{}\n".format(dedent(doc_addition).strip("\n"))) + + wrapper.__doc__ = "".join(doc_pieces) + return wrapper + + +def _get_options_class(func): + class_name = func._doc.options_class + if not class_name: + return None + try: + return globals()[class_name] + except KeyError: + warnings.warn("Python binding for {} not exposed" + .format(class_name), RuntimeWarning) + return None + + +def _handle_options(name, options_class, options, args, kwargs): + if args or kwargs: + if options is not None: + raise TypeError( + "Function {!r} called with both an 'options' argument " + "and additional arguments" + .format(name)) + return options_class(*args, **kwargs) + + if options is not None: + if isinstance(options, dict): + return options_class(**options) + elif isinstance(options, options_class): + return options + raise TypeError( + "Function {!r} expected a {} parameter, got {}" + .format(name, options_class, type(options))) + + return None + + +def _make_generic_wrapper(func_name, func, options_class, arity): + if options_class is None: + def wrapper(*args, memory_pool=None): + if arity is not Ellipsis and len(args) != arity: + raise TypeError( + f"{func_name} takes {arity} positional argument(s), " + f"but {len(args)} were given" + ) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args)) + return func.call(args, None, memory_pool) + else: + def wrapper(*args, memory_pool=None, options=None, **kwargs): + if arity is not Ellipsis: + if len(args) < arity: + raise TypeError( + f"{func_name} takes {arity} positional argument(s), " + f"but {len(args)} were given" + ) + option_args = args[arity:] + args = args[:arity] + else: + option_args = () + options = _handle_options(func_name, options_class, options, + option_args, kwargs) + if args and isinstance(args[0], Expression): + return Expression._call(func_name, list(args), options) + return func.call(args, options, memory_pool) + return wrapper + + +def _make_signature(arg_names, var_arg_names, options_class): + from inspect import Parameter + params = [] + for name in arg_names: + params.append(Parameter(name, Parameter.POSITIONAL_ONLY)) + for name in var_arg_names: + params.append(Parameter(name, Parameter.VAR_POSITIONAL)) + if options_class is not None: + options_sig = inspect.signature(options_class) + for p in options_sig.parameters.values(): + assert p.kind in (Parameter.POSITIONAL_OR_KEYWORD, + Parameter.KEYWORD_ONLY) + if var_arg_names: + # Cannot have a positional argument after a *args + p = p.replace(kind=Parameter.KEYWORD_ONLY) + params.append(p) + params.append(Parameter("options", Parameter.KEYWORD_ONLY, + default=None)) + params.append(Parameter("memory_pool", Parameter.KEYWORD_ONLY, + default=None)) + return inspect.Signature(params) + + +def _wrap_function(name, func): + options_class = _get_options_class(func) + arg_names = _get_arg_names(func) + has_vararg = arg_names and arg_names[-1].startswith('*') + if has_vararg: + var_arg_names = [arg_names.pop().lstrip('*')] + else: + var_arg_names = [] + + wrapper = _make_generic_wrapper( + name, func, options_class, arity=func.arity) + wrapper.__signature__ = _make_signature(arg_names, var_arg_names, + options_class) + return _decorate_compute_function(wrapper, name, func, options_class) + + +def _make_global_functions(): + """ + Make global functions wrapping each compute function. + + Note that some of the automatically-generated wrappers may be overridden + by custom versions below. + """ + g = globals() + reg = function_registry() + + # Avoid clashes with Python keywords + rewrites = {'and': 'and_', + 'or': 'or_'} + + for cpp_name in reg.list_functions(): + name = rewrites.get(cpp_name, cpp_name) + func = reg.get_function(cpp_name) + if func.kind == "hash_aggregate": + # Hash aggregate functions are not callable, + # so let's not expose them at module level. + continue + if func.kind == "scalar_aggregate" and func.arity == 0: + # Nullary scalar aggregate functions are not callable + # directly so let's not expose them at module level. + continue + assert name not in g, name + g[cpp_name] = g[name] = _wrap_function(name, func) + + +_make_global_functions() + + +def cast(arr, target_type=None, safe=None, options=None, memory_pool=None): + """ + Cast array values to another data type. Can also be invoked as an array + instance method. + + Parameters + ---------- + arr : Array-like + target_type : DataType or str + Type to cast to + safe : bool, default True + Check for overflows or other unsafe conversions + options : CastOptions, default None + Additional checks pass by CastOptions + memory_pool : MemoryPool, optional + memory pool to use for allocations during function execution. + + Examples + -------- + >>> from datetime import datetime + >>> import pyarrow as pa + >>> arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)]) + >>> arr.type + TimestampType(timestamp[us]) + + You can use ``pyarrow.DataType`` objects to specify the target type: + + >>> cast(arr, pa.timestamp('ms')) + + [ + 2010-01-01 00:00:00.000, + 2015-01-01 00:00:00.000 + ] + + >>> cast(arr, pa.timestamp('ms')).type + TimestampType(timestamp[ms]) + + Alternatively, it is also supported to use the string aliases for these + types: + + >>> arr.cast('timestamp[ms]') + + [ + 2010-01-01 00:00:00.000, + 2015-01-01 00:00:00.000 + ] + >>> arr.cast('timestamp[ms]').type + TimestampType(timestamp[ms]) + + Returns + ------- + casted : Array + The cast result as a new Array + """ + safe_vars_passed = (safe is not None) or (target_type is not None) + + if safe_vars_passed and (options is not None): + raise ValueError("Must either pass values for 'target_type' and 'safe'" + " or pass a value for 'options'") + + if options is None: + target_type = pa.types.lib.ensure_type(target_type) + if safe is False: + options = CastOptions.unsafe(target_type) + else: + options = CastOptions.safe(target_type) + return call_function("cast", [arr], options, memory_pool) + + +def index(data, value, start=None, end=None, *, memory_pool=None): + """ + Find the index of the first occurrence of a given value. + + Parameters + ---------- + data : Array-like + value : Scalar-like object + The value to search for. + start : int, optional + end : int, optional + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + index : int + the index, or -1 if not found + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["Lorem", "ipsum", "dolor", "sit", "Lorem", "ipsum"]) + >>> pc.index(arr, "ipsum") + + >>> pc.index(arr, "ipsum", start=2) + + >>> pc.index(arr, "amet") + + """ + if start is not None: + if end is not None: + data = data.slice(start, end - start) + else: + data = data.slice(start) + elif end is not None: + data = data.slice(0, end) + + if not isinstance(value, pa.Scalar): + value = pa.scalar(value, type=data.type) + elif data.type != value.type: + value = pa.scalar(value.as_py(), type=data.type) + options = IndexOptions(value=value) + result = call_function('index', [data], options, memory_pool) + if start is not None and result.as_py() >= 0: + result = pa.scalar(result.as_py() + start, type=pa.int64()) + return result + + +def take(data, indices, *, boundscheck=True, memory_pool=None): + """ + Select values (or records) from array- or table-like data given integer + selection indices. + + The result will be of the same type(s) as the input, with elements taken + from the input array (or record batch / table fields) at the given + indices. If an index is null then the corresponding value in the output + will be null. + + Parameters + ---------- + data : Array, ChunkedArray, RecordBatch, or Table + indices : Array, ChunkedArray + Must be of integer type + boundscheck : boolean, default True + Whether to boundscheck the indices. If False and there is an out of + bounds index, will likely cause the process to crash. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : depends on inputs + Selected values for the given indices + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> indices = pa.array([0, None, 4, 3]) + >>> arr.take(indices) + + [ + "a", + null, + "e", + null + ] + """ + options = TakeOptions(boundscheck=boundscheck) + return call_function('take', [data, indices], options, memory_pool) + + +def fill_null(values, fill_value): + """Replace each null element in values with a corresponding + element from fill_value. + + If fill_value is scalar-like, then every null element in values + will be replaced with fill_value. If fill_value is array-like, + then the i-th element in values will be replaced with the i-th + element in fill_value. + + The fill_value's type must be the same as that of values, or it + must be able to be implicitly casted to the array's type. + + This is an alias for :func:`coalesce`. + + Parameters + ---------- + values : Array, ChunkedArray, or Scalar-like object + Each null element is replaced with the corresponding value + from fill_value. + fill_value : Array, ChunkedArray, or Scalar-like object + If not same type as values, will attempt to cast. + + Returns + ------- + result : depends on inputs + Values with all null elements replaced + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array([1, 2, None, 3], type=pa.int8()) + >>> fill_value = pa.scalar(5, type=pa.int8()) + >>> arr.fill_null(fill_value) + + [ + 1, + 2, + 5, + 3 + ] + >>> arr = pa.array([1, 2, None, 4, None]) + >>> arr.fill_null(pa.array([10, 20, 30, 40, 50])) + + [ + 1, + 2, + 30, + 4, + 50 + ] + """ + if not isinstance(fill_value, (pa.Array, pa.ChunkedArray, pa.Scalar)): + fill_value = pa.scalar(fill_value, type=values.type) + elif values.type != fill_value.type: + fill_value = pa.scalar(fill_value.as_py(), type=values.type) + + return call_function("coalesce", [values, fill_value]) + + +def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): + """ + Select the indices of the top-k ordered elements from array- or table-like + data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + Data to sort and get top indices from. + k : int + The number of `k` elements to keep. + sort_keys : List-like + Column key names to order by when input is table-like data. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : Array + Indices of the top-k ordered elements + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.top_k_unstable(arr, k=3) + + [ + 5, + 4, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "descending")) + else: + sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys) + options = SelectKOptions(k, sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) + + +def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): + """ + Select the indices of the bottom-k ordered elements from + array- or table-like data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + Data to sort and get bottom indices from. + k : int + The number of `k` elements to keep. + sort_keys : List-like + Column key names to order by when input is table-like data. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : Array of indices + Indices of the bottom-k ordered elements + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.bottom_k_unstable(arr, k=3) + + [ + 0, + 1, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "ascending")) + else: + sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys) + options = SelectKOptions(k, sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) + + +def random(n, *, initializer='system', options=None, memory_pool=None): + """ + Generate numbers in the range [0, 1). + + Generated values are uniformly-distributed, double-precision + in range [0, 1). Algorithm and seed can be changed via RandomOptions. + + Parameters + ---------- + n : int + Number of values to generate, must be greater than or equal to 0 + initializer : int or str + How to initialize the underlying random generator. + If an integer is given, it is used as a seed. + If "system" is given, the random generator is initialized with + a system-specific source of (hopefully true) randomness. + Other values are invalid. + options : pyarrow.compute.RandomOptions, optional + Alternative way of passing options. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + """ + options = RandomOptions(initializer=initializer) + return call_function("random", [], options, memory_pool, length=n) + + +def field(*name_or_index): + """Reference a column of the dataset. + + Stores only the field's name. Type and other information is known only when + the expression is bound to a dataset having an explicit scheme. + + Nested references are allowed by passing multiple names or a tuple of + names. For example ``('foo', 'bar')`` references the field named "bar" + inside the field named "foo". + + Parameters + ---------- + *name_or_index : string, multiple strings, tuple or int + The name or index of the (possibly nested) field the expression + references to. + + Returns + ------- + field_expr : Expression + Reference to the given field + + Examples + -------- + >>> import pyarrow.compute as pc + >>> pc.field("a") + + >>> pc.field(1) + + >>> pc.field(("a", "b")) + >> pc.field("a", "b") + tobytes(path) + + check_status(Initialize(options)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cuda.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..18c530d4afe406366b6ff7c12cbc1c6813081e04 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/cuda.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + + +from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer, + HostBuffer, BufferReader, BufferWriter, + new_host_buffer, + serialize_record_batch, read_message, + read_record_batch) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/error.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/error.pxi new file mode 100644 index 0000000000000000000000000000000000000000..cbe25522e8d7ecbb8e0b7e5e024b9c22c56e6e9b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/error.pxi @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.exc cimport PyErr_CheckSignals, PyErr_SetInterrupt + +from pyarrow.includes.libarrow cimport CStatus +from pyarrow.includes.libarrow_python cimport IsPyError, RestorePyError +from pyarrow.includes.common cimport c_string + +from contextlib import contextmanager +import os +import signal +import threading + +from pyarrow.lib import is_threading_enabled +from pyarrow.util import _break_traceback_cycle_from_frame + + +class ArrowException(Exception): + pass + + +class ArrowInvalid(ValueError, ArrowException): + pass + + +class ArrowMemoryError(MemoryError, ArrowException): + pass + + +class ArrowKeyError(KeyError, ArrowException): + def __str__(self): + # Override KeyError.__str__, as it uses the repr() of the key + return ArrowException.__str__(self) + + +class ArrowTypeError(TypeError, ArrowException): + pass + + +class ArrowNotImplementedError(NotImplementedError, ArrowException): + pass + + +class ArrowCapacityError(ArrowException): + pass + + +class ArrowIndexError(IndexError, ArrowException): + pass + + +class ArrowSerializationError(ArrowException): + pass + + +class ArrowCancelled(ArrowException): + def __init__(self, message, signum=None): + super().__init__(message) + self.signum = signum + + +# Compatibility alias +ArrowIOError = IOError + + +# check_status() and convert_status() could be written directly in C++ +# if we didn't define Arrow-specific subclasses (ArrowInvalid etc.) +cdef int check_status(const CStatus& status) except -1 nogil: + if status.ok(): + return 0 + + with gil: + if IsPyError(status): + RestorePyError(status) + return -1 + + raise convert_status(status) + + +cdef object convert_status(const CStatus& status): + if IsPyError(status): + try: + RestorePyError(status) + except BaseException as e: + return e + + # We don't use Status::ToString() as it would redundantly include + # the C++ class name. + message = frombytes(status.message(), safe=True) + detail = status.detail() + if detail != nullptr: + message += ". Detail: " + frombytes(detail.get().ToString(), + safe=True) + + if status.IsInvalid(): + return ArrowInvalid(message) + elif status.IsIOError(): + # Note: OSError constructor is + # OSError(message) + # or + # OSError(errno, message, filename=None) + # or (on Windows) + # OSError(errno, message, filename, winerror) + errno = ErrnoFromStatus(status) + winerror = WinErrorFromStatus(status) + if winerror != 0: + return IOError(errno, message, None, winerror) + elif errno != 0: + return IOError(errno, message) + else: + return IOError(message) + elif status.IsOutOfMemory(): + return ArrowMemoryError(message) + elif status.IsKeyError(): + return ArrowKeyError(message) + elif status.IsNotImplemented(): + return ArrowNotImplementedError(message) + elif status.IsTypeError(): + return ArrowTypeError(message) + elif status.IsCapacityError(): + return ArrowCapacityError(message) + elif status.IsIndexError(): + return ArrowIndexError(message) + elif status.IsSerializationError(): + return ArrowSerializationError(message) + elif status.IsCancelled(): + signum = SignalFromStatus(status) + if signum > 0: + return ArrowCancelled(message, signum) + else: + return ArrowCancelled(message) + else: + message = frombytes(status.ToString(), safe=True) + return ArrowException(message) + + +# These are API functions for C++ PyArrow +cdef api int pyarrow_internal_check_status(const CStatus& status) \ + except -1 nogil: + return check_status(status) + +cdef api object pyarrow_internal_convert_status(const CStatus& status): + return convert_status(status) + + +cdef class StopToken: + cdef void init(self, CStopToken stop_token): + self.stop_token = move(stop_token) + + +cdef c_bool signal_handlers_enabled = True + + +def enable_signal_handlers(c_bool enable): + """ + Enable or disable interruption of long-running operations. + + By default, certain long running operations will detect user + interruptions, such as by pressing Ctrl-C. This detection relies + on setting a signal handler for the duration of the long-running + operation, and may therefore interfere with other frameworks or + libraries (such as an event loop). + + Parameters + ---------- + enable : bool + Whether to enable user interruption by setting a temporary + signal handler. + """ + global signal_handlers_enabled + signal_handlers_enabled = enable + + +# For internal use + +# Whether we need a workaround for https://bugs.python.org/issue42248 +have_signal_refcycle = (sys.version_info < (3, 8, 10) or + (3, 9) <= sys.version_info < (3, 9, 5) or + sys.version_info[:2] == (3, 10)) + +cdef class SignalStopHandler: + cdef: + StopToken _stop_token + vector[int] _signals + c_bool _enabled + + def __cinit__(self): + self._enabled = False + + self._init_signals() + if have_signal_refcycle: + _break_traceback_cycle_from_frame(sys._getframe(0)) + + self._stop_token = StopToken() + + if not self._signals.empty(): + maybe_source = SetSignalStopSource() + if not maybe_source.ok(): + # See ARROW-11841 / ARROW-17173: in complex interaction + # scenarios (such as R calling into Python), SetSignalStopSource() + # may have already activated a signal-receiving StopSource. + # Just warn instead of erroring out. + maybe_source.status().Warn() + else: + self._stop_token.init(deref(maybe_source).token()) + # signals don't work on Emscripten without threads. + # and possibly other single-thread environments. + self._enabled = is_threading_enabled() + + def _init_signals(self): + if (signal_handlers_enabled and + threading.current_thread() is threading.main_thread()): + self._signals = [ + sig for sig in (signal.SIGINT, signal.SIGTERM) + if signal.getsignal(sig) not in (signal.SIG_DFL, + signal.SIG_IGN, None)] + + def __enter__(self): + if self._enabled: + check_status(RegisterCancellingSignalHandler(self._signals)) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + if self._enabled: + UnregisterCancellingSignalHandler() + if exc_value is None: + # Make sure we didn't lose a signal + try: + check_status(self._stop_token.stop_token.Poll()) + except ArrowCancelled as e: + exc_value = e + if isinstance(exc_value, ArrowCancelled): + if exc_value.signum: + # Re-emit the exact same signal. We restored the Python signal + # handler above, so it should receive it. + if os.name == 'nt': + SendSignal(exc_value.signum) + else: + SendSignalToThread(exc_value.signum, + threading.main_thread().ident) + else: + # Simulate Python receiving a SIGINT + # (see https://bugs.python.org/issue43356 for why we can't + # simulate the exact signal number) + PyErr_SetInterrupt() + # Maximize chances of the Python signal handler being executed now. + # Otherwise a potential KeyboardInterrupt might be missed by an + # immediately enclosing try/except block. + PyErr_CheckSignals() + # ArrowCancelled will be re-raised if PyErr_CheckSignals() + # returned successfully. + + def __dealloc__(self): + if self._enabled: + ResetSignalStopSource() + + @property + def stop_token(self): + return self._stop_token diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/fs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/fs.py new file mode 100644 index 0000000000000000000000000000000000000000..abdd1a995751aa32aeba2a84176747e22bc64744 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/fs.py @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +FileSystem abstraction to interact with various local and remote filesystems. +""" + +from pyarrow.util import _is_path_like, _stringify_path + +from pyarrow._fs import ( # noqa + FileSelector, + FileType, + FileInfo, + FileSystem, + LocalFileSystem, + SubTreeFileSystem, + _MockFileSystem, + FileSystemHandler, + PyFileSystem, + _copy_files, + _copy_files_selector, +) + +# For backward compatibility. +FileStats = FileInfo + +_not_imported = [] +try: + from pyarrow._azurefs import AzureFileSystem # noqa +except ImportError: + _not_imported.append("AzureFileSystem") + +try: + from pyarrow._hdfs import HadoopFileSystem # noqa +except ImportError: + _not_imported.append("HadoopFileSystem") + +try: + from pyarrow._gcsfs import GcsFileSystem # noqa +except ImportError: + _not_imported.append("GcsFileSystem") + +try: + from pyarrow._s3fs import ( # noqa + AwsDefaultS3RetryStrategy, AwsStandardS3RetryStrategy, + S3FileSystem, S3LogLevel, S3RetryStrategy, ensure_s3_initialized, + finalize_s3, ensure_s3_finalized, initialize_s3, resolve_s3_region) +except ImportError: + _not_imported.append("S3FileSystem") +else: + # GH-38364: we don't initialize S3 eagerly as that could lead + # to crashes at shutdown even when S3 isn't used. + # Instead, S3 is initialized lazily using `ensure_s3_initialized` + # in assorted places. + import atexit + atexit.register(ensure_s3_finalized) + + +def __getattr__(name): + if name in _not_imported: + raise ImportError( + "The pyarrow installation is not built with support for " + "'{0}'".format(name) + ) + + raise AttributeError( + "module 'pyarrow.fs' has no attribute '{0}'".format(name) + ) + + +def _filesystem_from_str(uri): + # instantiate the file system from an uri, if the uri has a path + # component then it will be treated as a path prefix + filesystem, prefix = FileSystem.from_uri(uri) + prefix = filesystem.normalize_path(prefix) + if prefix: + # validate that the prefix is pointing to a directory + prefix_info = filesystem.get_file_info([prefix])[0] + if prefix_info.type != FileType.Directory: + raise ValueError( + "The path component of the filesystem URI must point to a " + "directory but it has a type: `{}`. The path component " + "is `{}` and the given filesystem URI is `{}`".format( + prefix_info.type.name, prefix_info.path, uri + ) + ) + filesystem = SubTreeFileSystem(prefix, filesystem) + return filesystem + + +def _ensure_filesystem(filesystem, *, use_mmap=False): + if isinstance(filesystem, FileSystem): + return filesystem + elif isinstance(filesystem, str): + if use_mmap: + raise ValueError( + "Specifying to use memory mapping not supported for " + "filesystem specified as an URI string" + ) + return _filesystem_from_str(filesystem) + + # handle fsspec-compatible filesystems + try: + import fsspec + except ImportError: + pass + else: + if isinstance(filesystem, fsspec.AbstractFileSystem): + if type(filesystem).__name__ == 'LocalFileSystem': + # In case its a simple LocalFileSystem, use native arrow one + return LocalFileSystem(use_mmap=use_mmap) + return PyFileSystem(FSSpecHandler(filesystem)) + + raise TypeError( + "Unrecognized filesystem: {}. `filesystem` argument must be a " + "FileSystem instance or a valid file system URI'".format( + type(filesystem)) + ) + + +def _resolve_filesystem_and_path(path, filesystem=None, *, memory_map=False): + """ + Return filesystem/path from path which could be an URI or a plain + filesystem path. + """ + if not _is_path_like(path): + if filesystem is not None: + raise ValueError( + "'filesystem' passed but the specified path is file-like, so" + " there is nothing to open with 'filesystem'." + ) + return filesystem, path + + if filesystem is not None: + filesystem = _ensure_filesystem(filesystem, use_mmap=memory_map) + if isinstance(filesystem, LocalFileSystem): + path = _stringify_path(path) + elif not isinstance(path, str): + raise TypeError( + "Expected string path; path-like objects are only allowed " + "with a local filesystem" + ) + path = filesystem.normalize_path(path) + return filesystem, path + + path = _stringify_path(path) + + # if filesystem is not given, try to automatically determine one + # first check if the file exists as a local (relative) file path + # if not then try to parse the path as an URI + filesystem = LocalFileSystem(use_mmap=memory_map) + + try: + file_info = filesystem.get_file_info(path) + except ValueError: # ValueError means path is likely an URI + file_info = None + exists_locally = False + else: + exists_locally = (file_info.type != FileType.NotFound) + + # if the file or directory doesn't exists locally, then assume that + # the path is an URI describing the file system as well + if not exists_locally: + try: + filesystem, path = FileSystem.from_uri(path) + except ValueError as e: + # neither an URI nor a locally existing path, so assume that + # local path was given and propagate a nicer file not found error + # instead of a more confusing scheme parsing error + if "empty scheme" not in str(e) \ + and "Cannot parse URI" not in str(e): + raise + else: + path = filesystem.normalize_path(path) + + return filesystem, path + + +def copy_files(source, destination, + source_filesystem=None, destination_filesystem=None, + *, chunk_size=1024*1024, use_threads=True): + """ + Copy files between FileSystems. + + This functions allows you to recursively copy directories of files from + one file system to another, such as from S3 to your local machine. + + Parameters + ---------- + source : string + Source file path or URI to a single file or directory. + If a directory, files will be copied recursively from this path. + destination : string + Destination file path or URI. If `source` is a file, `destination` + is also interpreted as the destination file (not directory). + Directories will be created as necessary. + source_filesystem : FileSystem, optional + Source filesystem, needs to be specified if `source` is not a URI, + otherwise inferred. + destination_filesystem : FileSystem, optional + Destination filesystem, needs to be specified if `destination` is not + a URI, otherwise inferred. + chunk_size : int, default 1MB + The maximum size of block to read before flushing to the + destination file. A larger chunk_size will use more memory while + copying but may help accommodate high latency FileSystems. + use_threads : bool, default True + Whether to use multiple threads to accelerate copying. + + Examples + -------- + Inspect an S3 bucket's files: + + >>> s3, path = fs.FileSystem.from_uri( + ... "s3://registry.opendata.aws/roda/ndjson/") + >>> selector = fs.FileSelector(path) + >>> s3.get_file_info(selector) + [>> fs.copy_files("s3://registry.opendata.aws/roda/ndjson/index.ndjson", + ... "file:///{}/index_copy.ndjson".format(local_path)) + + >>> fs.LocalFileSystem().get_file_info(str(local_path)+ + ... '/index_copy.ndjson') + + + Copy file using a FileSystem object: + + >>> fs.copy_files("registry.opendata.aws/roda/ndjson/index.ndjson", + ... "file:///{}/index_copy.ndjson".format(local_path), + ... source_filesystem=fs.S3FileSystem()) + """ + source_fs, source_path = _resolve_filesystem_and_path( + source, source_filesystem + ) + destination_fs, destination_path = _resolve_filesystem_and_path( + destination, destination_filesystem + ) + + file_info = source_fs.get_file_info(source_path) + if file_info.type == FileType.Directory: + source_sel = FileSelector(source_path, recursive=True) + _copy_files_selector(source_fs, source_sel, + destination_fs, destination_path, + chunk_size, use_threads) + else: + _copy_files(source_fs, source_path, + destination_fs, destination_path, + chunk_size, use_threads) + + +class FSSpecHandler(FileSystemHandler): + """ + Handler for fsspec-based Python filesystems. + + https://filesystem-spec.readthedocs.io/en/latest/index.html + + Parameters + ---------- + fs : FSSpec-compliant filesystem instance + + Examples + -------- + >>> PyFileSystem(FSSpecHandler(fsspec_fs)) # doctest: +SKIP + """ + + def __init__(self, fs): + self.fs = fs + + def __eq__(self, other): + if isinstance(other, FSSpecHandler): + return self.fs == other.fs + return NotImplemented + + def __ne__(self, other): + if isinstance(other, FSSpecHandler): + return self.fs != other.fs + return NotImplemented + + def get_type_name(self): + protocol = self.fs.protocol + if isinstance(protocol, list): + protocol = protocol[0] + return "fsspec+{0}".format(protocol) + + def normalize_path(self, path): + return path + + @staticmethod + def _create_file_info(path, info): + size = info["size"] + if info["type"] == "file": + ftype = FileType.File + elif info["type"] == "directory": + ftype = FileType.Directory + # some fsspec filesystems include a file size for directories + size = None + else: + ftype = FileType.Unknown + return FileInfo(path, ftype, size=size, mtime=info.get("mtime", None)) + + def get_file_info(self, paths): + infos = [] + for path in paths: + try: + info = self.fs.info(path) + except FileNotFoundError: + infos.append(FileInfo(path, FileType.NotFound)) + else: + infos.append(self._create_file_info(path, info)) + return infos + + def get_file_info_selector(self, selector): + if not self.fs.isdir(selector.base_dir): + if self.fs.exists(selector.base_dir): + raise NotADirectoryError(selector.base_dir) + else: + if selector.allow_not_found: + return [] + else: + raise FileNotFoundError(selector.base_dir) + + if selector.recursive: + maxdepth = None + else: + maxdepth = 1 + + infos = [] + selected_files = self.fs.find( + selector.base_dir, maxdepth=maxdepth, withdirs=True, detail=True + ) + for path, info in selected_files.items(): + _path = path.strip("/") + base_dir = selector.base_dir.strip("/") + # Need to exclude base directory from selected files if present + # (fsspec filesystems, see GH-37555) + if _path != base_dir: + infos.append(self._create_file_info(path, info)) + + return infos + + def create_dir(self, path, recursive): + # mkdir also raises FileNotFoundError when base directory is not found + try: + self.fs.mkdir(path, create_parents=recursive) + except FileExistsError: + pass + + def delete_dir(self, path): + self.fs.rm(path, recursive=True) + + def _delete_dir_contents(self, path, missing_dir_ok): + try: + subpaths = self.fs.listdir(path, detail=False) + except FileNotFoundError: + if missing_dir_ok: + return + raise + for subpath in subpaths: + if self.fs.isdir(subpath): + self.fs.rm(subpath, recursive=True) + elif self.fs.isfile(subpath): + self.fs.rm(subpath) + + def delete_dir_contents(self, path, missing_dir_ok): + if path.strip("/") == "": + raise ValueError( + "delete_dir_contents called on path '", path, "'") + self._delete_dir_contents(path, missing_dir_ok) + + def delete_root_dir_contents(self): + self._delete_dir_contents("/") + + def delete_file(self, path): + # fs.rm correctly raises IsADirectoryError when `path` is a directory + # instead of a file and `recursive` is not set to True + if not self.fs.exists(path): + raise FileNotFoundError(path) + self.fs.rm(path) + + def move(self, src, dest): + self.fs.mv(src, dest, recursive=True) + + def copy_file(self, src, dest): + # fs.copy correctly raises IsADirectoryError when `src` is a directory + # instead of a file + self.fs.copy(src, dest) + + # TODO can we read/pass metadata (e.g. Content-Type) in the methods below? + + def open_input_stream(self, path): + from pyarrow import PythonFile + + if not self.fs.isfile(path): + raise FileNotFoundError(path) + + return PythonFile(self.fs.open(path, mode="rb"), mode="r") + + def open_input_file(self, path): + from pyarrow import PythonFile + + if not self.fs.isfile(path): + raise FileNotFoundError(path) + + return PythonFile(self.fs.open(path, mode="rb"), mode="r") + + def open_output_stream(self, path, metadata): + from pyarrow import PythonFile + + return PythonFile(self.fs.open(path, mode="wb"), mode="w") + + def open_append_stream(self, path, metadata): + from pyarrow import PythonFile + + return PythonFile(self.fs.open(path, mode="ab"), mode="w") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.pxi new file mode 100644 index 0000000000000000000000000000000000000000..e15b0ea40ed2e7de9d5a7f1776d26ff40909b4c4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.pxi @@ -0,0 +1,1403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New + +from collections import namedtuple +import warnings +from cython import sizeof + +cpdef enum MetadataVersion: + V1 = CMetadataVersion_V1 + V2 = CMetadataVersion_V2 + V3 = CMetadataVersion_V3 + V4 = CMetadataVersion_V4 + V5 = CMetadataVersion_V5 + + +cdef object _wrap_metadata_version(CMetadataVersion version): + return MetadataVersion( version) + + +cdef CMetadataVersion _unwrap_metadata_version( + MetadataVersion version) except *: + if version == MetadataVersion.V1: + return CMetadataVersion_V1 + elif version == MetadataVersion.V2: + return CMetadataVersion_V2 + elif version == MetadataVersion.V3: + return CMetadataVersion_V3 + elif version == MetadataVersion.V4: + return CMetadataVersion_V4 + elif version == MetadataVersion.V5: + return CMetadataVersion_V5 + raise ValueError("Not a metadata version: " + repr(version)) + + +_WriteStats = namedtuple( + 'WriteStats', + ('num_messages', 'num_record_batches', 'num_dictionary_batches', + 'num_dictionary_deltas', 'num_replaced_dictionaries')) + + +class WriteStats(_WriteStats): + """IPC write statistics + + Parameters + ---------- + num_messages : int + Number of messages. + num_record_batches : int + Number of record batches. + num_dictionary_batches : int + Number of dictionary batches. + num_dictionary_deltas : int + Delta of dictionaries. + num_replaced_dictionaries : int + Number of replaced dictionaries. + """ + __slots__ = () + + +@staticmethod +cdef _wrap_write_stats(CIpcWriteStats c): + return WriteStats(c.num_messages, c.num_record_batches, + c.num_dictionary_batches, c.num_dictionary_deltas, + c.num_replaced_dictionaries) + + +_ReadStats = namedtuple( + 'ReadStats', + ('num_messages', 'num_record_batches', 'num_dictionary_batches', + 'num_dictionary_deltas', 'num_replaced_dictionaries')) + + +class ReadStats(_ReadStats): + """IPC read statistics + + Parameters + ---------- + num_messages : int + Number of messages. + num_record_batches : int + Number of record batches. + num_dictionary_batches : int + Number of dictionary batches. + num_dictionary_deltas : int + Delta of dictionaries. + num_replaced_dictionaries : int + Number of replaced dictionaries. + """ + __slots__ = () + + +@staticmethod +cdef _wrap_read_stats(CIpcReadStats c): + return ReadStats(c.num_messages, c.num_record_batches, + c.num_dictionary_batches, c.num_dictionary_deltas, + c.num_replaced_dictionaries) + + +cdef class IpcReadOptions(_Weakrefable): + """ + Serialization options for reading IPC format. + + Parameters + ---------- + ensure_native_endian : bool, default True + Whether to convert incoming data to platform-native endianness. + use_threads : bool + Whether to use the global CPU thread pool to parallelize any + computational tasks like decompression + included_fields : list + If empty (the default), return all deserialized fields. + If non-empty, the values are the indices of fields to read on + the top-level schema + """ + __slots__ = () + + # cdef block is in lib.pxd + + def __init__(self, *, bint ensure_native_endian=True, + bint use_threads=True, list included_fields=None): + self.c_options = CIpcReadOptions.Defaults() + self.ensure_native_endian = ensure_native_endian + self.use_threads = use_threads + if included_fields is not None: + self.included_fields = included_fields + + @property + def ensure_native_endian(self): + return self.c_options.ensure_native_endian + + @ensure_native_endian.setter + def ensure_native_endian(self, bint value): + self.c_options.ensure_native_endian = value + + @property + def use_threads(self): + return self.c_options.use_threads + + @use_threads.setter + def use_threads(self, bint value): + self.c_options.use_threads = value + + @property + def included_fields(self): + return self.c_options.included_fields + + @included_fields.setter + def included_fields(self, list value not None): + self.c_options.included_fields = value + + +cdef class IpcWriteOptions(_Weakrefable): + """ + Serialization options for the IPC format. + + Parameters + ---------- + metadata_version : MetadataVersion, default MetadataVersion.V5 + The metadata version to write. V5 is the current and latest, + V4 is the pre-1.0 metadata version (with incompatible Union layout). + allow_64bit : bool, default False + If true, allow field lengths that don't fit in a signed 32-bit int. + use_legacy_format : bool, default False + Whether to use the pre-Arrow 0.15 IPC format. + compression : str, Codec, or None + compression codec to use for record batch buffers. + If None then batch buffers will be uncompressed. + Must be "lz4", "zstd" or None. + To specify a compression_level use `pyarrow.Codec` + use_threads : bool + Whether to use the global CPU thread pool to parallelize any + computational tasks like compression. + emit_dictionary_deltas : bool + Whether to emit dictionary deltas. Default is false for maximum + stream compatibility. + unify_dictionaries : bool + If true then calls to write_table will attempt to unify dictionaries + across all batches in the table. This can help avoid the need for + replacement dictionaries (which the file format does not support) + but requires computing the unified dictionary and then remapping + the indices arrays. + + This parameter is ignored when writing to the IPC stream format as + the IPC stream format can support replacement dictionaries. + """ + __slots__ = () + + # cdef block is in lib.pxd + + def __init__(self, *, metadata_version=MetadataVersion.V5, + bint allow_64bit=False, use_legacy_format=False, + compression=None, bint use_threads=True, + bint emit_dictionary_deltas=False, + bint unify_dictionaries=False): + self.c_options = CIpcWriteOptions.Defaults() + self.allow_64bit = allow_64bit + self.use_legacy_format = use_legacy_format + self.metadata_version = metadata_version + if compression is not None: + self.compression = compression + self.use_threads = use_threads + self.emit_dictionary_deltas = emit_dictionary_deltas + self.unify_dictionaries = unify_dictionaries + + @property + def allow_64bit(self): + return self.c_options.allow_64bit + + @allow_64bit.setter + def allow_64bit(self, bint value): + self.c_options.allow_64bit = value + + @property + def use_legacy_format(self): + return self.c_options.write_legacy_ipc_format + + @use_legacy_format.setter + def use_legacy_format(self, bint value): + self.c_options.write_legacy_ipc_format = value + + @property + def metadata_version(self): + return _wrap_metadata_version(self.c_options.metadata_version) + + @metadata_version.setter + def metadata_version(self, value): + self.c_options.metadata_version = _unwrap_metadata_version(value) + + @property + def compression(self): + if self.c_options.codec == nullptr: + return None + else: + return frombytes(self.c_options.codec.get().name()) + + @compression.setter + def compression(self, value): + if value is None: + self.c_options.codec.reset() + elif isinstance(value, str): + codec_type = _ensure_compression(value) + if codec_type != CCompressionType_ZSTD and codec_type != CCompressionType_LZ4_FRAME: + raise ValueError("Compression type must be lz4, zstd or None") + self.c_options.codec = shared_ptr[CCodec](GetResultValue( + CCodec.Create(codec_type)).release()) + elif isinstance(value, Codec): + if value.name != "lz4" and value.name != "zstd": + raise ValueError("Compression type must be lz4, zstd or None") + self.c_options.codec = (value).wrapped + else: + raise TypeError( + "Property `compression` must be None, str, or pyarrow.Codec") + + @property + def use_threads(self): + return self.c_options.use_threads + + @use_threads.setter + def use_threads(self, bint value): + self.c_options.use_threads = value + + @property + def emit_dictionary_deltas(self): + return self.c_options.emit_dictionary_deltas + + @emit_dictionary_deltas.setter + def emit_dictionary_deltas(self, bint value): + self.c_options.emit_dictionary_deltas = value + + @property + def unify_dictionaries(self): + return self.c_options.unify_dictionaries + + @unify_dictionaries.setter + def unify_dictionaries(self, bint value): + self.c_options.unify_dictionaries = value + + +cdef class Message(_Weakrefable): + """ + Container for an Arrow IPC message with metadata and optional body + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.ipc.read_message` function instead." + .format(self.__class__.__name__)) + + @property + def type(self): + return frombytes(FormatMessageType(self.message.get().type())) + + @property + def metadata(self): + return pyarrow_wrap_buffer(self.message.get().metadata()) + + @property + def metadata_version(self): + return _wrap_metadata_version(self.message.get().metadata_version()) + + @property + def body(self): + cdef shared_ptr[CBuffer] body = self.message.get().body() + if body.get() == NULL: + return None + else: + return pyarrow_wrap_buffer(body) + + def equals(self, Message other): + """ + Returns True if the message contents (metadata and body) are identical + + Parameters + ---------- + other : Message + + Returns + ------- + are_equal : bool + """ + cdef c_bool result + with nogil: + result = self.message.get().Equals(deref(other.message.get())) + return result + + def serialize_to(self, NativeFile sink, alignment=8, memory_pool=None): + """ + Write message to generic OutputStream + + Parameters + ---------- + sink : NativeFile + alignment : int, default 8 + Byte alignment for metadata and body + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + """ + cdef: + int64_t output_length = 0 + COutputStream* out + CIpcWriteOptions options + + options.alignment = alignment + out = sink.get_output_stream().get() + with nogil: + check_status(self.message.get() + .SerializeTo(out, options, &output_length)) + + def serialize(self, alignment=8, memory_pool=None): + """ + Write message as encapsulated IPC message + + Parameters + ---------- + alignment : int, default 8 + Byte alignment for metadata and body + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + """ + stream = BufferOutputStream(memory_pool) + self.serialize_to(stream, alignment=alignment, memory_pool=memory_pool) + return stream.getvalue() + + def __repr__(self): + if self.message == nullptr: + return """pyarrow.Message(uninitialized)""" + + metadata_len = self.metadata.size + body = self.body + body_len = 0 if body is None else body.size + + return """pyarrow.Message +type: {0} +metadata length: {1} +body length: {2}""".format(self.type, metadata_len, body_len) + + +cdef class MessageReader(_Weakrefable): + """ + Interface for reading Message objects from some source (like an + InputStream) + """ + cdef: + unique_ptr[CMessageReader] reader + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.ipc.MessageReader.open_stream` function " + "instead.".format(self.__class__.__name__)) + + @staticmethod + def open_stream(source): + """ + Open stream from source, if you want to use memory map use + MemoryMappedFile as source. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + A readable source, like an InputStream + """ + cdef: + MessageReader result = MessageReader.__new__(MessageReader) + shared_ptr[CInputStream] in_stream + unique_ptr[CMessageReader] reader + + _get_input_stream(source, &in_stream) + with nogil: + reader = CMessageReader.Open(in_stream) + result.reader.reset(reader.release()) + + return result + + def __iter__(self): + return self + + def __next__(self): + return self.read_next_message() + + def read_next_message(self): + """ + Read next Message from the stream. + + Raises + ------ + StopIteration + At end of stream + """ + cdef Message result = Message.__new__(Message) + + with nogil: + result.message = move(GetResultValue(self.reader.get() + .ReadNextMessage())) + + if result.message.get() == NULL: + raise StopIteration + + return result + +# ---------------------------------------------------------------------- +# File and stream readers and writers + +cdef class _CRecordBatchWriter(_Weakrefable): + """The base RecordBatchWriter wrapper. + + Provides common implementations of convenience methods. Should not + be instantiated directly by user code. + """ + + # cdef block is in lib.pxd + + def write(self, table_or_batch): + """ + Write RecordBatch or Table to stream. + + Parameters + ---------- + table_or_batch : {RecordBatch, Table} + """ + if isinstance(table_or_batch, RecordBatch): + self.write_batch(table_or_batch) + elif isinstance(table_or_batch, Table): + self.write_table(table_or_batch) + else: + raise ValueError(type(table_or_batch)) + + def write_batch(self, RecordBatch batch, custom_metadata=None): + """ + Write RecordBatch to stream. + + Parameters + ---------- + batch : RecordBatch + custom_metadata : mapping or KeyValueMetadata + Keys and values must be string-like / coercible to bytes + """ + metadata = ensure_metadata(custom_metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + + with nogil: + check_status(self.writer.get() + .WriteRecordBatch(deref(batch.batch), c_meta)) + + def write_table(self, Table table, max_chunksize=None): + """ + Write Table to stream in (contiguous) RecordBatch objects. + + Parameters + ---------- + table : Table + max_chunksize : int, default None + Maximum number of rows for RecordBatch chunks. Individual chunks may + be smaller depending on the chunk layout of individual columns. + """ + cdef: + # max_chunksize must be > 0 to have any impact + int64_t c_max_chunksize = -1 + + if max_chunksize is not None: + c_max_chunksize = max_chunksize + + with nogil: + check_status(self.writer.get().WriteTable(table.table[0], + c_max_chunksize)) + + def close(self): + """ + Close stream and write end-of-stream 0 marker. + """ + with nogil: + check_status(self.writer.get().Close()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @property + def stats(self): + """ + Current IPC write statistics. + """ + if not self.writer: + raise ValueError("Operation on closed writer") + return _wrap_write_stats(self.writer.get().stats()) + + +cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): + cdef: + CIpcWriteOptions options + bint closed + + def __cinit__(self): + pass + + def __dealloc__(self): + pass + + @property + def _use_legacy_format(self): + # For testing (see test_ipc.py) + return self.options.write_legacy_ipc_format + + @property + def _metadata_version(self): + # For testing (see test_ipc.py) + return _wrap_metadata_version(self.options.metadata_version) + + def _open(self, sink, Schema schema not None, + IpcWriteOptions options=IpcWriteOptions()): + cdef: + shared_ptr[COutputStream] c_sink + + self.options = options.c_options + get_writer(sink, &c_sink) + with nogil: + self.writer = GetResultValue( + MakeStreamWriter(c_sink, schema.sp_schema, + self.options)) + + +cdef _get_input_stream(object source, shared_ptr[CInputStream]* out): + try: + source = as_buffer(source) + except TypeError: + # Non-buffer-like + pass + + get_input_stream(source, True, out) + + +class _ReadPandasMixin: + + def read_pandas(self, **options): + """ + Read contents of stream to a pandas.DataFrame. + + Read all record batches as a pyarrow.Table then convert it to a + pandas.DataFrame using Table.to_pandas. + + Parameters + ---------- + **options + Arguments to forward to :meth:`Table.to_pandas`. + + Returns + ------- + df : pandas.DataFrame + """ + table = self.read_all() + return table.to_pandas(**options) + + +cdef class RecordBatchReader(_Weakrefable): + """Base class for reading stream of record batches. + + Record batch readers function as iterators of record batches that also + provide the schema (without the need to get any batches). + + Warnings + -------- + Do not call this class's constructor directly, use one of the + ``RecordBatchReader.from_*`` functions instead. + + Notes + ----- + To import and export using the Arrow C stream interface, use the + ``_import_from_c`` and ``_export_to_c`` methods. However, keep in mind this + interface is intended for expert users. + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([('x', pa.int64())]) + >>> def iter_record_batches(): + ... for i in range(2): + ... yield pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema) + >>> reader = pa.RecordBatchReader.from_batches(schema, iter_record_batches()) + >>> print(reader.schema) + x: int64 + >>> for batch in reader: + ... print(batch) + pyarrow.RecordBatch + x: int64 + ---- + x: [1,2,3] + pyarrow.RecordBatch + x: int64 + ---- + x: [1,2,3] + """ + + # cdef block is in lib.pxd + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use one of the RecordBatchReader.from_* functions instead." + .format(self.__class__.__name__)) + + def __iter__(self): + return self + + def __next__(self): + return self.read_next_batch() + + @property + def schema(self): + """ + Shared schema of the record batches in the stream. + + Returns + ------- + Schema + """ + cdef shared_ptr[CSchema] c_schema + + with nogil: + c_schema = self.reader.get().schema() + + return pyarrow_wrap_schema(c_schema) + + def read_next_batch(self): + """ + Read next RecordBatch from the stream. + + Raises + ------ + StopIteration: + At end of stream. + + Returns + ------- + RecordBatch + """ + cdef shared_ptr[CRecordBatch] batch + + with nogil: + check_status(self.reader.get().ReadNext(&batch)) + + if batch.get() == NULL: + raise StopIteration + + return pyarrow_wrap_batch(batch) + + def read_next_batch_with_custom_metadata(self): + """ + Read next RecordBatch from the stream along with its custom metadata. + + Raises + ------ + StopIteration: + At end of stream. + + Returns + ------- + batch : RecordBatch + custom_metadata : KeyValueMetadata + """ + cdef: + CRecordBatchWithMetadata batch_with_metadata + + with nogil: + batch_with_metadata = GetResultValue(self.reader.get().ReadNext()) + + if batch_with_metadata.batch.get() == NULL: + raise StopIteration + + return _wrap_record_batch_with_metadata(batch_with_metadata) + + def iter_batches_with_custom_metadata(self): + """ + Iterate over record batches from the stream along with their custom + metadata. + + Yields + ------ + RecordBatchWithMetadata + """ + while True: + try: + yield self.read_next_batch_with_custom_metadata() + except StopIteration: + return + + def read_all(self): + """ + Read all record batches as a pyarrow.Table. + + Returns + ------- + Table + """ + cdef shared_ptr[CTable] table + with nogil: + check_status(self.reader.get().ToTable().Value(&table)) + return pyarrow_wrap_table(table) + + read_pandas = _ReadPandasMixin.read_pandas + + def close(self): + """ + Release any resources associated with the reader. + """ + with nogil: + check_status(self.reader.get().Close()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def cast(self, target_schema): + """ + Wrap this reader with one that casts each batch lazily as it is pulled. + Currently only a safe cast to target_schema is implemented. + + Parameters + ---------- + target_schema : Schema + Schema to cast to, the names and order of fields must match. + + Returns + ------- + RecordBatchReader + """ + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader out + + if self.schema.names != target_schema.names: + raise ValueError("Target schema's field names are not matching " + f"the table's field names: {self.schema.names}, " + f"{target_schema.names}") + + c_schema = pyarrow_unwrap_schema(target_schema) + c_reader = GetResultValue(CCastingRecordBatchReader.Make( + self.reader, c_schema)) + + out = RecordBatchReader.__new__(RecordBatchReader) + out.reader = c_reader + return out + + def _export_to_c(self, out_ptr): + """ + Export to a C ArrowArrayStream struct, given its pointer. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArrayStream struct. + + Be careful: if you don't pass the ArrowArrayStream struct to a + consumer, array memory will leak. This is a low-level function + intended for expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(out_ptr) + with nogil: + check_status(ExportRecordBatchReader( + self.reader, c_ptr)) + + @staticmethod + def _import_from_c(in_ptr): + """ + Import RecordBatchReader from a C ArrowArrayStream struct, + given its pointer. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArrayStream struct. + + This is a low-level function intended for expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(in_ptr) + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + with nogil: + c_reader = GetResultValue(ImportRecordBatchReader( + c_ptr)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export to a C ArrowArrayStream PyCapsule. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the stream should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + + Returns + ------- + PyCapsule + A capsule containing a C ArrowArrayStream struct. + """ + cdef: + ArrowArrayStream* c_stream + + if requested_schema is not None: + out_schema = Schema._import_from_c_capsule(requested_schema) + if self.schema != out_schema: + return self.cast(out_schema).__arrow_c_stream__() + + stream_capsule = alloc_c_stream(&c_stream) + + with nogil: + check_status(ExportRecordBatchReader(self.reader, c_stream)) + + return stream_capsule + + @staticmethod + def _import_from_c_capsule(stream): + """ + Import RecordBatchReader from a C ArrowArrayStream PyCapsule. + + Parameters + ---------- + stream: PyCapsule + A capsule containing a C ArrowArrayStream PyCapsule. + + Returns + ------- + RecordBatchReader + """ + cdef: + ArrowArrayStream* c_stream + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + c_stream = PyCapsule_GetPointer( + stream, 'arrow_array_stream' + ) + + with nogil: + c_reader = GetResultValue(ImportRecordBatchReader(c_stream)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + @staticmethod + def from_stream(data, schema=None): + """ + Create RecordBatchReader from a Arrow-compatible stream object. + + This accepts objects implementing the Arrow PyCapsule Protocol for + streams, i.e. objects that have a ``__arrow_c_stream__`` method. + + Parameters + ---------- + data : Arrow-compatible stream object + Any object that implements the Arrow PyCapsule Protocol for + streams. + schema : Schema, default None + The schema to which the stream should be casted, if supported + by the stream object. + + Returns + ------- + RecordBatchReader + """ + + if not hasattr(data, "__arrow_c_stream__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "streams (i.e. having a `__arrow_c_stream__` method), " + f"got {type(data)!r}." + ) + + if schema is not None: + if not hasattr(schema, "__arrow_c_schema__"): + raise TypeError( + "Expected an object implementing the Arrow PyCapsule Protocol for " + "schema (i.e. having a `__arrow_c_schema__` method), " + f"got {type(schema)!r}." + ) + requested = schema.__arrow_c_schema__() + else: + requested = None + + capsule = data.__arrow_c_stream__(requested) + return RecordBatchReader._import_from_c_capsule(capsule) + + @staticmethod + def from_batches(Schema schema not None, batches): + """ + Create RecordBatchReader from an iterable of batches. + + Parameters + ---------- + schema : Schema + The shared schema of the record batches + batches : Iterable[RecordBatch] + The batches that this reader will return. + + Returns + ------- + reader : RecordBatchReader + """ + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + c_schema = pyarrow_unwrap_schema(schema) + c_reader = GetResultValue(CPyRecordBatchReader.Make( + c_schema, batches)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + +cdef class _RecordBatchStreamReader(RecordBatchReader): + cdef: + shared_ptr[CInputStream] in_stream + CIpcReadOptions options + CRecordBatchStreamReader* stream_reader + + def __cinit__(self): + pass + + def _open(self, source, IpcReadOptions options=IpcReadOptions(), + MemoryPool memory_pool=None): + self.options = options.c_options + self.options.memory_pool = maybe_unbox_memory_pool(memory_pool) + _get_input_stream(source, &self.in_stream) + with nogil: + self.reader = GetResultValue(CRecordBatchStreamReader.Open( + self.in_stream, self.options)) + self.stream_reader = self.reader.get() + + @property + def stats(self): + """ + Current IPC read statistics. + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return _wrap_read_stats(self.stream_reader.stats()) + + +cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): + + def _open(self, sink, Schema schema not None, + IpcWriteOptions options=IpcWriteOptions()): + cdef: + shared_ptr[COutputStream] c_sink + + self.options = options.c_options + get_writer(sink, &c_sink) + with nogil: + self.writer = GetResultValue( + MakeFileWriter(c_sink, schema.sp_schema, self.options)) + +_RecordBatchWithMetadata = namedtuple( + 'RecordBatchWithMetadata', + ('batch', 'custom_metadata')) + + +class RecordBatchWithMetadata(_RecordBatchWithMetadata): + """RecordBatch with its custom metadata + + Parameters + ---------- + batch : RecordBatch + custom_metadata : KeyValueMetadata + """ + __slots__ = () + + +@staticmethod +cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c): + return RecordBatchWithMetadata(pyarrow_wrap_batch(c.batch), + pyarrow_wrap_metadata(c.custom_metadata)) + + +cdef class _RecordBatchFileReader(_Weakrefable): + cdef: + SharedPtrNoGIL[CRecordBatchFileReader] reader + shared_ptr[CRandomAccessFile] file + CIpcReadOptions options + + cdef readonly: + Schema schema + + def __cinit__(self): + pass + + def _open(self, source, footer_offset=None, + IpcReadOptions options=IpcReadOptions(), + MemoryPool memory_pool=None): + self.options = options.c_options + self.options.memory_pool = maybe_unbox_memory_pool(memory_pool) + try: + source = as_buffer(source) + except TypeError: + pass + + get_reader(source, False, &self.file) + + cdef int64_t offset = 0 + if footer_offset is not None: + offset = footer_offset + + with nogil: + if offset != 0: + self.reader = GetResultValue( + CRecordBatchFileReader.Open2(self.file.get(), offset, + self.options)) + + else: + self.reader = GetResultValue( + CRecordBatchFileReader.Open(self.file.get(), + self.options)) + + self.schema = pyarrow_wrap_schema(self.reader.get().schema()) + + @property + def num_record_batches(self): + """ + The number of record batches in the IPC file. + """ + return self.reader.get().num_record_batches() + + def get_batch(self, int i): + """ + Read the record batch with the given index. + + Parameters + ---------- + i : int + The index of the record batch in the IPC file. + + Returns + ------- + batch : RecordBatch + """ + cdef shared_ptr[CRecordBatch] batch + + if i < 0 or i >= self.num_record_batches: + raise ValueError('Batch number {0} out of range'.format(i)) + + with nogil: + batch = GetResultValue(self.reader.get().ReadRecordBatch(i)) + + return pyarrow_wrap_batch(batch) + + # TODO(wesm): ARROW-503: Function was renamed. Remove after a period of + # time has passed + get_record_batch = get_batch + + def get_batch_with_custom_metadata(self, int i): + """ + Read the record batch with the given index along with + its custom metadata + + Parameters + ---------- + i : int + The index of the record batch in the IPC file. + + Returns + ------- + batch : RecordBatch + custom_metadata : KeyValueMetadata + """ + cdef: + CRecordBatchWithMetadata batch_with_metadata + + if i < 0 or i >= self.num_record_batches: + raise ValueError('Batch number {0} out of range'.format(i)) + + with nogil: + batch_with_metadata = GetResultValue( + self.reader.get().ReadRecordBatchWithCustomMetadata(i)) + + return _wrap_record_batch_with_metadata(batch_with_metadata) + + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CTable] table + int i, nbatches + + nbatches = self.num_record_batches + + batches.resize(nbatches) + with nogil: + for i in range(nbatches): + batches[i] = GetResultValue(self.reader.get() + .ReadRecordBatch(i)) + table = GetResultValue( + CTable.FromRecordBatches(self.schema.sp_schema, move(batches))) + + return pyarrow_wrap_table(table) + + read_pandas = _ReadPandasMixin.read_pandas + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + @property + def stats(self): + """ + Current IPC read statistics. + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return _wrap_read_stats(self.reader.get().stats()) + + +def get_tensor_size(Tensor tensor): + """ + Return total size of serialized Tensor including metadata and padding. + + Parameters + ---------- + tensor : Tensor + The tensor for which we want to known the size. + """ + cdef int64_t size + with nogil: + check_status(GetTensorSize(deref(tensor.tp), &size)) + return size + + +def get_record_batch_size(RecordBatch batch): + """ + Return total size of serialized RecordBatch including metadata and padding. + + Parameters + ---------- + batch : RecordBatch + The recordbatch for which we want to know the size. + """ + cdef int64_t size + with nogil: + check_status(GetRecordBatchSize(deref(batch.batch), &size)) + return size + + +def write_tensor(Tensor tensor, NativeFile dest): + """ + Write pyarrow.Tensor to pyarrow.NativeFile object its current position. + + Parameters + ---------- + tensor : pyarrow.Tensor + dest : pyarrow.NativeFile + + Returns + ------- + bytes_written : int + Total number of bytes written to the file + """ + cdef: + int32_t metadata_length + int64_t body_length + + handle = dest.get_output_stream() + + with nogil: + check_status( + WriteTensor(deref(tensor.tp), handle.get(), + &metadata_length, &body_length)) + + return metadata_length + body_length + + +cdef NativeFile as_native_file(source): + if not isinstance(source, NativeFile): + if hasattr(source, 'read'): + source = PythonFile(source) + else: + source = BufferReader(source) + + if not isinstance(source, NativeFile): + raise ValueError('Unable to read message from object with type: {0}' + .format(type(source))) + return source + + +def read_tensor(source): + """Read pyarrow.Tensor from pyarrow.NativeFile object from current + position. If the file source supports zero copy (e.g. a memory map), then + this operation does not allocate any memory. This function not assume that + the stream is aligned + + Parameters + ---------- + source : pyarrow.NativeFile + + Returns + ------- + tensor : Tensor + + """ + cdef: + shared_ptr[CTensor] sp_tensor + CInputStream* c_stream + NativeFile nf = as_native_file(source) + + c_stream = nf.get_input_stream().get() + with nogil: + sp_tensor = GetResultValue(ReadTensor(c_stream)) + return pyarrow_wrap_tensor(sp_tensor) + + +def read_message(source): + """ + Read length-prefixed message from file or buffer-like object + + Parameters + ---------- + source : pyarrow.NativeFile, file-like object, or buffer-like object + + Returns + ------- + message : Message + """ + cdef: + Message result = Message.__new__(Message) + CInputStream* c_stream + + cdef NativeFile nf = as_native_file(source) + c_stream = nf.get_input_stream().get() + + with nogil: + result.message = move( + GetResultValue(ReadMessage(c_stream, c_default_memory_pool()))) + + if result.message == nullptr: + raise EOFError("End of Arrow stream") + + return result + + +def read_schema(obj, DictionaryMemo dictionary_memo=None): + """ + Read Schema from message or buffer + + Parameters + ---------- + obj : buffer or Message + dictionary_memo : DictionaryMemo, optional + Needed to be able to reconstruct dictionary-encoded fields + with read_record_batch + + Returns + ------- + schema : Schema + """ + cdef: + shared_ptr[CSchema] result + shared_ptr[CRandomAccessFile] cpp_file + Message message + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + if isinstance(obj, Message): + message = obj + with nogil: + result = GetResultValue(ReadSchema( + deref(message.message.get()), arg_dict_memo)) + else: + get_reader(obj, False, &cpp_file) + with nogil: + result = GetResultValue(ReadSchema(cpp_file.get(), arg_dict_memo)) + + return pyarrow_wrap_schema(result) + + +def read_record_batch(obj, Schema schema, + DictionaryMemo dictionary_memo=None): + """ + Read RecordBatch from message, given a known schema. If reading data from a + complete IPC stream, use ipc.open_stream instead + + Parameters + ---------- + obj : Message or Buffer-like + schema : Schema + dictionary_memo : DictionaryMemo, optional + If message contains dictionaries, must pass a populated + DictionaryMemo + + Returns + ------- + batch : RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] result + Message message + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + + if isinstance(obj, Message): + message = obj + else: + message = read_message(obj) + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + with nogil: + result = GetResultValue( + ReadRecordBatch(deref(message.message.get()), + schema.sp_schema, + arg_dict_memo, + CIpcReadOptions.Defaults())) + + return pyarrow_wrap_batch(result) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.py new file mode 100644 index 0000000000000000000000000000000000000000..523196e1e33894871319462cdd6c72bd85830cf0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/ipc.py @@ -0,0 +1,285 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Arrow file and stream reader/writer classes, and other messaging tools + +import os + +import pyarrow as pa + +from pyarrow.lib import (IpcReadOptions, IpcWriteOptions, ReadStats, WriteStats, # noqa + Message, MessageReader, + RecordBatchReader, _ReadPandasMixin, + MetadataVersion, + read_message, read_record_batch, read_schema, + read_tensor, write_tensor, + get_record_batch_size, get_tensor_size) +import pyarrow.lib as lib + + +class RecordBatchStreamReader(lib._RecordBatchStreamReader): + """ + Reader for the Arrow streaming binary format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + If you want to use memory map use MemoryMappedFile as source. + options : pyarrow.ipc.IpcReadOptions + Options for IPC deserialization. + If None, default values will be used. + memory_pool : MemoryPool, default None + If None, default memory pool is used. + """ + + def __init__(self, source, *, options=None, memory_pool=None): + options = _ensure_default_ipc_read_options(options) + self._open(source, options=options, memory_pool=memory_pool) + + +_ipc_writer_class_doc = """\ +Parameters +---------- +sink : str, pyarrow.NativeFile, or file-like Python object + Either a file path, or a writable file object. +schema : pyarrow.Schema + The Arrow schema for data to be written to the file. +use_legacy_format : bool, default None + Deprecated in favor of setting options. Cannot be provided with + options. + + If None, False will be used unless this default is overridden by + setting the environment variable ARROW_PRE_0_15_IPC_FORMAT=1 +options : pyarrow.ipc.IpcWriteOptions + Options for IPC serialization. + + If None, default values will be used: the legacy format will not + be used unless overridden by setting the environment variable + ARROW_PRE_0_15_IPC_FORMAT=1, and the V5 metadata version will be + used unless overridden by setting the environment variable + ARROW_PRE_1_0_METADATA_VERSION=1.""" + + +class RecordBatchStreamWriter(lib._RecordBatchStreamWriter): + __doc__ = """Writer for the Arrow streaming binary format + +{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, *, use_legacy_format=None, options=None): + options = _get_legacy_format_default(use_legacy_format, options) + self._open(sink, schema, options=options) + + +class RecordBatchFileReader(lib._RecordBatchFileReader): + """ + Class for reading Arrow record batch data from the Arrow binary file format + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + If you want to use memory map use MemoryMappedFile as source. + footer_offset : int, default None + If the file is embedded in some larger file, this is the byte offset to + the very end of the file data + options : pyarrow.ipc.IpcReadOptions + Options for IPC serialization. + If None, default values will be used. + memory_pool : MemoryPool, default None + If None, default memory pool is used. + """ + + def __init__(self, source, footer_offset=None, *, options=None, + memory_pool=None): + options = _ensure_default_ipc_read_options(options) + self._open(source, footer_offset=footer_offset, + options=options, memory_pool=memory_pool) + + +class RecordBatchFileWriter(lib._RecordBatchFileWriter): + + __doc__ = """Writer to create the Arrow binary file format + +{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, *, use_legacy_format=None, options=None): + options = _get_legacy_format_default(use_legacy_format, options) + self._open(sink, schema, options=options) + + +def _get_legacy_format_default(use_legacy_format, options): + if use_legacy_format is not None and options is not None: + raise ValueError( + "Can provide at most one of options and use_legacy_format") + elif options: + if not isinstance(options, IpcWriteOptions): + raise TypeError("expected IpcWriteOptions, got {}" + .format(type(options))) + return options + + metadata_version = MetadataVersion.V5 + if use_legacy_format is None: + use_legacy_format = \ + bool(int(os.environ.get('ARROW_PRE_0_15_IPC_FORMAT', '0'))) + if bool(int(os.environ.get('ARROW_PRE_1_0_METADATA_VERSION', '0'))): + metadata_version = MetadataVersion.V4 + return IpcWriteOptions(use_legacy_format=use_legacy_format, + metadata_version=metadata_version) + + +def _ensure_default_ipc_read_options(options): + if options and not isinstance(options, IpcReadOptions): + raise TypeError( + "expected IpcReadOptions, got {}".format(type(options)) + ) + return options or IpcReadOptions() + + +def new_stream(sink, schema, *, use_legacy_format=None, options=None): + return RecordBatchStreamWriter(sink, schema, + use_legacy_format=use_legacy_format, + options=options) + + +new_stream.__doc__ = """\ +Create an Arrow columnar IPC stream writer instance + +{} + +Returns +------- +writer : RecordBatchStreamWriter + A writer for the given sink +""".format(_ipc_writer_class_doc) + + +def open_stream(source, *, options=None, memory_pool=None): + """ + Create reader for Arrow streaming format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + options : pyarrow.ipc.IpcReadOptions + Options for IPC serialization. + If None, default values will be used. + memory_pool : MemoryPool, default None + If None, default memory pool is used. + + Returns + ------- + reader : RecordBatchStreamReader + A reader for the given source + """ + return RecordBatchStreamReader(source, options=options, + memory_pool=memory_pool) + + +def new_file(sink, schema, *, use_legacy_format=None, options=None): + return RecordBatchFileWriter(sink, schema, + use_legacy_format=use_legacy_format, + options=options) + + +new_file.__doc__ = """\ +Create an Arrow columnar IPC file writer instance + +{} + +Returns +------- +writer : RecordBatchFileWriter + A writer for the given sink +""".format(_ipc_writer_class_doc) + + +def open_file(source, footer_offset=None, *, options=None, memory_pool=None): + """ + Create reader for Arrow file format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + footer_offset : int, default None + If the file is embedded in some larger file, this is the byte offset to + the very end of the file data. + options : pyarrow.ipc.IpcReadOptions + Options for IPC serialization. + If None, default values will be used. + memory_pool : MemoryPool, default None + If None, default memory pool is used. + + Returns + ------- + reader : RecordBatchFileReader + A reader for the given source + """ + return RecordBatchFileReader( + source, footer_offset=footer_offset, + options=options, memory_pool=memory_pool) + + +def serialize_pandas(df, *, nthreads=None, preserve_index=None): + """ + Serialize a pandas DataFrame into a buffer protocol compatible object. + + Parameters + ---------- + df : pandas.DataFrame + nthreads : int, default None + Number of threads to use for conversion to Arrow, default all CPUs. + preserve_index : bool, default None + The default of None will store the index as a column, except for + RangeIndex which is stored as metadata only. If True, always + preserve the pandas index data as a column. If False, no index + information is saved and the result will have a default RangeIndex. + + Returns + ------- + buf : buffer + An object compatible with the buffer protocol. + """ + batch = pa.RecordBatch.from_pandas(df, nthreads=nthreads, + preserve_index=preserve_index) + sink = pa.BufferOutputStream() + with pa.RecordBatchStreamWriter(sink, batch.schema) as writer: + writer.write_batch(batch) + return sink.getvalue() + + +def deserialize_pandas(buf, *, use_threads=True): + """Deserialize a buffer protocol compatible object into a pandas DataFrame. + + Parameters + ---------- + buf : buffer + An object compatible with the buffer protocol. + use_threads : bool, default True + Whether to parallelize the conversion using multiple threads. + + Returns + ------- + df : pandas.DataFrame + The buffer deserialized as pandas DataFrame + """ + buffer_reader = pa.BufferReader(buf) + with pa.RecordBatchStreamReader(buffer_reader) as reader: + table = reader.read_all() + return table.to_pandas(use_threads=use_threads) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/jvm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/jvm.py new file mode 100644 index 0000000000000000000000000000000000000000..161c5ff4d6d74512dfcd76ddac5a4c4781ad63c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/jvm.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Functions to interact with Arrow memory allocated by Arrow Java. + +These functions convert the objects holding the metadata, the actual +data is not copied at all. + +This will only work with a JVM running in the same process such as provided +through jpype. Modules that talk to a remote JVM like py4j will not work as the +memory addresses reported by them are not reachable in the python process. +""" + +import pyarrow as pa + + +class _JvmBufferNanny: + """ + An object that keeps a org.apache.arrow.memory.ArrowBuf's underlying + memory alive. + """ + ref_manager = None + + def __init__(self, jvm_buf): + ref_manager = jvm_buf.getReferenceManager() + # Will raise a java.lang.IllegalArgumentException if the buffer + # is already freed. It seems that exception cannot easily be + # caught... + ref_manager.retain() + self.ref_manager = ref_manager + + def __del__(self): + if self.ref_manager is not None: + self.ref_manager.release() + + +def jvm_buffer(jvm_buf): + """ + Construct an Arrow buffer from org.apache.arrow.memory.ArrowBuf + + Parameters + ---------- + + jvm_buf: org.apache.arrow.memory.ArrowBuf + Arrow Buffer representation on the JVM. + + Returns + ------- + pyarrow.Buffer + Python Buffer that references the JVM memory. + """ + nanny = _JvmBufferNanny(jvm_buf) + address = jvm_buf.memoryAddress() + size = jvm_buf.capacity() + return pa.foreign_buffer(address, size, base=nanny) + + +def _from_jvm_int_type(jvm_type): + """ + Convert a JVM int type to its Python equivalent. + + Parameters + ---------- + jvm_type : org.apache.arrow.vector.types.pojo.ArrowType$Int + + Returns + ------- + typ : pyarrow.DataType + """ + + bit_width = jvm_type.getBitWidth() + if jvm_type.getIsSigned(): + if bit_width == 8: + return pa.int8() + elif bit_width == 16: + return pa.int16() + elif bit_width == 32: + return pa.int32() + elif bit_width == 64: + return pa.int64() + else: + if bit_width == 8: + return pa.uint8() + elif bit_width == 16: + return pa.uint16() + elif bit_width == 32: + return pa.uint32() + elif bit_width == 64: + return pa.uint64() + + +def _from_jvm_float_type(jvm_type): + """ + Convert a JVM float type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$FloatingPoint + + Returns + ------- + typ: pyarrow.DataType + """ + precision = jvm_type.getPrecision().toString() + if precision == 'HALF': + return pa.float16() + elif precision == 'SINGLE': + return pa.float32() + elif precision == 'DOUBLE': + return pa.float64() + + +def _from_jvm_time_type(jvm_type): + """ + Convert a JVM time type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Time + + Returns + ------- + typ: pyarrow.DataType + """ + time_unit = jvm_type.getUnit().toString() + if time_unit == 'SECOND': + assert jvm_type.getBitWidth() == 32 + return pa.time32('s') + elif time_unit == 'MILLISECOND': + assert jvm_type.getBitWidth() == 32 + return pa.time32('ms') + elif time_unit == 'MICROSECOND': + assert jvm_type.getBitWidth() == 64 + return pa.time64('us') + elif time_unit == 'NANOSECOND': + assert jvm_type.getBitWidth() == 64 + return pa.time64('ns') + + +def _from_jvm_timestamp_type(jvm_type): + """ + Convert a JVM timestamp type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Timestamp + + Returns + ------- + typ: pyarrow.DataType + """ + time_unit = jvm_type.getUnit().toString() + timezone = jvm_type.getTimezone() + if timezone is not None: + timezone = str(timezone) + if time_unit == 'SECOND': + return pa.timestamp('s', tz=timezone) + elif time_unit == 'MILLISECOND': + return pa.timestamp('ms', tz=timezone) + elif time_unit == 'MICROSECOND': + return pa.timestamp('us', tz=timezone) + elif time_unit == 'NANOSECOND': + return pa.timestamp('ns', tz=timezone) + + +def _from_jvm_date_type(jvm_type): + """ + Convert a JVM date type to its Python equivalent + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Date + + Returns + ------- + typ: pyarrow.DataType + """ + day_unit = jvm_type.getUnit().toString() + if day_unit == 'DAY': + return pa.date32() + elif day_unit == 'MILLISECOND': + return pa.date64() + + +def field(jvm_field): + """ + Construct a Field from a org.apache.arrow.vector.types.pojo.Field + instance. + + Parameters + ---------- + jvm_field: org.apache.arrow.vector.types.pojo.Field + + Returns + ------- + pyarrow.Field + """ + name = str(jvm_field.getName()) + jvm_type = jvm_field.getType() + + typ = None + if not jvm_type.isComplex(): + type_str = jvm_type.getTypeID().toString() + if type_str == 'Null': + typ = pa.null() + elif type_str == 'Int': + typ = _from_jvm_int_type(jvm_type) + elif type_str == 'FloatingPoint': + typ = _from_jvm_float_type(jvm_type) + elif type_str == 'Utf8': + typ = pa.string() + elif type_str == 'Binary': + typ = pa.binary() + elif type_str == 'FixedSizeBinary': + typ = pa.binary(jvm_type.getByteWidth()) + elif type_str == 'Bool': + typ = pa.bool_() + elif type_str == 'Time': + typ = _from_jvm_time_type(jvm_type) + elif type_str == 'Timestamp': + typ = _from_jvm_timestamp_type(jvm_type) + elif type_str == 'Date': + typ = _from_jvm_date_type(jvm_type) + elif type_str == 'Decimal': + typ = pa.decimal128(jvm_type.getPrecision(), jvm_type.getScale()) + else: + raise NotImplementedError( + "Unsupported JVM type: {}".format(type_str)) + else: + # TODO: The following JVM types are not implemented: + # Struct, List, FixedSizeList, Union, Dictionary + raise NotImplementedError( + "JVM field conversion only implemented for primitive types.") + + nullable = jvm_field.isNullable() + jvm_metadata = jvm_field.getMetadata() + if jvm_metadata.isEmpty(): + metadata = None + else: + metadata = {str(entry.getKey()): str(entry.getValue()) + for entry in jvm_metadata.entrySet()} + return pa.field(name, typ, nullable, metadata) + + +def schema(jvm_schema): + """ + Construct a Schema from a org.apache.arrow.vector.types.pojo.Schema + instance. + + Parameters + ---------- + jvm_schema: org.apache.arrow.vector.types.pojo.Schema + + Returns + ------- + pyarrow.Schema + """ + fields = jvm_schema.getFields() + fields = [field(f) for f in fields] + jvm_metadata = jvm_schema.getCustomMetadata() + if jvm_metadata.isEmpty(): + metadata = None + else: + metadata = {str(entry.getKey()): str(entry.getValue()) + for entry in jvm_metadata.entrySet()} + return pa.schema(fields, metadata) + + +def array(jvm_array): + """ + Construct an (Python) Array from its JVM equivalent. + + Parameters + ---------- + jvm_array : org.apache.arrow.vector.ValueVector + + Returns + ------- + array : Array + """ + if jvm_array.getField().getType().isComplex(): + minor_type_str = jvm_array.getMinorType().toString() + raise NotImplementedError( + "Cannot convert JVM Arrow array of type {}," + " complex types not yet implemented.".format(minor_type_str)) + dtype = field(jvm_array.getField()).type + buffers = [jvm_buffer(buf) + for buf in list(jvm_array.getBuffers(False))] + + # If JVM has an empty Vector, buffer list will be empty so create manually + if len(buffers) == 0: + return pa.array([], type=dtype) + + length = jvm_array.getValueCount() + null_count = jvm_array.getNullCount() + return pa.Array.from_buffers(dtype, length, buffers, null_count) + + +def record_batch(jvm_vector_schema_root): + """ + Construct a (Python) RecordBatch from a JVM VectorSchemaRoot + + Parameters + ---------- + jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot + + Returns + ------- + record_batch: pyarrow.RecordBatch + """ + pa_schema = schema(jvm_vector_schema_root.getSchema()) + + arrays = [] + for name in pa_schema.names: + arrays.append(array(jvm_vector_schema_root.getVector(name))) + + return pa.RecordBatch.from_arrays( + arrays, + pa_schema.names, + metadata=pa_schema.metadata + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pxd new file mode 100644 index 0000000000000000000000000000000000000000..bc9811b92b007aa577f891cb4f6902a71371d9cf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pxd @@ -0,0 +1,770 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cpython cimport PyObject +from libcpp cimport nullptr, bool as c_bool +from libcpp.cast cimport dynamic_cast +from libcpp.memory cimport dynamic_pointer_cast +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * + +# Will be available in Cython 3, not backported +# ref: https://github.com/cython/cython/issues/3293#issuecomment-1223058101 +cdef extern from "" namespace "std" nogil: + cdef cppclass nullopt_t: + nullopt_t() + + cdef nullopt_t nullopt + + cdef cppclass optional[T]: + ctypedef T value_type + optional() + optional(nullopt_t) + optional(optional&) except + + optional(T&) except + + c_bool has_value() + T& value() + T& value_or[U](U& default_value) + void swap(optional&) + void reset() + T& emplace(...) + T& operator*() + # T* operator->() # Not Supported + optional& operator=(optional&) + optional& operator=[U](U&) + c_bool operator bool() + c_bool operator!() + c_bool operator==[U](optional&, U&) + c_bool operator!=[U](optional&, U&) + c_bool operator<[U](optional&, U&) + c_bool operator>[U](optional&, U&) + c_bool operator<=[U](optional&, U&) + c_bool operator>=[U](optional&, U&) + + optional[T] make_optional[T](...) except + + +cdef extern from "Python.h": + int PySlice_Check(object) + + +cdef int check_status(const CStatus& status) except -1 nogil +cdef object convert_status(const CStatus& status) + + +cdef class _Weakrefable: + cdef object __weakref__ + + +cdef class IpcWriteOptions(_Weakrefable): + cdef: + CIpcWriteOptions c_options + + +cdef class IpcReadOptions(_Weakrefable): + cdef: + CIpcReadOptions c_options + + +cdef class Message(_Weakrefable): + cdef: + unique_ptr[CMessage] message + + +cdef class MemoryPool(_Weakrefable): + cdef: + CMemoryPool* pool + + cdef void init(self, CMemoryPool* pool) + + +cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool) + + +cdef object box_memory_pool(CMemoryPool* pool) + + +cdef class DataType(_Weakrefable): + cdef: + shared_ptr[CDataType] sp_type + CDataType* type + bytes pep3118_format + + cdef void init(self, const shared_ptr[CDataType]& type) except * + cpdef Field field(self, i) + + +cdef class ListType(DataType): + cdef: + const CListType* list_type + + +cdef class LargeListType(DataType): + cdef: + const CLargeListType* list_type + + +cdef class ListViewType(DataType): + cdef: + const CListViewType* list_view_type + + +cdef class LargeListViewType(DataType): + cdef: + const CLargeListViewType* list_view_type + + +cdef class MapType(DataType): + cdef: + const CMapType* map_type + + +cdef class FixedSizeListType(DataType): + cdef: + const CFixedSizeListType* list_type + + +cdef class StructType(DataType): + cdef: + const CStructType* struct_type + + cdef Field field_by_name(self, name) + + +cdef class DictionaryMemo(_Weakrefable): + cdef: + # Even though the CDictionaryMemo instance is private, we allocate + # it on the heap so as to avoid C++ ABI issues with Python wheels. + shared_ptr[CDictionaryMemo] sp_memo + CDictionaryMemo* memo + + +cdef class DictionaryType(DataType): + cdef: + const CDictionaryType* dict_type + + +cdef class TimestampType(DataType): + cdef: + const CTimestampType* ts_type + + +cdef class Time32Type(DataType): + cdef: + const CTime32Type* time_type + + +cdef class Time64Type(DataType): + cdef: + const CTime64Type* time_type + + +cdef class DurationType(DataType): + cdef: + const CDurationType* duration_type + + +cdef class FixedSizeBinaryType(DataType): + cdef: + const CFixedSizeBinaryType* fixed_size_binary_type + + +cdef class Decimal32Type(FixedSizeBinaryType): + cdef: + const CDecimal32Type* decimal32_type + + +cdef class Decimal64Type(FixedSizeBinaryType): + cdef: + const CDecimal64Type* decimal64_type + + +cdef class Decimal128Type(FixedSizeBinaryType): + cdef: + const CDecimal128Type* decimal128_type + + +cdef class Decimal256Type(FixedSizeBinaryType): + cdef: + const CDecimal256Type* decimal256_type + + +cdef class RunEndEncodedType(DataType): + cdef: + const CRunEndEncodedType* run_end_encoded_type + + +cdef class BaseExtensionType(DataType): + cdef: + const CExtensionType* ext_type + + +cdef class ExtensionType(BaseExtensionType): + cdef: + const CPyExtensionType* cpy_ext_type + + +cdef class FixedShapeTensorType(BaseExtensionType): + cdef: + const CFixedShapeTensorType* tensor_ext_type + +cdef class Bool8Type(BaseExtensionType): + cdef: + const CBool8Type* bool8_ext_type + +cdef class OpaqueType(BaseExtensionType): + cdef: + const COpaqueType* opaque_ext_type + +cdef class UuidType(BaseExtensionType): + cdef: + const CUuidType* uuid_ext_type + +cdef class JsonType(BaseExtensionType): + cdef: + const CJsonType* json_ext_type + + +cdef class PyExtensionType(ExtensionType): + pass + + +cdef class _Metadata(_Weakrefable): + # required because KeyValueMetadata also extends collections.abc.Mapping + # and the first parent class must be an extension type + pass + + +cdef class KeyValueMetadata(_Metadata): + cdef: + shared_ptr[const CKeyValueMetadata] wrapped + const CKeyValueMetadata* metadata + + cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp) + cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil + + +cdef class Field(_Weakrefable): + cdef: + shared_ptr[CField] sp_field + CField* field + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CField]& field) + + +cdef class Schema(_Weakrefable): + cdef: + shared_ptr[CSchema] sp_schema + CSchema* schema + + cdef void init(self, const vector[shared_ptr[CField]]& fields) + cdef void init_schema(self, const shared_ptr[CSchema]& schema) + + +cdef class Scalar(_Weakrefable): + cdef: + shared_ptr[CScalar] wrapped + + cdef void init(self, const shared_ptr[CScalar]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[CScalar]& wrapped) + + cdef inline shared_ptr[CScalar] unwrap(self) nogil + + +cdef class _PandasConvertible(_Weakrefable): + pass + + +cdef class Array(_PandasConvertible): + cdef: + shared_ptr[CArray] sp_array + CArray* ap + + cdef readonly: + DataType type + # To allow Table to propagate metadata to pandas.Series + object _name + + cdef void init(self, const shared_ptr[CArray]& sp_array) except * + cdef getitem(self, int64_t i) + cdef int64_t length(self) + cdef void _assert_cpu(self) except * + + +cdef class Tensor(_Weakrefable): + cdef: + shared_ptr[CTensor] sp_tensor + CTensor* tp + + cdef readonly: + DataType type + bytes _ssize_t_shape + bytes _ssize_t_strides + + cdef void init(self, const shared_ptr[CTensor]& sp_tensor) + + +cdef class SparseCSRMatrix(_Weakrefable): + cdef: + shared_ptr[CSparseCSRMatrix] sp_sparse_tensor + CSparseCSRMatrix* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) + + +cdef class SparseCSCMatrix(_Weakrefable): + cdef: + shared_ptr[CSparseCSCMatrix] sp_sparse_tensor + CSparseCSCMatrix* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) + + +cdef class SparseCOOTensor(_Weakrefable): + cdef: + shared_ptr[CSparseCOOTensor] sp_sparse_tensor + CSparseCOOTensor* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) + + +cdef class SparseCSFTensor(_Weakrefable): + cdef: + shared_ptr[CSparseCSFTensor] sp_sparse_tensor + CSparseCSFTensor* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) + + +cdef class NullArray(Array): + pass + + +cdef class BooleanArray(Array): + pass + + +cdef class NumericArray(Array): + pass + + +cdef class IntegerArray(NumericArray): + pass + + +cdef class FloatingPointArray(NumericArray): + pass + + +cdef class Int8Array(IntegerArray): + pass + + +cdef class UInt8Array(IntegerArray): + pass + + +cdef class Int16Array(IntegerArray): + pass + + +cdef class UInt16Array(IntegerArray): + pass + + +cdef class Int32Array(IntegerArray): + pass + + +cdef class UInt32Array(IntegerArray): + pass + + +cdef class Int64Array(IntegerArray): + pass + + +cdef class UInt64Array(IntegerArray): + pass + + +cdef class HalfFloatArray(FloatingPointArray): + pass + + +cdef class FloatArray(FloatingPointArray): + pass + + +cdef class DoubleArray(FloatingPointArray): + pass + + +cdef class FixedSizeBinaryArray(Array): + pass + + +cdef class Decimal32Array(FixedSizeBinaryArray): + pass + + +cdef class Decimal64Array(FixedSizeBinaryArray): + pass + + +cdef class Decimal128Array(FixedSizeBinaryArray): + pass + + +cdef class Decimal256Array(FixedSizeBinaryArray): + pass + + +cdef class StructArray(Array): + pass + + +cdef class BaseListArray(Array): + pass + + +cdef class ListArray(BaseListArray): + pass + + +cdef class LargeListArray(BaseListArray): + pass + + +cdef class ListViewArray(BaseListArray): + pass + + +cdef class LargeListViewArray(BaseListArray): + pass + + +cdef class MapArray(ListArray): + pass + + +cdef class FixedSizeListArray(BaseListArray): + pass + + +cdef class UnionArray(Array): + pass + + +cdef class StringArray(Array): + pass + + +cdef class BinaryArray(Array): + pass + + +cdef class StringViewArray(Array): + pass + + +cdef class BinaryViewArray(Array): + pass + + +cdef class DictionaryArray(Array): + cdef: + object _indices, _dictionary + + +cdef class ExtensionArray(Array): + pass + + +cdef class MonthDayNanoIntervalArray(Array): + pass + + +cdef wrap_array_output(PyObject* output) +cdef wrap_datum(const CDatum& datum) + + +cdef class ChunkedArray(_PandasConvertible): + cdef: + shared_ptr[CChunkedArray] sp_chunked_array + CChunkedArray* chunked_array + c_bool _is_cpu + c_bool _init_is_cpu + + cdef readonly: + # To allow Table to propagate metadata to pandas.Series + object _name + + cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array) + cdef getitem(self, int64_t i) + + +cdef class _Tabular(_PandasConvertible): + cdef void _assert_cpu(self) except * + + +cdef class Table(_Tabular): + cdef: + shared_ptr[CTable] sp_table + CTable* table + c_bool _is_cpu + c_bool _init_is_cpu + + cdef void init(self, const shared_ptr[CTable]& table) + + +cdef class RecordBatch(_Tabular): + cdef: + shared_ptr[CRecordBatch] sp_batch + CRecordBatch* batch + Schema _schema + + cdef void init(self, const shared_ptr[CRecordBatch]& table) + + +cdef class Device(_Weakrefable): + cdef: + shared_ptr[CDevice] device + + cdef void init(self, const shared_ptr[CDevice]& device) + + @staticmethod + cdef wrap(const shared_ptr[CDevice]& device) + + cdef inline shared_ptr[CDevice] unwrap(self) nogil + + +cdef class MemoryManager(_Weakrefable): + cdef: + shared_ptr[CMemoryManager] memory_manager + + cdef void init(self, const shared_ptr[CMemoryManager]& memory_manager) + + @staticmethod + cdef wrap(const shared_ptr[CMemoryManager]& mm) + + cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil + + +cdef class Buffer(_Weakrefable): + cdef: + shared_ptr[CBuffer] buffer + Py_ssize_t shape[1] + Py_ssize_t strides[1] + + cdef void init(self, const shared_ptr[CBuffer]& buffer) + cdef getitem(self, int64_t i) + + +cdef class ResizableBuffer(Buffer): + + cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer) + + +cdef class NativeFile(_Weakrefable): + cdef: + shared_ptr[CInputStream] input_stream + shared_ptr[CRandomAccessFile] random_access + shared_ptr[COutputStream] output_stream + bint is_readable + bint is_writable + bint is_seekable + bint _is_appending + bint own_file + + # By implementing these "virtual" functions (all functions in Cython + # extension classes are technically virtual in the C++ sense) we can expose + # the arrow::io abstract file interfaces to other components throughout the + # suite of Arrow C++ libraries + cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle) + cdef set_input_stream(self, shared_ptr[CInputStream] handle) + cdef set_output_stream(self, shared_ptr[COutputStream] handle) + + cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except * + cdef shared_ptr[CInputStream] get_input_stream(self) except * + cdef shared_ptr[COutputStream] get_output_stream(self) except * + + +cdef class BufferedInputStream(NativeFile): + pass + + +cdef class BufferedOutputStream(NativeFile): + pass + + +cdef class CompressedInputStream(NativeFile): + pass + + +cdef class CompressedOutputStream(NativeFile): + pass + + +cdef class _CRecordBatchWriter(_Weakrefable): + cdef: + SharedPtrNoGIL[CRecordBatchWriter] writer + + +cdef class RecordBatchReader(_Weakrefable): + cdef: + SharedPtrNoGIL[CRecordBatchReader] reader + + +cdef class CacheOptions(_Weakrefable): + cdef: + CCacheOptions wrapped + + cdef void init(self, CCacheOptions options) + + cdef inline CCacheOptions unwrap(self) + + @staticmethod + cdef wrap(const CCacheOptions options) + + +cdef class Codec(_Weakrefable): + cdef: + shared_ptr[CCodec] wrapped + + cdef inline CCodec* unwrap(self) nogil + + +# This class is only used internally for now +cdef class StopToken: + cdef: + CStopToken stop_token + + cdef void init(self, CStopToken stop_token) + + +cdef get_input_stream(object source, c_bool use_memory_map, + shared_ptr[CInputStream]* reader) +cdef get_reader(object source, c_bool use_memory_map, + shared_ptr[CRandomAccessFile]* reader) +cdef get_writer(object source, shared_ptr[COutputStream]* writer) +cdef NativeFile get_native_file(object source, c_bool use_memory_map) + +cdef shared_ptr[CInputStream] native_transcoding_input_stream( + shared_ptr[CInputStream] stream, src_encoding, + dest_encoding) except * + +cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func( + src_encoding, dest_encoding) except * + +# Default is allow_none=False +cpdef DataType ensure_type(object type, bint allow_none=*) + +cdef timeunit_to_string(TimeUnit unit) +cdef TimeUnit string_to_timeunit(unit) except * + +# Exceptions may be raised when converting dict values, so need to +# check exception state on return +cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata( + object meta) except * +cdef object pyarrow_wrap_metadata( + const shared_ptr[const CKeyValueMetadata]& meta) + +# +# Public Cython API for 3rd party code +# +# If you add functions to this list, please also update +# `cpp/src/arrow/python/pyarrow.{h, cc}` +# + +# Wrapping C++ -> Python + +cdef public object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf) +cdef public object pyarrow_wrap_resizable_buffer( + const shared_ptr[CResizableBuffer]& buf) + +cdef public object pyarrow_wrap_data_type(const shared_ptr[CDataType]& type) +cdef public object pyarrow_wrap_field(const shared_ptr[CField]& field) +cdef public object pyarrow_wrap_schema(const shared_ptr[CSchema]& type) + +cdef public object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar) + +cdef public object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array) +cdef public object pyarrow_wrap_chunked_array( + const shared_ptr[CChunkedArray]& sp_array) + +cdef public object pyarrow_wrap_sparse_coo_tensor( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csc_matrix( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csf_tensor( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csr_matrix( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) +cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) + +cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch) +cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) + +# Unwrapping Python -> C++ + +cdef public shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer) + +cdef public shared_ptr[CDataType] pyarrow_unwrap_data_type(object data_type) +cdef public shared_ptr[CField] pyarrow_unwrap_field(object field) +cdef public shared_ptr[CSchema] pyarrow_unwrap_schema(object schema) + +cdef public shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar) + +cdef public shared_ptr[CArray] pyarrow_unwrap_array(object array) +cdef public shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array( + object array) + +cdef public shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor( + object sparse_tensor) +cdef public shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix( + object sparse_tensor) +cdef public shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor( + object sparse_tensor) +cdef public shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix( + object sparse_tensor) +cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) + +cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) +cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib_api.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib_api.h new file mode 100644 index 0000000000000000000000000000000000000000..b5ddac2a39eb2499e66f95e8a17e58a93b68ae13 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib_api.h @@ -0,0 +1,201 @@ +/* Generated by Cython 3.0.12 */ + +#ifndef __PYX_HAVE_API__pyarrow__lib +#define __PYX_HAVE_API__pyarrow__lib +#ifdef __MINGW64__ +#define MS_WIN64 +#endif +#include "Python.h" +#include "lib.h" + +static PyObject *(*__pyx_api_f_7pyarrow_3lib_box_memory_pool)( arrow::MemoryPool *) = 0; +#define box_memory_pool __pyx_api_f_7pyarrow_3lib_box_memory_pool +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer)(std::shared_ptr< arrow::Buffer> const &) = 0; +#define pyarrow_wrap_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer)(std::shared_ptr< arrow::ResizableBuffer> const &) = 0; +#define pyarrow_wrap_resizable_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type)(std::shared_ptr< arrow::DataType> const &) = 0; +#define pyarrow_wrap_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field)(std::shared_ptr< arrow::Field> const &) = 0; +#define pyarrow_wrap_field __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema)(std::shared_ptr< arrow::Schema> const &) = 0; +#define pyarrow_wrap_schema __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar)(std::shared_ptr< arrow::Scalar> const &) = 0; +#define pyarrow_wrap_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array)(std::shared_ptr< arrow::Array> const &) = 0; +#define pyarrow_wrap_array __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array)(std::shared_ptr< arrow::ChunkedArray> const &) = 0; +#define pyarrow_wrap_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor)(std::shared_ptr< arrow::SparseCOOTensor> const &) = 0; +#define pyarrow_wrap_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix)(std::shared_ptr< arrow::SparseCSCMatrix> const &) = 0; +#define pyarrow_wrap_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor)(std::shared_ptr< arrow::SparseCSFTensor> const &) = 0; +#define pyarrow_wrap_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix)(std::shared_ptr< arrow::SparseCSRMatrix> const &) = 0; +#define pyarrow_wrap_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor)(std::shared_ptr< arrow::Tensor> const &) = 0; +#define pyarrow_wrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch)(std::shared_ptr< arrow::RecordBatch> const &) = 0; +#define pyarrow_wrap_batch __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table)(std::shared_ptr< arrow::Table> const &) = 0; +#define pyarrow_wrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table +static std::shared_ptr< arrow::Buffer> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer)(PyObject *) = 0; +#define pyarrow_unwrap_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer +static std::shared_ptr< arrow::DataType> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type)(PyObject *) = 0; +#define pyarrow_unwrap_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type +static std::shared_ptr< arrow::Field> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field)(PyObject *) = 0; +#define pyarrow_unwrap_field __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field +static std::shared_ptr< arrow::Schema> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema)(PyObject *) = 0; +#define pyarrow_unwrap_schema __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema +static std::shared_ptr< arrow::Scalar> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar)(PyObject *) = 0; +#define pyarrow_unwrap_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar +static std::shared_ptr< arrow::Array> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array)(PyObject *) = 0; +#define pyarrow_unwrap_array __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array +static std::shared_ptr< arrow::ChunkedArray> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array)(PyObject *) = 0; +#define pyarrow_unwrap_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array +static std::shared_ptr< arrow::SparseCOOTensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor +static std::shared_ptr< arrow::SparseCSCMatrix> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix +static std::shared_ptr< arrow::SparseCSFTensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor +static std::shared_ptr< arrow::SparseCSRMatrix> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix)(PyObject *) = 0; +#define pyarrow_unwrap_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix +static std::shared_ptr< arrow::Tensor> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor)(PyObject *) = 0; +#define pyarrow_unwrap_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor +static std::shared_ptr< arrow::RecordBatch> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch)(PyObject *) = 0; +#define pyarrow_unwrap_batch __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch +static std::shared_ptr< arrow::Table> (*__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table)(PyObject *) = 0; +#define pyarrow_unwrap_table __pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status)(arrow::Status const &) = 0; +#define pyarrow_internal_check_status __pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status +static PyObject *(*__pyx_api_f_7pyarrow_3lib_pyarrow_internal_convert_status)(arrow::Status const &) = 0; +#define pyarrow_internal_convert_status __pyx_api_f_7pyarrow_3lib_pyarrow_internal_convert_status +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer)(PyObject *) = 0; +#define pyarrow_is_buffer __pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type)(PyObject *) = 0; +#define pyarrow_is_data_type __pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata)(PyObject *) = 0; +#define pyarrow_is_metadata __pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_field)(PyObject *) = 0; +#define pyarrow_is_field __pyx_api_f_7pyarrow_3lib_pyarrow_is_field +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_schema)(PyObject *) = 0; +#define pyarrow_is_schema __pyx_api_f_7pyarrow_3lib_pyarrow_is_schema +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_array)(PyObject *) = 0; +#define pyarrow_is_array __pyx_api_f_7pyarrow_3lib_pyarrow_is_array +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array)(PyObject *) = 0; +#define pyarrow_is_chunked_array __pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar)(PyObject *) = 0; +#define pyarrow_is_scalar __pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor)(PyObject *) = 0; +#define pyarrow_is_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor)(PyObject *) = 0; +#define pyarrow_is_sparse_coo_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix)(PyObject *) = 0; +#define pyarrow_is_sparse_csr_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix)(PyObject *) = 0; +#define pyarrow_is_sparse_csc_matrix __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor)(PyObject *) = 0; +#define pyarrow_is_sparse_csf_tensor __pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_table)(PyObject *) = 0; +#define pyarrow_is_table __pyx_api_f_7pyarrow_3lib_pyarrow_is_table +static int (*__pyx_api_f_7pyarrow_3lib_pyarrow_is_batch)(PyObject *) = 0; +#define pyarrow_is_batch __pyx_api_f_7pyarrow_3lib_pyarrow_is_batch +#ifndef __PYX_HAVE_RT_ImportFunction_3_0_12 +#define __PYX_HAVE_RT_ImportFunction_3_0_12 +static int __Pyx_ImportFunction_3_0_12(PyObject *module, const char *funcname, void (**f)(void), const char *sig) { + PyObject *d = 0; + PyObject *cobj = 0; + union { + void (*fp)(void); + void *p; + } tmp; + d = PyObject_GetAttrString(module, (char *)"__pyx_capi__"); + if (!d) + goto bad; + cobj = PyDict_GetItemString(d, funcname); + if (!cobj) { + PyErr_Format(PyExc_ImportError, + "%.200s does not export expected C function %.200s", + PyModule_GetName(module), funcname); + goto bad; + } + if (!PyCapsule_IsValid(cobj, sig)) { + PyErr_Format(PyExc_TypeError, + "C function %.200s.%.200s has wrong signature (expected %.500s, got %.500s)", + PyModule_GetName(module), funcname, sig, PyCapsule_GetName(cobj)); + goto bad; + } + tmp.p = PyCapsule_GetPointer(cobj, sig); + *f = tmp.fp; + if (!(*f)) + goto bad; + Py_DECREF(d); + return 0; +bad: + Py_XDECREF(d); + return -1; +} +#endif + + +static int import_pyarrow__lib(void) { + PyObject *module = 0; + module = PyImport_ImportModule("pyarrow.lib"); + if (!module) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "box_memory_pool", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_box_memory_pool, "PyObject *( arrow::MemoryPool *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_buffer, "PyObject *(std::shared_ptr< arrow::Buffer> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_resizable_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_resizable_buffer, "PyObject *(std::shared_ptr< arrow::ResizableBuffer> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_data_type, "PyObject *(std::shared_ptr< arrow::DataType> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_field, "PyObject *(std::shared_ptr< arrow::Field> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_schema, "PyObject *(std::shared_ptr< arrow::Schema> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_scalar, "PyObject *(std::shared_ptr< arrow::Scalar> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_array, "PyObject *(std::shared_ptr< arrow::Array> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_chunked_array, "PyObject *(std::shared_ptr< arrow::ChunkedArray> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_coo_tensor, "PyObject *(std::shared_ptr< arrow::SparseCOOTensor> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csc_matrix, "PyObject *(std::shared_ptr< arrow::SparseCSCMatrix> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csf_tensor, "PyObject *(std::shared_ptr< arrow::SparseCSFTensor> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_sparse_csr_matrix, "PyObject *(std::shared_ptr< arrow::SparseCSRMatrix> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_tensor, "PyObject *(std::shared_ptr< arrow::Tensor> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_batch, "PyObject *(std::shared_ptr< arrow::RecordBatch> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_wrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_wrap_table, "PyObject *(std::shared_ptr< arrow::Table> const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_buffer, "std::shared_ptr< arrow::Buffer> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_data_type, "std::shared_ptr< arrow::DataType> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_field, "std::shared_ptr< arrow::Field> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_schema, "std::shared_ptr< arrow::Schema> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_scalar, "std::shared_ptr< arrow::Scalar> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_array, "std::shared_ptr< arrow::Array> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_chunked_array, "std::shared_ptr< arrow::ChunkedArray> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_coo_tensor, "std::shared_ptr< arrow::SparseCOOTensor> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csc_matrix, "std::shared_ptr< arrow::SparseCSCMatrix> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csf_tensor, "std::shared_ptr< arrow::SparseCSFTensor> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_sparse_csr_matrix, "std::shared_ptr< arrow::SparseCSRMatrix> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_tensor, "std::shared_ptr< arrow::Tensor> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_batch, "std::shared_ptr< arrow::RecordBatch> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_unwrap_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_unwrap_table, "std::shared_ptr< arrow::Table> (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_internal_check_status", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_internal_check_status, "int (arrow::Status const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_internal_convert_status", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_internal_convert_status, "PyObject *(arrow::Status const &)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_buffer", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_buffer, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_data_type", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_data_type, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_metadata", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_metadata, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_field", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_field, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_schema", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_schema, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_array, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_chunked_array", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_chunked_array, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_scalar", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_scalar, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_tensor, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_sparse_coo_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_coo_tensor, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_sparse_csr_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csr_matrix, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_sparse_csc_matrix", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csc_matrix, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_sparse_csf_tensor", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_sparse_csf_tensor, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_table", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_table, "int (PyObject *)") < 0) goto bad; + if (__Pyx_ImportFunction_3_0_12(module, "pyarrow_is_batch", (void (**)(void))&__pyx_api_f_7pyarrow_3lib_pyarrow_is_batch, "int (PyObject *)") < 0) goto bad; + Py_DECREF(module); module = 0; + return 0; + bad: + Py_XDECREF(module); + return -1; +} + +#endif /* !__PYX_HAVE_API__pyarrow__lib */ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so new file mode 100644 index 0000000000000000000000000000000000000000..cd9621415c785a581bf777beafe5ed2745331921 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/public-api.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/public-api.pxi new file mode 100644 index 0000000000000000000000000000000000000000..d1fa1192debc3cd24b23a34e226761ec4aab02cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/public-api.pxi @@ -0,0 +1,443 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from libcpp.memory cimport shared_ptr +from pyarrow.includes.libarrow cimport (CArray, CDataType, CField, + CRecordBatch, CSchema, + CTable, CTensor, + CSparseCOOTensor, CSparseCSRMatrix, + CSparseCSCMatrix, CSparseCSFTensor) + +# You cannot assign something to a dereferenced pointer in Cython thus these +# methods don't use Status to indicate a successful operation. + + +cdef api bint pyarrow_is_buffer(object buffer): + return isinstance(buffer, Buffer) + + +cdef api shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer): + cdef Buffer buf + if pyarrow_is_buffer(buffer): + buf = (buffer) + return buf.buffer + + return shared_ptr[CBuffer]() + + +cdef api object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf): + cdef Buffer result = Buffer.__new__(Buffer) + result.init(buf) + return result + + +cdef api object pyarrow_wrap_resizable_buffer( + const shared_ptr[CResizableBuffer]& buf): + cdef ResizableBuffer result = ResizableBuffer.__new__(ResizableBuffer) + result.init_rz(buf) + return result + + +cdef api bint pyarrow_is_data_type(object type_): + return isinstance(type_, DataType) + + +cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type( + object data_type): + cdef DataType type_ + if pyarrow_is_data_type(data_type): + type_ = (data_type) + return type_.sp_type + + return shared_ptr[CDataType]() + + +# Workaround for Cython parsing bug +# https://github.com/cython/cython/issues/2143 +ctypedef const CPyExtensionType* _CPyExtensionTypePtr + + +cdef api object pyarrow_wrap_data_type( + const shared_ptr[CDataType]& type): + cdef: + const CExtensionType* ext_type + const CPyExtensionType* cpy_ext_type + DataType out + + if type.get() == NULL: + return None + + if type.get().id() == _Type_DICTIONARY: + out = DictionaryType.__new__(DictionaryType) + elif type.get().id() == _Type_LIST: + out = ListType.__new__(ListType) + elif type.get().id() == _Type_LARGE_LIST: + out = LargeListType.__new__(LargeListType) + elif type.get().id() == _Type_LIST_VIEW: + out = ListViewType.__new__(ListViewType) + elif type.get().id() == _Type_LARGE_LIST_VIEW: + out = LargeListViewType.__new__(LargeListViewType) + elif type.get().id() == _Type_MAP: + out = MapType.__new__(MapType) + elif type.get().id() == _Type_FIXED_SIZE_LIST: + out = FixedSizeListType.__new__(FixedSizeListType) + elif type.get().id() == _Type_STRUCT: + out = StructType.__new__(StructType) + elif type.get().id() == _Type_SPARSE_UNION: + out = SparseUnionType.__new__(SparseUnionType) + elif type.get().id() == _Type_DENSE_UNION: + out = DenseUnionType.__new__(DenseUnionType) + elif type.get().id() == _Type_TIME32: + out = Time32Type.__new__(Time32Type) + elif type.get().id() == _Type_TIME64: + out = Time64Type.__new__(Time64Type) + elif type.get().id() == _Type_TIMESTAMP: + out = TimestampType.__new__(TimestampType) + elif type.get().id() == _Type_DURATION: + out = DurationType.__new__(DurationType) + elif type.get().id() == _Type_FIXED_SIZE_BINARY: + out = FixedSizeBinaryType.__new__(FixedSizeBinaryType) + elif type.get().id() == _Type_DECIMAL32: + out = Decimal32Type.__new__(Decimal32Type) + elif type.get().id() == _Type_DECIMAL64: + out = Decimal64Type.__new__(Decimal64Type) + elif type.get().id() == _Type_DECIMAL128: + out = Decimal128Type.__new__(Decimal128Type) + elif type.get().id() == _Type_DECIMAL256: + out = Decimal256Type.__new__(Decimal256Type) + elif type.get().id() == _Type_RUN_END_ENCODED: + out = RunEndEncodedType.__new__(RunEndEncodedType) + elif type.get().id() == _Type_EXTENSION: + ext_type = type.get() + cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type) + extension_name = ext_type.extension_name() + if cpy_ext_type != nullptr: + return cpy_ext_type.GetInstance() + elif extension_name == b"arrow.bool8": + out = Bool8Type.__new__(Bool8Type) + elif extension_name == b"arrow.fixed_shape_tensor": + out = FixedShapeTensorType.__new__(FixedShapeTensorType) + elif extension_name == b"arrow.opaque": + out = OpaqueType.__new__(OpaqueType) + elif extension_name == b"arrow.uuid": + out = UuidType.__new__(UuidType) + elif extension_name == b"arrow.json": + out = JsonType.__new__(JsonType) + else: + out = BaseExtensionType.__new__(BaseExtensionType) + else: + out = DataType.__new__(DataType) + + out.init(type) + return out + + +cdef object pyarrow_wrap_metadata( + const shared_ptr[const CKeyValueMetadata]& meta): + if meta.get() == nullptr: + return None + else: + return KeyValueMetadata.wrap(meta) + + +cdef api bint pyarrow_is_metadata(object metadata): + return isinstance(metadata, KeyValueMetadata) + + +cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(object meta): + cdef shared_ptr[const CKeyValueMetadata] c_meta + if pyarrow_is_metadata(meta): + c_meta = (meta).unwrap() + return c_meta + + +cdef api bint pyarrow_is_field(object field): + return isinstance(field, Field) + + +cdef api shared_ptr[CField] pyarrow_unwrap_field(object field): + cdef Field field_ + if pyarrow_is_field(field): + field_ = (field) + return field_.sp_field + + return shared_ptr[CField]() + + +cdef api object pyarrow_wrap_field(const shared_ptr[CField]& field): + if field.get() == NULL: + return None + cdef Field out = Field.__new__(Field) + out.init(field) + return out + + +cdef api bint pyarrow_is_schema(object schema): + return isinstance(schema, Schema) + + +cdef api shared_ptr[CSchema] pyarrow_unwrap_schema(object schema): + cdef Schema sch + if pyarrow_is_schema(schema): + sch = (schema) + return sch.sp_schema + + return shared_ptr[CSchema]() + + +cdef api object pyarrow_wrap_schema(const shared_ptr[CSchema]& schema): + cdef Schema out = Schema.__new__(Schema) + out.init_schema(schema) + return out + + +cdef api bint pyarrow_is_array(object array): + return isinstance(array, Array) + + +cdef api shared_ptr[CArray] pyarrow_unwrap_array(object array): + cdef Array arr + if pyarrow_is_array(array): + arr = (array) + return arr.sp_array + + return shared_ptr[CArray]() + + +cdef api object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array): + if sp_array.get() == NULL: + raise ValueError('Array was NULL') + + klass = get_array_class_from_type(sp_array.get().type()) + + cdef Array arr = klass.__new__(klass) + arr.init(sp_array) + return arr + + +cdef api bint pyarrow_is_chunked_array(object array): + return isinstance(array, ChunkedArray) + + +cdef api shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array(object array): + cdef ChunkedArray arr + if pyarrow_is_chunked_array(array): + arr = (array) + return arr.sp_chunked_array + + return shared_ptr[CChunkedArray]() + + +cdef api object pyarrow_wrap_chunked_array( + const shared_ptr[CChunkedArray]& sp_array): + if sp_array.get() == NULL: + raise ValueError('ChunkedArray was NULL') + + cdef CDataType* data_type = sp_array.get().type().get() + + if data_type == NULL: + raise ValueError('ChunkedArray data type was NULL') + + cdef ChunkedArray arr = ChunkedArray.__new__(ChunkedArray) + arr.init(sp_array) + return arr + + +cdef api bint pyarrow_is_scalar(object value): + return isinstance(value, Scalar) + + +cdef api shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar): + if pyarrow_is_scalar(scalar): + return ( scalar).unwrap() + return shared_ptr[CScalar]() + + +cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): + if sp_scalar.get() == NULL: + raise ValueError('Scalar was NULL') + + cdef CDataType* data_type = sp_scalar.get().type.get() + + if data_type == NULL: + raise ValueError('Scalar data type was NULL') + + if data_type.id() == _Type_NA: + return _NULL + + if data_type.id() not in _scalar_classes: + raise ValueError('Scalar type not supported') + + klass = get_scalar_class_from_type(sp_scalar.get().type) + + cdef Scalar scalar = klass.__new__(klass) + scalar.init(sp_scalar) + return scalar + + +cdef api bint pyarrow_is_tensor(object tensor): + return isinstance(tensor, Tensor) + + +cdef api shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor): + cdef Tensor ten + if pyarrow_is_tensor(tensor): + ten = (tensor) + return ten.sp_tensor + + return shared_ptr[CTensor]() + + +cdef api object pyarrow_wrap_tensor( + const shared_ptr[CTensor]& sp_tensor): + if sp_tensor.get() == NULL: + raise ValueError('Tensor was NULL') + + cdef Tensor tensor = Tensor.__new__(Tensor) + tensor.init(sp_tensor) + return tensor + + +cdef api bint pyarrow_is_sparse_coo_tensor(object sparse_tensor): + return isinstance(sparse_tensor, SparseCOOTensor) + +cdef api shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor( + object sparse_tensor): + cdef SparseCOOTensor sten + if pyarrow_is_sparse_coo_tensor(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCOOTensor]() + +cdef api object pyarrow_wrap_sparse_coo_tensor( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCOOTensor was NULL') + + cdef SparseCOOTensor sparse_tensor = SparseCOOTensor.__new__( + SparseCOOTensor) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csr_matrix(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSRMatrix) + +cdef api shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix( + object sparse_tensor): + cdef SparseCSRMatrix sten + if pyarrow_is_sparse_csr_matrix(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSRMatrix]() + +cdef api object pyarrow_wrap_sparse_csr_matrix( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSRMatrix was NULL') + + cdef SparseCSRMatrix sparse_tensor = SparseCSRMatrix.__new__( + SparseCSRMatrix) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csc_matrix(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSCMatrix) + +cdef api shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix( + object sparse_tensor): + cdef SparseCSCMatrix sten + if pyarrow_is_sparse_csc_matrix(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSCMatrix]() + +cdef api object pyarrow_wrap_sparse_csc_matrix( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSCMatrix was NULL') + + cdef SparseCSCMatrix sparse_tensor = SparseCSCMatrix.__new__( + SparseCSCMatrix) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csf_tensor(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSFTensor) + +cdef api shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor( + object sparse_tensor): + cdef SparseCSFTensor sten + if pyarrow_is_sparse_csf_tensor(sparse_tensor): + sten = (sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSFTensor]() + +cdef api object pyarrow_wrap_sparse_csf_tensor( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSFTensor was NULL') + + cdef SparseCSFTensor sparse_tensor = SparseCSFTensor.__new__( + SparseCSFTensor) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_table(object table): + return isinstance(table, Table) + + +cdef api shared_ptr[CTable] pyarrow_unwrap_table(object table): + cdef Table tab + if pyarrow_is_table(table): + tab =
(table) + return tab.sp_table + + return shared_ptr[CTable]() + + +cdef api object pyarrow_wrap_table(const shared_ptr[CTable]& ctable): + cdef Table table = Table.__new__(Table) + table.init(ctable) + return table + + +cdef api bint pyarrow_is_batch(object batch): + return isinstance(batch, RecordBatch) + + +cdef api shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch): + cdef RecordBatch bat + if pyarrow_is_batch(batch): + bat = (batch) + return bat.sp_batch + + return shared_ptr[CRecordBatch]() + + +cdef api object pyarrow_wrap_batch( + const shared_ptr[CRecordBatch]& cbatch): + cdef RecordBatch batch = RecordBatch.__new__(RecordBatch) + batch.init(cbatch) + return batch diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/table.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/table.pxi new file mode 100644 index 0000000000000000000000000000000000000000..af241e4be07d96c95e78ac32c93e3c2f5c182498 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/table.pxi @@ -0,0 +1,6562 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New + +import warnings +from cython import sizeof + +cdef class ChunkedArray(_PandasConvertible): + """ + An array-like composed from a (possibly empty) collection of pyarrow.Arrays + + Warnings + -------- + Do not call this class's constructor directly. + + Examples + -------- + To construct a ChunkedArray object use :func:`pyarrow.chunked_array`: + + >>> import pyarrow as pa + >>> pa.chunked_array([], type=pa.int8()) + + [ + ... + ] + + >>> pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> isinstance(pa.chunked_array([[2, 2, 4], [4, 5, 100]]), pa.ChunkedArray) + True + """ + + def __cinit__(self): + self.chunked_array = NULL + self._init_is_cpu = False + + def __init__(self): + raise TypeError("Do not call ChunkedArray's constructor directly, use " + "`chunked_array` function instead.") + + cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array): + self.sp_chunked_array = chunked_array + self.chunked_array = chunked_array.get() + + def __reduce__(self): + self._assert_cpu() + return chunked_array, (self.chunks, self.type) + + @property + def data(self): + import warnings + warnings.warn("Calling .data on ChunkedArray is provided for " + "compatibility after Column was removed, simply drop " + "this attribute", FutureWarning) + return self + + @property + def type(self): + """ + Return data type of a ChunkedArray. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.type + DataType(int64) + """ + return pyarrow_wrap_data_type(self.sp_chunked_array.get().type()) + + def length(self): + """ + Return length of a ChunkedArray. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.length() + 6 + """ + return self.chunked_array.length() + + def __len__(self): + return self.length() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def to_string(self, *, int indent=0, int window=5, int container_window=2, + c_bool skip_new_lines=False): + """ + Render a "pretty-printed" string representation of the ChunkedArray + + Parameters + ---------- + indent : int + How much to indent right the content of the array, + by default ``0``. + window : int + How many items to preview within each chunk at the begin and end + of the chunk when the chunk is bigger than the window. + The other elements will be ellipsed. + container_window : int + How many chunks to preview at the begin and end + of the array when the array is bigger than the window. + The other elements will be ellipsed. + This setting also applies to list columns. + skip_new_lines : bool + If the array should be rendered as a single line of text + or if each element should be on its own line. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.to_string(skip_new_lines=True) + '[[2,2,4],[4,5,100]]' + """ + cdef: + c_string result + PrettyPrintOptions options + + with nogil: + options = PrettyPrintOptions(indent, window) + options.skip_new_lines = skip_new_lines + options.container_window = container_window + check_status( + PrettyPrint( + deref(self.chunked_array), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def format(self, **kwargs): + """ + DEPRECATED, use pyarrow.ChunkedArray.to_string + + Parameters + ---------- + **kwargs : dict + + Returns + ------- + str + """ + import warnings + warnings.warn('ChunkedArray.format is deprecated, ' + 'use ChunkedArray.to_string') + return self.to_string(**kwargs) + + def __str__(self): + return self.to_string() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full : bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + self._assert_cpu() + with nogil: + check_status(self.sp_chunked_array.get().ValidateFull()) + else: + with nogil: + check_status(self.sp_chunked_array.get().Validate()) + + @property + def null_count(self): + """ + Number of null entries + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.null_count + 1 + """ + self._assert_cpu() + return self.chunked_array.null_count() + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the chunked array. + + In other words, the sum of bytes from all buffer ranges referenced. + + Unlike `get_total_buffer_size` this method will account for array + offsets. + + If buffers are shared between arrays then the shared + portion will only be counted multiple times. + + The dictionary of dictionary arrays will always be counted in their + entirety even if the array only references a portion of the dictionary. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.nbytes + 49 + """ + self._assert_cpu() + cdef: + CResult[int64_t] c_res_buffer + + with nogil: + c_res_buffer = ReferencedBufferSize(deref(self.chunked_array)) + size = GetResultValue(c_res_buffer) + return size + + def get_total_buffer_size(self): + """ + The sum of bytes in each buffer referenced by the chunked array. + + An array may only reference a portion of a buffer. + This method will overestimate in this case and return the + byte size of the entire buffer. + + If a buffer is referenced multiple times then it will + only be counted once. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.get_total_buffer_size() + 49 + """ + self._assert_cpu() + cdef: + int64_t total_buffer_size + + total_buffer_size = TotalBufferSize(deref(self.chunked_array)) + return total_buffer_size + + def __sizeof__(self): + return super(ChunkedArray, self).__sizeof__() + self.nbytes + + def __iter__(self): + for chunk in self.iterchunks(): + for item in chunk: + yield item + + def __getitem__(self, key): + """ + Slice or return value at given index + + Parameters + ---------- + key : integer or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + value : Scalar (index) or ChunkedArray (slice) + """ + self._assert_cpu() + if isinstance(key, slice): + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.chunked_array.length())) + + cdef getitem(self, int64_t i): + self._assert_cpu() + return Scalar.wrap(GetResultValue(self.chunked_array.GetScalar(i))) + + def is_null(self, *, nan_is_null=False): + """ + Return boolean array indicating the null values. + + Parameters + ---------- + nan_is_null : bool (optional, default False) + Whether floating-point NaN values should also be considered null. + + Returns + ------- + array : boolean Array or ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.is_null() + + [ + [ + false, + false, + false, + false, + true, + false + ] + ] + """ + self._assert_cpu() + options = _pc().NullOptions(nan_is_null=nan_is_null) + return _pc().call_function('is_null', [self], options) + + def is_nan(self): + """ + Return boolean array indicating the NaN values. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> arr = pa.chunked_array([[2, np.nan, 4], [4, None, 100]]) + >>> arr.is_nan() + + [ + [ + false, + true, + false, + false, + null, + false + ] + ] + """ + self._assert_cpu() + return _pc().is_nan(self) + + def is_valid(self): + """ + Return boolean array indicating the non-null values. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.is_valid() + + [ + [ + true, + true, + true + ], + [ + true, + false, + true + ] + ] + """ + self._assert_cpu() + return _pc().is_valid(self) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def fill_null(self, fill_value): + """ + Replace each null element in values with fill_value. + + See :func:`pyarrow.compute.fill_null` for full usage. + + Parameters + ---------- + fill_value : any + The replacement value for null entries. + + Returns + ------- + result : Array or ChunkedArray + A new array with nulls replaced by the given value. + + Examples + -------- + >>> import pyarrow as pa + >>> fill_value = pa.scalar(5, type=pa.int8()) + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.fill_null(fill_value) + + [ + [ + 2, + 2, + 4, + 4, + 5, + 100 + ] + ] + """ + self._assert_cpu() + return _pc().fill_null(self, fill_value) + + def equals(self, ChunkedArray other): + """ + Return whether the contents of two chunked arrays are equal. + + Parameters + ---------- + other : pyarrow.ChunkedArray + Chunked array to compare against. + + Returns + ------- + are_equal : bool + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> animals = pa.chunked_array(( + ... ["Flamingo", "Parrot", "Dog"], + ... ["Horse", "Brittle stars", "Centipede"] + ... )) + >>> n_legs.equals(n_legs) + True + >>> n_legs.equals(animals) + False + """ + self._assert_cpu() + if other is None: + return False + + cdef: + CChunkedArray* this_arr = self.chunked_array + CChunkedArray* other_arr = other.chunked_array + c_bool result + + with nogil: + result = this_arr.Equals(deref(other_arr)) + + return result + + def _to_pandas(self, options, types_mapper=None, **kwargs): + self._assert_cpu() + return _array_like_to_pandas(self, options, types_mapper=types_mapper) + + def to_numpy(self, zero_copy_only=False): + """ + Return a NumPy copy of this array (experimental). + + Parameters + ---------- + zero_copy_only : bool, default False + Introduced for signature consistence with pyarrow.Array.to_numpy. + This must be False here since NumPy arrays' buffer must be contiguous. + + Returns + ------- + array : numpy.ndarray + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.to_numpy() + array([ 2, 2, 4, 4, 5, 100]) + """ + self._assert_cpu() + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + if zero_copy_only: + raise ValueError( + "zero_copy_only must be False for pyarrow.ChunkedArray.to_numpy" + ) + cdef: + PyObject* out + PandasOptions c_options + object values + + c_options.to_numpy = True + + with nogil: + check_status( + ConvertChunkedArrayToPandas( + c_options, + self.sp_chunked_array, + self, + &out + ) + ) + + # wrap_array_output uses pandas to convert to Categorical, here + # always convert to numpy array + values = PyObject_to_object(out) + + if isinstance(values, dict): + values = np.take(values['dictionary'], values['indices']) + + return values + + def __array__(self, dtype=None, copy=None): + self._assert_cpu() + if copy is False: + raise ValueError( + "Unable to avoid a copy while creating a numpy array as requested " + "(converting a pyarrow.ChunkedArray always results in a copy).\n" + "If using `np.array(obj, copy=False)` replace it with " + "`np.asarray(obj)` to allow a copy when needed" + ) + # 'copy' can further be ignored because to_numpy() already returns a copy + values = self.to_numpy() + if dtype is None: + return values + return values.astype(dtype, copy=False) + + def cast(self, object target_type=None, safe=None, options=None): + """ + Cast array values to another data type + + See :func:`pyarrow.compute.cast` for usage. + + Parameters + ---------- + target_type : DataType, None + Type to cast array to. + safe : boolean, default True + Whether to check for conversion errors such as overflow. + options : CastOptions, default None + Additional checks pass by CastOptions + + Returns + ------- + cast : Array or ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.type + DataType(int64) + + Change the data type of an array: + + >>> n_legs_seconds = n_legs.cast(pa.duration('s')) + >>> n_legs_seconds.type + DurationType(duration[s]) + """ + self._assert_cpu() + return _pc().cast(self, target_type, safe=safe, options=options) + + def dictionary_encode(self, null_encoding='mask'): + """ + Compute dictionary-encoded representation of array. + + See :func:`pyarrow.compute.dictionary_encode` for full usage. + + Parameters + ---------- + null_encoding : str, default "mask" + How to handle null entries. + + Returns + ------- + encoded : ChunkedArray + A dictionary-encoded version of this array. + + Examples + -------- + >>> import pyarrow as pa + >>> animals = pa.chunked_array(( + ... ["Flamingo", "Parrot", "Dog"], + ... ["Horse", "Brittle stars", "Centipede"] + ... )) + >>> animals.dictionary_encode() + + [ + ... + -- dictionary: + [ + "Flamingo", + "Parrot", + "Dog", + "Horse", + "Brittle stars", + "Centipede" + ] + -- indices: + [ + 0, + 1, + 2 + ], + ... + -- dictionary: + [ + "Flamingo", + "Parrot", + "Dog", + "Horse", + "Brittle stars", + "Centipede" + ] + -- indices: + [ + 3, + 4, + 5 + ] + ] + """ + self._assert_cpu() + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) + + def flatten(self, MemoryPool memory_pool=None): + """ + Flatten this ChunkedArray. If it has a struct type, the column is + flattened into one array per struct field. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : list of ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> c_arr = pa.chunked_array(n_legs.value_counts()) + >>> c_arr + + [ + -- is_valid: all not null + -- child 0 type: int64 + [ + 2, + 4, + 5, + 100 + ] + -- child 1 type: int64 + [ + 2, + 2, + 1, + 1 + ] + ] + >>> c_arr.flatten() + [ + [ + [ + 2, + 4, + 5, + 100 + ] + ], + [ + [ + 2, + 2, + 1, + 1 + ] + ]] + >>> c_arr.type + StructType(struct) + >>> n_legs.type + DataType(int64) + """ + self._assert_cpu() + cdef: + vector[shared_ptr[CChunkedArray]] flattened + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + flattened = GetResultValue(self.chunked_array.Flatten(pool)) + + return [pyarrow_wrap_chunked_array(col) for col in flattened] + + def combine_chunks(self, MemoryPool memory_pool=None): + """ + Flatten this ChunkedArray into a single non-chunked array. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : Array + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.combine_chunks() + + [ + 2, + 2, + 4, + 4, + 5, + 100 + ] + """ + self._assert_cpu() + if self.num_chunks == 0: + return array([], type=self.type) + else: + return concat_arrays(self.chunks) + + def unique(self): + """ + Compute distinct elements in array + + Returns + ------- + pyarrow.Array + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.unique() + + [ + 2, + 4, + 5, + 100 + ] + """ + self._assert_cpu() + return _pc().call_function('unique', [self]) + + def value_counts(self): + """ + Compute counts of unique elements in array. + + Returns + ------- + An array of structs + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.value_counts() + + -- is_valid: all not null + -- child 0 type: int64 + [ + 2, + 4, + 5, + 100 + ] + -- child 1 type: int64 + [ + 2, + 2, + 1, + 1 + ] + """ + self._assert_cpu() + return _pc().call_function('value_counts', [self]) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this ChunkedArray + + Parameters + ---------- + offset : int, default 0 + Offset from start of array to slice + length : int, default None + Length of slice (default is until end of batch starting from + offset) + + Returns + ------- + sliced : ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.slice(2,2) + + [ + [ + 4 + ], + [ + 4 + ] + ] + """ + cdef shared_ptr[CChunkedArray] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.chunked_array.Slice(offset) + else: + result = self.chunked_array.Slice(offset, length) + + return pyarrow_wrap_chunked_array(result) + + def filter(self, mask, object null_selection_behavior="drop"): + """ + Select values from the chunked array. + + See :func:`pyarrow.compute.filter` for full usage. + + Parameters + ---------- + mask : Array or array-like + The boolean mask to filter the chunked array with. + null_selection_behavior : str, default "drop" + How nulls in the mask should be handled. + + Returns + ------- + filtered : Array or ChunkedArray + An array of the same type, with only the elements selected by + the boolean mask. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> mask = pa.array([True, False, None, True, False, True]) + >>> n_legs.filter(mask) + + [ + [ + 2 + ], + [ + 4, + 100 + ] + ] + >>> n_legs.filter(mask, null_selection_behavior="emit_null") + + [ + [ + 2, + null + ], + [ + 4, + 100 + ] + ] + """ + self._assert_cpu() + return _pc().filter(self, mask, null_selection_behavior) + + def index(self, value, start=None, end=None, *, memory_pool=None): + """ + Find the first index of a value. + + See :func:`pyarrow.compute.index` for full usage. + + Parameters + ---------- + value : Scalar or object + The value to look for in the array. + start : int, optional + The start index where to look for `value`. + end : int, optional + The end index where to look for `value`. + memory_pool : MemoryPool, optional + A memory pool for potential memory allocations. + + Returns + ------- + index : Int64Scalar + The index of the value in the array (-1 if not found). + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.index(4) + + >>> n_legs.index(4, start=3) + + """ + self._assert_cpu() + return _pc().index(self, value, start, end, memory_pool=memory_pool) + + def take(self, object indices): + """ + Select values from the chunked array. + + See :func:`pyarrow.compute.take` for full usage. + + Parameters + ---------- + indices : Array or array-like + The indices in the array whose values will be returned. + + Returns + ------- + taken : Array or ChunkedArray + An array with the same datatype, containing the taken values. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.take([1,4,5]) + + [ + [ + 2, + 5, + 100 + ] + ] + """ + self._assert_cpu() + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from a chunked array. + See :func:`pyarrow.compute.drop_null` for full description. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, None], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + null + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.drop_null() + + [ + [ + 2, + 2 + ], + [ + 4, + 5, + 100 + ] + ] + """ + self._assert_cpu() + return _pc().drop_null(self) + + def sort(self, order="ascending", **kwargs): + """ + Sort the ChunkedArray + + Parameters + ---------- + order : str, default "ascending" + Which order to sort values in. + Accepted values are "ascending", "descending". + **kwargs : dict, optional + Additional sorting options. + As allowed by :class:`SortOptions` + + Returns + ------- + result : ChunkedArray + """ + self._assert_cpu() + indices = _pc().sort_indices( + self, + options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + ) + return self.take(indices) + + def unify_dictionaries(self, MemoryPool memory_pool=None): + """ + Unify dictionaries across all chunks. + + This method returns an equivalent chunked array, but where all + chunks share the same dictionary values. Dictionary indices are + transposed accordingly. + + If there are no dictionaries in the chunked array, it is returned + unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> arr_1 = pa.array(["Flamingo", "Parrot", "Dog"]).dictionary_encode() + >>> arr_2 = pa.array(["Horse", "Brittle stars", "Centipede"]).dictionary_encode() + >>> c_arr = pa.chunked_array([arr_1, arr_2]) + >>> c_arr + + [ + ... + -- dictionary: + [ + "Flamingo", + "Parrot", + "Dog" + ] + -- indices: + [ + 0, + 1, + 2 + ], + ... + -- dictionary: + [ + "Horse", + "Brittle stars", + "Centipede" + ] + -- indices: + [ + 0, + 1, + 2 + ] + ] + >>> c_arr.unify_dictionaries() + + [ + ... + -- dictionary: + [ + "Flamingo", + "Parrot", + "Dog", + "Horse", + "Brittle stars", + "Centipede" + ] + -- indices: + [ + 0, + 1, + 2 + ], + ... + -- dictionary: + [ + "Flamingo", + "Parrot", + "Dog", + "Horse", + "Brittle stars", + "Centipede" + ] + -- indices: + [ + 3, + 4, + 5 + ] + ] + """ + self._assert_cpu() + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CChunkedArray] c_result + + with nogil: + c_result = GetResultValue(CDictionaryUnifier.UnifyChunkedArray( + self.sp_chunked_array, pool)) + + return pyarrow_wrap_chunked_array(c_result) + + @property + def num_chunks(self): + """ + Number of underlying chunks. + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, None], [4, 5, 100]]) + >>> n_legs.num_chunks + 2 + """ + return self.chunked_array.num_chunks() + + def chunk(self, i): + """ + Select a chunk by its index. + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Array + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, None], [4, 5, 100]]) + >>> n_legs.chunk(1) + + [ + 4, + 5, + 100 + ] + """ + if i >= self.num_chunks or i < 0: + raise IndexError('Chunk index out of range.') + + return pyarrow_wrap_array(self.chunked_array.chunk(i)) + + @property + def chunks(self): + """ + Convert to a list of single-chunked arrays. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, None], [4, 5, 100]]) + >>> n_legs + + [ + [ + 2, + 2, + null + ], + [ + 4, + 5, + 100 + ] + ] + >>> n_legs.chunks + [ + [ + 2, + 2, + null + ], + [ + 4, + 5, + 100 + ]] + """ + return list(self.iterchunks()) + + def iterchunks(self): + """ + Convert to an iterator of ChunkArrays. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> for i in n_legs.iterchunks(): + ... print(i.null_count) + ... + 0 + 1 + + """ + for i in range(self.num_chunks): + yield self.chunk(i) + + def to_pylist(self): + """ + Convert to a list of native Python objects. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, None, 100]]) + >>> n_legs.to_pylist() + [2, 2, 4, 4, None, 100] + """ + self._assert_cpu() + result = [] + for i in range(self.num_chunks): + result += self.chunk(i).to_pylist() + return result + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export to a C ArrowArrayStream PyCapsule. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the stream should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + + Returns + ------- + PyCapsule + A capsule containing a C ArrowArrayStream struct. + """ + self._assert_cpu() + cdef: + ChunkedArray chunked + ArrowArrayStream* c_stream = NULL + + if requested_schema is not None: + target_type = DataType._import_from_c_capsule(requested_schema) + + if target_type != self.type: + try: + chunked = self.cast(target_type, safe=True) + except ArrowInvalid as e: + raise ValueError( + f"Could not cast {self.type} to requested type {target_type}: {e}" + ) + else: + chunked = self + else: + chunked = self + + stream_capsule = alloc_c_stream(&c_stream) + + with nogil: + check_status(ExportChunkedArray(chunked.sp_chunked_array, c_stream)) + + return stream_capsule + + @staticmethod + def _import_from_c_capsule(stream): + """ + Import ChunkedArray from a C ArrowArrayStream PyCapsule. + + Parameters + ---------- + stream: PyCapsule + A capsule containing a C ArrowArrayStream PyCapsule. + + Returns + ------- + ChunkedArray + """ + cdef: + ArrowArrayStream* c_stream + shared_ptr[CChunkedArray] c_chunked_array + ChunkedArray self + + c_stream = PyCapsule_GetPointer( + stream, 'arrow_array_stream' + ) + + with nogil: + c_chunked_array = GetResultValue(ImportChunkedArray(c_stream)) + + self = ChunkedArray.__new__(ChunkedArray) + self.init(c_chunked_array) + return self + + @property + def is_cpu(self): + """ + Whether all chunks in the ChunkedArray are CPU-accessible. + """ + if not self._init_is_cpu: + self._is_cpu = self.chunked_array.is_cpu() + self._init_is_cpu = True + return self._is_cpu + + def _assert_cpu(self): + if not self.is_cpu: + raise NotImplementedError("Implemented only for data on CPU device") + + +def chunked_array(arrays, type=None): + """ + Construct chunked array from list of array-like objects + + Parameters + ---------- + arrays : Array, list of Array, or array-like + Must all be the same data type. Can be empty only if type also passed. + Any Arrow-compatible array that implements the Arrow PyCapsule Protocol + (has an ``__arrow_c_array__`` or ``__arrow_c_stream__`` method) can be + passed as well. + type : DataType or string coercible to DataType + + Returns + ------- + ChunkedArray + + Examples + -------- + >>> import pyarrow as pa + >>> pa.chunked_array([], type=pa.int8()) + + [ + ... + ] + + >>> pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + + [ + [ + 2, + 2, + 4 + ], + [ + 4, + 5, + 100 + ] + ] + """ + cdef: + Array arr + vector[shared_ptr[CArray]] c_arrays + shared_ptr[CChunkedArray] c_result + shared_ptr[CDataType] c_type + + type = ensure_type(type, allow_none=True) + + if isinstance(arrays, Array): + arrays = [arrays] + elif hasattr(arrays, "__arrow_c_stream__"): + if type is not None: + requested_type = type.__arrow_c_schema__() + else: + requested_type = None + capsule = arrays.__arrow_c_stream__(requested_type) + result = ChunkedArray._import_from_c_capsule(capsule) + if type is not None and result.type != type: + # __arrow_c_stream__ coerces schema with best effort, so we might + # need to cast it if the producer wasn't able to cast to exact schema. + result = result.cast(type) + return result + elif hasattr(arrays, "__arrow_c_array__"): + arr = array(arrays, type=type) + arrays = [arr] + + for x in arrays: + arr = x if isinstance(x, Array) else array(x, type=type) + + if type is None: + # it allows more flexible chunked array construction from to coerce + # subsequent arrays to the firstly inferred array type + # it also spares the inference overhead after the first chunk + type = arr.type + + c_arrays.push_back(arr.sp_array) + + c_type = pyarrow_unwrap_data_type(type) + with nogil: + c_result = GetResultValue(CChunkedArray.Make(c_arrays, c_type)) + return pyarrow_wrap_chunked_array(c_result) + + +cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema): + cdef: + Py_ssize_t K = len(arrays) + c_string c_name + shared_ptr[CDataType] c_type + shared_ptr[const CKeyValueMetadata] c_meta + vector[shared_ptr[CField]] c_fields + + if metadata is not None: + c_meta = KeyValueMetadata(metadata).unwrap() + + if K == 0: + if names is None or len(names) == 0: + schema.reset(new CSchema(c_fields, c_meta)) + return arrays + else: + raise ValueError('Length of names ({}) does not match ' + 'length of arrays ({})'.format(len(names), K)) + + c_fields.resize(K) + + if names is None: + raise ValueError('Must pass names or schema when constructing ' + 'Table or RecordBatch.') + + if len(names) != K: + raise ValueError('Length of names ({}) does not match ' + 'length of arrays ({})'.format(len(names), K)) + + converted_arrays = [] + for i in range(K): + val = arrays[i] + if not isinstance(val, (Array, ChunkedArray)): + val = array(val) + + c_type = ( val.type).sp_type + + if names[i] is None: + c_name = b'None' + else: + c_name = tobytes(names[i]) + c_fields[i].reset(new CField(c_name, c_type, True)) + converted_arrays.append(val) + + schema.reset(new CSchema(c_fields, c_meta)) + return converted_arrays + + +cdef _sanitize_arrays(arrays, names, schema, metadata, + shared_ptr[CSchema]* c_schema): + cdef Schema cy_schema + if schema is None: + converted_arrays = _schema_from_arrays(arrays, names, metadata, + c_schema) + else: + if names is not None: + raise ValueError('Cannot pass both schema and names') + if metadata is not None: + raise ValueError('Cannot pass both schema and metadata') + cy_schema = schema + + if len(schema) != len(arrays): + raise ValueError('Schema and number of arrays unequal') + + c_schema[0] = cy_schema.sp_schema + converted_arrays = [] + for i, item in enumerate(arrays): + item = asarray(item, type=schema[i].type) + converted_arrays.append(item) + return converted_arrays + +cdef class _Tabular(_PandasConvertible): + """Internal: An interface for common operations on tabular objects.""" + + def __init__(self): + raise TypeError(f"Do not call {self.__class__.__name__}'s constructor directly, use " + f"one of the `{self.__class__.__name__}.from_*` functions instead.") + + def __array__(self, dtype=None, copy=None): + self._assert_cpu() + if copy is False: + raise ValueError( + "Unable to avoid a copy while creating a numpy array as requested " + f"(converting a pyarrow.{self.__class__.__name__} always results " + "in a copy).\n" + "If using `np.array(obj, copy=False)` replace it with " + "`np.asarray(obj)` to allow a copy when needed" + ) + # 'copy' can further be ignored because stacking will result in a copy + column_arrays = [ + np.asarray(self.column(i), dtype=dtype) for i in range(self.num_columns) + ] + if column_arrays: + arr = np.stack(column_arrays, axis=1) + else: + arr = np.empty((self.num_rows, 0), dtype=dtype) + return arr + + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Return the dataframe interchange object implementing the interchange protocol. + + Parameters + ---------- + nan_as_null : bool, default False + Whether to tell the DataFrame to overwrite null values in the data + with ``NaN`` (or ``NaT``). + allow_copy : bool, default True + Whether to allow memory copying when exporting. If set to False + it would cause non-zero-copy exports to fail. + + Returns + ------- + DataFrame interchange object + The object which consuming library can use to ingress the dataframe. + + Notes + ----- + Details on the interchange protocol: + https://data-apis.org/dataframe-protocol/latest/index.html + `nan_as_null` currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + """ + + from pyarrow.interchange.dataframe import _PyArrowDataFrame + + return _PyArrowDataFrame(self, nan_as_null, allow_copy) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __getitem__(self, key): + """ + Slice or return column at given index or column name + + Parameters + ---------- + key : integer, str, or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + Array (from RecordBatch) or ChunkedArray (from Table) for column input. + RecordBatch or Table for slice input. + """ + if isinstance(key, slice): + return _normalize_slice(self, key) + + return self.column(key) + + def __len__(self): + return self.num_rows + + def __repr__(self): + if not self._is_initialized(): + raise ValueError("This object's internal pointer is NULL, do not " + "use any methods or attributes on this object") + return self.to_string(preview_cols=10) + + def _column(self, int i): + raise NotImplementedError + + def _ensure_integer_index(self, i): + """ + Ensure integer index (convert string column name to integer if needed). + """ + if isinstance(i, (bytes, str)): + field_indices = self.schema.get_all_field_indices(i) + + if len(field_indices) == 0: + raise KeyError("Field \"{}\" does not exist in schema" + .format(i)) + elif len(field_indices) > 1: + raise KeyError("Field \"{}\" exists {} times in schema" + .format(i, len(field_indices))) + else: + return field_indices[0] + elif isinstance(i, int): + return i + else: + raise TypeError("Index must either be string or integer") + + def _is_initialized(self): + raise NotImplementedError + + def column(self, i): + """ + Select single column from Table or RecordBatch. + + Parameters + ---------- + i : int or string + The index or name of the column to retrieve. + + Returns + ------- + column : Array (for RecordBatch) or ChunkedArray (for Table) + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Select a column by numeric index: + + >>> table.column(0) + + [ + [ + 2, + 4, + 5, + 100 + ] + ] + + Select a column by its name: + + >>> table.column("animals") + + [ + [ + "Flamingo", + "Horse", + "Brittle stars", + "Centipede" + ] + ] + """ + return self._column(self._ensure_integer_index(i)) + + @property + def column_names(self): + """ + Names of the Table or RecordBatch columns. + + Returns + ------- + list of str + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> table = pa.Table.from_arrays([[2, 4, 5, 100], + ... ["Flamingo", "Horse", "Brittle stars", "Centipede"]], + ... names=['n_legs', 'animals']) + >>> table.column_names + ['n_legs', 'animals'] + """ + return [self.field(i).name for i in range(self.num_columns)] + + @property + def columns(self): + """ + List of all columns in numerical order. + + Returns + ------- + columns : list of Array (for RecordBatch) or list of ChunkedArray (for Table) + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.columns + [ + [ + [ + null, + 4, + 5, + null + ] + ], + [ + [ + "Flamingo", + "Horse", + null, + "Centipede" + ] + ]] + """ + return [self._column(i) for i in range(self.num_columns)] + + def drop_null(self): + """ + Remove rows that contain missing values from a Table or RecordBatch. + + See :func:`pyarrow.compute.drop_null` for full usage. + + Returns + ------- + Table or RecordBatch + A tabular object with the same schema, with rows containing + no missing values. + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [None, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.drop_null() + pyarrow.Table + year: double + n_legs: int64 + animals: string + ---- + year: [[2022,2021]] + n_legs: [[4,100]] + animals: [["Horse","Centipede"]] + """ + self._assert_cpu() + return _pc().drop_null(self) + + def field(self, i): + """ + Select a schema field by its column name or numeric index. + + Parameters + ---------- + i : int or string + The index or name of the field to retrieve. + + Returns + ------- + Field + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.field(0) + pyarrow.Field + >>> table.field(1) + pyarrow.Field + """ + return self.schema.field(i) + + @classmethod + def from_pydict(cls, mapping, schema=None, metadata=None): + """ + Construct a Table or RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table or RecordBatch + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> pydict = {'n_legs': n_legs, 'animals': animals} + + Construct a Table from a dictionary of arrays: + + >>> pa.Table.from_pydict(pydict) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> pa.Table.from_pydict(pydict).schema + n_legs: int64 + animals: string + + Construct a Table from a dictionary of arrays with metadata: + + >>> my_metadata={"n_legs": "Number of legs per animal"} + >>> pa.Table.from_pydict(pydict, metadata=my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Construct a Table from a dictionary of arrays with pyarrow schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.Table.from_pydict(pydict, schema=my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + """ + + return _from_pydict(cls=cls, + mapping=mapping, + schema=schema, + metadata=metadata) + + @classmethod + def from_pylist(cls, mapping, schema=None, metadata=None): + """ + Construct a Table or RecordBatch from list of rows / dictionaries. + + Parameters + ---------- + mapping : list of dicts of rows + A mapping of strings to row values. + schema : Schema, default None + If not passed, will be inferred from the first row of the + mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table or RecordBatch + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> pylist = [{'n_legs': 2, 'animals': 'Flamingo'}, + ... {'n_legs': 4, 'animals': 'Dog'}] + + Construct a Table from a list of rows: + + >>> pa.Table.from_pylist(pylist) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4]] + animals: [["Flamingo","Dog"]] + + Construct a Table from a list of rows with metadata: + + >>> my_metadata={"n_legs": "Number of legs per animal"} + >>> pa.Table.from_pylist(pylist, metadata=my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Construct a Table from a list of rows with pyarrow schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.Table.from_pylist(pylist, schema=my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + """ + + return _from_pylist(cls=cls, + mapping=mapping, + schema=schema, + metadata=metadata) + + def itercolumns(self): + """ + Iterator over all columns in their numerical order. + + Yields + ------ + Array (for RecordBatch) or ChunkedArray (for Table) + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> for i in table.itercolumns(): + ... print(i.null_count) + ... + 2 + 1 + """ + for i in range(self.num_columns): + yield self._column(i) + + @property + def num_columns(self): + raise NotImplementedError + + @property + def num_rows(self): + raise NotImplementedError + + @property + def shape(self): + """ + Dimensions of the table or record batch: (#rows, #columns). + + Returns + ------- + (int, int) + Number of rows and number of columns. + + Examples + -------- + >>> import pyarrow as pa + >>> table = pa.table({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table.shape + (4, 2) + """ + return (self.num_rows, self.num_columns) + + @property + def schema(self): + raise NotImplementedError + + def sort_by(self, sorting, **kwargs): + """ + Sort the Table or RecordBatch by one or multiple columns. + + Parameters + ---------- + sorting : str or list[tuple(name, order)] + Name of the column to use to sort (ascending), or + a list of multiple sorting conditions where + each entry is a tuple with column name + and sorting order ("ascending" or "descending") + **kwargs : dict, optional + Additional sorting options. + As allowed by :class:`SortOptions` + + Returns + ------- + Table or RecordBatch + A new tabular object sorted according to the sort keys. + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.sort_by('animal') + pyarrow.Table + year: int64 + n_legs: int64 + animal: string + ---- + year: [[2019,2021,2021,2020,2022,2022]] + n_legs: [[5,100,4,2,4,2]] + animal: [["Brittle stars","Centipede","Dog","Flamingo","Horse","Parrot"]] + """ + self._assert_cpu() + if isinstance(sorting, str): + sorting = [(sorting, "ascending")] + + indices = _pc().sort_indices( + self, + options=_pc().SortOptions(sort_keys=sorting, **kwargs) + ) + return self.take(indices) + + def take(self, object indices): + """ + Select rows from a Table or RecordBatch. + + See :func:`pyarrow.compute.take` for full usage. + + Parameters + ---------- + indices : Array or array-like + The indices in the tabular object whose rows will be returned. + + Returns + ------- + Table or RecordBatch + A tabular object with the same schema, containing the taken rows. + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.take([1,3]) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2022,2021]] + n_legs: [[4,100]] + animals: [["Horse","Centipede"]] + """ + self._assert_cpu() + return _pc().take(self, indices) + + def filter(self, mask, object null_selection_behavior="drop"): + """ + Select rows from the table or record batch based on a boolean mask. + + The Table can be filtered based on a mask, which will be passed to + :func:`pyarrow.compute.filter` to perform the filtering, or it can + be filtered through a boolean :class:`.Expression` + + Parameters + ---------- + mask : Array or array-like or .Expression + The boolean mask or the :class:`.Expression` to filter the table with. + null_selection_behavior : str, default "drop" + How nulls in the mask should be handled, does nothing if + an :class:`.Expression` is used. + + Returns + ------- + filtered : Table or RecordBatch + A tabular object of the same schema, with only the rows selected + by applied filtering + + Examples + -------- + Using a Table (works similarly for RecordBatch): + + >>> import pyarrow as pa + >>> table = pa.table({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + + Define an expression and select rows: + + >>> import pyarrow.compute as pc + >>> expr = pc.field("year") <= 2020 + >>> table.filter(expr) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2019]] + n_legs: [[2,5]] + animals: [["Flamingo","Brittle stars"]] + + Define a mask and select rows: + + >>> mask=[True, True, False, None] + >>> table.filter(mask) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2022]] + n_legs: [[2,4]] + animals: [["Flamingo","Horse"]] + >>> table.filter(mask, null_selection_behavior='emit_null') + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2022,null]] + n_legs: [[2,4,null]] + animals: [["Flamingo","Horse",null]] + """ + self._assert_cpu() + if isinstance(mask, _pc().Expression): + return _pac()._filter_table(self, mask) + else: + return _pc().filter(self, mask, null_selection_behavior) + + def to_pydict(self): + """ + Convert the Table or RecordBatch to a dict or OrderedDict. + + Returns + ------- + dict + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> table = pa.Table.from_arrays([n_legs, animals], names=["n_legs", "animals"]) + >>> table.to_pydict() + {'n_legs': [2, 2, 4, 4, 5, 100], 'animals': ['Flamingo', 'Parrot', ..., 'Centipede']} + """ + entries = [] + for i in range(self.num_columns): + name = self.field(i).name + column = self[i].to_pylist() + entries.append((name, column)) + return ordered_dict(entries) + + def to_pylist(self): + """ + Convert the Table or RecordBatch to a list of rows / dictionaries. + + Returns + ------- + list + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> data = [[2, 4, 5, 100], + ... ["Flamingo", "Horse", "Brittle stars", "Centipede"]] + >>> table = pa.table(data, names=["n_legs", "animals"]) + >>> table.to_pylist() + [{'n_legs': 2, 'animals': 'Flamingo'}, {'n_legs': 4, 'animals': 'Horse'}, ... + """ + pydict = self.to_pydict() + names = self.schema.names + pylist = [{column: pydict[column][row] for column in names} + for row in range(self.num_rows)] + return pylist + + def to_string(self, *, show_metadata=False, preview_cols=0): + """ + Return human-readable string representation of Table or RecordBatch. + + Parameters + ---------- + show_metadata : bool, default False + Display Field-level and Schema-level KeyValueMetadata. + preview_cols : int, default 0 + Display values of the columns for the first N columns. + + Returns + ------- + str + """ + # Use less verbose schema output. + schema_as_string = self.schema.to_string( + show_field_metadata=show_metadata, + show_schema_metadata=show_metadata + ) + title = 'pyarrow.{}\n{}'.format(type(self).__name__, schema_as_string) + pieces = [title] + if preview_cols: + pieces.append('----') + for i in range(min(self.num_columns, preview_cols)): + pieces.append('{}: {}'.format( + self.field(i).name, + self.column(i).to_string(indent=0, skip_new_lines=True) + )) + if preview_cols < self.num_columns: + pieces.append('...') + return '\n'.join(pieces) + + def remove_column(self, int i): + # implemented in RecordBatch/Table subclasses + raise NotImplementedError + + def drop_columns(self, columns): + """ + Drop one or more columns and return a new Table or RecordBatch. + + Parameters + ---------- + columns : str or list[str] + Field name(s) referencing existing column(s). + + Raises + ------ + KeyError + If any of the passed column names do not exist. + + Returns + ------- + Table or RecordBatch + A tabular object without the column(s). + + Examples + -------- + Table (works similarly for RecordBatch) + + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Drop one column: + + >>> table.drop_columns("animals") + pyarrow.Table + n_legs: int64 + ---- + n_legs: [[2,4,5,100]] + + Drop one or more columns: + + >>> table.drop_columns(["n_legs", "animals"]) + pyarrow.Table + ... + ---- + """ + if isinstance(columns, str): + columns = [columns] + + indices = [] + for col in columns: + idx = self.schema.get_field_index(col) + if idx == -1: + raise KeyError("Column {!r} not found".format(col)) + indices.append(idx) + + indices.sort() + indices.reverse() + + res = self + for idx in indices: + res = res.remove_column(idx) + + return res + + def add_column(self, int i, field_, column): + # implemented in RecordBatch/Table subclasses + raise NotImplementedError + + def append_column(self, field_, column): + """ + Append column at end of columns. + + Parameters + ---------- + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array or value coercible to array + Column data. + + Returns + ------- + Table or RecordBatch + New table or record batch with the passed column added. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Append column at the end: + + >>> year = [2021, 2022, 2019, 2021] + >>> table.append_column('year', [year]) + pyarrow.Table + n_legs: int64 + animals: string + year: int64 + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + year: [[2021,2022,2019,2021]] + """ + return self.add_column(self.num_columns, field_, column) + + cdef void _assert_cpu(self) except *: + return + + +cdef class RecordBatch(_Tabular): + """ + Batch of rows of columns of equal length + + Warnings + -------- + Do not call this class's constructor directly, use one of the + ``RecordBatch.from_*`` functions instead. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Constructing a RecordBatch from arrays: + + >>> pa.RecordBatch.from_arrays([n_legs, animals], names=names) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + >>> pa.RecordBatch.from_arrays([n_legs, animals], names=names).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + + Constructing a RecordBatch from pandas DataFrame: + + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022], + ... 'month': [3, 5, 7, 9], + ... 'day': [1, 5, 9, 13], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> pa.RecordBatch.from_pandas(df) + pyarrow.RecordBatch + year: int64 + month: int64 + day: int64 + n_legs: int64 + animals: string + ---- + year: [2020,2022,2021,2022] + month: [3,5,7,9] + day: [1,5,9,13] + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + >>> pa.RecordBatch.from_pandas(df).to_pandas() + year month day n_legs animals + 0 2020 3 1 2 Flamingo + 1 2022 5 5 4 Horse + 2 2021 7 9 5 Brittle stars + 3 2022 9 13 100 Centipede + + Constructing a RecordBatch from pylist: + + >>> pylist = [{'n_legs': 2, 'animals': 'Flamingo'}, + ... {'n_legs': 4, 'animals': 'Dog'}] + >>> pa.RecordBatch.from_pylist(pylist).to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Dog + + You can also construct a RecordBatch using :func:`pyarrow.record_batch`: + + >>> pa.record_batch([n_legs, animals], names=names).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + + >>> pa.record_batch(df) + pyarrow.RecordBatch + year: int64 + month: int64 + day: int64 + n_legs: int64 + animals: string + ---- + year: [2020,2022,2021,2022] + month: [3,5,7,9] + day: [1,5,9,13] + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + """ + + def __cinit__(self): + self.batch = NULL + self._schema = None + + cdef void init(self, const shared_ptr[CRecordBatch]& batch): + self.sp_batch = batch + self.batch = batch.get() + + def _is_initialized(self): + return self.batch != NULL + + def __reduce__(self): + self._assert_cpu() + return _reconstruct_record_batch, (self.columns, self.schema) + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full : bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + self._assert_cpu() + with nogil: + check_status(self.batch.ValidateFull()) + else: + with nogil: + check_status(self.batch.Validate()) + + def replace_schema_metadata(self, metadata=None): + """ + Create shallow copy of record batch by replacing schema + key-value metadata with the indicated new metadata (which may be None, + which deletes any existing metadata + + Parameters + ---------- + metadata : dict, default None + + Returns + ------- + shallow_copy : RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + + Constructing a RecordBatch with schema and metadata: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> batch = pa.RecordBatch.from_arrays([n_legs], schema=my_schema) + >>> batch.schema + n_legs: int64 + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Shallow copy of a RecordBatch with deleted schema metadata: + + >>> batch.replace_schema_metadata().schema + n_legs: int64 + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CRecordBatch] c_batch + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + with nogil: + c_batch = self.batch.ReplaceSchemaMetadata(c_meta) + + return pyarrow_wrap_batch(c_batch) + + @property + def num_columns(self): + """ + Number of columns + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.num_columns + 2 + """ + return self.batch.num_columns() + + @property + def num_rows(self): + """ + Number of rows + + Due to the definition of a RecordBatch, all columns have the same + number of rows. + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.num_rows + 6 + """ + return self.batch.num_rows() + + @property + def schema(self): + """ + Schema of the RecordBatch and its columns + + Returns + ------- + pyarrow.Schema + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.schema + n_legs: int64 + animals: string + """ + if self._schema is None: + self._schema = pyarrow_wrap_schema(self.batch.schema()) + + return self._schema + + def _column(self, int i): + """ + Select single column from record batch by its numeric index. + + Parameters + ---------- + i : int + The index of the column to retrieve. + + Returns + ------- + column : pyarrow.Array + """ + cdef int index = _normalize_index(i, self.num_columns) + cdef Array result = pyarrow_wrap_array(self.batch.column(index)) + result._name = self.schema[index].name + return result + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the record batch. + + In other words, the sum of bytes from all buffer ranges referenced. + + Unlike `get_total_buffer_size` this method will account for array + offsets. + + If buffers are shared between arrays then the shared + portion will only be counted multiple times. + + The dictionary of dictionary arrays will always be counted in their + entirety even if the array only references a portion of the dictionary. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.nbytes + 116 + """ + self._assert_cpu() + cdef: + CResult[int64_t] c_res_buffer + + with nogil: + c_res_buffer = ReferencedBufferSize(deref(self.batch)) + size = GetResultValue(c_res_buffer) + return size + + def get_total_buffer_size(self): + """ + The sum of bytes in each buffer referenced by the record batch + + An array may only reference a portion of a buffer. + This method will overestimate in this case and return the + byte size of the entire buffer. + + If a buffer is referenced multiple times then it will + only be counted once. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.get_total_buffer_size() + 120 + """ + self._assert_cpu() + cdef: + int64_t total_buffer_size + + total_buffer_size = TotalBufferSize(deref(self.batch)) + return total_buffer_size + + def __sizeof__(self): + return super(RecordBatch, self).__sizeof__() + self.nbytes + + def add_column(self, int i, field_, column): + """ + Add column to RecordBatch at position i. + + A new record batch is returned with the column added, the original record batch + object is left unchanged. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array or value coercible to array + Column data. + + Returns + ------- + RecordBatch + New record batch with the passed column added. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> batch = pa.RecordBatch.from_pandas(df) + + Add column: + + >>> year = [2021, 2022, 2019, 2021] + >>> batch.add_column(0,"year", year) + pyarrow.RecordBatch + year: int64 + n_legs: int64 + animals: string + ---- + year: [2021,2022,2019,2021] + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + + Original record batch is left unchanged: + + >>> batch + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + """ + cdef: + shared_ptr[CRecordBatch] c_batch + Field c_field + Array c_arr + CDeviceAllocationType device_type = self.sp_batch.get().device_type() + + if isinstance(column, Array): + c_arr = column + else: + c_arr = array(column) + + if device_type != c_arr.sp_array.get().device_type(): + raise TypeError("The column must be allocated on the same " + "device as the RecordBatch. Got column on " + f"device {c_arr.device_type!r}, but expected " + f"{self.device_type!r}.") + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_batch = GetResultValue(self.batch.AddColumn( + i, c_field.sp_field, c_arr.sp_array)) + + return pyarrow_wrap_batch(c_batch) + + def remove_column(self, int i): + """ + Create new RecordBatch with the indicated column removed. + + Parameters + ---------- + i : int + Index of column to remove. + + Returns + ------- + Table + New record batch without the column. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> batch = pa.RecordBatch.from_pandas(df) + >>> batch.remove_column(1) + pyarrow.RecordBatch + n_legs: int64 + ---- + n_legs: [2,4,5,100] + """ + cdef shared_ptr[CRecordBatch] c_batch + + with nogil: + c_batch = GetResultValue(self.batch.RemoveColumn(i)) + + return pyarrow_wrap_batch(c_batch) + + def set_column(self, int i, field_, column): + """ + Replace column in RecordBatch at position. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array or value coercible to array + Column data. + + Returns + ------- + RecordBatch + New record batch with the passed column set. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> batch = pa.RecordBatch.from_pandas(df) + + Replace a column: + + >>> year = [2021, 2022, 2019, 2021] + >>> batch.set_column(1,'year', year) + pyarrow.RecordBatch + n_legs: int64 + year: int64 + ---- + n_legs: [2,4,5,100] + year: [2021,2022,2019,2021] + """ + cdef: + shared_ptr[CRecordBatch] c_batch + Field c_field + Array c_arr + CDeviceAllocationType device_type = self.sp_batch.get().device_type() + + if isinstance(column, Array): + c_arr = column + else: + c_arr = array(column) + + if device_type != c_arr.sp_array.get().device_type(): + raise TypeError("The column must be allocated on the same " + "device as the RecordBatch. Got column on " + f"device {c_arr.device_type!r}, but expected " + f"{self.device_type!r}.") + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_batch = GetResultValue(self.batch.SetColumn( + i, c_field.sp_field, c_arr.sp_array)) + + return pyarrow_wrap_batch(c_batch) + + def rename_columns(self, names): + """ + Create new record batch with columns renamed to provided names. + + Parameters + ---------- + names : list[str] or dict[str, str] + List of new column names or mapping of old column names to new column names. + + If a mapping of old to new column names is passed, then all columns which are + found to match a provided old column name will be renamed to the new column name. + If any column names are not found in the mapping, a KeyError will be raised. + + Raises + ------ + KeyError + If any of the column names passed in the names mapping do not exist. + + Returns + ------- + RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> batch = pa.RecordBatch.from_pandas(df) + >>> new_names = ["n", "name"] + >>> batch.rename_columns(new_names) + pyarrow.RecordBatch + n: int64 + name: string + ---- + n: [2,4,5,100] + name: ["Flamingo","Horse","Brittle stars","Centipede"] + >>> new_names = {"n_legs": "n", "animals": "name"} + >>> batch.rename_columns(new_names) + pyarrow.RecordBatch + n: int64 + name: string + ---- + n: [2,4,5,100] + name: ["Flamingo","Horse","Brittle stars","Centipede"] + """ + cdef: + shared_ptr[CRecordBatch] c_batch + vector[c_string] c_names + + if isinstance(names, (list, tuple)): + for name in names: + c_names.push_back(tobytes(name)) + elif isinstance(names, dict): + idx_to_new_name = {} + for name, new_name in names.items(): + indices = self.schema.get_all_field_indices(name) + + if not indices: + raise KeyError("Column {!r} not found".format(name)) + + for index in indices: + idx_to_new_name[index] = new_name + + for i in range(self.num_columns): + new_name = idx_to_new_name.get(i, self.column_names[i]) + c_names.push_back(tobytes(new_name)) + else: + raise TypeError(f"names must be a list or dict not {type(names)!r}") + + with nogil: + c_batch = GetResultValue(self.batch.RenameColumns(move(c_names))) + + return pyarrow_wrap_batch(c_batch) + + def serialize(self, memory_pool=None): + """ + Write RecordBatch to Buffer as encapsulated IPC message, which does not + include a Schema. + + To reconstruct a RecordBatch from the encapsulated IPC message Buffer + returned by this function, a Schema must be passed separately. See + Examples. + + Parameters + ---------- + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> buf = batch.serialize() + >>> buf + + + Reconstruct RecordBatch from IPC message Buffer and original Schema + + >>> pa.ipc.read_record_batch(buf, batch.schema) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + """ + self._assert_cpu() + cdef shared_ptr[CBuffer] buffer + cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults() + options.memory_pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + buffer = GetResultValue( + SerializeRecordBatch(deref(self.batch), options)) + return pyarrow_wrap_buffer(buffer) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this RecordBatch + + Parameters + ---------- + offset : int, default 0 + Offset from start of record batch to slice + length : int, default None + Length of slice (default is until end of batch starting from + offset) + + Returns + ------- + sliced : RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch.to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + >>> batch.slice(offset=3).to_pandas() + n_legs animals + 0 4 Horse + 1 5 Brittle stars + 2 100 Centipede + >>> batch.slice(length=2).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + >>> batch.slice(offset=3, length=1).to_pandas() + n_legs animals + 0 4 Horse + """ + cdef shared_ptr[CRecordBatch] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.batch.Slice(offset) + else: + result = self.batch.Slice(offset, length) + + return pyarrow_wrap_batch(result) + + def equals(self, object other, bint check_metadata=False): + """ + Check if contents of two record batches are equal. + + Parameters + ---------- + other : pyarrow.RecordBatch + RecordBatch to compare against. + check_metadata : bool, default False + Whether schema metadata equality should be checked as well. + + Returns + ------- + are_equal : bool + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"]) + >>> batch_0 = pa.record_batch([]) + >>> batch_1 = pa.RecordBatch.from_arrays([n_legs, animals], + ... names=["n_legs", "animals"], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> batch.equals(batch) + True + >>> batch.equals(batch_0) + False + >>> batch.equals(batch_1) + True + >>> batch.equals(batch_1, check_metadata=True) + False + """ + self._assert_cpu() + cdef: + CRecordBatch* this_batch = self.batch + shared_ptr[CRecordBatch] other_batch = pyarrow_unwrap_batch(other) + c_bool result + + if not other_batch: + return False + + with nogil: + result = this_batch.Equals(deref(other_batch), check_metadata) + + return result + + def select(self, object columns): + """ + Select columns of the RecordBatch. + + Returns a new RecordBatch with the specified columns, and metadata + preserved. + + Parameters + ---------- + columns : list-like + The column names or integer indices to select. + + Returns + ------- + RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.record_batch([n_legs, animals], + ... names=["n_legs", "animals"]) + + Select columns my indices: + + >>> batch.select([1]) + pyarrow.RecordBatch + animals: string + ---- + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + + Select columns by names: + + >>> batch.select(["n_legs"]) + pyarrow.RecordBatch + n_legs: int64 + ---- + n_legs: [2,2,4,4,5,100] + """ + cdef: + shared_ptr[CRecordBatch] c_batch + vector[int] c_indices + + for idx in columns: + idx = self._ensure_integer_index(idx) + idx = _normalize_index(idx, self.num_columns) + c_indices.push_back( idx) + + with nogil: + c_batch = GetResultValue(self.batch.SelectColumns(move(c_indices))) + + return pyarrow_wrap_batch(c_batch) + + def cast(self, Schema target_schema, safe=None, options=None): + """ + Cast record batch values to another schema. + + Parameters + ---------- + target_schema : Schema + Schema to cast to, the names and order of fields must match. + safe : bool, default True + Check for overflows or other unsafe conversions. + options : CastOptions, default None + Additional checks pass by CastOptions + + Returns + ------- + RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> batch = pa.RecordBatch.from_pandas(df) + >>> batch.schema + n_legs: int64 + animals: string + -- schema metadata -- + pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, ... + + Define new schema and cast batch values: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.duration('s')), + ... pa.field('animals', pa.string())] + ... ) + >>> batch.cast(target_schema=my_schema) + pyarrow.RecordBatch + n_legs: duration[s] + animals: string + ---- + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + """ + cdef: + Array column, casted + Field field + list newcols = [] + + if self.schema.names != target_schema.names: + raise ValueError("Target schema's field names are not matching " + "the record batch's field names: {!r}, {!r}" + .format(self.schema.names, target_schema.names)) + + for column, field in zip(self.itercolumns(), target_schema): + if not field.nullable and column.null_count > 0: + raise ValueError("Casting field {!r} with null values to non-nullable" + .format(field.name)) + casted = column.cast(field.type, safe=safe, options=options) + newcols.append(casted) + + return RecordBatch.from_arrays(newcols, schema=target_schema) + + def _to_pandas(self, options, **kwargs): + self._assert_cpu() + return Table.from_batches([self])._to_pandas(options, **kwargs) + + @classmethod + def from_pandas(cls, df, Schema schema=None, preserve_index=None, + nthreads=None, columns=None): + """ + Convert pandas.DataFrame to an Arrow RecordBatch + + Parameters + ---------- + df : pandas.DataFrame + schema : pyarrow.Schema, optional + The expected schema of the RecordBatch. This can be used to + indicate the type of columns if we cannot infer it automatically. + If passed, the output will have exactly this schema. Columns + specified in the schema that are not found in the DataFrame columns + or its index will raise an error. Additional columns or index + levels in the DataFrame which are not specified in the schema will + be ignored. + preserve_index : bool, optional + Whether to store the index as an additional column in the resulting + ``RecordBatch``. The default of None will store the index as a + column, except for RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + nthreads : int, default None + If greater than 1, convert columns to Arrow in parallel using + indicated number of threads. By default, this follows + :func:`pyarrow.cpu_count` (may use up to system CPU count threads). + columns : list, optional + List of column to be converted. If None, use all columns. + + Returns + ------- + pyarrow.RecordBatch + + + Examples + -------- + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022], + ... 'month': [3, 5, 7, 9], + ... 'day': [1, 5, 9, 13], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + + Convert pandas DataFrame to RecordBatch: + + >>> import pyarrow as pa + >>> pa.RecordBatch.from_pandas(df) + pyarrow.RecordBatch + year: int64 + month: int64 + day: int64 + n_legs: int64 + animals: string + ---- + year: [2020,2022,2021,2022] + month: [3,5,7,9] + day: [1,5,9,13] + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + + Convert pandas DataFrame to RecordBatch using schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.RecordBatch.from_pandas(df, schema=my_schema) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + + Convert pandas DataFrame to RecordBatch specifying columns: + + >>> pa.RecordBatch.from_pandas(df, columns=["n_legs"]) + pyarrow.RecordBatch + n_legs: int64 + ---- + n_legs: [2,4,5,100] + """ + from pyarrow.pandas_compat import dataframe_to_arrays + arrays, schema, n_rows = dataframe_to_arrays( + df, schema, preserve_index, nthreads=nthreads, columns=columns + ) + + # If df is empty but row index is not, create empty RecordBatch with rows >0 + cdef vector[shared_ptr[CArray]] c_arrays + if n_rows: + return pyarrow_wrap_batch(CRecordBatch.Make(( schema).sp_schema, + n_rows, c_arrays)) + else: + return cls.from_arrays(arrays, schema=schema) + + @staticmethod + def from_arrays(list arrays, names=None, schema=None, metadata=None): + """ + Construct a RecordBatch from multiple pyarrow.Arrays + + Parameters + ---------- + arrays : list of pyarrow.Array + One for each field in RecordBatch + names : list of str, optional + Names for the batch fields. If not passed, schema must be passed + schema : Schema, default None + Schema for the created batch. If not passed, names must be passed + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + pyarrow.RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Construct a RecordBatch from pyarrow Arrays using names: + + >>> pa.RecordBatch.from_arrays([n_legs, animals], names=names) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + >>> pa.RecordBatch.from_arrays([n_legs, animals], names=names).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + + Construct a RecordBatch from pyarrow Arrays using schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.RecordBatch.from_arrays([n_legs, animals], schema=my_schema).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + >>> pa.RecordBatch.from_arrays([n_legs, animals], schema=my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + """ + cdef: + Array arr + shared_ptr[CSchema] c_schema + vector[shared_ptr[CArray]] c_arrays + int64_t num_rows + + if len(arrays) > 0: + num_rows = len(arrays[0]) + else: + num_rows = 0 + + if isinstance(names, Schema): + import warnings + warnings.warn("Schema passed to names= option, please " + "pass schema= explicitly. " + "Will raise exception in future", FutureWarning) + schema = names + names = None + + converted_arrays = _sanitize_arrays(arrays, names, schema, metadata, + &c_schema) + + c_arrays.reserve(len(arrays)) + for arr in converted_arrays: + if len(arr) != num_rows: + raise ValueError('Arrays were not all the same length: ' + '{0} vs {1}'.format(len(arr), num_rows)) + c_arrays.push_back(arr.sp_array) + + result = pyarrow_wrap_batch(CRecordBatch.Make(c_schema, num_rows, + c_arrays)) + result.validate() + return result + + @staticmethod + def from_struct_array(StructArray struct_array): + """ + Construct a RecordBatch from a StructArray. + + Each field in the StructArray will become a column in the resulting + ``RecordBatch``. + + Parameters + ---------- + struct_array : StructArray + Array to construct the record batch from. + + Returns + ------- + pyarrow.RecordBatch + + Examples + -------- + >>> import pyarrow as pa + >>> struct = pa.array([{'n_legs': 2, 'animals': 'Parrot'}, + ... {'year': 2022, 'n_legs': 4}]) + >>> pa.RecordBatch.from_struct_array(struct).to_pandas() + animals n_legs year + 0 Parrot 2 NaN + 1 None 4 2022.0 + """ + cdef: + shared_ptr[CRecordBatch] c_record_batch + if struct_array.sp_array.get().device_type() != CDeviceAllocationType_kCPU: + raise NotImplementedError("Implemented only for data on CPU device") + with nogil: + c_record_batch = GetResultValue( + CRecordBatch.FromStructArray(struct_array.sp_array)) + return pyarrow_wrap_batch(c_record_batch) + + def to_struct_array(self): + """ + Convert to a struct array. + """ + self._assert_cpu() + cdef: + shared_ptr[CRecordBatch] c_record_batch + shared_ptr[CArray] c_array + + c_record_batch = pyarrow_unwrap_batch(self) + with nogil: + c_array = GetResultValue( + deref(c_record_batch).ToStructArray()) + return pyarrow_wrap_array(c_array) + + def to_tensor(self, c_bool null_to_nan=False, c_bool row_major=True, MemoryPool memory_pool=None): + """ + Convert to a :class:`~pyarrow.Tensor`. + + RecordBatches that can be converted have fields of type signed or unsigned + integer or float, including all bit-widths. + + ``null_to_nan`` is ``False`` by default and this method will raise an error in case + any nulls are present. RecordBatches with nulls can be converted with ``null_to_nan`` + set to ``True``. In this case null values are converted to ``NaN`` and integer type + arrays are promoted to the appropriate float type. + + Parameters + ---------- + null_to_nan : bool, default False + Whether to write null values in the result as ``NaN``. + row_major : bool, default True + Whether resulting Tensor is row-major or column-major + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Examples + -------- + >>> import pyarrow as pa + >>> batch = pa.record_batch( + ... [ + ... pa.array([1, 2, 3, 4, None], type=pa.int32()), + ... pa.array([10, 20, 30, 40, None], type=pa.float32()), + ... ], names = ["a", "b"] + ... ) + + >>> batch + pyarrow.RecordBatch + a: int32 + b: float + ---- + a: [1,2,3,4,null] + b: [10,20,30,40,null] + + Convert a RecordBatch to row-major Tensor with null values + written as ``NaN``s + + >>> batch.to_tensor(null_to_nan=True) + + type: double + shape: (5, 2) + strides: (16, 8) + >>> batch.to_tensor(null_to_nan=True).to_numpy() + array([[ 1., 10.], + [ 2., 20.], + [ 3., 30.], + [ 4., 40.], + [nan, nan]]) + + Convert a RecordBatch to column-major Tensor + + >>> batch.to_tensor(null_to_nan=True, row_major=False) + + type: double + shape: (5, 2) + strides: (8, 40) + >>> batch.to_tensor(null_to_nan=True, row_major=False).to_numpy() + array([[ 1., 10.], + [ 2., 20.], + [ 3., 30.], + [ 4., 40.], + [nan, nan]]) + """ + self._assert_cpu() + cdef: + shared_ptr[CRecordBatch] c_record_batch + shared_ptr[CTensor] c_tensor + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + c_record_batch = pyarrow_unwrap_batch(self) + with nogil: + c_tensor = GetResultValue( + deref(c_record_batch).ToTensor(null_to_nan, + row_major, pool)) + return pyarrow_wrap_tensor(c_tensor) + + def copy_to(self, destination): + """ + Copy the entire RecordBatch to destination device. + + This copies each column of the record batch to create + a new record batch where all underlying buffers for the columns have + been copied to the destination MemoryManager. + + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] c_batch + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_batch = GetResultValue(self.batch.CopyTo(c_memory_manager)) + return pyarrow_wrap_batch(c_batch) + + def _export_to_c(self, out_ptr, out_schema_ptr=0): + """ + Export to a C ArrowArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the record batch + schema is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowArray struct to a consumer, + array memory will leak. This is a low-level function intended for + expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(out_ptr) + void* c_schema_ptr = _as_c_pointer(out_schema_ptr, + allow_null=True) + with nogil: + check_status(ExportRecordBatch(deref(self.sp_batch), + c_ptr, + c_schema_ptr)) + + @staticmethod + def _import_from_c(in_ptr, schema): + """ + Import RecordBatch from a C ArrowArray struct, given its pointer + and the imported schema. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArray struct. + type: Schema or int + Either a Schema object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(in_ptr) + void* c_schema_ptr + shared_ptr[CRecordBatch] c_batch + + c_schema = pyarrow_unwrap_schema(schema) + if c_schema == nullptr: + # Not a Schema object, perhaps a raw ArrowSchema pointer + c_schema_ptr = _as_c_pointer(schema, allow_null=True) + with nogil: + c_batch = GetResultValue(ImportRecordBatch( + c_ptr, c_schema_ptr)) + else: + with nogil: + c_batch = GetResultValue(ImportRecordBatch( + c_ptr, c_schema)) + return pyarrow_wrap_batch(c_batch) + + def __arrow_c_array__(self, requested_schema=None): + """ + Get a pair of PyCapsules containing a C ArrowArray representation of the object. + + Parameters + ---------- + requested_schema : PyCapsule | None + A PyCapsule containing a C ArrowSchema representation of a requested + schema. PyArrow will attempt to cast the batch to this schema. + If None, the batch will be returned as-is, with a schema matching the + one returned by :meth:`__arrow_c_schema__()`. + + Returns + ------- + Tuple[PyCapsule, PyCapsule] + A pair of PyCapsules containing a C ArrowSchema and ArrowArray, + respectively. + """ + self._assert_cpu() + cdef: + ArrowArray* c_array + ArrowSchema* c_schema + + if requested_schema is not None: + target_schema = Schema._import_from_c_capsule(requested_schema) + + if target_schema != self.schema: + try: + casted_batch = self.cast(target_schema, safe=True) + inner_batch = pyarrow_unwrap_batch(casted_batch) + except ArrowInvalid as e: + raise ValueError( + f"Could not cast {self.schema} to requested schema {target_schema}: {e}" + ) + else: + inner_batch = self.sp_batch + else: + inner_batch = self.sp_batch + + schema_capsule = alloc_c_schema(&c_schema) + array_capsule = alloc_c_array(&c_array) + + with nogil: + check_status(ExportRecordBatch(deref(inner_batch), c_array, c_schema)) + + return schema_capsule, array_capsule + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export the batch as an Arrow C stream PyCapsule. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the stream should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + Currently, this is not supported and will raise a + NotImplementedError if the schema doesn't match the current schema. + + Returns + ------- + PyCapsule + """ + self._assert_cpu() + return Table.from_batches([self]).__arrow_c_stream__(requested_schema) + + @staticmethod + def _import_from_c_capsule(schema_capsule, array_capsule): + """ + Import RecordBatch from a pair of PyCapsules containing a C ArrowSchema + and ArrowArray, respectively. + + Parameters + ---------- + schema_capsule : PyCapsule + A PyCapsule containing a C ArrowSchema representation of the schema. + array_capsule : PyCapsule + A PyCapsule containing a C ArrowArray representation of the array. + + Returns + ------- + pyarrow.RecordBatch + """ + cdef: + ArrowSchema* c_schema + ArrowArray* c_array + shared_ptr[CRecordBatch] c_batch + + c_schema = PyCapsule_GetPointer(schema_capsule, 'arrow_schema') + c_array = PyCapsule_GetPointer(array_capsule, 'arrow_array') + + with nogil: + c_batch = GetResultValue(ImportRecordBatch(c_array, c_schema)) + + return pyarrow_wrap_batch(c_batch) + + def _export_to_c_device(self, out_ptr, out_schema_ptr=0): + """ + Export to a C ArrowDeviceArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the record batch + schema is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowDeviceArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowDeviceArray struct to a consumer, + array memory will leak. This is a low-level function intended for + expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(out_ptr) + void* c_schema_ptr = _as_c_pointer(out_schema_ptr, + allow_null=True) + with nogil: + check_status(ExportDeviceRecordBatch( + deref(self.sp_batch), NULL, + c_ptr, c_schema_ptr) + ) + + @staticmethod + def _import_from_c_device(in_ptr, schema): + """ + Import RecordBatch from a C ArrowDeviceArray struct, given its pointer + and the imported schema. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowDeviceArray struct. + type: Schema or int + Either a Schema object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + ArrowDeviceArray* c_device_array = _as_c_pointer(in_ptr) + void* c_schema_ptr + shared_ptr[CRecordBatch] c_batch + + if c_device_array.device_type == ARROW_DEVICE_CUDA: + _ensure_cuda_loaded() + + c_schema = pyarrow_unwrap_schema(schema) + if c_schema == nullptr: + # Not a Schema object, perhaps a raw ArrowSchema pointer + c_schema_ptr = _as_c_pointer(schema, allow_null=True) + with nogil: + c_batch = GetResultValue(ImportDeviceRecordBatch( + c_device_array, c_schema_ptr)) + else: + with nogil: + c_batch = GetResultValue(ImportDeviceRecordBatch( + c_device_array, c_schema)) + return pyarrow_wrap_batch(c_batch) + + def __arrow_c_device_array__(self, requested_schema=None, **kwargs): + """ + Get a pair of PyCapsules containing a C ArrowDeviceArray representation + of the object. + + Parameters + ---------- + requested_schema : PyCapsule | None + A PyCapsule containing a C ArrowSchema representation of a requested + schema. PyArrow will attempt to cast the batch to this data type. + If None, the batch will be returned as-is, with a type matching the + one returned by :meth:`__arrow_c_schema__()`. + kwargs + Currently no additional keyword arguments are supported, but + this method will accept any keyword with a value of ``None`` + for compatibility with future keywords. + + Returns + ------- + Tuple[PyCapsule, PyCapsule] + A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray, + respectively. + """ + cdef: + ArrowDeviceArray* c_array + ArrowSchema* c_schema + shared_ptr[CRecordBatch] inner_batch + + non_default_kwargs = [ + name for name, value in kwargs.items() if value is not None + ] + if non_default_kwargs: + raise NotImplementedError( + f"Received unsupported keyword argument(s): {non_default_kwargs}" + ) + + if requested_schema is not None: + target_schema = Schema._import_from_c_capsule(requested_schema) + + if target_schema != self.schema: + if not self.is_cpu: + raise NotImplementedError( + "Casting to a requested schema is only supported for CPU data" + ) + try: + casted_batch = self.cast(target_schema, safe=True) + inner_batch = pyarrow_unwrap_batch(casted_batch) + except ArrowInvalid as e: + raise ValueError( + f"Could not cast {self.schema} to requested schema {target_schema}: {e}" + ) + else: + inner_batch = self.sp_batch + else: + inner_batch = self.sp_batch + + schema_capsule = alloc_c_schema(&c_schema) + array_capsule = alloc_c_device_array(&c_array) + + with nogil: + check_status(ExportDeviceRecordBatch( + deref(inner_batch), NULL, c_array, c_schema)) + + return schema_capsule, array_capsule + + @staticmethod + def _import_from_c_device_capsule(schema_capsule, array_capsule): + """ + Import RecordBatch from a pair of PyCapsules containing a + C ArrowSchema and ArrowDeviceArray, respectively. + + Parameters + ---------- + schema_capsule : PyCapsule + A PyCapsule containing a C ArrowSchema representation of the schema. + array_capsule : PyCapsule + A PyCapsule containing a C ArrowDeviceArray representation of the array. + + Returns + ------- + pyarrow.RecordBatch + """ + cdef: + ArrowSchema* c_schema + ArrowDeviceArray* c_array + shared_ptr[CRecordBatch] batch + + c_schema = PyCapsule_GetPointer(schema_capsule, 'arrow_schema') + c_array = PyCapsule_GetPointer( + array_capsule, 'arrow_device_array' + ) + + with nogil: + batch = GetResultValue(ImportDeviceRecordBatch(c_array, c_schema)) + + return pyarrow_wrap_batch(batch) + + @property + def device_type(self): + """ + The device type where the arrays in the RecordBatch reside. + + Returns + ------- + DeviceAllocationType + """ + return _wrap_device_allocation_type(self.sp_batch.get().device_type()) + + @property + def is_cpu(self): + """ + Whether the RecordBatch's arrays are CPU-accessible. + """ + return self.device_type == DeviceAllocationType.CPU + + cdef void _assert_cpu(self) except *: + if self.sp_batch.get().device_type() != CDeviceAllocationType_kCPU: + raise NotImplementedError("Implemented only for data on CPU device") + + +def _reconstruct_record_batch(columns, schema): + """ + Internal: reconstruct RecordBatch from pickled components. + """ + return RecordBatch.from_arrays(columns, schema=schema) + + +def table_to_blocks(options, Table table, categories, extension_columns): + cdef: + PyObject* result_obj + shared_ptr[CTable] c_table + CMemoryPool* pool + PandasOptions c_options = _convert_pandas_options(options) + + if categories is not None: + c_options.categorical_columns = {tobytes(cat) for cat in categories} + if extension_columns is not None: + c_options.extension_columns = {tobytes(col) + for col in extension_columns} + + if pandas_api.is_v1(): + # ARROW-3789: Coerce date/timestamp types to datetime64[ns] + c_options.coerce_temporal_nanoseconds = True + + if c_options.self_destruct: + # Move the shared_ptr, table is now unsafe to use further + c_table = move(table.sp_table) + table.table = NULL + else: + c_table = table.sp_table + + with nogil: + check_status( + libarrow_python.ConvertTableToPandas(c_options, move(c_table), + &result_obj) + ) + + return PyObject_to_object(result_obj) + + +cdef class Table(_Tabular): + """ + A collection of top-level named, equal length Arrow arrays. + + Warnings + -------- + Do not call this class's constructor directly, use one of the ``from_*`` + methods instead. + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Construct a Table from arrays: + + >>> pa.Table.from_arrays([n_legs, animals], names=names) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from a RecordBatch: + + >>> batch = pa.record_batch([n_legs, animals], names=names) + >>> pa.Table.from_batches([batch]) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from pandas DataFrame: + + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> pa.Table.from_pandas(df) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2022,2019,2021]] + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from a dictionary of arrays: + + >>> pydict = {'n_legs': n_legs, 'animals': animals} + >>> pa.Table.from_pydict(pydict) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> pa.Table.from_pydict(pydict).schema + n_legs: int64 + animals: string + + Construct a Table from a dictionary of arrays with metadata: + + >>> my_metadata={"n_legs": "Number of legs per animal"} + >>> pa.Table.from_pydict(pydict, metadata=my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Construct a Table from a list of rows: + + >>> pylist = [{'n_legs': 2, 'animals': 'Flamingo'}, {'year': 2021, 'animals': 'Centipede'}] + >>> pa.Table.from_pylist(pylist) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,null]] + animals: [["Flamingo","Centipede"]] + + Construct a Table from a list of rows with pyarrow schema: + + >>> my_schema = pa.schema([ + ... pa.field('year', pa.int64()), + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"year": "Year of entry"}) + >>> pa.Table.from_pylist(pylist, schema=my_schema).schema + year: int64 + n_legs: int64 + animals: string + -- schema metadata -- + year: 'Year of entry' + + Construct a Table with :func:`pyarrow.table`: + + >>> pa.table([n_legs, animals], names=names) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + + def __cinit__(self): + self.table = NULL + self._init_is_cpu = False + + cdef void init(self, const shared_ptr[CTable]& table): + self.sp_table = table + self.table = table.get() + + def _is_initialized(self): + return self.table != NULL + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full : bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + self._assert_cpu() + with nogil: + check_status(self.table.ValidateFull()) + else: + with nogil: + check_status(self.table.Validate()) + + def __reduce__(self): + # Reduce the columns as ChunkedArrays to avoid serializing schema + # data twice + self._assert_cpu() + columns = [col for col in self.columns] + return _reconstruct_table, (columns, self.schema) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this Table. + + Parameters + ---------- + offset : int, default 0 + Offset from start of table to slice. + length : int, default None + Length of slice (default is until end of table starting from + offset). + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.slice(length=3) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2022,2019]] + n_legs: [[2,4,5]] + animals: [["Flamingo","Horse","Brittle stars"]] + >>> table.slice(offset=2) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2019,2021]] + n_legs: [[5,100]] + animals: [["Brittle stars","Centipede"]] + >>> table.slice(offset=2, length=1) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2019]] + n_legs: [[5]] + animals: [["Brittle stars"]] + """ + cdef shared_ptr[CTable] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.table.Slice(offset) + else: + result = self.table.Slice(offset, length) + + return pyarrow_wrap_table(result) + + def select(self, object columns): + """ + Select columns of the Table. + + Returns a new Table with the specified columns, and metadata + preserved. + + Parameters + ---------- + columns : list-like + The column names or integer indices to select. + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.select([0,1]) + pyarrow.Table + year: int64 + n_legs: int64 + ---- + year: [[2020,2022,2019,2021]] + n_legs: [[2,4,5,100]] + >>> table.select(["year"]) + pyarrow.Table + year: int64 + ---- + year: [[2020,2022,2019,2021]] + """ + cdef: + shared_ptr[CTable] c_table + vector[int] c_indices + + for idx in columns: + idx = self._ensure_integer_index(idx) + idx = _normalize_index(idx, self.num_columns) + c_indices.push_back( idx) + + with nogil: + c_table = GetResultValue(self.table.SelectColumns(move(c_indices))) + + return pyarrow_wrap_table(c_table) + + def replace_schema_metadata(self, metadata=None): + """ + Create shallow copy of table by replacing schema + key-value metadata with the indicated new metadata (which may be None), + which deletes any existing metadata. + + Parameters + ---------- + metadata : dict, default None + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Constructing a Table with pyarrow schema and metadata: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> table= pa.table(df, my_schema) + >>> table.schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + pandas: ... + + Create a shallow copy of a Table with deleted schema metadata: + + >>> table.replace_schema_metadata().schema + n_legs: int64 + animals: string + + Create a shallow copy of a Table with new schema metadata: + + >>> metadata={"animals": "Which animal"} + >>> table.replace_schema_metadata(metadata = metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + animals: 'Which animal' + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CTable] c_table + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + with nogil: + c_table = self.table.ReplaceSchemaMetadata(c_meta) + + return pyarrow_wrap_table(c_table) + + def flatten(self, MemoryPool memory_pool=None): + """ + Flatten this Table. + + Each column with a struct type is flattened + into one column per struct field. Other columns are left unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> struct = pa.array([{'n_legs': 2, 'animals': 'Parrot'}, + ... {'year': 2022, 'n_legs': 4}]) + >>> month = pa.array([4, 6]) + >>> table = pa.Table.from_arrays([struct,month], + ... names = ["a", "month"]) + >>> table + pyarrow.Table + a: struct + child 0, animals: string + child 1, n_legs: int64 + child 2, year: int64 + month: int64 + ---- + a: [ + -- is_valid: all not null + -- child 0 type: string + ["Parrot",null] + -- child 1 type: int64 + [2,4] + -- child 2 type: int64 + [null,2022]] + month: [[4,6]] + + Flatten the columns with struct field: + + >>> table.flatten() + pyarrow.Table + a.animals: string + a.n_legs: int64 + a.year: int64 + month: int64 + ---- + a.animals: [["Parrot",null]] + a.n_legs: [[2,4]] + a.year: [[null,2022]] + month: [[4,6]] + """ + self._assert_cpu() + cdef: + shared_ptr[CTable] flattened + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + flattened = GetResultValue(self.table.Flatten(pool)) + + return pyarrow_wrap_table(flattened) + + def combine_chunks(self, MemoryPool memory_pool=None): + """ + Make a new table by combining the chunks this table has. + + All the underlying chunks in the ChunkedArray of each column are + concatenated into zero or one chunk. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> animals = pa.chunked_array([["Flamingo", "Parrot", "Dog"], ["Horse", "Brittle stars", "Centipede"]]) + >>> names = ["n_legs", "animals"] + >>> table = pa.table([n_legs, animals], names=names) + >>> table + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,2,4],[4,5,100]] + animals: [["Flamingo","Parrot","Dog"],["Horse","Brittle stars","Centipede"]] + >>> table.combine_chunks() + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,2,4,4,5,100]] + animals: [["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]] + """ + self._assert_cpu() + cdef: + shared_ptr[CTable] combined + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + combined = GetResultValue(self.table.CombineChunks(pool)) + + return pyarrow_wrap_table(combined) + + def unify_dictionaries(self, MemoryPool memory_pool=None): + """ + Unify dictionaries across all chunks. + + This method returns an equivalent table, but where all chunks of + each column share the same dictionary values. Dictionary indices + are transposed accordingly. + + Columns without dictionaries are returned unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> arr_1 = pa.array(["Flamingo", "Parrot", "Dog"]).dictionary_encode() + >>> arr_2 = pa.array(["Horse", "Brittle stars", "Centipede"]).dictionary_encode() + >>> c_arr = pa.chunked_array([arr_1, arr_2]) + >>> table = pa.table([c_arr], names=["animals"]) + >>> table + pyarrow.Table + animals: dictionary + ---- + animals: [ -- dictionary: + ["Flamingo","Parrot","Dog"] -- indices: + [0,1,2], -- dictionary: + ["Horse","Brittle stars","Centipede"] -- indices: + [0,1,2]] + + Unify dictionaries across both chunks: + + >>> table.unify_dictionaries() + pyarrow.Table + animals: dictionary + ---- + animals: [ -- dictionary: + ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] -- indices: + [0,1,2], -- dictionary: + ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] -- indices: + [3,4,5]] + """ + self._assert_cpu() + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CTable] c_result + + with nogil: + c_result = GetResultValue(CDictionaryUnifier.UnifyTable( + deref(self.table), pool)) + + return pyarrow_wrap_table(c_result) + + def equals(self, Table other, bint check_metadata=False): + """ + Check if contents of two tables are equal. + + Parameters + ---------- + other : pyarrow.Table + Table to compare against. + check_metadata : bool, default False + Whether schema metadata equality should be checked as well. + + Returns + ------- + bool + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> names=["n_legs", "animals"] + >>> table = pa.Table.from_arrays([n_legs, animals], names=names) + >>> table_0 = pa.Table.from_arrays([]) + >>> table_1 = pa.Table.from_arrays([n_legs, animals], + ... names=names, + ... metadata={"n_legs": "Number of legs per animal"}) + >>> table.equals(table) + True + >>> table.equals(table_0) + False + >>> table.equals(table_1) + True + >>> table.equals(table_1, check_metadata=True) + False + """ + self._assert_cpu() + if other is None: + return False + + cdef: + CTable* this_table = self.table + CTable* other_table = other.table + c_bool result + + with nogil: + result = this_table.Equals(deref(other_table), check_metadata) + + return result + + def cast(self, Schema target_schema, safe=None, options=None): + """ + Cast table values to another schema. + + Parameters + ---------- + target_schema : Schema + Schema to cast to, the names and order of fields must match. + safe : bool, default True + Check for overflows or other unsafe conversions. + options : CastOptions, default None + Additional checks pass by CastOptions + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.schema + n_legs: int64 + animals: string + -- schema metadata -- + pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, ... + + Define new schema and cast table values: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.duration('s')), + ... pa.field('animals', pa.string())] + ... ) + >>> table.cast(target_schema=my_schema) + pyarrow.Table + n_legs: duration[s] + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + self._assert_cpu() + cdef: + ChunkedArray column, casted + Field field + list newcols = [] + + if self.schema.names != target_schema.names: + raise ValueError("Target schema's field names are not matching " + "the table's field names: {!r}, {!r}" + .format(self.schema.names, target_schema.names)) + + for column, field in zip(self.itercolumns(), target_schema): + if not field.nullable and column.null_count > 0: + raise ValueError("Casting field {!r} with null values to non-nullable" + .format(field.name)) + casted = column.cast(field.type, safe=safe, options=options) + newcols.append(casted) + + return Table.from_arrays(newcols, schema=target_schema) + + @classmethod + def from_pandas(cls, df, Schema schema=None, preserve_index=None, + nthreads=None, columns=None, bint safe=True): + """ + Convert pandas.DataFrame to an Arrow Table. + + The column types in the resulting Arrow Table are inferred from the + dtypes of the pandas.Series in the DataFrame. In the case of non-object + Series, the NumPy dtype is translated to its Arrow equivalent. In the + case of `object`, we need to guess the datatype by looking at the + Python objects in this Series. + + Be aware that Series of the `object` dtype don't carry enough + information to always lead to a meaningful Arrow type. In the case that + we cannot infer a type, e.g. because the DataFrame is of length 0 or + the Series only contains None/nan objects, the type is set to + null. This behavior can be avoided by constructing an explicit schema + and passing it to this function. + + Parameters + ---------- + df : pandas.DataFrame + schema : pyarrow.Schema, optional + The expected schema of the Arrow Table. This can be used to + indicate the type of columns if we cannot infer it automatically. + If passed, the output will have exactly this schema. Columns + specified in the schema that are not found in the DataFrame columns + or its index will raise an error. Additional columns or index + levels in the DataFrame which are not specified in the schema will + be ignored. + preserve_index : bool, optional + Whether to store the index as an additional column in the resulting + ``Table``. The default of None will store the index as a column, + except for RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + nthreads : int, default None + If greater than 1, convert columns to Arrow in parallel using + indicated number of threads. By default, this follows + :func:`pyarrow.cpu_count` (may use up to system CPU count threads). + columns : list, optional + List of column to be converted. If None, use all columns. + safe : bool, default True + Check for overflows or other unsafe conversions. + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> pa.Table.from_pandas(df) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + from pyarrow.pandas_compat import dataframe_to_arrays + arrays, schema, n_rows = dataframe_to_arrays( + df, + schema=schema, + preserve_index=preserve_index, + nthreads=nthreads, + columns=columns, + safe=safe + ) + + # If df is empty but row index is not, create empty Table with rows >0 + cdef vector[shared_ptr[CChunkedArray]] c_arrays + if n_rows: + return pyarrow_wrap_table( + CTable.MakeWithRows(( schema).sp_schema, c_arrays, n_rows)) + else: + return cls.from_arrays(arrays, schema=schema) + + @staticmethod + def from_arrays(arrays, names=None, schema=None, metadata=None): + """ + Construct a Table from Arrow arrays. + + Parameters + ---------- + arrays : list of pyarrow.Array or pyarrow.ChunkedArray + Equal-length arrays that should form the table. + names : list of str, optional + Names for the table columns. If not passed, schema must be passed. + schema : Schema, default None + Schema for the created table. If not passed, names must be passed. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Construct a Table from arrays: + + >>> pa.Table.from_arrays([n_legs, animals], names=names) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from arrays with metadata: + + >>> my_metadata={"n_legs": "Number of legs per animal"} + >>> pa.Table.from_arrays([n_legs, animals], + ... names=names, + ... metadata=my_metadata) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> pa.Table.from_arrays([n_legs, animals], + ... names=names, + ... metadata=my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Construct a Table from arrays with pyarrow schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"animals": "Name of the animal species"}) + >>> pa.Table.from_arrays([n_legs, animals], + ... schema=my_schema) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> pa.Table.from_arrays([n_legs, animals], + ... schema=my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + animals: 'Name of the animal species' + """ + cdef: + vector[shared_ptr[CChunkedArray]] columns + shared_ptr[CSchema] c_schema + int i, K = len(arrays) + + converted_arrays = _sanitize_arrays(arrays, names, schema, metadata, + &c_schema) + + columns.reserve(K) + for item in converted_arrays: + if isinstance(item, Array): + columns.push_back( + make_shared[CChunkedArray]( + ( item).sp_array + ) + ) + elif isinstance(item, ChunkedArray): + columns.push_back(( item).sp_chunked_array) + else: + raise TypeError(type(item)) + + result = pyarrow_wrap_table(CTable.Make(c_schema, columns)) + result.validate() + return result + + @staticmethod + def from_struct_array(struct_array): + """ + Construct a Table from a StructArray. + + Each field in the StructArray will become a column in the resulting + ``Table``. + + Parameters + ---------- + struct_array : StructArray or ChunkedArray + Array to construct the table from. + + Returns + ------- + pyarrow.Table + + Examples + -------- + >>> import pyarrow as pa + >>> struct = pa.array([{'n_legs': 2, 'animals': 'Parrot'}, + ... {'year': 2022, 'n_legs': 4}]) + >>> pa.Table.from_struct_array(struct).to_pandas() + animals n_legs year + 0 Parrot 2 NaN + 1 None 4 2022.0 + """ + if isinstance(struct_array, Array): + return Table.from_batches([RecordBatch.from_struct_array(struct_array)]) + else: + return Table.from_batches([ + RecordBatch.from_struct_array(chunk) + for chunk in struct_array.chunks + ]) + + def to_struct_array(self, max_chunksize=None): + """ + Convert to a chunked array of struct type. + + Parameters + ---------- + max_chunksize : int, default None + Maximum number of rows for ChunkedArray chunks. Individual chunks + may be smaller depending on the chunk layout of individual columns. + + Returns + ------- + ChunkedArray + """ + self._assert_cpu() + return chunked_array([ + batch.to_struct_array() + for batch in self.to_batches(max_chunksize=max_chunksize) + ]) + + @staticmethod + def from_batches(batches, Schema schema=None): + """ + Construct a Table from a sequence or iterator of Arrow RecordBatches. + + Parameters + ---------- + batches : sequence or iterator of RecordBatch + Sequence of RecordBatch to be converted, all schemas must be equal. + schema : Schema, default None + If not passed, will be inferred from the first RecordBatch. + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + >>> batch = pa.record_batch([n_legs, animals], names=names) + >>> batch.to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + 2 5 Brittle stars + 3 100 Centipede + + Construct a Table from a RecordBatch: + + >>> pa.Table.from_batches([batch]) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from a sequence of RecordBatches: + + >>> pa.Table.from_batches([batch, batch]) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100],[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"],["Flamingo","Horse","Brittle stars","Centipede"]] + """ + cdef: + vector[shared_ptr[CRecordBatch]] c_batches + shared_ptr[CTable] c_table + shared_ptr[CSchema] c_schema + RecordBatch batch + + for batch in batches: + c_batches.push_back(batch.sp_batch) + + if schema is None: + if c_batches.size() == 0: + raise ValueError('Must pass schema, or at least ' + 'one RecordBatch') + c_schema = c_batches[0].get().schema() + else: + c_schema = schema.sp_schema + + with nogil: + c_table = GetResultValue( + CTable.FromRecordBatches(c_schema, move(c_batches))) + + return pyarrow_wrap_table(c_table) + + def to_batches(self, max_chunksize=None): + """ + Convert Table to a list of RecordBatch objects. + + Note that this method is zero-copy, it merely exposes the same data + under a different API. + + Parameters + ---------- + max_chunksize : int, default None + Maximum number of rows for each RecordBatch chunk. Individual chunks + may be smaller depending on the chunk layout of individual columns. + + Returns + ------- + list[RecordBatch] + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Convert a Table to a RecordBatch: + + >>> table.to_batches()[0].to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + 2 5 Brittle stars + 3 100 Centipede + + Convert a Table to a list of RecordBatches: + + >>> table.to_batches(max_chunksize=2)[0].to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + >>> table.to_batches(max_chunksize=2)[1].to_pandas() + n_legs animals + 0 5 Brittle stars + 1 100 Centipede + """ + cdef: + unique_ptr[TableBatchReader] reader + int64_t c_max_chunksize + list result = [] + shared_ptr[CRecordBatch] batch + + reader.reset(new TableBatchReader(deref(self.table))) + + if max_chunksize is not None: + if not max_chunksize > 0: + raise ValueError("'max_chunksize' should be strictly positive") + c_max_chunksize = max_chunksize + reader.get().set_chunksize(c_max_chunksize) + + while True: + with nogil: + check_status(reader.get().ReadNext(&batch)) + + if batch.get() == NULL: + break + + result.append(pyarrow_wrap_batch(batch)) + + return result + + def to_reader(self, max_chunksize=None): + """ + Convert the Table to a RecordBatchReader. + + Note that this method is zero-copy, it merely exposes the same data + under a different API. + + Parameters + ---------- + max_chunksize : int, default None + Maximum number of rows for each RecordBatch chunk. Individual chunks + may be smaller depending on the chunk layout of individual columns. + + Returns + ------- + RecordBatchReader + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Convert a Table to a RecordBatchReader: + + >>> table.to_reader() + + + >>> reader = table.to_reader() + >>> reader.schema + n_legs: int64 + animals: string + -- schema metadata -- + pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, ... + >>> reader.read_all() + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + cdef: + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader + shared_ptr[TableBatchReader] t_reader + t_reader = make_shared[TableBatchReader](self.sp_table) + + if max_chunksize is not None: + t_reader.get().set_chunksize(max_chunksize) + + c_reader = dynamic_pointer_cast[CRecordBatchReader, TableBatchReader]( + t_reader) + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return reader + + def _to_pandas(self, options, categories=None, ignore_metadata=False, + types_mapper=None): + self._assert_cpu() + from pyarrow.pandas_compat import table_to_dataframe + df = table_to_dataframe( + options, self, categories, + ignore_metadata=ignore_metadata, + types_mapper=types_mapper) + return df + + @property + def schema(self): + """ + Schema of the table and its columns. + + Returns + ------- + Schema + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.schema + n_legs: int64 + animals: string + -- schema metadata -- + pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, "' ... + """ + return pyarrow_wrap_schema(self.table.schema()) + + def _column(self, int i): + """ + Select a column by its numeric index. + + Parameters + ---------- + i : int + The index of the column to retrieve. + + Returns + ------- + ChunkedArray + """ + cdef int index = _normalize_index(i, self.num_columns) + cdef ChunkedArray result = pyarrow_wrap_chunked_array( + self.table.column(index)) + result._name = self.schema[index].name + return result + + @property + def num_columns(self): + """ + Number of columns in this table. + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.num_columns + 2 + """ + return self.table.num_columns() + + @property + def num_rows(self): + """ + Number of rows in this table. + + Due to the definition of a table, all columns have the same number of + rows. + + Returns + ------- + int + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.num_rows + 4 + """ + return self.table.num_rows() + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the table. + + In other words, the sum of bytes from all buffer ranges referenced. + + Unlike `get_total_buffer_size` this method will account for array + offsets. + + If buffers are shared between arrays then the shared + portion will only be counted multiple times. + + The dictionary of dictionary arrays will always be counted in their + entirety even if the array only references a portion of the dictionary. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.nbytes + 72 + """ + self._assert_cpu() + cdef: + CResult[int64_t] c_res_buffer + + with nogil: + c_res_buffer = ReferencedBufferSize(deref(self.table)) + size = GetResultValue(c_res_buffer) + return size + + def get_total_buffer_size(self): + """ + The sum of bytes in each buffer referenced by the table. + + An array may only reference a portion of a buffer. + This method will overestimate in this case and return the + byte size of the entire buffer. + + If a buffer is referenced multiple times then it will + only be counted once. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [None, 4, 5, None], + ... 'animals': ["Flamingo", "Horse", None, "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.get_total_buffer_size() + 76 + """ + self._assert_cpu() + cdef: + int64_t total_buffer_size + + total_buffer_size = TotalBufferSize(deref(self.table)) + return total_buffer_size + + def __sizeof__(self): + return super(Table, self).__sizeof__() + self.nbytes + + def add_column(self, int i, field_, column): + """ + Add column to Table at position. + + A new table is returned with the column added, the original table + object is left unchanged. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array, list of Array, or values coercible to arrays + Column data. + + Returns + ------- + Table + New table with the passed column added. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Add column: + + >>> year = [2021, 2022, 2019, 2021] + >>> table.add_column(0,"year", [year]) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2021,2022,2019,2021]] + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Original table is left unchanged: + + >>> table + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + cdef: + shared_ptr[CTable] c_table + Field c_field + ChunkedArray c_arr + + if isinstance(column, ChunkedArray): + c_arr = column + else: + c_arr = chunked_array(column) + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_table = GetResultValue(self.table.AddColumn( + i, c_field.sp_field, c_arr.sp_chunked_array)) + + return pyarrow_wrap_table(c_table) + + def remove_column(self, int i): + """ + Create new Table with the indicated column removed. + + Parameters + ---------- + i : int + Index of column to remove. + + Returns + ------- + Table + New table without the column. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.remove_column(1) + pyarrow.Table + n_legs: int64 + ---- + n_legs: [[2,4,5,100]] + """ + cdef shared_ptr[CTable] c_table + + with nogil: + c_table = GetResultValue(self.table.RemoveColumn(i)) + + return pyarrow_wrap_table(c_table) + + def set_column(self, int i, field_, column): + """ + Replace column in Table at position. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array, list of Array, or values coercible to arrays + Column data. + + Returns + ------- + Table + New table with the passed column set. + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + + Replace a column: + + >>> year = [2021, 2022, 2019, 2021] + >>> table.set_column(1,'year', [year]) + pyarrow.Table + n_legs: int64 + year: int64 + ---- + n_legs: [[2,4,5,100]] + year: [[2021,2022,2019,2021]] + """ + cdef: + shared_ptr[CTable] c_table + Field c_field + ChunkedArray c_arr + + if isinstance(column, ChunkedArray): + c_arr = column + else: + c_arr = chunked_array(column) + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_table = GetResultValue(self.table.SetColumn( + i, c_field.sp_field, c_arr.sp_chunked_array)) + + return pyarrow_wrap_table(c_table) + + def rename_columns(self, names): + """ + Create new table with columns renamed to provided names. + + Parameters + ---------- + names : list[str] or dict[str, str] + List of new column names or mapping of old column names to new column names. + + If a mapping of old to new column names is passed, then all columns which are + found to match a provided old column name will be renamed to the new column name. + If any column names are not found in the mapping, a KeyError will be raised. + + Raises + ------ + KeyError + If any of the column names passed in the names mapping do not exist. + + Returns + ------- + Table + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> new_names = ["n", "name"] + >>> table.rename_columns(new_names) + pyarrow.Table + n: int64 + name: string + ---- + n: [[2,4,5,100]] + name: [["Flamingo","Horse","Brittle stars","Centipede"]] + >>> new_names = {"n_legs": "n", "animals": "name"} + >>> table.rename_columns(new_names) + pyarrow.Table + n: int64 + name: string + ---- + n: [[2,4,5,100]] + name: [["Flamingo","Horse","Brittle stars","Centipede"]] + """ + cdef: + shared_ptr[CTable] c_table + vector[c_string] c_names + + if isinstance(names, (list, tuple)): + for name in names: + c_names.push_back(tobytes(name)) + elif isinstance(names, dict): + idx_to_new_name = {} + for name, new_name in names.items(): + indices = self.schema.get_all_field_indices(name) + + if not indices: + raise KeyError("Column {!r} not found".format(name)) + + for index in indices: + idx_to_new_name[index] = new_name + + for i in range(self.num_columns): + c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name))) + else: + raise TypeError(f"names must be a list or dict not {type(names)!r}") + + with nogil: + c_table = GetResultValue(self.table.RenameColumns(move(c_names))) + + return pyarrow_wrap_table(c_table) + + def drop(self, columns): + """ + Drop one or more columns and return a new table. + + Alias of Table.drop_columns, but kept for backwards compatibility. + + Parameters + ---------- + columns : str or list[str] + Field name(s) referencing existing column(s). + + Returns + ------- + Table + New table without the column(s). + """ + return self.drop_columns(columns) + + def group_by(self, keys, use_threads=True): + """ + Declare a grouping over the columns of the table. + + Resulting grouping can then be used to perform aggregations + with a subsequent ``aggregate()`` method. + + Parameters + ---------- + keys : str or list[str] + Name of the columns that should be used as the grouping key. + use_threads : bool, default True + Whether to use multithreading or not. When set to True (the + default), no stable ordering of the output is guaranteed. + + Returns + ------- + TableGroupBy + + See Also + -------- + TableGroupBy.aggregate + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + >>> table = pa.Table.from_pandas(df) + >>> table.group_by('year').aggregate([('n_legs', 'sum')]) + pyarrow.Table + year: int64 + n_legs_sum: int64 + ---- + year: [[2020,2022,2021,2019]] + n_legs_sum: [[2,6,104,5]] + """ + self._assert_cpu() + return TableGroupBy(self, keys, use_threads=use_threads) + + def join(self, right_table, keys, right_keys=None, join_type="left outer", + left_suffix=None, right_suffix=None, coalesce_keys=True, + use_threads=True): + """ + Perform a join between this table and another one. + + Result of the join will be a new Table, where further + operations can be applied. + + Parameters + ---------- + right_table : Table + The table to join to the current one, acting as the right table + in the join operation. + keys : str or list[str] + The columns from current table that should be used as keys + of the join operation left side. + right_keys : str or list[str], default None + The columns from the right_table that should be used as keys + on the join operation right side. + When ``None`` use the same key names as the left table. + join_type : str, default "left outer" + The kind of join that should be performed, one of + ("left semi", "right semi", "left anti", "right anti", + "inner", "left outer", "right outer", "full outer") + left_suffix : str, default None + Which suffix to add to left column names. This prevents confusion + when the columns in left and right tables have colliding names. + right_suffix : str, default None + Which suffix to add to the right column names. This prevents confusion + when the columns in left and right tables have colliding names. + coalesce_keys : bool, default True + If the duplicated keys should be omitted from one of the sides + in the join result. + use_threads : bool, default True + Whether to use multithreading or not. + + Returns + ------- + Table + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> df1 = pd.DataFrame({'id': [1, 2, 3], + ... 'year': [2020, 2022, 2019]}) + >>> df2 = pd.DataFrame({'id': [3, 4], + ... 'n_legs': [5, 100], + ... 'animal': ["Brittle stars", "Centipede"]}) + >>> t1 = pa.Table.from_pandas(df1) + >>> t2 = pa.Table.from_pandas(df2) + + Left outer join: + + >>> t1.join(t2, 'id').combine_chunks().sort_by('year') + pyarrow.Table + id: int64 + year: int64 + n_legs: int64 + animal: string + ---- + id: [[3,1,2]] + year: [[2019,2020,2022]] + n_legs: [[5,null,null]] + animal: [["Brittle stars",null,null]] + + Full outer join: + + >>> t1.join(t2, 'id', join_type="full outer").combine_chunks().sort_by('year') + pyarrow.Table + id: int64 + year: int64 + n_legs: int64 + animal: string + ---- + id: [[3,1,2,4]] + year: [[2019,2020,2022,null]] + n_legs: [[5,null,null,100]] + animal: [["Brittle stars",null,null,"Centipede"]] + + Right outer join: + + >>> t1.join(t2, 'id', join_type="right outer").combine_chunks().sort_by('year') + pyarrow.Table + year: int64 + id: int64 + n_legs: int64 + animal: string + ---- + year: [[2019,null]] + id: [[3,4]] + n_legs: [[5,100]] + animal: [["Brittle stars","Centipede"]] + + Right anti join + + >>> t1.join(t2, 'id', join_type="right anti") + pyarrow.Table + id: int64 + n_legs: int64 + animal: string + ---- + id: [[4]] + n_legs: [[100]] + animal: [["Centipede"]] + """ + self._assert_cpu() + if right_keys is None: + right_keys = keys + return _pac()._perform_join( + join_type, self, keys, right_table, right_keys, + left_suffix=left_suffix, right_suffix=right_suffix, + use_threads=use_threads, coalesce_keys=coalesce_keys, + output_type=Table + ) + + def join_asof(self, right_table, on, by, tolerance, right_on=None, right_by=None): + """ + Perform an asof join between this table and another one. + + This is similar to a left-join except that we match on nearest key rather + than equal keys. Both tables must be sorted by the key. This type of join + is most useful for time series data that are not perfectly aligned. + + Optionally match on equivalent keys with "by" before searching with "on". + + Result of the join will be a new Table, where further + operations can be applied. + + Parameters + ---------- + right_table : Table + The table to join to the current one, acting as the right table + in the join operation. + on : str + The column from current table that should be used as the "on" key + of the join operation left side. + + An inexact match is used on the "on" key, i.e. a row is considered a + match if and only if left_on - tolerance <= right_on <= left_on. + + The input dataset must be sorted by the "on" key. Must be a single + field of a common type. + + Currently, the "on" key must be an integer, date, or timestamp type. + by : str or list[str] + The columns from current table that should be used as the keys + of the join operation left side. The join operation is then done + only for the matches in these columns. + tolerance : int + The tolerance for inexact "on" key matching. A right row is considered + a match with the left row ``right.on - left.on <= tolerance``. The + ``tolerance`` may be: + + - negative, in which case a past-as-of-join occurs; + - or positive, in which case a future-as-of-join occurs; + - or zero, in which case an exact-as-of-join occurs. + + The tolerance is interpreted in the same units as the "on" key. + right_on : str or list[str], default None + The columns from the right_table that should be used as the on key + on the join operation right side. + When ``None`` use the same key name as the left table. + right_by : str or list[str], default None + The columns from the right_table that should be used as keys + on the join operation right side. + When ``None`` use the same key names as the left table. + + Returns + ------- + Table + + Example + -------- + >>> import pyarrow as pa + >>> t1 = pa.table({'id': [1, 3, 2, 3, 3], + ... 'year': [2020, 2021, 2022, 2022, 2023]}) + >>> t2 = pa.table({'id': [3, 4], + ... 'year': [2020, 2021], + ... 'n_legs': [5, 100], + ... 'animal': ["Brittle stars", "Centipede"]}) + + >>> t1.join_asof(t2, on='year', by='id', tolerance=-2) + pyarrow.Table + id: int64 + year: int64 + n_legs: int64 + animal: string + ---- + id: [[1,3,2,3,3]] + year: [[2020,2021,2022,2022,2023]] + n_legs: [[null,5,null,5,null]] + animal: [[null,"Brittle stars",null,"Brittle stars",null]] + """ + self._assert_cpu() + if right_on is None: + right_on = on + if right_by is None: + right_by = by + return _pac()._perform_join_asof(self, on, by, + right_table, right_on, right_by, + tolerance, output_type=Table) + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export the table as an Arrow C stream PyCapsule. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the stream should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + Currently, this is not supported and will raise a + NotImplementedError if the schema doesn't match the current schema. + + Returns + ------- + PyCapsule + """ + self._assert_cpu() + return self.to_reader().__arrow_c_stream__(requested_schema) + + @property + def is_cpu(self): + """ + Whether all ChunkedArrays are CPU-accessible. + """ + if not self._init_is_cpu: + self._is_cpu = all(c.is_cpu for c in self.itercolumns()) + self._init_is_cpu = True + return self._is_cpu + + cdef void _assert_cpu(self) except *: + if not self.is_cpu: + raise NotImplementedError("Implemented only for data on CPU device") + + +def _reconstruct_table(arrays, schema): + """ + Internal: reconstruct pa.Table from pickled components. + """ + return Table.from_arrays(arrays, schema=schema) + + +def record_batch(data, names=None, schema=None, metadata=None): + """ + Create a pyarrow.RecordBatch from another Python data structure or sequence + of arrays. + + Parameters + ---------- + data : dict, list, pandas.DataFrame, Arrow-compatible table + A mapping of strings to Arrays or Python lists, a list of Arrays, + a pandas DataFame, or any tabular object implementing the + Arrow PyCapsule Protocol (has an ``__arrow_c_array__`` or + ``__arrow_c_device_array__`` method). + names : list, default None + Column names if list of arrays passed as data. Mutually exclusive with + 'schema' argument. + schema : Schema, default None + The expected schema of the RecordBatch. If not passed, will be inferred + from the data. Mutually exclusive with 'names' argument. + metadata : dict or Mapping, default None + Optional metadata for the schema (if schema not passed). + + Returns + ------- + RecordBatch + + See Also + -------- + RecordBatch.from_arrays, RecordBatch.from_pandas, table + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 2, 4, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Construct a RecordBatch from a python dictionary: + + >>> pa.record_batch({"n_legs": n_legs, "animals": animals}) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + >>> pa.record_batch({"n_legs": n_legs, "animals": animals}).to_pandas() + n_legs animals + 0 2 Flamingo + 1 2 Parrot + 2 4 Dog + 3 4 Horse + 4 5 Brittle stars + 5 100 Centipede + + Creating a RecordBatch from a list of arrays with names: + + >>> pa.record_batch([n_legs, animals], names=names) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + + Creating a RecordBatch from a list of arrays with names and metadata: + + >>> my_metadata={"n_legs": "How many legs does an animal have?"} + >>> pa.record_batch([n_legs, animals], + ... names=names, + ... metadata = my_metadata) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,2,4,4,5,100] + animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"] + >>> pa.record_batch([n_legs, animals], + ... names=names, + ... metadata = my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'How many legs does an animal have?' + + Creating a RecordBatch from a pandas DataFrame: + + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2021, 2022], + ... 'month': [3, 5, 7, 9], + ... 'day': [1, 5, 9, 13], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> pa.record_batch(df) + pyarrow.RecordBatch + year: int64 + month: int64 + day: int64 + n_legs: int64 + animals: string + ---- + year: [2020,2022,2021,2022] + month: [3,5,7,9] + day: [1,5,9,13] + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + + >>> pa.record_batch(df).to_pandas() + year month day n_legs animals + 0 2020 3 1 2 Flamingo + 1 2022 5 5 4 Horse + 2 2021 7 9 5 Brittle stars + 3 2022 9 13 100 Centipede + + Creating a RecordBatch from a pandas DataFrame with schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.record_batch(df, my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + pandas: ... + >>> pa.record_batch(df, my_schema).to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + 2 5 Brittle stars + 3 100 Centipede + """ + # accept schema as first argument for backwards compatibility / usability + if isinstance(names, Schema) and schema is None: + schema = names + names = None + + if isinstance(data, (list, tuple)): + return RecordBatch.from_arrays(data, names=names, schema=schema, + metadata=metadata) + elif isinstance(data, dict): + if names is not None: + raise ValueError( + "The 'names' argument is not valid when passing a dictionary") + return RecordBatch.from_pydict(data, schema=schema, metadata=metadata) + elif hasattr(data, "__arrow_c_device_array__"): + if schema is not None: + requested_schema = schema.__arrow_c_schema__() + else: + requested_schema = None + schema_capsule, array_capsule = data.__arrow_c_device_array__(requested_schema) + batch = RecordBatch._import_from_c_device_capsule(schema_capsule, array_capsule) + if schema is not None and batch.schema != schema: + # __arrow_c_device_array__ coerces schema with best effort, so we might + # need to cast it if the producer wasn't able to cast to exact schema. + batch = batch.cast(schema) + return batch + elif hasattr(data, "__arrow_c_array__"): + if schema is not None: + requested_schema = schema.__arrow_c_schema__() + else: + requested_schema = None + schema_capsule, array_capsule = data.__arrow_c_array__(requested_schema) + batch = RecordBatch._import_from_c_capsule(schema_capsule, array_capsule) + if schema is not None and batch.schema != schema: + # __arrow_c_array__ coerces schema with best effort, so we might + # need to cast it if the producer wasn't able to cast to exact schema. + batch = batch.cast(schema) + return batch + + elif _pandas_api.is_data_frame(data): + return RecordBatch.from_pandas(data, schema=schema) + + else: + raise TypeError("Expected pandas DataFrame or list of arrays") + + +def table(data, names=None, schema=None, metadata=None, nthreads=None): + """ + Create a pyarrow.Table from a Python data structure or sequence of arrays. + + Parameters + ---------- + data : dict, list, pandas.DataFrame, Arrow-compatible table + A mapping of strings to Arrays or Python lists, a list of arrays or + chunked arrays, a pandas DataFame, or any tabular object implementing + the Arrow PyCapsule Protocol (has an ``__arrow_c_array__``, + ``__arrow_c_device_array__`` or ``__arrow_c_stream__`` method). + names : list, default None + Column names if list of arrays passed as data. Mutually exclusive with + 'schema' argument. + schema : Schema, default None + The expected schema of the Arrow Table. If not passed, will be inferred + from the data. Mutually exclusive with 'names' argument. + If passed, the output will have exactly this schema (raising an error + when columns are not found in the data and ignoring additional data not + specified in the schema, when data is a dict or DataFrame). + metadata : dict or Mapping, default None + Optional metadata for the schema (if schema not passed). + nthreads : int, default None + For pandas.DataFrame inputs: if greater than 1, convert columns to + Arrow in parallel using indicated number of threads. By default, + this follows :func:`pyarrow.cpu_count` (may use up to system CPU count + threads). + + Returns + ------- + Table + + See Also + -------- + Table.from_arrays, Table.from_pandas, Table.from_pydict + + Examples + -------- + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> names = ["n_legs", "animals"] + + Construct a Table from a python dictionary: + + >>> pa.table({"n_legs": n_legs, "animals": animals}) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from arrays: + + >>> pa.table([n_legs, animals], names=names) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from arrays with metadata: + + >>> my_metadata={"n_legs": "Number of legs per animal"} + >>> pa.table([n_legs, animals], names=names, metadata = my_metadata).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Construct a Table from pandas DataFrame: + + >>> import pandas as pd + >>> df = pd.DataFrame({'year': [2020, 2022, 2019, 2021], + ... 'n_legs': [2, 4, 5, 100], + ... 'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> pa.table(df) + pyarrow.Table + year: int64 + n_legs: int64 + animals: string + ---- + year: [[2020,2022,2019,2021]] + n_legs: [[2,4,5,100]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + + Construct a Table from pandas DataFrame with pyarrow schema: + + >>> my_schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> pa.table(df, my_schema).schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + pandas: '{"index_columns": [], "column_indexes": [{"name": null, ... + + Construct a Table from chunked arrays: + + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> animals = pa.chunked_array([["Flamingo", "Parrot", "Dog"], ["Horse", "Brittle stars", "Centipede"]]) + >>> table = pa.table([n_legs, animals], names=names) + >>> table + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,2,4],[4,5,100]] + animals: [["Flamingo","Parrot","Dog"],["Horse","Brittle stars","Centipede"]] + """ + # accept schema as first argument for backwards compatibility / usability + if isinstance(names, Schema) and schema is None: + schema = names + names = None + + if isinstance(data, (list, tuple)): + return Table.from_arrays(data, names=names, schema=schema, + metadata=metadata) + elif isinstance(data, dict): + if names is not None: + raise ValueError( + "The 'names' argument is not valid when passing a dictionary") + return Table.from_pydict(data, schema=schema, metadata=metadata) + elif _pandas_api.is_data_frame(data): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "passing a pandas DataFrame") + return Table.from_pandas(data, schema=schema, nthreads=nthreads) + elif hasattr(data, "__arrow_c_stream__"): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "using Arrow PyCapsule Interface") + if schema is not None: + requested = schema.__arrow_c_schema__() + else: + requested = None + capsule = data.__arrow_c_stream__(requested) + reader = RecordBatchReader._import_from_c_capsule(capsule) + table = reader.read_all() + if schema is not None and table.schema != schema: + # __arrow_c_array__ coerces schema with best effort, so we might + # need to cast it if the producer wasn't able to cast to exact schema. + table = table.cast(schema) + return table + elif hasattr(data, "__arrow_c_array__") or hasattr(data, "__arrow_c_device_array__"): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "using Arrow PyCapsule Interface") + batch = record_batch(data, schema) + return Table.from_batches([batch]) + else: + raise TypeError( + "Expected pandas DataFrame, python dictionary or list of arrays") + + +def concat_tables(tables, MemoryPool memory_pool=None, str promote_options="none", **kwargs): + """ + Concatenate pyarrow.Table objects. + + If promote_options="none", a zero-copy concatenation will be performed. The schemas + of all the Tables must be the same (except the metadata), otherwise an + exception will be raised. The result Table will share the metadata with the + first table. + + If promote_options="default", any null type arrays will be casted to the type of other + arrays in the column of the same name. If a table is missing a particular + field, null values of the appropriate type will be generated to take the + place of the missing field. The new schema will share the metadata with the + first table. Each field in the new schema will share the metadata with the + first table which has the field defined. Note that type promotions may + involve additional allocations on the given ``memory_pool``. + + If promote_options="permissive", the behavior of default plus types will be promoted + to the common denominator that fits all the fields. + + Parameters + ---------- + tables : iterable of pyarrow.Table objects + Pyarrow tables to concatenate into a single Table. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + promote_options : str, default none + Accepts strings "none", "default" and "permissive". + **kwargs : dict, optional + + Examples + -------- + >>> import pyarrow as pa + >>> t1 = pa.table([ + ... pa.array([2, 4, 5, 100]), + ... pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + ... ], names=['n_legs', 'animals']) + >>> t2 = pa.table([ + ... pa.array([2, 4]), + ... pa.array(["Parrot", "Dog"]) + ... ], names=['n_legs', 'animals']) + >>> pa.concat_tables([t1,t2]) + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[2,4,5,100],[2,4]] + animals: [["Flamingo","Horse","Brittle stars","Centipede"],["Parrot","Dog"]] + + """ + cdef: + vector[shared_ptr[CTable]] c_tables + shared_ptr[CTable] c_result_table + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + Table table + CConcatenateTablesOptions options = ( + CConcatenateTablesOptions.Defaults()) + + if "promote" in kwargs: + warnings.warn( + "promote has been superseded by promote_options='default'.", + FutureWarning, stacklevel=2) + if kwargs['promote'] is True: + promote_options = "default" + + for table in tables: + c_tables.push_back(table.sp_table) + + if promote_options == "permissive": + options.field_merge_options = CField.CMergeOptions.Permissive() + elif promote_options in {"default", "none"}: + options.field_merge_options = CField.CMergeOptions.Defaults() + else: + raise ValueError(f"Invalid promote options: {promote_options}") + + with nogil: + options.unify_schemas = promote_options != "none" + c_result_table = GetResultValue( + ConcatenateTables(c_tables, options, pool)) + + return pyarrow_wrap_table(c_result_table) + + +def concat_batches(recordbatches, MemoryPool memory_pool=None): + """ + Concatenate pyarrow.RecordBatch objects. + + All recordbatches must share the same Schema, + the operation implies a copy of the data to merge + the arrays of the different RecordBatches. + + Parameters + ---------- + recordbatches : iterable of pyarrow.RecordBatch objects + Pyarrow record batches to concatenate into a single RecordBatch. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Examples + -------- + >>> import pyarrow as pa + >>> t1 = pa.record_batch([ + ... pa.array([2, 4, 5, 100]), + ... pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + ... ], names=['n_legs', 'animals']) + >>> t2 = pa.record_batch([ + ... pa.array([2, 4]), + ... pa.array(["Parrot", "Dog"]) + ... ], names=['n_legs', 'animals']) + >>> pa.concat_batches([t1,t2]) + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,4,5,100,2,4] + animals: ["Flamingo","Horse","Brittle stars","Centipede","Parrot","Dog"] + + """ + cdef: + vector[shared_ptr[CRecordBatch]] c_recordbatches + shared_ptr[CRecordBatch] c_result_recordbatch + RecordBatch recordbatch + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + for recordbatch in recordbatches: + c_recordbatches.push_back(recordbatch.sp_batch) + + with nogil: + c_result_recordbatch = GetResultValue( + ConcatenateRecordBatches(c_recordbatches, pool)) + + return pyarrow_wrap_batch(c_result_recordbatch) + + +def _from_pydict(cls, mapping, schema, metadata): + """ + Construct a Table/RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + cls : Class Table/RecordBatch + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table/RecordBatch + """ + + arrays = [] + if schema is None: + names = [] + for k, v in mapping.items(): + names.append(k) + arrays.append(asarray(v)) + return cls.from_arrays(arrays, names, metadata=metadata) + elif isinstance(schema, Schema): + for field in schema: + try: + v = mapping[field.name] + except KeyError: + try: + v = mapping[tobytes(field.name)] + except KeyError: + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". + format(', '.join(missing)) + ) + arrays.append(asarray(v, type=field.type)) + # Will raise if metadata is not None + return cls.from_arrays(arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema') + + +def _from_pylist(cls, mapping, schema, metadata): + """ + Construct a Table/RecordBatch from list of rows / dictionaries. + + Parameters + ---------- + cls : Class Table/RecordBatch + mapping : list of dicts of rows + A mapping of strings to row values. + schema : Schema, default None + If not passed, will be inferred from the first row of the + mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table/RecordBatch + """ + + arrays = [] + if schema is None: + names = [] + if mapping: + names = list(mapping[0].keys()) + for n in names: + v = [row[n] if n in row else None for row in mapping] + arrays.append(v) + return cls.from_arrays(arrays, names, metadata=metadata) + else: + if isinstance(schema, Schema): + for n in schema.names: + v = [row[n] if n in row else None for row in mapping] + arrays.append(v) + # Will raise if metadata is not None + return cls.from_arrays(arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema') + + +class TableGroupBy: + """ + A grouping of columns in a table on which to perform aggregations. + + Parameters + ---------- + table : pyarrow.Table + Input table to execute the aggregation on. + keys : str or list[str] + Name of the grouped columns. + use_threads : bool, default True + Whether to use multithreading or not. When set to True (the default), + no stable ordering of the output is guaranteed. + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.table([ + ... pa.array(["a", "a", "b", "b", "c"]), + ... pa.array([1, 2, 3, 4, 5]), + ... ], names=["keys", "values"]) + + Grouping of columns: + + >>> pa.TableGroupBy(t,"keys") + + + Perform aggregations: + + >>> pa.TableGroupBy(t,"keys").aggregate([("values", "sum")]) + pyarrow.Table + keys: string + values_sum: int64 + ---- + keys: [["a","b","c"]] + values_sum: [[3,7,5]] + """ + + def __init__(self, table, keys, use_threads=True): + if isinstance(keys, str): + keys = [keys] + + self._table = table + self.keys = keys + self._use_threads = use_threads + + def aggregate(self, aggregations): + """ + Perform an aggregation over the grouped columns of the table. + + Parameters + ---------- + aggregations : list[tuple(str, str)] or \ +list[tuple(str, str, FunctionOptions)] + List of tuples, where each tuple is one aggregation specification + and consists of: aggregation column name followed + by function name and optionally aggregation function option. + Pass empty list to get a single row for each group. + The column name can be a string, an empty list or a list of + column names, for unary, nullary and n-ary aggregation functions + respectively. + + For the list of function names and respective aggregation + function options see :ref:`py-grouped-aggrs`. + + Returns + ------- + Table + Results of the aggregation functions. + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.table([ + ... pa.array(["a", "a", "b", "b", "c"]), + ... pa.array([1, 2, 3, 4, 5]), + ... ], names=["keys", "values"]) + + Sum the column "values" over the grouped column "keys": + + >>> t.group_by("keys").aggregate([("values", "sum")]) + pyarrow.Table + keys: string + values_sum: int64 + ---- + keys: [["a","b","c"]] + values_sum: [[3,7,5]] + + Count the rows over the grouped column "keys": + + >>> t.group_by("keys").aggregate([([], "count_all")]) + pyarrow.Table + keys: string + count_all: int64 + ---- + keys: [["a","b","c"]] + count_all: [[2,2,1]] + + Do multiple aggregations: + + >>> t.group_by("keys").aggregate([ + ... ("values", "sum"), + ... ("keys", "count") + ... ]) + pyarrow.Table + keys: string + values_sum: int64 + keys_count: int64 + ---- + keys: [["a","b","c"]] + values_sum: [[3,7,5]] + keys_count: [[2,2,1]] + + Count the number of non-null values for column "values" + over the grouped column "keys": + + >>> import pyarrow.compute as pc + >>> t.group_by(["keys"]).aggregate([ + ... ("values", "count", pc.CountOptions(mode="only_valid")) + ... ]) + pyarrow.Table + keys: string + values_count: int64 + ---- + keys: [["a","b","c"]] + values_count: [[2,2,1]] + + Get a single row for each group in column "keys": + + >>> t.group_by("keys").aggregate([]) + pyarrow.Table + keys: string + ---- + keys: [["a","b","c"]] + """ + group_by_aggrs = [] + for aggr in aggregations: + # Set opt to None if not specified + if len(aggr) == 2: + target, func = aggr + opt = None + else: + target, func, opt = aggr + # Ensure target is a list + if not isinstance(target, (list, tuple)): + target = [target] + # Ensure aggregate function is hash_ if needed + if len(self.keys) > 0 and not func.startswith("hash_"): + func = "hash_" + func + if len(self.keys) == 0 and func.startswith("hash_"): + func = func[5:] + # Determine output field name + func_nohash = func if not func.startswith("hash_") else func[5:] + if len(target) == 0: + aggr_name = func_nohash + else: + aggr_name = "_".join(target) + "_" + func_nohash + group_by_aggrs.append((target, func, opt, aggr_name)) + + return _pac()._group_by( + self._table, group_by_aggrs, self.keys, use_threads=self._use_threads + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/util.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5878d1f902627f9f4399ae4c71e8ea114efa6a0e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/util.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Miscellaneous utility code + +import os +import contextlib +import functools +import gc +import socket +import sys +import textwrap +import types +import warnings + + +_DEPR_MSG = ( + "pyarrow.{} is deprecated as of {}, please use pyarrow.{} instead." +) + + +def doc(*docstrings, **params): + """ + A decorator that takes docstring templates, concatenates them, and finally + performs string substitution on them. + This decorator will add a variable "_docstring_components" to the wrapped + callable to keep track of the original docstring template for potential future use. + If the docstring is a template, it will be saved as a string. + Otherwise, it will be saved as a callable and the docstring will be obtained via + the __doc__ attribute. + This decorator cannot be used on Cython classes due to a CPython constraint, + which enforces the __doc__ attribute to be read-only. + See https://github.com/python/cpython/issues/91309 + + Parameters + ---------- + *docstrings : None, str, or callable + The string / docstring / docstring template to be prepended in order + before the default docstring under the callable. + **params + The key/value pairs used to format the docstring template. + """ + + def decorator(decorated): + docstring_components = [] + + # collect docstrings and docstring templates + for docstring in docstrings: + if docstring is None: + continue + if hasattr(docstring, "_docstring_components"): + docstring_components.extend( + docstring._docstring_components + ) + elif isinstance(docstring, str) or docstring.__doc__: + docstring_components.append(docstring) + + # append the callable's docstring last + if decorated.__doc__: + docstring_components.append(textwrap.dedent(decorated.__doc__)) + + params_applied = [ + component.format(**params) + if isinstance(component, str) and len(params) > 0 + else component + for component in docstring_components + ] + + decorated.__doc__ = "".join( + [ + component + if isinstance(component, str) + else textwrap.dedent(component.__doc__ or "") + for component in params_applied + ] + ) + + decorated._docstring_components = ( + docstring_components + ) + return decorated + + return decorator + + +def _deprecate_api(old_name, new_name, api, next_version, type=FutureWarning): + msg = _DEPR_MSG.format(old_name, next_version, new_name) + + def wrapper(*args, **kwargs): + warnings.warn(msg, type) + return api(*args, **kwargs) + return wrapper + + +def _deprecate_class(old_name, new_class, next_version, + instancecheck=True): + """ + Raise warning if a deprecated class is used in an isinstance check. + """ + class _DeprecatedMeta(type): + def __instancecheck__(self, other): + warnings.warn( + _DEPR_MSG.format(old_name, next_version, new_class.__name__), + FutureWarning, + stacklevel=2 + ) + return isinstance(other, new_class) + + return _DeprecatedMeta(old_name, (new_class,), {}) + + +def _is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + +def _is_path_like(path): + return isinstance(path, str) or hasattr(path, '__fspath__') + + +def _stringify_path(path): + """ + Convert *path* to a string or unicode path if possible. + """ + if isinstance(path, str): + return os.path.expanduser(path) + + # checking whether path implements the filesystem protocol + try: + return os.path.expanduser(path.__fspath__()) + except AttributeError: + pass + + raise TypeError("not a path-like object") + + +def product(seq): + """ + Return a product of sequence items. + """ + return functools.reduce(lambda a, b: a*b, seq, 1) + + +def get_contiguous_span(shape, strides, itemsize): + """ + Return a contiguous span of N-D array data. + + Parameters + ---------- + shape : tuple + strides : tuple + itemsize : int + Specify array shape data + + Returns + ------- + start, end : int + The span end points. + """ + if not strides: + start = 0 + end = itemsize * product(shape) + else: + start = 0 + end = itemsize + for i, dim in enumerate(shape): + if dim == 0: + start = end = 0 + break + stride = strides[i] + if stride > 0: + end += stride * (dim - 1) + elif stride < 0: + start += stride * (dim - 1) + if end - start != itemsize * product(shape): + raise ValueError('array data is non-contiguous') + return start, end + + +def find_free_port(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with contextlib.closing(sock) as sock: + sock.bind(('', 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return sock.getsockname()[1] + + +def guid(): + from uuid import uuid4 + return uuid4().hex + + +def _break_traceback_cycle_from_frame(frame): + # Clear local variables in all inner frames, so as to break the + # reference cycle. + this_frame = sys._getframe(0) + refs = gc.get_referrers(frame) + while refs: + for frame in refs: + if frame is not this_frame and isinstance(frame, types.FrameType): + break + else: + # No frame found in referrers (finished?) + break + refs = None + # Clear the frame locals, to try and break the cycle (it is + # somewhere along the chain of execution frames). + frame.clear() + # To visit the inner frame, we need to find it among the + # referrers of this frame (while `frame.f_back` would let + # us visit the outer frame). + refs = gc.get_referrers(frame) + refs = frame = this_frame = None + + +def _download_urllib(url, out_path): + from urllib.request import urlopen + with urlopen(url) as response: + with open(out_path, 'wb') as f: + f.write(response.read()) + + +def _download_requests(url, out_path): + import requests + with requests.get(url) as response: + with open(out_path, 'wb') as f: + f.write(response.content) + + +def download_tzdata_on_windows(): + r""" + Download and extract latest IANA timezone database into the + location expected by Arrow which is %USERPROFILE%\Downloads\tzdata. + """ + if sys.platform != 'win32': + raise TypeError(f"Timezone database is already provided by {sys.platform}") + + import tarfile + + tzdata_url = "https://data.iana.org/time-zones/tzdata-latest.tar.gz" + tzdata_path = os.path.expandvars(r"%USERPROFILE%\Downloads\tzdata") + tzdata_compressed_path = os.path.join(tzdata_path, "tzdata.tar.gz") + windows_zones_url = "https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml" # noqa + windows_zones_path = os.path.join(tzdata_path, "windowsZones.xml") + os.makedirs(tzdata_path, exist_ok=True) + + # Try to download the files with requests and then fall back to urllib. This + # works around possible issues in certain older environment (GH-45295) + try: + _download_requests(tzdata_url, tzdata_compressed_path) + _download_requests(windows_zones_url, windows_zones_path) + except ImportError: + _download_urllib(tzdata_url, tzdata_compressed_path) + _download_urllib(windows_zones_url, windows_zones_path) + + assert os.path.exists(tzdata_compressed_path) + assert os.path.exists(windows_zones_path) + + tarfile.open(tzdata_compressed_path).extractall(tzdata_path) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/LICENSE b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4415505566f261c802b671426be529a31f914137 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2020 Will McGugan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..7f3ac7f8140264345c1883a603c3e27f38efcb69 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/METADATA @@ -0,0 +1,473 @@ +Metadata-Version: 2.1 +Name: rich +Version: 13.9.4 +Summary: Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal +Home-page: https://github.com/Textualize/rich +License: MIT +Author: Will McGugan +Author-email: willmcgugan@gmail.com +Requires-Python: >=3.8.0 +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Framework :: IPython +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: MacOS +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: POSIX :: Linux +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Typing :: Typed +Provides-Extra: jupyter +Requires-Dist: ipywidgets (>=7.5.1,<9) ; extra == "jupyter" +Requires-Dist: markdown-it-py (>=2.2.0) +Requires-Dist: pygments (>=2.13.0,<3.0.0) +Requires-Dist: typing-extensions (>=4.0.0,<5.0) ; python_version < "3.11" +Project-URL: Documentation, https://rich.readthedocs.io/en/latest/ +Description-Content-Type: text/markdown + +[![Supported Python Versions](https://img.shields.io/pypi/pyversions/rich/13.2.0)](https://pypi.org/project/rich/) [![PyPI version](https://badge.fury.io/py/rich.svg)](https://badge.fury.io/py/rich) + +[![Downloads](https://pepy.tech/badge/rich/month)](https://pepy.tech/project/rich) +[![codecov](https://img.shields.io/codecov/c/github/Textualize/rich?label=codecov&logo=codecov)](https://codecov.io/gh/Textualize/rich) +[![Rich blog](https://img.shields.io/badge/blog-rich%20news-yellowgreen)](https://www.willmcgugan.com/tag/rich/) +[![Twitter Follow](https://img.shields.io/twitter/follow/willmcgugan.svg?style=social)](https://twitter.com/willmcgugan) + +![Logo](https://github.com/textualize/rich/raw/master/imgs/logo.svg) + +[English readme](https://github.com/textualize/rich/blob/master/README.md) + • [简体中文 readme](https://github.com/textualize/rich/blob/master/README.cn.md) + • [正體中文 readme](https://github.com/textualize/rich/blob/master/README.zh-tw.md) + • [Lengua española readme](https://github.com/textualize/rich/blob/master/README.es.md) + • [Deutsche readme](https://github.com/textualize/rich/blob/master/README.de.md) + • [Läs på svenska](https://github.com/textualize/rich/blob/master/README.sv.md) + • [日本語 readme](https://github.com/textualize/rich/blob/master/README.ja.md) + • [한국어 readme](https://github.com/textualize/rich/blob/master/README.kr.md) + • [Français readme](https://github.com/textualize/rich/blob/master/README.fr.md) + • [Schwizerdütsch readme](https://github.com/textualize/rich/blob/master/README.de-ch.md) + • [हिन्दी readme](https://github.com/textualize/rich/blob/master/README.hi.md) + • [Português brasileiro readme](https://github.com/textualize/rich/blob/master/README.pt-br.md) + • [Italian readme](https://github.com/textualize/rich/blob/master/README.it.md) + • [Русский readme](https://github.com/textualize/rich/blob/master/README.ru.md) + • [Indonesian readme](https://github.com/textualize/rich/blob/master/README.id.md) + • [فارسی readme](https://github.com/textualize/rich/blob/master/README.fa.md) + • [Türkçe readme](https://github.com/textualize/rich/blob/master/README.tr.md) + • [Polskie readme](https://github.com/textualize/rich/blob/master/README.pl.md) + + +Rich is a Python library for _rich_ text and beautiful formatting in the terminal. + +The [Rich API](https://rich.readthedocs.io/en/latest/) makes it easy to add color and style to terminal output. Rich can also render pretty tables, progress bars, markdown, syntax highlighted source code, tracebacks, and more — out of the box. + +![Features](https://github.com/textualize/rich/raw/master/imgs/features.png) + +For a video introduction to Rich see [calmcode.io](https://calmcode.io/rich/introduction.html) by [@fishnets88](https://twitter.com/fishnets88). + +See what [people are saying about Rich](https://www.willmcgugan.com/blog/pages/post/rich-tweets/). + +## Compatibility + +Rich works with Linux, macOS and Windows. True color / emoji works with new Windows Terminal, classic terminal is limited to 16 colors. Rich requires Python 3.8 or later. + +Rich works with [Jupyter notebooks](https://jupyter.org/) with no additional configuration required. + +## Installing + +Install with `pip` or your favorite PyPI package manager. + +```sh +python -m pip install rich +``` + +Run the following to test Rich output on your terminal: + +```sh +python -m rich +``` + +## Rich Print + +To effortlessly add rich output to your application, you can import the [rich print](https://rich.readthedocs.io/en/latest/introduction.html#quick-start) method, which has the same signature as the builtin Python function. Try this: + +```python +from rich import print + +print("Hello, [bold magenta]World[/bold magenta]!", ":vampire:", locals()) +``` + +![Hello World](https://github.com/textualize/rich/raw/master/imgs/print.png) + +## Rich REPL + +Rich can be installed in the Python REPL, so that any data structures will be pretty printed and highlighted. + +```python +>>> from rich import pretty +>>> pretty.install() +``` + +![REPL](https://github.com/textualize/rich/raw/master/imgs/repl.png) + +## Using the Console + +For more control over rich terminal content, import and construct a [Console](https://rich.readthedocs.io/en/latest/reference/console.html#rich.console.Console) object. + +```python +from rich.console import Console + +console = Console() +``` + +The Console object has a `print` method which has an intentionally similar interface to the builtin `print` function. Here's an example of use: + +```python +console.print("Hello", "World!") +``` + +As you might expect, this will print `"Hello World!"` to the terminal. Note that unlike the builtin `print` function, Rich will word-wrap your text to fit within the terminal width. + +There are a few ways of adding color and style to your output. You can set a style for the entire output by adding a `style` keyword argument. Here's an example: + +```python +console.print("Hello", "World!", style="bold red") +``` + +The output will be something like the following: + +![Hello World](https://github.com/textualize/rich/raw/master/imgs/hello_world.png) + +That's fine for styling a line of text at a time. For more finely grained styling, Rich renders a special markup which is similar in syntax to [bbcode](https://en.wikipedia.org/wiki/BBCode). Here's an example: + +```python +console.print("Where there is a [bold cyan]Will[/bold cyan] there [u]is[/u] a [i]way[/i].") +``` + +![Console Markup](https://github.com/textualize/rich/raw/master/imgs/where_there_is_a_will.png) + +You can use a Console object to generate sophisticated output with minimal effort. See the [Console API](https://rich.readthedocs.io/en/latest/console.html) docs for details. + +## Rich Inspect + +Rich has an [inspect](https://rich.readthedocs.io/en/latest/reference/init.html?highlight=inspect#rich.inspect) function which can produce a report on any Python object, such as class, instance, or builtin. + +```python +>>> my_list = ["foo", "bar"] +>>> from rich import inspect +>>> inspect(my_list, methods=True) +``` + +![Log](https://github.com/textualize/rich/raw/master/imgs/inspect.png) + +See the [inspect docs](https://rich.readthedocs.io/en/latest/reference/init.html#rich.inspect) for details. + +# Rich Library + +Rich contains a number of builtin _renderables_ you can use to create elegant output in your CLI and help you debug your code. + +Click the following headings for details: + +
+Log + +The Console object has a `log()` method which has a similar interface to `print()`, but also renders a column for the current time and the file and line which made the call. By default Rich will do syntax highlighting for Python structures and for repr strings. If you log a collection (i.e. a dict or a list) Rich will pretty print it so that it fits in the available space. Here's an example of some of these features. + +```python +from rich.console import Console +console = Console() + +test_data = [ + {"jsonrpc": "2.0", "method": "sum", "params": [None, 1, 2, 4, False, True], "id": "1",}, + {"jsonrpc": "2.0", "method": "notify_hello", "params": [7]}, + {"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": "2"}, +] + +def test_log(): + enabled = False + context = { + "foo": "bar", + } + movies = ["Deadpool", "Rise of the Skywalker"] + console.log("Hello from", console, "!") + console.log(test_data, log_locals=True) + + +test_log() +``` + +The above produces the following output: + +![Log](https://github.com/textualize/rich/raw/master/imgs/log.png) + +Note the `log_locals` argument, which outputs a table containing the local variables where the log method was called. + +The log method could be used for logging to the terminal for long running applications such as servers, but is also a very nice debugging aid. + +
+
+Logging Handler + +You can also use the builtin [Handler class](https://rich.readthedocs.io/en/latest/logging.html) to format and colorize output from Python's logging module. Here's an example of the output: + +![Logging](https://github.com/textualize/rich/raw/master/imgs/logging.png) + +
+ +
+Emoji + +To insert an emoji in to console output place the name between two colons. Here's an example: + +```python +>>> console.print(":smiley: :vampire: :pile_of_poo: :thumbs_up: :raccoon:") +😃 🧛 💩 👍 🦝 +``` + +Please use this feature wisely. + +
+ +
+Tables + +Rich can render flexible [tables](https://rich.readthedocs.io/en/latest/tables.html) with unicode box characters. There is a large variety of formatting options for borders, styles, cell alignment etc. + +![table movie](https://github.com/textualize/rich/raw/master/imgs/table_movie.gif) + +The animation above was generated with [table_movie.py](https://github.com/textualize/rich/blob/master/examples/table_movie.py) in the examples directory. + +Here's a simpler table example: + +```python +from rich.console import Console +from rich.table import Table + +console = Console() + +table = Table(show_header=True, header_style="bold magenta") +table.add_column("Date", style="dim", width=12) +table.add_column("Title") +table.add_column("Production Budget", justify="right") +table.add_column("Box Office", justify="right") +table.add_row( + "Dec 20, 2019", "Star Wars: The Rise of Skywalker", "$275,000,000", "$375,126,118" +) +table.add_row( + "May 25, 2018", + "[red]Solo[/red]: A Star Wars Story", + "$275,000,000", + "$393,151,347", +) +table.add_row( + "Dec 15, 2017", + "Star Wars Ep. VIII: The Last Jedi", + "$262,000,000", + "[bold]$1,332,539,889[/bold]", +) + +console.print(table) +``` + +This produces the following output: + +![table](https://github.com/textualize/rich/raw/master/imgs/table.png) + +Note that console markup is rendered in the same way as `print()` and `log()`. In fact, anything that is renderable by Rich may be included in the headers / rows (even other tables). + +The `Table` class is smart enough to resize columns to fit the available width of the terminal, wrapping text as required. Here's the same example, with the terminal made smaller than the table above: + +![table2](https://github.com/textualize/rich/raw/master/imgs/table2.png) + +
+ +
+Progress Bars + +Rich can render multiple flicker-free [progress](https://rich.readthedocs.io/en/latest/progress.html) bars to track long-running tasks. + +For basic usage, wrap any sequence in the `track` function and iterate over the result. Here's an example: + +```python +from rich.progress import track + +for step in track(range(100)): + do_step(step) +``` + +It's not much harder to add multiple progress bars. Here's an example taken from the docs: + +![progress](https://github.com/textualize/rich/raw/master/imgs/progress.gif) + +The columns may be configured to show any details you want. Built-in columns include percentage complete, file size, file speed, and time remaining. Here's another example showing a download in progress: + +![progress](https://github.com/textualize/rich/raw/master/imgs/downloader.gif) + +To try this out yourself, see [examples/downloader.py](https://github.com/textualize/rich/blob/master/examples/downloader.py) which can download multiple URLs simultaneously while displaying progress. + +
+ +
+Status + +For situations where it is hard to calculate progress, you can use the [status](https://rich.readthedocs.io/en/latest/reference/console.html#rich.console.Console.status) method which will display a 'spinner' animation and message. The animation won't prevent you from using the console as normal. Here's an example: + +```python +from time import sleep +from rich.console import Console + +console = Console() +tasks = [f"task {n}" for n in range(1, 11)] + +with console.status("[bold green]Working on tasks...") as status: + while tasks: + task = tasks.pop(0) + sleep(1) + console.log(f"{task} complete") +``` + +This generates the following output in the terminal. + +![status](https://github.com/textualize/rich/raw/master/imgs/status.gif) + +The spinner animations were borrowed from [cli-spinners](https://www.npmjs.com/package/cli-spinners). You can select a spinner by specifying the `spinner` parameter. Run the following command to see the available values: + +``` +python -m rich.spinner +``` + +The above command generates the following output in the terminal: + +![spinners](https://github.com/textualize/rich/raw/master/imgs/spinners.gif) + +
+ +
+Tree + +Rich can render a [tree](https://rich.readthedocs.io/en/latest/tree.html) with guide lines. A tree is ideal for displaying a file structure, or any other hierarchical data. + +The labels of the tree can be simple text or anything else Rich can render. Run the following for a demonstration: + +``` +python -m rich.tree +``` + +This generates the following output: + +![markdown](https://github.com/textualize/rich/raw/master/imgs/tree.png) + +See the [tree.py](https://github.com/textualize/rich/blob/master/examples/tree.py) example for a script that displays a tree view of any directory, similar to the linux `tree` command. + +
+ +
+Columns + +Rich can render content in neat [columns](https://rich.readthedocs.io/en/latest/columns.html) with equal or optimal width. Here's a very basic clone of the (MacOS / Linux) `ls` command which displays a directory listing in columns: + +```python +import os +import sys + +from rich import print +from rich.columns import Columns + +directory = os.listdir(sys.argv[1]) +print(Columns(directory)) +``` + +The following screenshot is the output from the [columns example](https://github.com/textualize/rich/blob/master/examples/columns.py) which displays data pulled from an API in columns: + +![columns](https://github.com/textualize/rich/raw/master/imgs/columns.png) + +
+ +
+Markdown + +Rich can render [markdown](https://rich.readthedocs.io/en/latest/markdown.html) and does a reasonable job of translating the formatting to the terminal. + +To render markdown import the `Markdown` class and construct it with a string containing markdown code. Then print it to the console. Here's an example: + +```python +from rich.console import Console +from rich.markdown import Markdown + +console = Console() +with open("README.md") as readme: + markdown = Markdown(readme.read()) +console.print(markdown) +``` + +This will produce output something like the following: + +![markdown](https://github.com/textualize/rich/raw/master/imgs/markdown.png) + +
+ +
+Syntax Highlighting + +Rich uses the [pygments](https://pygments.org/) library to implement [syntax highlighting](https://rich.readthedocs.io/en/latest/syntax.html). Usage is similar to rendering markdown; construct a `Syntax` object and print it to the console. Here's an example: + +```python +from rich.console import Console +from rich.syntax import Syntax + +my_code = ''' +def iter_first_last(values: Iterable[T]) -> Iterable[Tuple[bool, bool, T]]: + """Iterate and generate a tuple with a flag for first and last value.""" + iter_values = iter(values) + try: + previous_value = next(iter_values) + except StopIteration: + return + first = True + for value in iter_values: + yield first, False, previous_value + first = False + previous_value = value + yield first, True, previous_value +''' +syntax = Syntax(my_code, "python", theme="monokai", line_numbers=True) +console = Console() +console.print(syntax) +``` + +This will produce the following output: + +![syntax](https://github.com/textualize/rich/raw/master/imgs/syntax.png) + +
+ +
+Tracebacks + +Rich can render [beautiful tracebacks](https://rich.readthedocs.io/en/latest/traceback.html) which are easier to read and show more code than standard Python tracebacks. You can set Rich as the default traceback handler so all uncaught exceptions will be rendered by Rich. + +Here's what it looks like on OSX (similar on Linux): + +![traceback](https://github.com/textualize/rich/raw/master/imgs/traceback.png) + +
+ +All Rich renderables make use of the [Console Protocol](https://rich.readthedocs.io/en/latest/protocol.html), which you can also use to implement your own Rich content. + +# Rich CLI + + +See also [Rich CLI](https://github.com/textualize/rich-cli) for a command line application powered by Rich. Syntax highlight code, render markdown, display CSVs in tables, and more, directly from the command prompt. + + +![Rich CLI](https://raw.githubusercontent.com/Textualize/rich-cli/main/imgs/rich-cli-splash.jpg) + +# Textual + +See also Rich's sister project, [Textual](https://github.com/Textualize/textual), which you can use to build sophisticated User Interfaces in the terminal. + +![Textual screenshot](https://raw.githubusercontent.com/Textualize/textual/main/imgs/textual.png) + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..220057726aa2498318c371cde183965f87d76297 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/RECORD @@ -0,0 +1,162 @@ +rich-13.9.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +rich-13.9.4.dist-info/LICENSE,sha256=3u18F6QxgVgZCj6iOcyHmlpQJxzruYrnAl9I--WNyhU,1056 +rich-13.9.4.dist-info/METADATA,sha256=dg29ATErmwW3hqOEbIsmWW2Y4ieh38w98r9l8MfIrGI,18274 +rich-13.9.4.dist-info/RECORD,, +rich-13.9.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88 +rich/__init__.py,sha256=lh2WcoIOJp5M5_lbAsSUMGv8oiJeumROazHH_AYMS8I,6066 +rich/__main__.py,sha256=Wvh53rmOMyWeUeyqUHpn1PXsHlBc4TVcQnqrw46nf9Y,8333 +rich/__pycache__/__init__.cpython-312.pyc,, +rich/__pycache__/__main__.cpython-312.pyc,, +rich/__pycache__/_cell_widths.cpython-312.pyc,, +rich/__pycache__/_emoji_codes.cpython-312.pyc,, +rich/__pycache__/_emoji_replace.cpython-312.pyc,, +rich/__pycache__/_export_format.cpython-312.pyc,, +rich/__pycache__/_extension.cpython-312.pyc,, +rich/__pycache__/_fileno.cpython-312.pyc,, +rich/__pycache__/_inspect.cpython-312.pyc,, +rich/__pycache__/_log_render.cpython-312.pyc,, +rich/__pycache__/_loop.cpython-312.pyc,, +rich/__pycache__/_null_file.cpython-312.pyc,, +rich/__pycache__/_palettes.cpython-312.pyc,, +rich/__pycache__/_pick.cpython-312.pyc,, +rich/__pycache__/_ratio.cpython-312.pyc,, +rich/__pycache__/_spinners.cpython-312.pyc,, +rich/__pycache__/_stack.cpython-312.pyc,, +rich/__pycache__/_timer.cpython-312.pyc,, +rich/__pycache__/_win32_console.cpython-312.pyc,, +rich/__pycache__/_windows.cpython-312.pyc,, +rich/__pycache__/_windows_renderer.cpython-312.pyc,, +rich/__pycache__/_wrap.cpython-312.pyc,, +rich/__pycache__/abc.cpython-312.pyc,, +rich/__pycache__/align.cpython-312.pyc,, +rich/__pycache__/ansi.cpython-312.pyc,, +rich/__pycache__/bar.cpython-312.pyc,, +rich/__pycache__/box.cpython-312.pyc,, +rich/__pycache__/cells.cpython-312.pyc,, +rich/__pycache__/color.cpython-312.pyc,, +rich/__pycache__/color_triplet.cpython-312.pyc,, +rich/__pycache__/columns.cpython-312.pyc,, +rich/__pycache__/console.cpython-312.pyc,, +rich/__pycache__/constrain.cpython-312.pyc,, +rich/__pycache__/containers.cpython-312.pyc,, +rich/__pycache__/control.cpython-312.pyc,, +rich/__pycache__/default_styles.cpython-312.pyc,, +rich/__pycache__/diagnose.cpython-312.pyc,, +rich/__pycache__/emoji.cpython-312.pyc,, +rich/__pycache__/errors.cpython-312.pyc,, +rich/__pycache__/file_proxy.cpython-312.pyc,, +rich/__pycache__/filesize.cpython-312.pyc,, +rich/__pycache__/highlighter.cpython-312.pyc,, +rich/__pycache__/json.cpython-312.pyc,, +rich/__pycache__/jupyter.cpython-312.pyc,, +rich/__pycache__/layout.cpython-312.pyc,, +rich/__pycache__/live.cpython-312.pyc,, +rich/__pycache__/live_render.cpython-312.pyc,, +rich/__pycache__/logging.cpython-312.pyc,, +rich/__pycache__/markdown.cpython-312.pyc,, +rich/__pycache__/markup.cpython-312.pyc,, +rich/__pycache__/measure.cpython-312.pyc,, +rich/__pycache__/padding.cpython-312.pyc,, +rich/__pycache__/pager.cpython-312.pyc,, +rich/__pycache__/palette.cpython-312.pyc,, +rich/__pycache__/panel.cpython-312.pyc,, +rich/__pycache__/pretty.cpython-312.pyc,, +rich/__pycache__/progress.cpython-312.pyc,, +rich/__pycache__/progress_bar.cpython-312.pyc,, +rich/__pycache__/prompt.cpython-312.pyc,, +rich/__pycache__/protocol.cpython-312.pyc,, +rich/__pycache__/region.cpython-312.pyc,, +rich/__pycache__/repr.cpython-312.pyc,, +rich/__pycache__/rule.cpython-312.pyc,, +rich/__pycache__/scope.cpython-312.pyc,, +rich/__pycache__/screen.cpython-312.pyc,, +rich/__pycache__/segment.cpython-312.pyc,, +rich/__pycache__/spinner.cpython-312.pyc,, +rich/__pycache__/status.cpython-312.pyc,, +rich/__pycache__/style.cpython-312.pyc,, +rich/__pycache__/styled.cpython-312.pyc,, +rich/__pycache__/syntax.cpython-312.pyc,, +rich/__pycache__/table.cpython-312.pyc,, +rich/__pycache__/terminal_theme.cpython-312.pyc,, +rich/__pycache__/text.cpython-312.pyc,, +rich/__pycache__/theme.cpython-312.pyc,, +rich/__pycache__/themes.cpython-312.pyc,, +rich/__pycache__/traceback.cpython-312.pyc,, +rich/__pycache__/tree.cpython-312.pyc,, +rich/_cell_widths.py,sha256=fbmeyetEdHjzE_Vx2l1uK7tnPOhMs2X1lJfO3vsKDpA,10209 +rich/_emoji_codes.py,sha256=hu1VL9nbVdppJrVoijVshRlcRRe_v3dju3Mmd2sKZdY,140235 +rich/_emoji_replace.py,sha256=n-kcetsEUx2ZUmhQrfeMNc-teeGhpuSQ5F8VPBsyvDo,1064 +rich/_export_format.py,sha256=RI08pSrm5tBSzPMvnbTqbD9WIalaOoN5d4M1RTmLq1Y,2128 +rich/_extension.py,sha256=G66PkbH_QdTJh6jD-J228O76CmAnr2hLQv72CgPPuzE,241 +rich/_fileno.py,sha256=HWZxP5C2ajMbHryvAQZseflVfQoGzsKOHzKGsLD8ynQ,799 +rich/_inspect.py,sha256=QM05lEFnFoTaFqpnbx-zBEI6k8oIKrD3cvjEOQNhKig,9655 +rich/_log_render.py,sha256=xBKCxqiO4FZk8eG56f8crFdrmJxFrJsQE3V3F-fFekc,3213 +rich/_loop.py,sha256=hV_6CLdoPm0va22Wpw4zKqM0RYsz3TZxXj0PoS-9eDQ,1236 +rich/_null_file.py,sha256=ADGKp1yt-k70FMKV6tnqCqecB-rSJzp-WQsD7LPL-kg,1394 +rich/_palettes.py,sha256=cdev1JQKZ0JvlguV9ipHgznTdnvlIzUFDBb0It2PzjI,7063 +rich/_pick.py,sha256=evDt8QN4lF5CiwrUIXlOJCntitBCOsI3ZLPEIAVRLJU,423 +rich/_ratio.py,sha256=d2k38QnkJKhkHAqqSseqMQ-ZuvgbwnocRKhMQq84EdI,5459 +rich/_spinners.py,sha256=U2r1_g_1zSjsjiUdAESc2iAMc3i4ri_S8PYP6kQ5z1I,19919 +rich/_stack.py,sha256=-C8OK7rxn3sIUdVwxZBBpeHhIzX0eI-VM3MemYfaXm0,351 +rich/_timer.py,sha256=zelxbT6oPFZnNrwWPpc1ktUeAT-Vc4fuFcRZLQGLtMI,417 +rich/_win32_console.py,sha256=o2QN_IRx10biGP3Ap1neaqX8FBGlUKSmWM6Kw4OSg-U,22719 +rich/_windows.py,sha256=is3WpbHMj8WaTHYB11hc6lP2t4hlvt4TViTlHSmjsi0,1901 +rich/_windows_renderer.py,sha256=d799xOnxLbCCCzGu9-U7YLmIQkxtxQIBFQQ6iu4veSc,2759 +rich/_wrap.py,sha256=FlSsom5EX0LVkA3KWy34yHnCfLtqX-ZIepXKh-70rpc,3404 +rich/abc.py,sha256=dALMOGfKVNeAbvqq66IpTQxQUerxD7AE4FKwqd0eQKk,878 +rich/align.py,sha256=gxlfgvi4ah8ERmg8RpGFtWY1Z4WBuWm-6qSIUSFx4bQ,10421 +rich/ansi.py,sha256=Avs1LHbSdcyOvDOdpELZUoULcBiYewY76eNBp6uFBhs,6921 +rich/bar.py,sha256=ldbVHOzKJOnflVNuv1xS7g6dLX2E3wMnXkdPbpzJTcs,3263 +rich/box.py,sha256=46rA0eBKLBcqNhCXmEKS4pN1dz36F0Vzi52hyVT-tyc,10783 +rich/cells.py,sha256=KrQkj5-LghCCpJLSNQIyAZjndc4bnEqOEmi5YuZ9UCY,5130 +rich/color.py,sha256=3HSULVDj7qQkXUdFWv78JOiSZzfy5y1nkcYhna296V0,18211 +rich/color_triplet.py,sha256=3lhQkdJbvWPoLDO-AnYImAWmJvV5dlgYNCVZ97ORaN4,1054 +rich/columns.py,sha256=HUX0KcMm9dsKNi11fTbiM_h2iDtl8ySCaVcxlalEzq8,7131 +rich/console.py,sha256=zgSwvRDPiDXh6wQ_kbnNSxff-s7uuljVmaTeoYPyh6E,100084 +rich/constrain.py,sha256=1VIPuC8AgtKWrcncQrjBdYqA3JVWysu6jZo1rrh7c7Q,1288 +rich/containers.py,sha256=c_56TxcedGYqDepHBMTuZdUIijitAQgnox-Qde0Z1qo,5502 +rich/control.py,sha256=Ix-rO8ZhSB2q1Biazr4l72ZyAw27H9or7ElipWVVo0M,6606 +rich/default_styles.py,sha256=gY-aX6rUxxlxdOOt5CqxnltpFDQqqqdHuXwAy2OD1o8,8123 +rich/diagnose.py,sha256=ZopD2EpWVtmmKptgbXT-sOMkAJ7DGrMSUXUiaU2GZ78,924 +rich/emoji.py,sha256=1jTRHFwvQxY1ciul22MdEZcWc7brfjKT8FG6ZjXj5dM,2465 +rich/errors.py,sha256=5pP3Kc5d4QJ_c0KFsxrfyhjiPVe7J1zOqSFbFAzcV-Y,642 +rich/file_proxy.py,sha256=Tl9THMDZ-Pk5Wm8sI1gGg_U5DhusmxD-FZ0fUbcU0W0,1683 +rich/filesize.py,sha256=_iz9lIpRgvW7MNSeCZnLg-HwzbP4GETg543WqD8SFs0,2484 +rich/highlighter.py,sha256=G_sn-8DKjM1sEjLG_oc4ovkWmiUpWvj8bXi0yed2LnY,9586 +rich/json.py,sha256=omC2WHTgURxEosna1ftoSJCne2EX7MDuQtCdswS3qsk,5019 +rich/jupyter.py,sha256=G9pOJmR4ESIFYSd4MKGqmHqCtstx0oRWpyeTgv54-Xc,3228 +rich/layout.py,sha256=WR8PCSroYnteIT3zawxQ3k3ad1sQO5wGG1SZOoeBuBM,13944 +rich/live.py,sha256=DhzAPEnjTxQuq9_0Y2xh2MUwQcP_aGPkenLfKETslwM,14270 +rich/live_render.py,sha256=QaiB8dtGikCdssoXpkEmmiH55fxT-9bzLkBO9pbBvrU,3654 +rich/logging.py,sha256=aqZpsmIEE45-wbnZqWnEaNSdQ89cbGcaL26-ZV0poj0,12446 +rich/markdown.py,sha256=eDi7dMN7RQD5u21tuqCOSpNWGZdKmyGtKmaZNt257rA,25969 +rich/markup.py,sha256=btpr271BLhiCR1jNglRnv2BpIzVcNefYwSMeW9teDbc,8427 +rich/measure.py,sha256=HmrIJX8sWRTHbgh8MxEay_83VkqNW_70s8aKP5ZcYI8,5305 +rich/padding.py,sha256=h8XnIivLrNtlxI3vQPKHXh4hAwjOJqZx0slM0z3g1_M,4896 +rich/pager.py,sha256=SO_ETBFKbg3n_AgOzXm41Sv36YxXAyI3_R-KOY2_uSc,828 +rich/palette.py,sha256=Ar6ZUrYHiFt6-Rr2k-k9F8V7hxgJYHNdqjk2vVXsLgc,3288 +rich/panel.py,sha256=fFRHcviXvWhk3V3zx5Zwmsb_RL9KJ3esD-sU0NYEVyw,11235 +rich/pretty.py,sha256=eQs437AksYaCB2qO_d-z6e0DF_t5F1KfXfa1Hi-Ya0E,36355 +rich/progress.py,sha256=tLmBGHrAfxIQxfB2kq1IpNXTVFNuvl9bXd_QkLQUN8Q,60333 +rich/progress_bar.py,sha256=mZTPpJUwcfcdgQCTTz3kyY-fc79ddLwtx6Ghhxfo064,8162 +rich/prompt.py,sha256=k0CUIW-3I55jGk8U3O1WiEhdF6yXa2EiWeRqRhuJXWA,12435 +rich/protocol.py,sha256=Wt-2HZd67OYiopUkCTOz7lM38vyo5r3HEQZ9TOPDl5Q,1367 +rich/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +rich/region.py,sha256=rNT9xZrVZTYIXZC0NYn41CJQwYNbR-KecPOxTgQvB8Y,166 +rich/repr.py,sha256=HIsurPLZK9Gray75l3_vQx7S27AzTpAj4ChXSfe1Fes,4419 +rich/rule.py,sha256=umO21Wjw0FcYAeTB3UumNLCsDWhejzxnjlf2VwiXiDI,4590 +rich/scope.py,sha256=lf6Qet_e4JOY34lwhYSAG-NBXYKBcYu6t_igv_JoGog,2831 +rich/screen.py,sha256=rL_j2wX-4SeuIOI2oOlc418QP9EAvD59GInUmEAE6jQ,1579 +rich/segment.py,sha256=7gOdwSPrzu0a2gRmxBDtu3u2S8iG5s9l7wlB58dKMy0,24707 +rich/spinner.py,sha256=PT5qgXPG3ZpqRj7n3EZQ6NW56mx3ldZqZCU7gEMyZk4,4364 +rich/status.py,sha256=kkPph3YeAZBo-X-4wPp8gTqZyU466NLwZBA4PZTTewo,4424 +rich/style.py,sha256=aSoUNbVgfP1PAnduAqgbbl4AMQy668qs2S1FEwr3Oqs,27067 +rich/styled.py,sha256=wljVsVTXbABMMZvkzkO43ZEk_-irzEtvUiQ-sNnikQ8,1234 +rich/syntax.py,sha256=NY1DRIqXBkFExudqxm5K3BJXFCttN63AF_3IZAvtLMg,35655 +rich/table.py,sha256=RX26U8oHV0s1U-gl6WqylfesmOT2qt7VVtMtC18-Pk0,40067 +rich/terminal_theme.py,sha256=1j5-ufJfnvlAo5Qsi_ACZiXDmwMXzqgmFByObT9-yJY,3370 +rich/text.py,sha256=v-vCOG8gS_D5QDhOhU19478-yEJGAXKVi8iYCCk7O_M,47540 +rich/theme.py,sha256=oNyhXhGagtDlbDye3tVu3esWOWk0vNkuxFw-_unlaK0,3771 +rich/themes.py,sha256=0xgTLozfabebYtcJtDdC5QkX5IVUEaviqDUJJh4YVFk,102 +rich/traceback.py,sha256=hCLOig4Uwtc7f0FqseEkFZ8YUwzvGOli8BOG517mipg,31725 +rich/tree.py,sha256=QoOwg424FkdwGfR8K0tZ6Q7qtzWNAUP_m4sFaYuG6nw,9391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..d73ccaae8e0eea45949b0957a5af034099b36aa4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/rich-13.9.4.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: poetry-core 1.9.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_imp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_imp.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d9f29218987d4f830f2d57aca9e3f74d00a095 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_imp.py @@ -0,0 +1,87 @@ +""" +Re-implementation of find_module and get_frozen_object +from the deprecated imp module. +""" + +import importlib.machinery +import importlib.util +import os +import tokenize +from importlib.util import module_from_spec + +PY_SOURCE = 1 +PY_COMPILED = 2 +C_EXTENSION = 3 +C_BUILTIN = 6 +PY_FROZEN = 7 + + +def find_spec(module, paths): + finder = ( + importlib.machinery.PathFinder().find_spec + if isinstance(paths, list) + else importlib.util.find_spec + ) + return finder(module, paths) + + +def find_module(module, paths=None): + """Just like 'imp.find_module()', but with package support""" + spec = find_spec(module, paths) + if spec is None: + raise ImportError(f"Can't find {module}") + if not spec.has_location and hasattr(spec, 'submodule_search_locations'): + spec = importlib.util.spec_from_loader('__init__.py', spec.loader) + + kind = -1 + file = None + static = isinstance(spec.loader, type) + if ( + spec.origin == 'frozen' + or static + and issubclass(spec.loader, importlib.machinery.FrozenImporter) + ): + kind = PY_FROZEN + path = None # imp compabilty + suffix = mode = '' # imp compatibility + elif ( + spec.origin == 'built-in' + or static + and issubclass(spec.loader, importlib.machinery.BuiltinImporter) + ): + kind = C_BUILTIN + path = None # imp compabilty + suffix = mode = '' # imp compatibility + elif spec.has_location: + path = spec.origin + suffix = os.path.splitext(path)[1] + mode = 'r' if suffix in importlib.machinery.SOURCE_SUFFIXES else 'rb' + + if suffix in importlib.machinery.SOURCE_SUFFIXES: + kind = PY_SOURCE + file = tokenize.open(path) + elif suffix in importlib.machinery.BYTECODE_SUFFIXES: + kind = PY_COMPILED + file = open(path, 'rb') + elif suffix in importlib.machinery.EXTENSION_SUFFIXES: + kind = C_EXTENSION + + else: + path = None + suffix = mode = '' + + return file, path, (suffix, mode, kind) + + +def get_frozen_object(module, paths=None): + spec = find_spec(module, paths) + if not spec: + raise ImportError(f"Can't find {module}") + return spec.loader.get_code(module) + + +def get_module(module, paths, info): + spec = find_spec(module, paths) + if not spec: + raise ImportError(f"Can't find {module}") + return module_from_spec(spec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_importlib.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_importlib.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0fd52653b56c9c2cb2b2c7bfb35e3ec3c61408 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_importlib.py @@ -0,0 +1,9 @@ +import sys + +if sys.version_info < (3, 10): + import importlib_metadata as metadata # pragma: no cover +else: + import importlib.metadata as metadata # noqa: F401 + + +import importlib.resources as resources # noqa: F401 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_itertools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_itertools.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ca841353ce39ac4361013f5c8160d69028d0d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_itertools.py @@ -0,0 +1,23 @@ +from more_itertools import consume # noqa: F401 + + +# copied from jaraco.itertools 6.1 +def ensure_unique(iterable, key=lambda x: x): + """ + Wrap an iterable to raise a ValueError if non-unique values are encountered. + + >>> list(ensure_unique('abc')) + ['a', 'b', 'c'] + >>> consume(ensure_unique('abca')) + Traceback (most recent call last): + ... + ValueError: Duplicate element 'a' encountered. + """ + seen = set() + seen_add = seen.add + for element in iterable: + k = key(element) + if k in seen: + raise ValueError(f"Duplicate element {element!r} encountered.") + seen_add(k) + yield element diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_path.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_path.py new file mode 100644 index 0000000000000000000000000000000000000000..0d99b0f539ff5f819b167013c48726180cd83d49 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/_path.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import contextlib +import os +import sys +from typing import TYPE_CHECKING, TypeVar, Union + +from more_itertools import unique_everseen + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +StrPath: TypeAlias = Union[str, os.PathLike[str]] # Same as _typeshed.StrPath +StrPathT = TypeVar("StrPathT", bound=Union[str, os.PathLike[str]]) + + +def ensure_directory(path): + """Ensure that the parent directory of `path` exists""" + dirname = os.path.dirname(path) + os.makedirs(dirname, exist_ok=True) + + +def same_path(p1: StrPath, p2: StrPath) -> bool: + """Differs from os.path.samefile because it does not require paths to exist. + Purely string based (no comparison between i-nodes). + >>> same_path("a/b", "./a/b") + True + >>> same_path("a/b", "a/./b") + True + >>> same_path("a/b", "././a/b") + True + >>> same_path("a/b", "./a/b/c/..") + True + >>> same_path("a/b", "../a/b/c") + False + >>> same_path("a", "a/b") + False + """ + return normpath(p1) == normpath(p2) + + +def normpath(filename: StrPath) -> str: + """Normalize a file/dir name for comparison purposes.""" + # See pkg_resources.normalize_path for notes about cygwin + file = os.path.abspath(filename) if sys.platform == 'cygwin' else filename + return os.path.normcase(os.path.realpath(os.path.normpath(file))) + + +@contextlib.contextmanager +def paths_on_pythonpath(paths): + """ + Add the indicated paths to the head of the PYTHONPATH environment + variable so that subprocesses will also see the packages at + these paths. + + Do this in a context that restores the value on exit. + + >>> getfixture('monkeypatch').setenv('PYTHONPATH', 'anything') + >>> with paths_on_pythonpath(['foo', 'bar']): + ... assert 'foo' in os.environ['PYTHONPATH'] + ... assert 'anything' in os.environ['PYTHONPATH'] + >>> os.environ['PYTHONPATH'] + 'anything' + + >>> getfixture('monkeypatch').delenv('PYTHONPATH') + >>> with paths_on_pythonpath(['foo', 'bar']): + ... assert 'foo' in os.environ['PYTHONPATH'] + >>> os.environ.get('PYTHONPATH') + """ + nothing = object() + orig_pythonpath = os.environ.get('PYTHONPATH', nothing) + current_pythonpath = os.environ.get('PYTHONPATH', '') + try: + prefix = os.pathsep.join(unique_everseen(paths)) + to_join = filter(None, [prefix, current_pythonpath]) + new_path = os.pathsep.join(to_join) + if new_path: + os.environ['PYTHONPATH'] = new_path + yield + finally: + if orig_pythonpath is nothing: + os.environ.pop('PYTHONPATH', None) + else: + os.environ['PYTHONPATH'] = orig_pythonpath diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/build_meta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/build_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..00fa5e1f7048c8781cfc4838f705e122fd547825 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/build_meta.py @@ -0,0 +1,560 @@ +"""A PEP 517 interface to setuptools + +Previously, when a user or a command line tool (let's call it a "frontend") +needed to make a request of setuptools to take a certain action, for +example, generating a list of installation requirements, the frontend +would call "setup.py egg_info" or "setup.py bdist_wheel" on the command line. + +PEP 517 defines a different method of interfacing with setuptools. Rather +than calling "setup.py" directly, the frontend should: + + 1. Set the current directory to the directory with a setup.py file + 2. Import this module into a safe python interpreter (one in which + setuptools can potentially set global variables or crash hard). + 3. Call one of the functions defined in PEP 517. + +What each function does is defined in PEP 517. However, here is a "casual" +definition of the functions (this definition should not be relied on for +bug reports or API stability): + + - `build_wheel`: build a wheel in the folder and return the basename + - `get_requires_for_build_wheel`: get the `setup_requires` to build + - `prepare_metadata_for_build_wheel`: get the `install_requires` + - `build_sdist`: build an sdist in the folder and return the basename + - `get_requires_for_build_sdist`: get the `setup_requires` to build + +Again, this is not a formal definition! Just a "taste" of the module. +""" + +from __future__ import annotations + +import contextlib +import io +import os +import shlex +import shutil +import sys +import tempfile +import tokenize +import warnings +from collections.abc import Iterable, Iterator, Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Union + +import setuptools + +from . import errors +from ._path import StrPath, same_path +from ._reqs import parse_strings +from .warnings import SetuptoolsDeprecationWarning + +import distutils +from distutils.util import strtobool + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +__all__ = [ + 'get_requires_for_build_sdist', + 'get_requires_for_build_wheel', + 'prepare_metadata_for_build_wheel', + 'build_wheel', + 'build_sdist', + 'get_requires_for_build_editable', + 'prepare_metadata_for_build_editable', + 'build_editable', + '__legacy__', + 'SetupRequirementsError', +] + +SETUPTOOLS_ENABLE_FEATURES = os.getenv("SETUPTOOLS_ENABLE_FEATURES", "").lower() +LEGACY_EDITABLE = "legacy-editable" in SETUPTOOLS_ENABLE_FEATURES.replace("_", "-") + + +class SetupRequirementsError(BaseException): + def __init__(self, specifiers) -> None: + self.specifiers = specifiers + + +class Distribution(setuptools.dist.Distribution): + def fetch_build_eggs(self, specifiers): + specifier_list = list(parse_strings(specifiers)) + + raise SetupRequirementsError(specifier_list) + + @classmethod + @contextlib.contextmanager + def patch(cls): + """ + Replace + distutils.dist.Distribution with this class + for the duration of this context. + """ + orig = distutils.core.Distribution + distutils.core.Distribution = cls # type: ignore[misc] # monkeypatching + try: + yield + finally: + distutils.core.Distribution = orig # type: ignore[misc] # monkeypatching + + +@contextlib.contextmanager +def no_install_setup_requires(): + """Temporarily disable installing setup_requires + + Under PEP 517, the backend reports build dependencies to the frontend, + and the frontend is responsible for ensuring they're installed. + So setuptools (acting as a backend) should not try to install them. + """ + orig = setuptools._install_setup_requires + setuptools._install_setup_requires = lambda attrs: None + try: + yield + finally: + setuptools._install_setup_requires = orig + + +def _get_immediate_subdirectories(a_dir): + return [ + name for name in os.listdir(a_dir) if os.path.isdir(os.path.join(a_dir, name)) + ] + + +def _file_with_extension(directory: StrPath, extension: str | tuple[str, ...]): + matching = (f for f in os.listdir(directory) if f.endswith(extension)) + try: + (file,) = matching + except ValueError: + raise ValueError( + 'No distribution was found. Ensure that `setup.py` ' + 'is not empty and that it calls `setup()`.' + ) from None + return file + + +def _open_setup_script(setup_script): + if not os.path.exists(setup_script): + # Supply a default setup.py + return io.StringIO("from setuptools import setup; setup()") + + return tokenize.open(setup_script) + + +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'setup.py install is deprecated') + yield + + +_ConfigSettings: TypeAlias = Union[Mapping[str, Union[str, list[str], None]], None] +""" +Currently the user can run:: + + pip install -e . --config-settings key=value + python -m build -C--key=value -C key=value + +- pip will pass both key and value as strings and overwriting repeated keys + (pypa/pip#11059). +- build will accumulate values associated with repeated keys in a list. + It will also accept keys with no associated value. + This means that an option passed by build can be ``str | list[str] | None``. +- PEP 517 specifies that ``config_settings`` is an optional dict. +""" + + +class _ConfigSettingsTranslator: + """Translate ``config_settings`` into distutils-style command arguments. + Only a limited number of options is currently supported. + """ + + # See pypa/setuptools#1928 pypa/setuptools#2491 + + def _get_config(self, key: str, config_settings: _ConfigSettings) -> list[str]: + """ + Get the value of a specific key in ``config_settings`` as a list of strings. + + >>> fn = _ConfigSettingsTranslator()._get_config + >>> fn("--global-option", None) + [] + >>> fn("--global-option", {}) + [] + >>> fn("--global-option", {'--global-option': 'foo'}) + ['foo'] + >>> fn("--global-option", {'--global-option': ['foo']}) + ['foo'] + >>> fn("--global-option", {'--global-option': 'foo'}) + ['foo'] + >>> fn("--global-option", {'--global-option': 'foo bar'}) + ['foo', 'bar'] + """ + cfg = config_settings or {} + opts = cfg.get(key) or [] + return shlex.split(opts) if isinstance(opts, str) else opts + + def _global_args(self, config_settings: _ConfigSettings) -> Iterator[str]: + """ + Let the user specify ``verbose`` or ``quiet`` + escape hatch via + ``--global-option``. + Note: ``-v``, ``-vv``, ``-vvv`` have similar effects in setuptools, + so we just have to cover the basic scenario ``-v``. + + >>> fn = _ConfigSettingsTranslator()._global_args + >>> list(fn(None)) + [] + >>> list(fn({"verbose": "False"})) + ['-q'] + >>> list(fn({"verbose": "1"})) + ['-v'] + >>> list(fn({"--verbose": None})) + ['-v'] + >>> list(fn({"verbose": "true", "--global-option": "-q --no-user-cfg"})) + ['-v', '-q', '--no-user-cfg'] + >>> list(fn({"--quiet": None})) + ['-q'] + """ + cfg = config_settings or {} + falsey = {"false", "no", "0", "off"} + if "verbose" in cfg or "--verbose" in cfg: + level = str(cfg.get("verbose") or cfg.get("--verbose") or "1") + yield ("-q" if level.lower() in falsey else "-v") + if "quiet" in cfg or "--quiet" in cfg: + level = str(cfg.get("quiet") or cfg.get("--quiet") or "1") + yield ("-v" if level.lower() in falsey else "-q") + + yield from self._get_config("--global-option", config_settings) + + def __dist_info_args(self, config_settings: _ConfigSettings) -> Iterator[str]: + """ + The ``dist_info`` command accepts ``tag-date`` and ``tag-build``. + + .. warning:: + We cannot use this yet as it requires the ``sdist`` and ``bdist_wheel`` + commands run in ``build_sdist`` and ``build_wheel`` to reuse the egg-info + directory created in ``prepare_metadata_for_build_wheel``. + + >>> fn = _ConfigSettingsTranslator()._ConfigSettingsTranslator__dist_info_args + >>> list(fn(None)) + [] + >>> list(fn({"tag-date": "False"})) + ['--no-date'] + >>> list(fn({"tag-date": None})) + ['--no-date'] + >>> list(fn({"tag-date": "true", "tag-build": ".a"})) + ['--tag-date', '--tag-build', '.a'] + """ + cfg = config_settings or {} + if "tag-date" in cfg: + val = strtobool(str(cfg["tag-date"] or "false")) + yield ("--tag-date" if val else "--no-date") + if "tag-build" in cfg: + yield from ["--tag-build", str(cfg["tag-build"])] + + def _editable_args(self, config_settings: _ConfigSettings) -> Iterator[str]: + """ + The ``editable_wheel`` command accepts ``editable-mode=strict``. + + >>> fn = _ConfigSettingsTranslator()._editable_args + >>> list(fn(None)) + [] + >>> list(fn({"editable-mode": "strict"})) + ['--mode', 'strict'] + """ + cfg = config_settings or {} + mode = cfg.get("editable-mode") or cfg.get("editable_mode") + if not mode: + return + yield from ["--mode", str(mode)] + + def _arbitrary_args(self, config_settings: _ConfigSettings) -> Iterator[str]: + """ + Users may expect to pass arbitrary lists of arguments to a command + via "--global-option" (example provided in PEP 517 of a "escape hatch"). + + >>> fn = _ConfigSettingsTranslator()._arbitrary_args + >>> list(fn(None)) + [] + >>> list(fn({})) + [] + >>> list(fn({'--build-option': 'foo'})) + ['foo'] + >>> list(fn({'--build-option': ['foo']})) + ['foo'] + >>> list(fn({'--build-option': 'foo'})) + ['foo'] + >>> list(fn({'--build-option': 'foo bar'})) + ['foo', 'bar'] + >>> list(fn({'--global-option': 'foo'})) + [] + """ + yield from self._get_config("--build-option", config_settings) + + +class _BuildMetaBackend(_ConfigSettingsTranslator): + def _get_build_requires( + self, config_settings: _ConfigSettings, requirements: list[str] + ): + sys.argv = [ + *sys.argv[:1], + *self._global_args(config_settings), + "egg_info", + ] + try: + with Distribution.patch(): + self.run_setup() + except SetupRequirementsError as e: + requirements += e.specifiers + + return requirements + + def run_setup(self, setup_script: str = 'setup.py'): + # Note that we can reuse our build directory between calls + # Correctness comes first, then optimization later + __file__ = os.path.abspath(setup_script) + __name__ = '__main__' + + with _open_setup_script(__file__) as f: + code = f.read().replace(r'\r\n', r'\n') + + try: + exec(code, locals()) + except SystemExit as e: + if e.code: + raise + # We ignore exit code indicating success + SetuptoolsDeprecationWarning.emit( + "Running `setup.py` directly as CLI tool is deprecated.", + "Please avoid using `sys.exit(0)` or similar statements " + "that don't fit in the paradigm of a configuration file.", + see_url="https://blog.ganssle.io/articles/2021/10/" + "setup-py-deprecated.html", + ) + + def get_requires_for_build_wheel(self, config_settings: _ConfigSettings = None): + return self._get_build_requires(config_settings, requirements=[]) + + def get_requires_for_build_sdist(self, config_settings: _ConfigSettings = None): + return self._get_build_requires(config_settings, requirements=[]) + + def _bubble_up_info_directory( + self, metadata_directory: StrPath, suffix: str + ) -> str: + """ + PEP 517 requires that the .dist-info directory be placed in the + metadata_directory. To comply, we MUST copy the directory to the root. + + Returns the basename of the info directory, e.g. `proj-0.0.0.dist-info`. + """ + info_dir = self._find_info_directory(metadata_directory, suffix) + if not same_path(info_dir.parent, metadata_directory): + shutil.move(str(info_dir), metadata_directory) + # PEP 517 allow other files and dirs to exist in metadata_directory + return info_dir.name + + def _find_info_directory(self, metadata_directory: StrPath, suffix: str) -> Path: + for parent, dirs, _ in os.walk(metadata_directory): + candidates = [f for f in dirs if f.endswith(suffix)] + + if len(candidates) != 0 or len(dirs) != 1: + assert len(candidates) == 1, f"Multiple {suffix} directories found" + return Path(parent, candidates[0]) + + msg = f"No {suffix} directory found in {metadata_directory}" + raise errors.InternalError(msg) + + def prepare_metadata_for_build_wheel( + self, metadata_directory: StrPath, config_settings: _ConfigSettings = None + ): + sys.argv = [ + *sys.argv[:1], + *self._global_args(config_settings), + "dist_info", + "--output-dir", + str(metadata_directory), + "--keep-egg-info", + ] + with no_install_setup_requires(): + self.run_setup() + + self._bubble_up_info_directory(metadata_directory, ".egg-info") + return self._bubble_up_info_directory(metadata_directory, ".dist-info") + + def _build_with_temp_dir( + self, + setup_command: Iterable[str], + result_extension: str | tuple[str, ...], + result_directory: StrPath, + config_settings: _ConfigSettings, + arbitrary_args: Iterable[str] = (), + ): + result_directory = os.path.abspath(result_directory) + + # Build in a temporary directory, then copy to the target. + os.makedirs(result_directory, exist_ok=True) + + with tempfile.TemporaryDirectory( + prefix=".tmp-", dir=result_directory + ) as tmp_dist_dir: + sys.argv = [ + *sys.argv[:1], + *self._global_args(config_settings), + *setup_command, + "--dist-dir", + tmp_dist_dir, + *arbitrary_args, + ] + with no_install_setup_requires(): + self.run_setup() + + result_basename = _file_with_extension(tmp_dist_dir, result_extension) + result_path = os.path.join(result_directory, result_basename) + if os.path.exists(result_path): + # os.rename will fail overwriting on non-Unix. + os.remove(result_path) + os.rename(os.path.join(tmp_dist_dir, result_basename), result_path) + + return result_basename + + def build_wheel( + self, + wheel_directory: StrPath, + config_settings: _ConfigSettings = None, + metadata_directory: StrPath | None = None, + ): + def _build(cmd: list[str]): + with suppress_known_deprecation(): + return self._build_with_temp_dir( + cmd, + '.whl', + wheel_directory, + config_settings, + self._arbitrary_args(config_settings), + ) + + if metadata_directory is None: + return _build(['bdist_wheel']) + + try: + return _build(['bdist_wheel', '--dist-info-dir', str(metadata_directory)]) + except SystemExit as ex: # pragma: nocover + # pypa/setuptools#4683 + if "--dist-info-dir not recognized" not in str(ex): + raise + _IncompatibleBdistWheel.emit() + return _build(['bdist_wheel']) + + def build_sdist( + self, sdist_directory: StrPath, config_settings: _ConfigSettings = None + ): + return self._build_with_temp_dir( + ['sdist', '--formats', 'gztar'], '.tar.gz', sdist_directory, config_settings + ) + + def _get_dist_info_dir(self, metadata_directory: StrPath | None) -> str | None: + if not metadata_directory: + return None + dist_info_candidates = list(Path(metadata_directory).glob("*.dist-info")) + assert len(dist_info_candidates) <= 1 + return str(dist_info_candidates[0]) if dist_info_candidates else None + + if not LEGACY_EDITABLE: + # PEP660 hooks: + # build_editable + # get_requires_for_build_editable + # prepare_metadata_for_build_editable + def build_editable( + self, + wheel_directory: StrPath, + config_settings: _ConfigSettings = None, + metadata_directory: StrPath | None = None, + ): + # XXX can or should we hide our editable_wheel command normally? + info_dir = self._get_dist_info_dir(metadata_directory) + opts = ["--dist-info-dir", info_dir] if info_dir else [] + cmd = ["editable_wheel", *opts, *self._editable_args(config_settings)] + with suppress_known_deprecation(): + return self._build_with_temp_dir( + cmd, ".whl", wheel_directory, config_settings + ) + + def get_requires_for_build_editable( + self, config_settings: _ConfigSettings = None + ): + return self.get_requires_for_build_wheel(config_settings) + + def prepare_metadata_for_build_editable( + self, metadata_directory: StrPath, config_settings: _ConfigSettings = None + ): + return self.prepare_metadata_for_build_wheel( + metadata_directory, config_settings + ) + + +class _BuildMetaLegacyBackend(_BuildMetaBackend): + """Compatibility backend for setuptools + + This is a version of setuptools.build_meta that endeavors + to maintain backwards + compatibility with pre-PEP 517 modes of invocation. It + exists as a temporary + bridge between the old packaging mechanism and the new + packaging mechanism, + and will eventually be removed. + """ + + def run_setup(self, setup_script: str = 'setup.py'): + # In order to maintain compatibility with scripts assuming that + # the setup.py script is in a directory on the PYTHONPATH, inject + # '' into sys.path. (pypa/setuptools#1642) + sys_path = list(sys.path) # Save the original path + + script_dir = os.path.dirname(os.path.abspath(setup_script)) + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + # Some setup.py scripts (e.g. in pygame and numpy) use sys.argv[0] to + # get the directory of the source code. They expect it to refer to the + # setup.py script. + sys_argv_0 = sys.argv[0] + sys.argv[0] = setup_script + + try: + super().run_setup(setup_script=setup_script) + finally: + # While PEP 517 frontends should be calling each hook in a fresh + # subprocess according to the standard (and thus it should not be + # strictly necessary to restore the old sys.path), we'll restore + # the original path so that the path manipulation does not persist + # within the hook after run_setup is called. + sys.path[:] = sys_path + sys.argv[0] = sys_argv_0 + + +class _IncompatibleBdistWheel(SetuptoolsDeprecationWarning): + _SUMMARY = "wheel.bdist_wheel is deprecated, please import it from setuptools" + _DETAILS = """ + Ensure that any custom bdist_wheel implementation is a subclass of + setuptools.command.bdist_wheel.bdist_wheel. + """ + _DUE_DATE = (2025, 10, 15) + # Initially introduced in 2024/10/15, but maybe too disruptive to be enforced? + _SEE_URL = "https://github.com/pypa/wheel/pull/631" + + +# The primary backend +_BACKEND = _BuildMetaBackend() + +get_requires_for_build_wheel = _BACKEND.get_requires_for_build_wheel +get_requires_for_build_sdist = _BACKEND.get_requires_for_build_sdist +prepare_metadata_for_build_wheel = _BACKEND.prepare_metadata_for_build_wheel +build_wheel = _BACKEND.build_wheel +build_sdist = _BACKEND.build_sdist + +if not LEGACY_EDITABLE: + get_requires_for_build_editable = _BACKEND.get_requires_for_build_editable + prepare_metadata_for_build_editable = _BACKEND.prepare_metadata_for_build_editable + build_editable = _BACKEND.build_editable + + +# The legacy backend +__legacy__ = _BuildMetaLegacyBackend() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-32.exe b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-32.exe new file mode 100644 index 0000000000000000000000000000000000000000..65c3cd99cc7433f271a5b9387abdd1ddb949d1a6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-32.exe differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-arm64.exe b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-arm64.exe new file mode 100644 index 0000000000000000000000000000000000000000..da96455a07a0bad4cde5dc5626544325f82c722b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli-arm64.exe differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli.exe b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli.exe new file mode 100644 index 0000000000000000000000000000000000000000..65c3cd99cc7433f271a5b9387abdd1ddb949d1a6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/cli.exe differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/depends.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/depends.py new file mode 100644 index 0000000000000000000000000000000000000000..e5223b79561c36d9b6c45ead78288098e1cb0f1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/depends.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import contextlib +import dis +import marshal +import sys +from types import CodeType +from typing import Any, Literal, TypeVar + +from packaging.version import Version + +from . import _imp +from ._imp import PY_COMPILED, PY_FROZEN, PY_SOURCE, find_module + +_T = TypeVar("_T") + +__all__ = ['Require', 'find_module'] + + +class Require: + """A prerequisite to building or installing a distribution""" + + def __init__( + self, + name, + requested_version, + module, + homepage: str = '', + attribute=None, + format=None, + ) -> None: + if format is None and requested_version is not None: + format = Version + + if format is not None: + requested_version = format(requested_version) + if attribute is None: + attribute = '__version__' + + self.__dict__.update(locals()) + del self.self + + def full_name(self): + """Return full package/distribution name, w/version""" + if self.requested_version is not None: + return f'{self.name}-{self.requested_version}' + return self.name + + def version_ok(self, version): + """Is 'version' sufficiently up-to-date?""" + return ( + self.attribute is None + or self.format is None + or str(version) != "unknown" + and self.format(version) >= self.requested_version + ) + + def get_version( + self, paths=None, default: _T | Literal["unknown"] = "unknown" + ) -> _T | Literal["unknown"] | None | Any: + """Get version number of installed module, 'None', or 'default' + + Search 'paths' for module. If not found, return 'None'. If found, + return the extracted version attribute, or 'default' if no version + attribute was specified, or the value cannot be determined without + importing the module. The version is formatted according to the + requirement's version format (if any), unless it is 'None' or the + supplied 'default'. + """ + + if self.attribute is None: + try: + f, _p, _i = find_module(self.module, paths) + except ImportError: + return None + if f: + f.close() + return default + + v = get_module_constant(self.module, self.attribute, default, paths) + + if v is not None and v is not default and self.format is not None: + return self.format(v) + + return v + + def is_present(self, paths=None): + """Return true if dependency is present on 'paths'""" + return self.get_version(paths) is not None + + def is_current(self, paths=None): + """Return true if dependency is present and up-to-date on 'paths'""" + version = self.get_version(paths) + if version is None: + return False + return self.version_ok(str(version)) + + +def maybe_close(f): + @contextlib.contextmanager + def empty(): + yield + return + + if not f: + return empty() + + return contextlib.closing(f) + + +# Some objects are not available on some platforms. +# XXX it'd be better to test assertions about bytecode instead. +if not sys.platform.startswith('java') and sys.platform != 'cli': + + def get_module_constant( + module, symbol, default: _T | int = -1, paths=None + ) -> _T | int | None | Any: + """Find 'module' by searching 'paths', and extract 'symbol' + + Return 'None' if 'module' does not exist on 'paths', or it does not define + 'symbol'. If the module defines 'symbol' as a constant, return the + constant. Otherwise, return 'default'.""" + + try: + f, path, (_suffix, _mode, kind) = info = find_module(module, paths) + except ImportError: + # Module doesn't exist + return None + + with maybe_close(f): + if kind == PY_COMPILED: + f.read(8) # skip magic & date + code = marshal.load(f) + elif kind == PY_FROZEN: + code = _imp.get_frozen_object(module, paths) + elif kind == PY_SOURCE: + code = compile(f.read(), path, 'exec') + else: + # Not something we can parse; we'll have to import it. :( + imported = _imp.get_module(module, paths, info) + return getattr(imported, symbol, None) + + return extract_constant(code, symbol, default) + + def extract_constant( + code: CodeType, symbol: str, default: _T | int = -1 + ) -> _T | int | None | Any: + """Extract the constant value of 'symbol' from 'code' + + If the name 'symbol' is bound to a constant value by the Python code + object 'code', return that value. If 'symbol' is bound to an expression, + return 'default'. Otherwise, return 'None'. + + Return value is based on the first assignment to 'symbol'. 'symbol' must + be a global, or at least a non-"fast" local in the code block. That is, + only 'STORE_NAME' and 'STORE_GLOBAL' opcodes are checked, and 'symbol' + must be present in 'code.co_names'. + """ + if symbol not in code.co_names: + # name's not there, can't possibly be an assignment + return None + + name_idx = list(code.co_names).index(symbol) + + STORE_NAME = dis.opmap['STORE_NAME'] + STORE_GLOBAL = dis.opmap['STORE_GLOBAL'] + LOAD_CONST = dis.opmap['LOAD_CONST'] + + const = default + + for byte_code in dis.Bytecode(code): + op = byte_code.opcode + arg = byte_code.arg + + if op == LOAD_CONST: + assert arg is not None + const = code.co_consts[arg] + elif arg == name_idx and (op == STORE_NAME or op == STORE_GLOBAL): + return const + else: + const = default + + return None + + __all__ += ['get_module_constant', 'extract_constant'] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui-64.exe b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui-64.exe new file mode 100644 index 0000000000000000000000000000000000000000..031cb77c17ba8d8a983448268851d612e05e80d1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui-64.exe differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui.exe b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui.exe new file mode 100644 index 0000000000000000000000000000000000000000..1eb430c6d614a5daea4139badc09c222a4b0e72a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/gui.exe differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/sandbox.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..2d84242d667c7df6a20aa56eabce6ac34ddc4a7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/sandbox.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +import builtins +import contextlib +import functools +import itertools +import operator +import os +import pickle +import re +import sys +import tempfile +import textwrap +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar + +import pkg_resources +from pkg_resources import working_set + +from distutils.errors import DistutilsError + +if TYPE_CHECKING: + import os as _os +elif sys.platform.startswith('java'): + import org.python.modules.posix.PosixModule as _os # pyright: ignore[reportMissingImports] +else: + _os = sys.modules[os.name] +_open = open + + +if TYPE_CHECKING: + from typing_extensions import Self + +__all__ = [ + "AbstractSandbox", + "DirectorySandbox", + "SandboxViolation", + "run_setup", +] + + +def _execfile(filename, globals, locals=None): + """ + Python 3 implementation of execfile. + """ + mode = 'rb' + with open(filename, mode) as stream: + script = stream.read() + if locals is None: + locals = globals + code = compile(script, filename, 'exec') + exec(code, globals, locals) + + +@contextlib.contextmanager +def save_argv(repl=None): + saved = sys.argv[:] + if repl is not None: + sys.argv[:] = repl + try: + yield saved + finally: + sys.argv[:] = saved + + +@contextlib.contextmanager +def save_path(): + saved = sys.path[:] + try: + yield saved + finally: + sys.path[:] = saved + + +@contextlib.contextmanager +def override_temp(replacement): + """ + Monkey-patch tempfile.tempdir with replacement, ensuring it exists + """ + os.makedirs(replacement, exist_ok=True) + + saved = tempfile.tempdir + + tempfile.tempdir = replacement + + try: + yield + finally: + tempfile.tempdir = saved + + +@contextlib.contextmanager +def pushd(target): + saved = os.getcwd() + os.chdir(target) + try: + yield saved + finally: + os.chdir(saved) + + +class UnpickleableException(Exception): + """ + An exception representing another Exception that could not be pickled. + """ + + @staticmethod + def dump(type, exc): + """ + Always return a dumped (pickled) type and exc. If exc can't be pickled, + wrap it in UnpickleableException first. + """ + try: + return pickle.dumps(type), pickle.dumps(exc) + except Exception: + # get UnpickleableException inside the sandbox + from setuptools.sandbox import UnpickleableException as cls + + return cls.dump(cls, cls(repr(exc))) + + +class ExceptionSaver: + """ + A Context Manager that will save an exception, serialize, and restore it + later. + """ + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: + if not exc: + return False + + # dump the exception + self._saved = UnpickleableException.dump(type, exc) + self._tb = tb + + # suppress the exception + return True + + def resume(self): + "restore and re-raise any exception" + + if '_saved' not in vars(self): + return + + _type, exc = map(pickle.loads, self._saved) + raise exc.with_traceback(self._tb) + + +@contextlib.contextmanager +def save_modules(): + """ + Context in which imported modules are saved. + + Translates exceptions internal to the context into the equivalent exception + outside the context. + """ + saved = sys.modules.copy() + with ExceptionSaver() as saved_exc: + yield saved + + sys.modules.update(saved) + # remove any modules imported since + del_modules = ( + mod_name + for mod_name in sys.modules + if mod_name not in saved + # exclude any encodings modules. See #285 + and not mod_name.startswith('encodings.') + ) + _clear_modules(del_modules) + + saved_exc.resume() + + +def _clear_modules(module_names): + for mod_name in list(module_names): + del sys.modules[mod_name] + + +@contextlib.contextmanager +def save_pkg_resources_state(): + saved = pkg_resources.__getstate__() + try: + yield saved + finally: + pkg_resources.__setstate__(saved) + + +@contextlib.contextmanager +def setup_context(setup_dir): + temp_dir = os.path.join(setup_dir, 'temp') + with save_pkg_resources_state(): + with save_modules(): + with save_path(): + hide_setuptools() + with save_argv(): + with override_temp(temp_dir): + with pushd(setup_dir): + # ensure setuptools commands are available + __import__('setuptools') + yield + + +_MODULES_TO_HIDE = { + 'setuptools', + 'distutils', + 'pkg_resources', + 'Cython', + '_distutils_hack', +} + + +def _needs_hiding(mod_name): + """ + >>> _needs_hiding('setuptools') + True + >>> _needs_hiding('pkg_resources') + True + >>> _needs_hiding('setuptools_plugin') + False + >>> _needs_hiding('setuptools.__init__') + True + >>> _needs_hiding('distutils') + True + >>> _needs_hiding('os') + False + >>> _needs_hiding('Cython') + True + """ + base_module = mod_name.split('.', 1)[0] + return base_module in _MODULES_TO_HIDE + + +def hide_setuptools(): + """ + Remove references to setuptools' modules from sys.modules to allow the + invocation to import the most appropriate setuptools. This technique is + necessary to avoid issues such as #315 where setuptools upgrading itself + would fail to find a function declared in the metadata. + """ + _distutils_hack = sys.modules.get('_distutils_hack', None) + if _distutils_hack is not None: + _distutils_hack._remove_shim() + + modules = filter(_needs_hiding, sys.modules) + _clear_modules(modules) + + +def run_setup(setup_script, args): + """Run a distutils setup script, sandboxed in its directory""" + setup_dir = os.path.abspath(os.path.dirname(setup_script)) + with setup_context(setup_dir): + try: + sys.argv[:] = [setup_script] + list(args) + sys.path.insert(0, setup_dir) + # reset to include setup dir, w/clean callback list + working_set.__init__() + working_set.callbacks.append(lambda dist: dist.activate()) + + with DirectorySandbox(setup_dir): + ns = dict(__file__=setup_script, __name__='__main__') + _execfile(setup_script, ns) + except SystemExit as v: + if v.args and v.args[0]: + raise + # Normal exit, just return + + +class AbstractSandbox: + """Wrap 'os' module and 'open()' builtin for virtualizing setup scripts""" + + _active = False + + def __init__(self) -> None: + self._attrs = [ + name + for name in dir(_os) + if not name.startswith('_') and hasattr(self, name) + ] + + def _copy(self, source): + for name in self._attrs: + setattr(os, name, getattr(source, name)) + + def __enter__(self) -> None: + self._copy(self) + builtins.open = self._open + self._active = True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self._active = False + builtins.open = _open + self._copy(_os) + + def run(self, func): + """Run 'func' under os sandboxing""" + with self: + return func() + + def _mk_dual_path_wrapper(name: str): # type: ignore[misc] # https://github.com/pypa/setuptools/pull/4099 + original = getattr(_os, name) + + def wrap(self, src, dst, *args, **kw): + if self._active: + src, dst = self._remap_pair(name, src, dst, *args, **kw) + return original(src, dst, *args, **kw) + + return wrap + + for __name in ["rename", "link", "symlink"]: + if hasattr(_os, __name): + locals()[__name] = _mk_dual_path_wrapper(__name) + + def _mk_single_path_wrapper(name: str, original=None): # type: ignore[misc] # https://github.com/pypa/setuptools/pull/4099 + original = original or getattr(_os, name) + + def wrap(self, path, *args, **kw): + if self._active: + path = self._remap_input(name, path, *args, **kw) + return original(path, *args, **kw) + + return wrap + + _open = _mk_single_path_wrapper('open', _open) + for __name in [ + "stat", + "listdir", + "chdir", + "open", + "chmod", + "chown", + "mkdir", + "remove", + "unlink", + "rmdir", + "utime", + "lchown", + "chroot", + "lstat", + "startfile", + "mkfifo", + "mknod", + "pathconf", + "access", + ]: + if hasattr(_os, __name): + locals()[__name] = _mk_single_path_wrapper(__name) + + def _mk_single_with_return(name: str): # type: ignore[misc] # https://github.com/pypa/setuptools/pull/4099 + original = getattr(_os, name) + + def wrap(self, path, *args, **kw): + if self._active: + path = self._remap_input(name, path, *args, **kw) + return self._remap_output(name, original(path, *args, **kw)) + return original(path, *args, **kw) + + return wrap + + for __name in ['readlink', 'tempnam']: + if hasattr(_os, __name): + locals()[__name] = _mk_single_with_return(__name) + + def _mk_query(name: str): # type: ignore[misc] # https://github.com/pypa/setuptools/pull/4099 + original = getattr(_os, name) + + def wrap(self, *args, **kw): + retval = original(*args, **kw) + if self._active: + return self._remap_output(name, retval) + return retval + + return wrap + + for __name in ['getcwd', 'tmpnam']: + if hasattr(_os, __name): + locals()[__name] = _mk_query(__name) + + def _validate_path(self, path): + """Called to remap or validate any path, whether input or output""" + return path + + def _remap_input(self, operation, path, *args, **kw): + """Called for path inputs""" + return self._validate_path(path) + + def _remap_output(self, operation, path): + """Called for path outputs""" + return self._validate_path(path) + + def _remap_pair(self, operation, src, dst, *args, **kw): + """Called for path pairs like rename, link, and symlink operations""" + return ( + self._remap_input(operation + '-from', src, *args, **kw), + self._remap_input(operation + '-to', dst, *args, **kw), + ) + + if TYPE_CHECKING: + # This is a catch-all for all the dynamically created attributes. + # This isn't public API anyway + def __getattribute__(self, name: str) -> Any: ... + + +if hasattr(os, 'devnull'): + _EXCEPTIONS = [os.devnull] +else: + _EXCEPTIONS = [] + + +class DirectorySandbox(AbstractSandbox): + """Restrict operations to a single subdirectory - pseudo-chroot""" + + write_ops: ClassVar[dict[str, None]] = dict.fromkeys([ + "open", + "chmod", + "chown", + "mkdir", + "remove", + "unlink", + "rmdir", + "utime", + "lchown", + "chroot", + "mkfifo", + "mknod", + "tempnam", + ]) + + _exception_patterns: list[str | re.Pattern] = [] + "exempt writing to paths that match the pattern" + + def __init__(self, sandbox, exceptions=_EXCEPTIONS) -> None: + self._sandbox = os.path.normcase(os.path.realpath(sandbox)) + self._prefix = os.path.join(self._sandbox, '') + self._exceptions = [ + os.path.normcase(os.path.realpath(path)) for path in exceptions + ] + AbstractSandbox.__init__(self) + + def _violation(self, operation, *args, **kw): + from setuptools.sandbox import SandboxViolation + + raise SandboxViolation(operation, args, kw) + + def _open(self, path, mode='r', *args, **kw): + if mode not in ('r', 'rt', 'rb', 'rU', 'U') and not self._ok(path): + self._violation("open", path, mode, *args, **kw) + return _open(path, mode, *args, **kw) + + def tmpnam(self) -> None: + self._violation("tmpnam") + + def _ok(self, path): + active = self._active + try: + self._active = False + realpath = os.path.normcase(os.path.realpath(path)) + return ( + self._exempted(realpath) + or realpath == self._sandbox + or realpath.startswith(self._prefix) + ) + finally: + self._active = active + + def _exempted(self, filepath): + start_matches = ( + filepath.startswith(exception) for exception in self._exceptions + ) + pattern_matches = ( + re.match(pattern, filepath) for pattern in self._exception_patterns + ) + candidates = itertools.chain(start_matches, pattern_matches) + return any(candidates) + + def _remap_input(self, operation, path, *args, **kw): + """Called for path inputs""" + if operation in self.write_ops and not self._ok(path): + self._violation(operation, os.path.realpath(path), *args, **kw) + return path + + def _remap_pair(self, operation, src, dst, *args, **kw): + """Called for path pairs like rename, link, and symlink operations""" + if not self._ok(src) or not self._ok(dst): + self._violation(operation, src, dst, *args, **kw) + return (src, dst) + + def open(self, file, flags, mode: int = 0o777, *args, **kw) -> int: + """Called for low-level os.open()""" + if flags & WRITE_FLAGS and not self._ok(file): + self._violation("os.open", file, flags, mode, *args, **kw) + return _os.open(file, flags, mode, *args, **kw) + + +WRITE_FLAGS = functools.reduce( + operator.or_, + [ + getattr(_os, a, 0) + for a in "O_WRONLY O_RDWR O_APPEND O_CREAT O_TRUNC O_TEMPORARY".split() + ], +) + + +class SandboxViolation(DistutilsError): + """A setup script attempted to modify the filesystem outside the sandbox""" + + tmpl = textwrap.dedent( + """ + SandboxViolation: {cmd}{args!r} {kwargs} + + The package setup script has attempted to modify files on your system + that are not within the EasyInstall build area, and has been aborted. + + This package cannot be safely installed by EasyInstall, and may not + support alternate installation locations even if you run its setup + script by hand. Please inform the package's author and the EasyInstall + maintainers to find out if a fix or workaround is available. + """ + ).lstrip() + + def __str__(self) -> str: + cmd, args, kwargs = self.args + return self.tmpl.format(**locals()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/script (dev).tmpl b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/script (dev).tmpl new file mode 100644 index 0000000000000000000000000000000000000000..39a24b04888e79df51e2237577b303a2f901be63 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/script (dev).tmpl @@ -0,0 +1,6 @@ +# EASY-INSTALL-DEV-SCRIPT: %(spec)r,%(script_name)r +__requires__ = %(spec)r +__import__('pkg_resources').require(%(spec)r) +__file__ = %(dev_path)r +with open(__file__) as f: + exec(compile(f.read(), __file__, 'exec')) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/version.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/version.py new file mode 100644 index 0000000000000000000000000000000000000000..ec253c414474677d3a5977511cfe901bfb786740 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/version.py @@ -0,0 +1,6 @@ +from ._importlib import metadata + +try: + __version__ = metadata.version('setuptools') or '0.dev0+unknown' +except Exception: + __version__ = '0.dev0+unknown' diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/warnings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/warnings.py new file mode 100644 index 0000000000000000000000000000000000000000..96467787c237846bfbacf2d44eb833be0a88b633 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/warnings.py @@ -0,0 +1,110 @@ +"""Provide basic warnings used by setuptools modules. + +Using custom classes (other than ``UserWarning``) allow users to set +``PYTHONWARNINGS`` filters to run tests and prepare for upcoming changes in +setuptools. +""" + +from __future__ import annotations + +import os +import warnings +from datetime import date +from inspect import cleandoc +from textwrap import indent +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeAlias + +_DueDate: TypeAlias = tuple[int, int, int] # time tuple +_INDENT = 8 * " " +_TEMPLATE = f"""{80 * '*'}\n{{details}}\n{80 * '*'}""" + + +class SetuptoolsWarning(UserWarning): + """Base class in ``setuptools`` warning hierarchy.""" + + @classmethod + def emit( + cls, + summary: str | None = None, + details: str | None = None, + due_date: _DueDate | None = None, + see_docs: str | None = None, + see_url: str | None = None, + stacklevel: int = 2, + **kwargs, + ) -> None: + """Private: reserved for ``setuptools`` internal use only""" + # Default values: + summary_ = summary or getattr(cls, "_SUMMARY", None) or "" + details_ = details or getattr(cls, "_DETAILS", None) or "" + due_date = due_date or getattr(cls, "_DUE_DATE", None) + docs_ref = see_docs or getattr(cls, "_SEE_DOCS", None) + docs_url = docs_ref and f"https://setuptools.pypa.io/en/latest/{docs_ref}" + see_url = see_url or getattr(cls, "_SEE_URL", None) + due = date(*due_date) if due_date else None + + text = cls._format(summary_, details_, due, see_url or docs_url, kwargs) + if due and due < date.today() and _should_enforce(): + raise cls(text) + warnings.warn(text, cls, stacklevel=stacklevel + 1) + + @classmethod + def _format( + cls, + summary: str, + details: str, + due_date: date | None = None, + see_url: str | None = None, + format_args: dict | None = None, + ) -> str: + """Private: reserved for ``setuptools`` internal use only""" + today = date.today() + summary = cleandoc(summary).format_map(format_args or {}) + possible_parts = [ + cleandoc(details).format_map(format_args or {}), + ( + f"\nBy {due_date:%Y-%b-%d}, you need to update your project and remove " + "deprecated calls\nor your builds will no longer be supported." + if due_date and due_date > today + else None + ), + ( + "\nThis deprecation is overdue, please update your project and remove " + "deprecated\ncalls to avoid build errors in the future." + if due_date and due_date < today + else None + ), + (f"\nSee {see_url} for details." if see_url else None), + ] + parts = [x for x in possible_parts if x] + if parts: + body = indent(_TEMPLATE.format(details="\n".join(parts)), _INDENT) + return "\n".join([summary, "!!\n", body, "\n!!"]) + return summary + + +class InformationOnly(SetuptoolsWarning): + """Currently there is no clear way of displaying messages to the users + that use the setuptools backend directly via ``pip``. + The only thing that might work is a warning, although it is not the + most appropriate tool for the job... + + See pypa/packaging-problems#558. + """ + + +class SetuptoolsDeprecationWarning(SetuptoolsWarning): + """ + Base class for warning deprecations in ``setuptools`` + + This class is not derived from ``DeprecationWarning``, and as such is + visible by default. + """ + + +def _should_enforce(): + enforce = os.getenv("SETUPTOOLS_ENFORCE_DEPRECATION", "false").lower() + return enforce in ("true", "on", "ok", "1") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/windows_support.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/windows_support.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2b53a291409c66851961a559eb4d69be0f4acc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/setuptools/windows_support.py @@ -0,0 +1,30 @@ +import platform + + +def windows_only(func): + if platform.system() != 'Windows': + return lambda *args, **kwargs: None + return func + + +@windows_only +def hide_file(path: str) -> None: + """ + Set the hidden attribute on a file or directory. + + From https://stackoverflow.com/questions/19622133/ + + `path` must be text. + """ + import ctypes + import ctypes.wintypes + + SetFileAttributes = ctypes.windll.kernel32.SetFileAttributesW + SetFileAttributes.argtypes = ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD + SetFileAttributes.restype = ctypes.wintypes.BOOL + + FILE_ATTRIBUTE_HIDDEN = 0x02 + + ret = SetFileAttributes(path, FILE_ATTRIBUTE_HIDDEN) + if not ret: + raise ctypes.WinError() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ee8e3880c144579f96307e1fb64f9099c4c8714 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a7db179e36a97e2691df3ebecee8dce00d563f0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/callback.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f0a1f6914f7cbb520b81529b556ee34f433d8ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/code_context.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d55e8bc3fc09cea73cff299af80fe8bfb98410 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..983e73a9ecc9238014125dc8db8c35354fcc4443 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d59425f1fed0b6de193dfb91988bb5ee69de459 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3215d0b5617b0d87711f1d96f86915883bb4fae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b96bbeb8f05c234e3ab26704b7bc7c29dad74b8b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dfa5c3700065b335b5b2c0687daf6b505bd42ab Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/functional_export.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a105fcd5f368e20f446b7c5449f384f104608714 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_bytecode_inputs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3cb2d27548ae842ec9d5e31a709988263c190af Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38dbb426d0816fd0cf36541346e5d037bc991ef1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb6fc54f835115f396d40fe8f287fc7db488fdbe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/logging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2724fa814d37c606740239b8b2678898662db76d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/package.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad234016da92f07d5342ff887f866051a1c22a5a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_case.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292160273dbbd1ae9611e5e9f97ddc124c20bfd3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bebdbdd355c15ede75e4f5a754697fc28a5fbb2e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb0f4f6b2756fd78c0d605cfee98090fd539908 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/__pycache__/testing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..122b933e9b206263bff94b0d4a44071e74176eff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5f22e2552b3c9befcbac870628c5f17bdd67cda Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793b7ff1031cd197f812f492dede656fda77b366 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35f2ba40a7161cfd8be75029d14fa8b99af092c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3e277ac2c995c4e5c25bd18d17fcba0920732c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b59f0b3cc65ebf6e2d4bc02b2e6b84701474647 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2b6ecff0c17d70fd978058c1b5a5915aa41158 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py @@ -0,0 +1,183 @@ +""" +This module provides common utilities and base classes for TorchDynamo backends. + +Key components: +- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends +- Backend utilities for handling: + - Fake tensor conversion + - Device/dtype detection from inputs + - Memory efficient fusion + - Graph flattening + - Common compiler configurations + +The utilities here are used by various backend implementations to handle +common operations and provide consistent behavior across different backends. +AOT autograd functionality is particularly important as it enables ahead-of-time +optimization of both forward and backward passes. +""" + +import contextlib +import functools +import logging +from collections.abc import Callable, Iterable +from typing import Any +from typing_extensions import ParamSpec, TypeVar +from unittest.mock import patch + +import torch +from torch._dynamo import disable +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._dynamo.utils import counters, defake, flatten_graph_inputs +from torch._functorch.aot_autograd import ( + aot_module_simplified, + SerializableAOTDispatchCompiler, +) +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +class AotAutograd: + def __init__(self, **kwargs: Any) -> None: + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any + ) -> Callable[..., Any]: + if kwargs: + log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): + return flatten_graph_inputs( + gm, + example_inputs, + self, + ) + + # Hack to get around circular import problems with aot_eager_decomp_partition + if callable(self.kwargs.get("decompositions")): + self.kwargs["decompositions"] = self.kwargs["decompositions"]() + + # NB: dont delete counter increment + counters["aot_autograd"]["total"] += 1 + use_fallback = False + + if use_fallback: + log.debug("Unable to use AOT Autograd because graph has mutation") + counters["aot_autograd"]["not_ok"] += 1 + return gm + + def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]: + def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: + # Note [Wrapping bw_compiler in disable] + # The two disables here: + # - stop TorchDynamo from trying to compile the bw_compiler function itself + # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces + + return disable( + disable( + bw_compiler_fn, reason="do not trace backward compiler function" + )(*args, **kwargs), # type: ignore[misc] + reason="do not trace generated backwards pass", + ) + + _wrapped_bw_compiler._is_wrapped_bw_compiler = ( # pyrefly: ignore [missing-attribute] + True + ) + return _wrapped_bw_compiler + + bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] + + if isinstance(bw_compiler, SerializableAOTDispatchCompiler): + bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn) + elif getattr(bw_compiler, "_is_wrapped_bw_compiler", False): + bw_compiler.compiler_fn = bw_compiler + else: + bw_compiler = wrap_bw_compiler(bw_compiler) + + self.kwargs["bw_compiler"] = bw_compiler + self.kwargs["inference_compiler"] = ( + self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"] + ) + + from functorch.compile import nop + from torch._inductor.debug import enable_aot_logging + + # debug asserts slow down compile time noticeably, + # So only default them on when the aot_eager backend is used. + if self.kwargs.get("fw_compiler", None) is nop: + patch_config: contextlib.AbstractContextManager[Any] = patch( + "functorch.compile.config.debug_assert", True + ) + else: + patch_config = contextlib.nullcontext() + + try: + # NB: NOT cloned! + with enable_aot_logging(), patch_config: + cg = aot_module_simplified(gm, example_inputs, **self.kwargs) + counters["aot_autograd"]["ok"] += 1 + return disable(cg, reason="do not trace AOT-compiled graph") + except TensorifyScalarRestartAnalysis: + raise + except Exception: + counters["aot_autograd"]["not_ok"] += 1 + raise + + +def aot_autograd(**kwargs: Any) -> AotAutograd: + return AotAutograd(**kwargs) + + +def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]: + from functorch.compile import ( + default_decompositions, + min_cut_rematerialization_partition, + ts_compile, + ) + + kwargs = { + # these are taken from memory_efficient_fusion() + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + } + + if use_decomps: + kwargs["decompositions"] = default_decompositions + + return kwargs + + +def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any: + """ + Decorator for backends that need real inputs. We swap out fake + tensors for zero tensors. + """ + + @functools.wraps(fn) + def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any: + with _disable_current_modes(): + inputs = list(map(defake, inputs)) + return fn(model, inputs, **kwargs) # type: ignore[call-arg] + + return wrapper + + +def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device: + for x in example_inputs: + if hasattr(x, "device"): + return x.device + return torch.device("cpu") # Default fallback + + +def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype: + for x in example_inputs: + if hasattr(x, "dtype"): + return x.dtype + return torch.float32 # Default fallback diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..0e62e08cf1fc93a3acb11249e561ee06eb44e655 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py @@ -0,0 +1,558 @@ +""" +This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot +compilation and execution issues. It includes: + +Key Debugging Backends: +- eager: Simple pass-through backend that runs models in eager mode +- eager_noexcept: Similar to eager but with additional exception handling +- eager_debug: Adds schema validation checks for custom operators +- aot_eager: Uses AOT Autograd with nop compiler for debugging +- aot_eager_decomp_partition: Uses TorchInductor decompositions for debugging +- torchscript: Compiles using TorchScript for debugging JIT-related issues + +Testing and Development Tools: +- Backends for inducing specific errors (compile/runtime/accuracy) +- ExplainOutput class for detailed graph compilation analysis +- Utilities for cross-referencing and mode management +- Tools for graph detail inspection and break reason analysis + +These backends are primarily used for: +1. Debugging graph breaks and compilation failures +2. Testing error handling and recovery mechanisms +3. Analyzing performance bottlenecks +4. Validating operator schemas and decompositions +""" + +import dataclasses +import functools +import logging +from collections.abc import Callable, Iterable +from importlib import import_module +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from functorch.compile import min_cut_rematerialization_partition +from torch import _guards +from torch._dynamo.output_graph import GraphCompileReason +from torch._functorch import config as functorch_config +from torch._functorch.compilers import ts_compile + +from .common import aot_autograd +from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend + + +if TYPE_CHECKING: + from torch.fx.node import Target + + +log = logging.getLogger(__name__) + + +@register_backend +def eager( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager backend ignoring extra kwargs %s", kwargs) + return gm.forward + + +def make_eager_backend_with_torch_function_mode( + mode: torch.overrides.TorchFunctionMode, +) -> Callable[..., Any]: + return make_eager_backend_with_torch_function_modes([mode]) + + +def make_eager_backend_with_torch_function_modes( + modes: Iterable[torch.overrides.TorchFunctionMode], +) -> Callable[..., Any]: + """Used to trace HOPs (cond and while) for eager execution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + from contextlib import ExitStack + + def fn( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any + ) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + with ExitStack() as stack: + for mode in modes: + stack.enter_context(mode) + return gm.forward(*args, **kwargs) + + return wrapper + + return fn + + +@register_backend +def eager_noexcept( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs) + + # This backend is intended to check that dynamo-generated GraphModules + # do not cause errors. + def inner(*args: Any) -> Any: + try: + return gm(*args) + except Exception as e: + raise torch._dynamo.exc.TorchDynamoException( + "Unexpected exception when running generated GraphModule" + ) from e + + return inner + + +@register_backend +def pre_dispatch_eager( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> torch.fx.GraphModule: + if kwargs: + log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs) + + from torch.fx.experimental.proxy_tensor import make_fx + + def runnable_gm(*args: Any) -> Any: + return torch.fx.Interpreter(gm).run(*args) + + pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) + pre_dispatch_gm.print_readable() + + return pre_dispatch_gm + + +@register_backend +def eager_debug( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning("eager_debug backend ignoring extra kwargs %s", kwargs) + + from torch._subclasses.schema_check_mode import SchemaCheckMode + + # We could add more debugging bits here. + # Right now, this backend can be used to check for and error on + # custom dispatcher ops that have incorrect schemas. + def inner(*args: Any) -> Any: + with SchemaCheckMode(): + return torch.fx.Interpreter(gm).run(*args) + + return inner + + +@register_backend(name="ts") # type: ignore[misc] +def torchscript( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] +) -> torch.jit.ScriptModule: + return torch.jit.script(gm) + + +# used boxed call to discard inputs when they are no longer needed +def boxed_nop( + fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Wrap the forward method in a function so we can set _boxed_call attribute + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def boxed_nop_with_mode( + fx_g: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + *, + mode: torch.overrides.TorchFunctionMode, +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Create a wrapper that runs with the mode + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + with mode: + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def fake_crossref_boxed_nop( + fx_g: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + ignore_op_fn: Optional[Callable[[torch._ops.OpOverload], bool]] = None, +) -> Callable[..., Any]: + from torch.fx.graph import _BoxedCodeGen + + # Set the graph to use boxed codegen + fx_g.graph.set_codegen(_BoxedCodeGen()) + fx_g.recompile() + + # Create a wrapper that runs with the mode + forward_fn = fx_g.forward + + def run(args: Any) -> Any: + with torch._subclasses.CrossRefFakeMode(ignore_op_fn): + return forward_fn(args) + + run._boxed_call = True # type: ignore[attr-defined] + return run + + +def ignore_builtins(op: torch._ops.OpOverload) -> bool: + return op.namespace in ("aten", "prims", "prim") + + +def get_nop_func() -> Callable[ + [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any] +]: + if not torch._functorch.config.fake_tensor_crossref: + return boxed_nop + elif torch._functorch.config.fake_tensor_crossref == "all": + return fake_crossref_boxed_nop + else: + assert torch._functorch.config.fake_tensor_crossref == "custom_ops" + return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins) + + +# Useful for debugging purpose +# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. +def aot_eager( + gm: torch.fx.GraphModule, + fake_tensor_inputs: list[torch.Tensor], + fw_compiler: Optional[Callable[..., Any]] = None, + bw_compiler: Optional[Callable[..., Any]] = None, + **kwargs: Any, +) -> Callable[..., Any]: + return aot_autograd( + fw_compiler=fw_compiler or boxed_nop, + bw_compiler=bw_compiler or boxed_nop, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + )(gm, fake_tensor_inputs, **kwargs) + + +register_backend(name="aot_eager", compiler_fn=aot_eager) + +aot_eager_default_partitioner = aot_autograd( + fw_compiler=boxed_nop, keep_inference_input_mutations=True +) +register_backend( + name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner +) + + +# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs +# inductor problems. +# aot_eager_decomp_partition just replaces the inductor compiler with nop to help +# isolate inductor vs aot_eager errors +def aot_eager_decomp_partition( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + if kwargs: + log.warning( + "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs + ) + + from torch._inductor.compiler_bisector import CompilerBisector + + config_patches = {"unlift_effect_tokens": True} + if bisect_changes := CompilerBisector.get_config_change( + "aot_eager_decomp_partition" + ): + config_patches.update(bisect_changes) # type: ignore[arg-type] + + with functorch_config.patch(config_patches): + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=get_nop_func(), + bw_compiler=get_nop_func(), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition +) + + +# aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition, +# except that it takes a TorchDispatchMode mode and run the fw/bw in the mode +def aot_eager_decomp_partition_with_mode( + gm: torch.fx.GraphModule, + fake_tensor_inputs: list[torch.Tensor], + mode: Any, + **kwarg: Any, +) -> Callable[..., Any]: + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + bw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition_with_mode", + compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type] +) + + +def aot_eager_decomp_partition_crossref( + gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any +) -> Callable[..., Any]: + # if the config is set, respect it, otherwise only test custom_ops. + # custom_op bad metas always manifest as an error whereas aten will only sometimes. + # by default, use the less noisy option + config_val = ( + "custom_ops" + if not functorch_config.fake_tensor_crossref + else functorch_config.fake_tensor_crossref + ) + with functorch_config.patch(fake_tensor_crossref=config_val): + return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs) + + +register_backend( + name="aot_eager_decomp_partition_crossref", + compiler_fn=aot_eager_decomp_partition_crossref, +) + + +# AOT Autograd with torchscript backend. Default partitioner. +# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser +# by using the relevant fuser with torch.jit.fuser(...) +aot_ts = aot_autograd(fw_compiler=ts_compile) +register_backend(name="aot_ts", compiler_fn=aot_ts) + +# These buggy backends are used for inducing bugs so that we can test +# our repro extraction / minifier scripts + + +class ReluCompileError(Exception): + pass + + +class TestingOnlyCompileError(Exception): + pass + + +@register_backend +def relu_compile_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + raise ReluCompileError + return gm + + +@register_backend +def relu_runtime_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + node.target = torch._assert + node.args = (False, "ReluRuntimeError") + gm.recompile() + return gm + + +@register_backend +def relu_accuracy_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + for node in gm.graph.nodes: + if node.target is torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm + + +@register_backend +def non_leaf_compile_error_TESTING_ONLY( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] +) -> torch.fx.GraphModule: + # Require at least one non-trivial thing in the graph, + # see https://github.com/pytorch/pytorch/issues/102898 + for node in gm.graph.nodes: + if node.op == "call_function": + break + else: + return gm + for t in example_inputs: + if not t.is_leaf: + raise TestingOnlyCompileError + return gm + + +@dataclasses.dataclass +class ExplainOutput: + """ + This is the output of :func:`torch._dynamo.explain()` + There is no reason to create this class directly. + """ + + graphs: list[torch.fx.GraphModule] + graph_count: int + graph_break_count: int + break_reasons: list[GraphCompileReason] + op_count: int + ops_per_graph: Optional[list[list["Target"]]] = None + out_guards: Optional[list[_guards.Guard]] = None + compile_times: Optional[str] = None + + def __str__(self) -> str: + output = f"Graph Count: {self.graph_count}\n" + output += f"Graph Break Count: {self.graph_break_count}\n" + output += f"Op Count: {self.op_count}\n" + + output += "Break Reasons:\n" + for idx, break_reason in enumerate(self.break_reasons): + output += f" Break Reason {idx + 1}:\n" + output += f" Reason: {break_reason.reason}\n" + output += " User Stack:\n" + for frame_summary in break_reason.user_stack: + output += f" {frame_summary}\n" + + if self.ops_per_graph is not None: + output += "Ops per Graph:\n" + for idx, ops in enumerate(self.ops_per_graph): + output += f" Ops {idx + 1}:\n" + for op in ops: + output += f" {op}\n" + + if self.out_guards is not None: + output += "Out Guards:\n" + for i, guard in enumerate(self.out_guards): + output += f" Guard {i + 1}:\n" + output += f" {str(guard)}" + + if self.compile_times is not None: + output += f"Compile Times: {self.compile_times}\n" + return output + + +def _explain_graph_detail( + gm: torch.fx.GraphModule, + graphs: list[torch.fx.GraphModule], + op_count: int, + ops_per_graph: list[list["Target"]], + break_reasons: list[GraphCompileReason], +) -> tuple[ + torch.fx.GraphModule, + list[torch.fx.GraphModule], + int, + list[list["Target"]], + list[GraphCompileReason], +]: + """ + This function is a utility which processes a torch.fx.GraphModule and + accumulates information about its ops, graph breaks, and other details. It + is intended to be used by the ExplainWithBackend class and + `torch._dynamo.explain()` to provide details from Dynamo's graph capture. + + Parameters: + gm (torch.fx.GraphModule): The GraphModule to be processed. + graphs (list): A list that accumulates all the GraphModules processed. + op_count (int): The total count of operations in all GraphModules processed so far. + ops_per_graph (list): A list that accumulates the operations of each GraphModule. + break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule. + + Returns: + tuple: A tuple containing the processed GraphModule, the updated lists of graphs, + operations per graph, and break reasons, and the updated operation count. + """ + graphs.append(gm) + ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] + op_count += len(ops) + ops_per_graph.append(ops) + if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr] + break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type] + + return gm, graphs, op_count, ops_per_graph, break_reasons + + +class ExplainWithBackend: + """ + This class is intended to be used as a backend for `torch.compile`. It is + composable with other backends. When used in this way, it accumulates + information about graph breaks, ops, and other info and provides a string + representation summarizing this information. + + Attributes: + backend (str): The name of the backend to use for optimization. + graphs (list): A list of the graphs captured by TorchDynamo. + op_count (int): The total number of operations in all optimized graphs. + break_reasons (list): A list of graph break reasons with stack traces. + + Example Usage: + def fn(x): + x = torch.sigmoid(x) + return x + + torch._dynamo.reset() + eb = ExplainWithBackend("inductor") + optimized_fn = torch.compile(fn, backend=eb) + result = optimized_fn(torch.randn(5)) + print(eb.output()) + """ + + def __init__(self, backend: Union[CompilerFn, str]) -> None: + from .registry import lookup_backend + + self.backend = lookup_backend(backend) + self.graphs: list[torch.fx.GraphModule] = [] + self.op_count = 0 + self.break_reasons: list[GraphCompileReason] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> CompiledFn: + ops_per_graph: list[list[Target]] = [] + gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( + gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons + ) + return self.backend(gm, example_inputs) + + def output(self) -> ExplainOutput: + graph_count = len(self.graphs) + output = ExplainOutput( + self.graphs, + graph_count, + graph_count - 1, + self.break_reasons, + self.op_count, + ) + + return output diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py new file mode 100644 index 0000000000000000000000000000000000000000..ae62dd56678b8349d27fe909f12482b884ca596c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py @@ -0,0 +1,31 @@ +""" +This module provides the TorchInductor backend integration for TorchDynamo. + +TorchInductor is a compiler backend that generates optimized code for both CPU and GPU. +This module lazily imports and registers the TorchInductor compiler to avoid loading it +into memory when it is not being used. This helps reduce memory overhead when using +other backends. + +The inductor backend can be used with torch.compile(): + model = torch.compile(model, backend="inductor") +""" + +from typing import Any + +from torch._dynamo import register_backend +from torch._dynamo.utils import dynamo_timed + + +@register_backend +def inductor(*args: Any, **kwargs: Any) -> Any: + with dynamo_timed("inductor_import", log_pt2_compile_event=True): + # do import here to avoid loading inductor into memory when it is not used + # The AsyncCompile subproc pool can be slow to start, so warm it up as early + # as possible. + from torch._inductor.async_compile import maybe_warm_pool + + maybe_warm_pool() + + from torch._inductor.compile_fx import compile_fx + + return compile_fx(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..493e21a9dfc5fe929fdeefdf6153834d470ab561 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tensorrt.py @@ -0,0 +1,12 @@ +# import torch # type: ignore[import] +# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import] +# from .registry import register_backend # type: ignore[import] + +""" +Placeholder for TensorRT backend for dynamo via torch-tensorrt +""" + +# @register_backend +# def tensorrt(gm, example_inputs): +# import torch_tensorrt # type: ignore[import] +# pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py new file mode 100644 index 0000000000000000000000000000000000000000..02dde50de0fe02d793226b64d852967d99d31de6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/tvm.py @@ -0,0 +1,197 @@ +""" +This module provides TVM backend integration for TorchDynamo. + +Apache TVM is a deep learning compiler framework that can optimize and execute +models on various hardware backends. This module enables: + +- Compilation of PyTorch models to TVM's computation graphs +- Multiple scheduling options: + - Default scheduler + - Auto-scheduler for automatic optimization + - Meta-schedule for evolutionary search-based tuning +- Hardware-specific optimizations: + - CUDA GPU support + - CPU support with LLVM targeting and architecture-specific tuning + - Automatic detection of CPU capabilities (AVX2, AVX512) +- Tensor conversion utilities between PyTorch and TVM formats +- Configurable optimization levels and tuning trials + +The backend can be used with torch.compile(): + model = torch.compile(model, backend="tvm") +""" + +import functools +import importlib +import logging +import os +import sys +import tempfile +from collections.abc import Callable +from pathlib import Path +from types import MappingProxyType +from typing import Any, Optional + +import torch +from torch import fx + +from .common import device_from_inputs, fake_tensor_unsupported +from .registry import register_backend + + +log = logging.getLogger(__name__) + + +@register_backend +@fake_tensor_unsupported # type: ignore[arg-type] +def tvm( + gm: fx.GraphModule, + example_inputs: list[torch.Tensor], + *, + options: Optional[MappingProxyType[str, Any]] = None, +) -> Callable[..., Any]: + if options is None: + options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3}) + assert options is not None + import tvm # type: ignore[import] + from tvm import relay # type: ignore[import] + from tvm.contrib import graph_executor # type: ignore[import] + + jit_mod = torch.jit.trace(gm, example_inputs) + device = device_from_inputs(example_inputs) + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + example_outputs = gm(*example_inputs) + if len(example_outputs) == 0: + log.warning("Explicitly fall back to eager due to zero output") + return gm.forward + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + scheduler = options.get("scheduler", None) + if scheduler is None: + scheduler = os.environ.get("TVM_SCHEDULER", None) + + trials = options.get("trials", 20000) + opt_level = options.get("opt_level", 3) + + if scheduler == "auto_scheduler": + # pyrefly: ignore [import-error] + from tvm import auto_scheduler + + with ( + tempfile.NamedTemporaryFile() as log_file, + auto_scheduler.ApplyHistoryBest(log_file), + tvm.transform.PassContext( + opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} + ), + ): + lib = relay.build(mod, target=target, params=params) + elif scheduler == "meta_schedule": + # pyrefly: ignore [import-error] + from tvm import meta_schedule as ms + + with tempfile.TemporaryDirectory() as work_dir: + if device.type != "cuda": + # meta_schedule needs num-cores to be specified + # here we use the maximum core count + target = tvm.target.Target( + f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}" + ) + # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch + # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + assert trials > 0 + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + work_dir=work_dir, + max_trials_global=trials, + num_trials_per_iter=64, + params=params, + strategy="evolutionary", + opt_level=opt_level, + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + opt_level=opt_level, + ) + elif scheduler == "default" or not scheduler: + # no autotuning + with tvm.transform.PassContext(opt_level=opt_level): + lib = relay.build(mod, target=target, params=params) + else: + raise NotImplementedError( + "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " + "There are three available options: default, auto_scheduler and meta_schedule." + ) + m = graph_executor.GraphModule(lib["default"](dev)) + + def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor: + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. + return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + + def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array: + """A helper function to transfer a torch.tensor to NDArray.""" + if torch_tensor.dtype == torch.bool: + # same reason as above, fallback to numpy conversion which + # could introduce data copy overhead + return tvm.nd.array(torch_tensor.cpu().numpy()) + return tvm.nd.from_dlpack(torch_tensor) + + def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]: + args = [a.contiguous() for a in i_args] + shape_info, _ = m.get_input_info() + active_inputs = {name for name, _ in shape_info.items()} + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + inp_name = f"inp_{idx}" + if inp_name not in active_inputs: + log.warning( + "input %s skipped as not found in tvm's runtime library", + inp_name, + ) + continue + m.set_input( + inp_name, + to_tvm_tensor(arg), + ) + m.run() + return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())] + + return exec_tvm + + +tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule") +tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler") + + +def has_tvm() -> bool: + try: + importlib.import_module("tvm") + return True + except ImportError: + return False + + +@functools.cache +def llvm_target() -> str: + if sys.platform == "linux": + cpuinfo = Path("/proc/cpuinfo").read_text() + if "avx512" in cpuinfo: + return "llvm -mcpu=skylake-avx512" + elif "avx2" in cpuinfo: + return "llvm -mcpu=core-avx2" + return "llvm" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59f6f76317e6daf4d6dbcfc93d363442b5e4335f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__init__.py @@ -0,0 +1,431 @@ +""" +Python polyfills for common builtins. +""" + +# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports. +# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py. +# Add it in the TYPE_CHECKING block below as well. + +# mypy: allow-untyped-defs + +import types +from collections import OrderedDict +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from itertools import repeat as _repeat +from operator import eq, ne +from typing import Any, TYPE_CHECKING + +import torch + +from ..utils import dict_keys + + +if TYPE_CHECKING: + # Load by torch._dynamo.polyfills.loader + # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py + # Put the submodules here to avoid circular imports + from . import ( + _collections as _collections, + builtins as builtins, + functools as functools, + itertools as itertools, + operator as operator, + os as os, + pytree as pytree, + struct as struct, + sys as sys, + ) + +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + + +def index(iterator, item, start=0, end=None): + from itertools import islice + + for i, elem in islice(enumerate(iterator), start, end): + if item == elem: + return i + # This will not run in dynamo + raise ValueError(f"{item} is not in {type(iterator)}") + + +def repeat(item, count): + for _ in range(count): + yield item + + +def radians(x): + import math + + return math.pi / 180.0 * x + + +def impl_CONTAINS_OP_fallback(a, b): + # performs fallback "a in b" + if hasattr(b, "__iter__"): + # use __iter__ if __contains__ is not available + for x in b: + if x == a: + return True + return False + raise TypeError(f"argument of type {type(b)} is not iterable") + + +def accumulate_grad(x, new_grad): + # polyfills according to the Gradient Layout Contract + if new_grad is None: + return + new_grad_strided = torch.empty_like(x) + new_grad_strided.copy_(new_grad) + if x.grad is None: + x.grad = new_grad_strided + elif torch.is_grad_enabled(): + x.grad = x.grad + new_grad_strided + else: + x.grad.add_(new_grad_strided) + + +# This mirrors +# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413 +def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]): + """emulate `(1,2,3) > (1,2)` etc""" + + # Optimization: For equality, short-circuit if lengths differ + # This avoids iterating through elements and triggering guards on SymInts + left_len = len(left) + right_len = len(right) + + if op is eq and left_len != right_len: + return False + if op is ne and left_len != right_len: + return True + + # Apply `op` to the first pair that differ + for a, b in zip(left, right): + if a != b: + return op(a, b) + + # No more pairs to compare, so compare sizes. + return op(left_len, right_len) + + +def dict___eq__(d, other): + if (len(d) != len(other)) or (d.keys() != other.keys()): + return False + + if all(isinstance(a, OrderedDict) for a in (d, other)): + return list(d.items()) == list(other.items()) + + for k, v in d.items(): + if v != other[k]: + return False + + return True + + +def set_symmetric_difference(set1, set2): + symmetric_difference_set = set() + for x in set1: + if x not in set2: + symmetric_difference_set.add(x) + for x in set2: + if x not in set1: + symmetric_difference_set.add(x) + return symmetric_difference_set + + +def set_symmetric_difference_update(set1, set2): + result = set1.symmetric_difference(set2) + set1.clear() + set1.update(result) + + +def set_isdisjoint(set1, set2): + if not isinstance(set2, Iterable): + raise TypeError(f"'{type(set2)}' object is not iterable") + + for x in set1: + for y in set2: + if not isinstance(y, Hashable): + raise TypeError(f"unhashable type: '{type(y)}'") + if x == y: + return False + return True + + +def set_intersection(set1, *others): + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + # return a new set with elements common in all sets + intersection_set = set() + for x in set1: + for set2 in others: + if not any(x == y for y in set2): + break + else: + intersection_set.add(x) + return intersection_set + + +def set_intersection_update(set1, *others): + result = set1.intersection(*others) + set1.clear() + set1.update(result) + + +def set_union(set1, *others): + # frozenset also uses this function + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.union expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + union_set = set(set1.copy()) + for set2 in others: + set_update(union_set, set2) + + # frozenset also uses this function + return type(set1)(union_set) + + +def set_update(set1, *others): + if len(others) == 0: + return set1 + + for set2 in others: + for x in set2: + if x not in set1: + set1.add(x) + + +def set_difference(set1, *others): + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + for s in others: + if any(not isinstance(x, Hashable) for x in s): + raise TypeError("unhashable type") + + difference_set = set() + for x in set1: + for set2 in others: + if x in set2: + break + else: + difference_set.add(x) + return difference_set + + +def set_difference_update(set1, *others): + result = set1.difference(*others) + set1.clear() + set1.update(result) + + +def assert_dict_equal(self_, d1, d2, msg=None): + self_.assertTrue(d1 == d2, msg) + + +def assert_multi_line_equal(self_, first, second, msg=None): + return self_.assertTrue(first == second, msg) + + +# The original impl. uses difflib +def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None): + return self_.assertTrue(seq1 == seq2, msg) + + +def getattr_and_trace(*args, **kwargs): + wrapper_obj = args[0] + attr_name = args[1] + fn = getattr(wrapper_obj, attr_name) + return fn(*args[2:], **kwargs) + + +def mapping_get(obj, key, value=None, /): + try: + return obj.__getitem__(key) + except KeyError: + return value + + +def instantiate_user_defined_class_object(cls, /, *args, **kwargs): + obj = cls.__new__(cls, *args, **kwargs) + + # Only call __init__ if the object is an instance of the class + # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673 + if isinstance(obj, cls): + obj.__init__(*args, **kwargs) + return obj + + +def mutable_mapping_update(self, data=(), /, **kwargs): + if isinstance(data, Mapping): + # Merge standard mapping with PyMapping_Items + for key, value in data.items(): + self[key] = value + # FIXME: Enabling the `elif`-branch below needs too many `VariableClass.call_obj_hasattr` changes. + # >>> class Foo: + # ... def __init__(self): + # ... self.keys = lambda: ['a', 'b', 'c'] # not required to be a method + # ... + # ... def __getitem__(self, key): + # ... return 0 + # ... + # >>> dict(Foo()) + # {'a': 0, 'b': 0, 'c': 0} + # + # > This is a rare case, so we comment it out for now. + # + # elif hasattr(data, "keys"): + # # Merge mapping-like object with PyMapping_Keys + PyObject_GetItem + # for key in data.keys(): + # self[key] = data[key] + else: + if not isinstance(data, Iterable): + raise TypeError(f"{type(data).__name__!r} object is not iterable") + # Likely a sequence of pairs + for key, value in data: + self[key] = value + + if kwargs: + for key, value in kwargs.items(): + self[key] = value + + +# Used with something like dict(obj) +def construct_dict(cls, data=(), /, **kwargs): + self = cls.__new__(cls) + mutable_mapping_update(self, data, **kwargs) + return self + + +def foreach_map_fn(*args): + op = args[0] + new_args: list[Any] = [] + at_least_one_list = False + for arg in args[1:]: + if not isinstance(arg, (list, tuple)): + new_args.append(_repeat(arg)) + else: + at_least_one_list = True + new_args.append(arg) + + # Just apply op once to args if there are no lists + if not at_least_one_list: + return op(*args[1:]) + + out = [] + for unpacked in zip(*new_args): + out.append(op(*unpacked)) + + return out + + +def foreach_lerp_inplace(self, end, weight): + # decompose foreach lerp into constituent ops, prevents a graph break due to + # converting a value to a scalar when arg[2] is a single tensor + result = torch._foreach_sub(end, self) + result = torch._foreach_mul(result, weight) + return torch._foreach_add_(self, result) + + +def foreach_pow_scalar(scalar, exps): + return torch._foreach_pow([scalar for _ in exps], exps) + + +def addcmul_inplace(self, tensor1, tensor2, value): + return self.add_(tensor1 * tensor2 * value) + + +def predicate(obj: Any) -> bool: + # This will cause the rest of dynamo to handle the if statement correctly, so we don't have to rewrite it here. + # We can't just use bool() here since we can't trace into that in general. + if obj: + return True + return False + + +def cmp_eq(a, b): + # Note that the commented `is` check should ideally be removed. This is a + # CPython optimization that skips the __eq__ checks it the obj id's are + # same. But, these lines adds many `is` nodes in the Fx graph for + # SymNodeVariable. For now, we can just skip this check. This is STILL + # correct because one of the __eq__ checks will pass later, just could be + # slow in some corner cases. + # if a is b: + # return True + result = a.__eq__(b) + if result is NotImplemented: + result = b.__eq__(a) + return result is not NotImplemented and result + + +def cmp_ne(a, b): + # Check if __ne__ is overridden + if isinstance(type(a).__ne__, types.FunctionType): + return a.__ne__(b) + return not cmp_eq(a, b) + + +def cmp_lt(a, b): + result = a.__lt__(b) + if result is NotImplemented: + raise TypeError(f"{type(a)} does not support the < operator") + return result + + +def cmp_le(a, b): + # Check if __le__ is overridden + if isinstance(type(a).__le__, types.FunctionType): + return a.__le__(b) + return cmp_eq(a, b) or cmp_lt(a, b) + + +def cmp_gt(a, b): + # Check if __gt__ is overridden + if isinstance(type(a).__gt__, types.FunctionType): + return a.__gt__(b) + # a > b is equivalent to b < a + return cmp_lt(b, a) + + +def cmp_ge(a, b): + # Check if __ge__ is overridden + if isinstance(type(a).__ge__, types.FunctionType): + return a.__ge__(b) + return cmp_eq(a, b) or cmp_gt(a, b) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20468b21de65c91fe61320401d25c6790d0b7201 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f5bf339dab1aeb8d2083eaafbdf400105d0d30a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/_collections.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0bc3fcdb52b212be37261e010684e318a6ba32b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce6cba4647cb125093ed5d3ee179aed0df7a617 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f158a32ca8fa20518cc26c5c2297d4098c3d485 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8fe9c23721fba42d37d8c0c6788da64ea6e9605 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/heapq.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd4585644cc0d46411255d136ef077c6afc7a7c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf054666fe5cdc57cee8a8eaf2163803b37f949 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9ab2eb084279381a05b4adb6003401734456fbd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55012b715d3ef552783945977f4398e8da80a8f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9ec47edae87832631801cece2ae001e62299a5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18af41ae62645ee0caaaeb8e06a8a9ec995d3af Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/struct.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67cd46156cee6400d434a77dd1943d601824ee7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0683cac44a2b61d2650b87a6af85641eca96b78 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..9773635ae30587b06bb9f6b82c003392767b3873 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/_collections.py @@ -0,0 +1,33 @@ +""" +Python polyfills for builtins +""" + +from collections.abc import Iterable, MutableMapping +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = [] + + +T = TypeVar("T") + + +try: + import _collections # type: ignore[import-not-found] + + @substitute_in_graph(_collections._count_elements) + def _count_elements( + mapping: MutableMapping[T, int], + iterable: Iterable[T], + ) -> None: + "Tally elements from the iterable." + mapping_get = mapping.get + for elem in iterable: + mapping[elem] = mapping_get(elem, 0) + 1 + + __all__.append("_count_elements") + +except ImportError: + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..45feac9ca5dce561251c85794593c276dabaa4ef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/builtins.py @@ -0,0 +1,123 @@ +""" +Python polyfills for builtins +""" + +from __future__ import annotations + +import builtins +import functools +import operator +from collections.abc import Callable +from typing import TYPE_CHECKING, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +__all__ = [ + "all", + "any", + "enumerate", + "sum", +] + + +_T = TypeVar("_T") + + +@substitute_in_graph(builtins.all, can_constant_fold_through=True) +def all(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if not elem: + return False + return True + + +@substitute_in_graph(builtins.any, can_constant_fold_through=True) +def any(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if elem: + return True + return False + + +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] +def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] + return functools.reduce(operator.add, iterable, start) + + +class _CallableIterator: + def __init__(self, fn, sentinel): # type: ignore[no-untyped-def] + self.fn = fn + self.sentinel = sentinel + + def __iter__(self): # type: ignore[no-untyped-def] + return self + + def __next__(self): # type: ignore[no-untyped-def] + # The iterator created in this case will call object with no arguments + # for each call to its __next__() method; + r = self.fn() + + # If the value returned is equal to sentinel, StopIteration will be raised + if r == self.sentinel: + raise StopIteration + + # otherwise the value will be returned. + return r + + +class _SENTINEL_MISSING: + pass + + +# TODO(guilhermeleobas): use substitute_in_graph for iter() +def iter_(fn_or_iterable, sentinel=_SENTINEL_MISSING, /): # type: ignore[no-untyped-def] + # Without a second argument, object must be a collection object which supports + # the iterable (__iter__) or the sequence protocol (__getitem__ with an integer + # starting at 0) + if sentinel is _SENTINEL_MISSING: + iterable = fn_or_iterable + if hasattr(iterable, "__iter__"): + iterator = iterable.__iter__() + if hasattr(iterator, "__next__"): + return iterator + else: + raise TypeError(f"'{type(iterator)}' object is not iterable") + if hasattr(iterable, "__getitem__"): + # Needs to be a new function to avoid iter becoming a generator + def sequence_protocol(iterable): # type: ignore[no-untyped-def] + i = 0 + while True: + try: + yield iterable.__getitem__(i) + i += 1 + except IndexError: + break + + return sequence_protocol(iterable) + raise TypeError(f"'{type(iterable)}' object is not iterable") + else: + # If the second argument, sentinel, is given, then object must be a + # callable object. + fn = fn_or_iterable + + if not isinstance(fn, Callable): # type: ignore[arg-type] + raise TypeError("iter(v, w): v must be a callable") + + return _CallableIterator(fn, sentinel) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py new file mode 100644 index 0000000000000000000000000000000000000000..f70ca59bcea3eeab647583843bd1073e05e14639 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/functools.py @@ -0,0 +1,47 @@ +""" +Python polyfills for functools +""" + +import functools +from collections.abc import Callable, Iterable +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = ["reduce"] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class _INITIAL_MISSING: + pass + + +# Reference: https://docs.python.org/3/library/functools.html#functools.reduce +@substitute_in_graph(functools.reduce) +def reduce( + function: Callable[[_U, _T], _U], + iterable: Iterable[_T], + initial: _U = _INITIAL_MISSING, # type: ignore[assignment] + /, +) -> _U: + it = iter(iterable) + + value: _U + if initial is _INITIAL_MISSING: + try: + value = next(it) # type: ignore[assignment] + except StopIteration: + raise TypeError( + "reduce() of empty iterable with no initial value", + ) from None + else: + value = initial + + for element in it: + value = function(value, element) + + return value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5ed97e0899d94fc4478de5acfa7879f5560ab2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/fx.py @@ -0,0 +1,41 @@ +from collections.abc import Callable +from typing import Any + +from torch._C import _fx_map_aggregate, _fx_map_arg +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.node import Node + +from ..decorators import substitute_in_graph + + +@substitute_in_graph(_fx_map_arg, can_constant_fold_through=True) +def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any: + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + + +@substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True) +def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: + result: Any + if isinstance(a, tuple): + it = (map_aggregate(elem, fn) for elem in a) + # Support NamedTuple (if it has `_fields`) by repacking into original type. + result = type(a)(*it) if hasattr(a, "_fields") else tuple(it) + elif isinstance(a, list): + result = immutable_list([map_aggregate(elem, fn) for elem in a]) + elif isinstance(a, dict): + result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()]) + elif isinstance(a, slice): + result = slice( + map_aggregate(a.start, fn), + map_aggregate(a.stop, fn), + map_aggregate(a.step, fn), + ) + else: + result = fn(a) + return result + + +__all__ = [ + "map_arg", + "map_aggregate", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py new file mode 100644 index 0000000000000000000000000000000000000000..feddb5723614f581fdd232a162feaf00a3ca2fae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/heapq.py @@ -0,0 +1,119 @@ +""" +Python polyfills for heapq +""" + +from __future__ import annotations + +import heapq +import importlib +import sys +from typing import TYPE_CHECKING, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from types import ModuleType + + +_T = TypeVar("_T") + + +# Partially copied from CPython test/support/import_helper.py +# https://github.com/python/cpython/blob/bb8791c0b75b5970d109e5557bfcca8a578a02af/Lib/test/support/import_helper.py +def _save_and_remove_modules(names: set[str]) -> dict[str, ModuleType]: + orig_modules = {} + prefixes = tuple(name + "." for name in names) + for modname in list(sys.modules): + if modname in names or modname.startswith(prefixes): + orig_modules[modname] = sys.modules.pop(modname) + return orig_modules + + +def import_fresh_module(name: str, blocked: list[str]) -> ModuleType: + # Keep track of modules saved for later restoration as well + # as those which just need a blocking entry removed + names = {name, *blocked} + orig_modules = _save_and_remove_modules(names) + for modname in blocked: + sys.modules[modname] = None # type: ignore[assignment] + + try: + return importlib.import_module(name) + finally: + _save_and_remove_modules(names) + sys.modules.update(orig_modules) + + +# Import the pure Python heapq module, blocking the C extension +py_heapq = import_fresh_module("heapq", blocked=["_heapq"]) + + +__all__ = [ + "_heapify_max", + "_heappop_max", + "_heapreplace_max", + "heapify", + "heappop", + "heappush", + "heappushpop", + "heapreplace", + "merge", + "nlargest", + "nsmallest", +] + + +@substitute_in_graph(heapq._heapify_max) +def _heapify_max(heap: list[_T], /) -> None: + return py_heapq._heapify_max(heap) + + +@substitute_in_graph(heapq._heappop_max) # type: ignore[attr-defined] +def _heappop_max(heap: list[_T]) -> _T: + return py_heapq._heappop_max(heap) + + +@substitute_in_graph(heapq._heapreplace_max) # type: ignore[attr-defined] +def _heapreplace_max(heap: list[_T], item: _T) -> _T: + return py_heapq._heapreplace_max(heap, item) + + +@substitute_in_graph(heapq.heapify) +def heapify(heap: list[_T], /) -> None: + return py_heapq.heapify(heap) + + +@substitute_in_graph(heapq.heappop) +def heappop(heap: list[_T], /) -> _T: + return py_heapq.heappop(heap) + + +@substitute_in_graph(heapq.heappush) +def heappush(heap: list[_T], item: _T) -> None: + return py_heapq.heappush(heap, item) + + +@substitute_in_graph(heapq.heappushpop) +def heappushpop(heap: list[_T], item: _T) -> _T: + return py_heapq.heappushpop(heap, item) + + +@substitute_in_graph(heapq.heapreplace) +def heapreplace(heap: list[_T], item: _T) -> _T: + return py_heapq.heapreplace(heap, item) + + +@substitute_in_graph(heapq.merge) # type: ignore[arg-type] +def merge(*iterables, key=None, reverse=False): # type: ignore[no-untyped-def] + return py_heapq.merge(*iterables, key=key, reverse=reverse) + + +@substitute_in_graph(heapq.nlargest) # type: ignore[arg-type] +def nlargest(n, iterable, key=None): # type: ignore[no-untyped-def] + return py_heapq.nlargest(n, iterable, key=key) + + +@substitute_in_graph(heapq.nsmallest) # type: ignore[arg-type] +def nsmallest(n, iterable, key=None): # type: ignore[no-untyped-def] + return py_heapq.nsmallest(n, iterable, key=key) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbf9dfa1706751df86abcb55c2186c2ab47dd6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/itertools.py @@ -0,0 +1,276 @@ +""" +Python polyfills for itertools +""" + +from __future__ import annotations + +import itertools +import operator +from collections.abc import Callable +from typing import Optional, overload, TYPE_CHECKING, TypeAlias, TypeVar + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + +__all__ = [ + "accumulate", + "chain", + "chain_from_iterable", + "compress", + "cycle", + "dropwhile", + "filterfalse", + "islice", + "tee", + "zip_longest", + "pairwise", +] + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_Predicate: TypeAlias = Callable[[_T], object] +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate +@substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type] +def accumulate( + iterable: Iterable[_T], + func: Optional[Callable[[_T, _T], _T]] = None, + *, + initial: Optional[_T] = None, +) -> Iterator[_T]: + # call iter outside of the generator to match cypthon behavior + iterator = iter(iterable) + if func is None: + func = operator.add + + def _accumulate(iterator: Iterator[_T]) -> Iterator[_T]: + total = initial + if total is None: + try: + total = next(iterator) + except StopIteration: + return + + yield total + for element in iterator: + total = func(total, element) + yield total + + return _accumulate(iterator) + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + # previous version of this code was: + # return itertools.chain(*iterable) + # If iterable is an infinite generator, this will lead to infinite recursion + for it in iterable: + yield from it + + +chain.from_iterable = chain_from_iterable # type: ignore[attr-defined] + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress +@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type] +def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]: + return (datum for datum, selector in zip(data, selectors) if selector) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle +@substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type] +def cycle(iterable: Iterable[_T]) -> Iterator[_T]: + iterator = iter(iterable) + + def _cycle(iterator: Iterator[_T]) -> Iterator[_T]: + saved = [] + for element in iterable: + yield element + saved.append(element) + + while saved: + for element in saved: + yield element + + return _cycle(iterator) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile +@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type] +def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: + # dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8 + + iterator = iter(iterable) + for x in iterator: + if not predicate(x): + yield x + break + + yield from iterator + + +@substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type] +def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: + it = iter(iterable) + if function is None: + return filter(operator.not_, it) + else: + return filter(lambda x: not function(x), it) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise +@substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] +def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee +@substitute_in_graph(itertools.tee) +def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: + iterator = iter(iterable) + shared_link = [None, None] + + def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] + try: + while True: + if link[1] is None: + link[0] = next(iterator) + link[1] = [None, None] + value, link = link + yield value + except StopIteration: + return + + return tuple(_tee(shared_link) for _ in range(n)) + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, +) -> Iterator[tuple[_T1 | None, _T2 | None]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], +) -> Iterator[tuple[_T | None, ...]]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], + fillvalue: _U = ..., +) -> Iterator[tuple[_T | _U, ...]]: ... + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.zip_longest +@substitute_in_graph(itertools.zip_longest, is_embedded_type=True) # type: ignore[arg-type,misc] +def zip_longest( + *iterables: Iterable[_T], + fillvalue: _U = None, # type: ignore[assignment] +) -> Iterator[tuple[_T | _U, ...]]: + # zip_longest('ABCD', 'xy', fillvalue='-') -> Ax By C- D- + + iterators = list(map(iter, iterables)) + num_active = len(iterators) + if not num_active: + return + + while True: + values = [] + for i, iterator in enumerate(iterators): + try: + value = next(iterator) + except StopIteration: + num_active -= 1 + if not num_active: + return + iterators[i] = itertools.repeat(fillvalue) # type: ignore[arg-type] + value = fillvalue # type: ignore[assignment] + values.append(value) + yield tuple(values) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..31479e9d86ce6163c1c54ccdea73cc224ac82904 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/loader.py @@ -0,0 +1,45 @@ +# Used to load and initialize polyfill handlers when importing torch._dynamo +# Please add a new import when adding a new polyfill module. + +import importlib +from typing import TYPE_CHECKING + +import torch.utils._pytree as python_pytree + +from .. import polyfills, trace_rules + + +if TYPE_CHECKING: + from types import ModuleType + + +# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py +POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( + "_collections", + "builtins", + "functools", + "itertools", + "operator", + "os", + "struct", + "sys", + "fx", + "tensor", +) +if python_pytree._cxx_pytree_dynamo_traceable: + POLYFILLED_MODULE_NAMES += ("pytree",) + +POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( + importlib.import_module(f".{submodule}", package=polyfills.__name__) + for submodule in POLYFILLED_MODULE_NAMES +) + + +# Unregister the builtin functions from _builtin_function_ids to let them to be +# dispatched with the appropriate VariableTracker type. Otherwise, they will be +# dispatched with BuiltinVariable if present in _builtin_function_ids. +for polyfill_module in POLYFILLED_MODULES: + for polyfill_name in polyfill_module.__all__: + polyfill_handler = getattr(polyfill_module, polyfill_name) + original_fn = polyfill_handler.__torch_dynamo_original__ + trace_rules._builtin_function_ids.remove(id(original_fn)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..cae61df2c04307f294f1bf56fa68323acabc0e48 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/operator.py @@ -0,0 +1,119 @@ +""" +Python polyfills for operator +""" + +from __future__ import annotations + +import operator +from typing import Any, overload, TYPE_CHECKING, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + +# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) +__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"] + + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_Ts = TypeVarTuple("_Ts") +_U = TypeVar("_U") +_U1 = TypeVar("_U1") +_U2 = TypeVar("_U2") +_Us = TypeVarTuple("_Us") + + +@overload +# pyrefly: ignore [inconsistent-overload] +def attrgetter(attr: str, /) -> Callable[[Any], _U]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def attrgetter( + attr1: str, attr2: str, /, *attrs: str +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter +@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(attrs) == 0: + raise TypeError("attrgetter expected 1 argument, got 0") + + if any(not isinstance(attr, str) for attr in attrs): + raise TypeError("attribute name must be a string") + + def resolve_attr(obj: Any, attr: str) -> Any: + for name in attr.split("."): + obj = getattr(obj, name) + return obj + + if len(attrs) == 1: + attr = attrs[0] + + def getter(obj: Any) -> Any: + return resolve_attr(obj, attr) + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(resolve_attr(obj, attr) for attr in attrs) + + return getter + + +@overload +# pyrefly: ignore [inconsistent-overload] +def itemgetter(item: _T, /) -> Callable[[Any], _U]: ... + + +@overload +# pyrefly: ignore [inconsistent-overload] +def itemgetter( + item1: _T1, item2: _T2, /, *items: Unpack[_Ts] +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter +@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(items) == 0: + raise TypeError("itemgetter expected 1 argument, got 0") + + if len(items) == 1: + item = items[0] + + def getter(obj: Any) -> Any: + return obj[item] + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(obj[item] for item in items) + + return getter + + +# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller +@substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type] +def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]: + if not isinstance(name, str): + raise TypeError("method name must be a string") + + def caller(obj: Any) -> Any: + return getattr(obj, name)(*args, **kwargs) + + return caller + + +# Reference: https://docs.python.org/3/library/operator.html#operator.countOf +@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] +def countOf(a: Iterable[_T], b: _T, /) -> int: + return sum(it is b or it == b for it in a) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py new file mode 100644 index 0000000000000000000000000000000000000000..2f55d436ad8978bc0ddb46bdeeb356c518590547 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/os.py @@ -0,0 +1,37 @@ +""" +Python polyfills for os +""" + +from __future__ import annotations + +import os +from typing import AnyStr + +from ..decorators import substitute_in_graph + + +__all__ = ["fspath"] + + +# Copied from os.py in the standard library +@substitute_in_graph(os.fspath, can_constant_fold_through=True) +def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: + if isinstance(path, (str, bytes)): + # pyrefly: ignore [bad-return] + return path + + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) # type: ignore[arg-type] + except AttributeError: + if hasattr(path_type, "__fspath__"): + raise + raise TypeError( + f"expected str, bytes or os.PathLike object, not {path_type.__name__}", + ) from None + if isinstance(path_repr, (str, bytes)): + return path_repr # type: ignore[return-value] + raise TypeError( + f"expected {path_type.__name__}.__fspath__() to return str or bytes, " + f"not {type(path_repr).__name__}", + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f9c1830333641b785b96780bb9b6b0475282e4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/pytree.py @@ -0,0 +1,758 @@ +""" +Python polyfills for torch.utils.pytree +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +from typing import Any, TYPE_CHECKING, TypeVar + +import optree +import optree._C +import optree.utils +from optree import ( + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + namedtuple_fields, + structseq_fields, +) + +import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 +import torch.utils._pytree as python_pytree +from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + import builtins + from collections.abc import Callable, Iterable, Mapping + from typing_extensions import Self, TypeIs + + from torch.utils._cxx_pytree import PyTree + + +__all__ = [ + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + "namedtuple_fields", + "structseq_fields", + "treespec_leaf", + "treespec_tuple", + "treespec_dict", + "tree_is_leaf", + "tree_iter", + "tree_leaves", + "tree_flatten", + "tree_flatten_with_path", + "tree_structure", + "tree_unflatten", +] + + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +@substitute_in_graph( + optree._C.is_dict_insertion_ordered, + can_constant_fold_through=True, +) +def _(*args: Any, **kwargs: Any) -> bool: + # In namespace 'torch', the dictionary is always traversed in insertion order. + # This function returns True. + raise ValueError( + "Should not be called directly " + "because the original function will be called in the constant fold path." + ) + + +__name = "" +for __name, __func in ( + ("is_namedtuple", is_namedtuple), + ("is_namedtuple_class", is_namedtuple_class), + ("is_namedtuple_instance", is_namedtuple_instance), + ("is_structseq", is_structseq), + ("is_structseq_class", is_structseq_class), + ("is_structseq_instance", is_structseq_instance), + ("namedtuple_fields", namedtuple_fields), + ("structseq_fields", structseq_fields), +): + globals()[__name] = substitute_in_graph( + __func, # type: ignore[arg-type] + can_constant_fold_through=True, + )(__func.__python_implementation__) # type: ignore[attr-defined] + del __func +del __name + + +@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_is_leaf( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> bool: + if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): + return True + if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: + return True + return False + + +@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] +def tree_iter( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> Iterable[Any]: + stack = [tree] + while stack: + node = stack.pop() + if tree_is_leaf( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ): + yield node + continue + + children, *_ = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + stack.extend(reversed(children)) + + +@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_leaves( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> list[Any]: + return list( + tree_iter( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + ) + + +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + + def __repr__(self) -> str: + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk + + +@dataclass(frozen=True) +class PyTreeSpec: + """Analog for :class:`optree.PyTreeSpec` in Python.""" + + _children: tuple[PyTreeSpec, ...] + _type: builtins.type | None + _metadata: Any + _entries: tuple[Any, ...] + _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None + none_is_leaf: bool + namespace: str + + num_nodes: int = field(init=False) + num_leaves: int = field(init=False) + num_children: int = field(init=False) + + def __post_init__(self, /) -> None: + if self._type is None: + assert len(self._children) == 0 + assert self._metadata is None + assert self._entries == () + assert self._unflatten_func is None + num_nodes = 1 + num_leaves = 1 + num_children = 0 + else: + assert callable(self._unflatten_func) + num_nodes = 1 + num_leaves = 0 + for child in self._children: + num_nodes += child.num_nodes + num_leaves += child.num_leaves + num_children = len(self._children) + + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + + def __repr__(self, /) -> str: + def helper(treespec: PyTreeSpec) -> str: + if treespec.is_leaf(): + assert treespec.type is None + return _asterisk + + assert treespec.type is not None + assert callable(treespec._unflatten_func) + children_representations = [ + helper(subspec) for subspec in treespec._children + ] + if ( + treespec.type in BUILTIN_TYPES + or (treespec.type is type(None) and not self.none_is_leaf) + or optree.is_namedtuple_class(treespec.type) + or optree.is_structseq_class(treespec.type) + ): + # pyrefly: ignore [bad-return] + return treespec._unflatten_func( + treespec._metadata, + children_representations, + ) + return ( + f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " + f"[{', '.join(children_representations)}])" + ) + + inner = [ + str(helper(self)), + *(["NoneIsLeaf"] if self.none_is_leaf else []), + f"namespace={self.namespace!r}", + ] + return f"PyTreeSpec({', '.join(inner)})" + + def __len__(self, /) -> int: + return self.num_leaves + + @property + def type(self, /) -> builtins.type | None: + return self._type + + def is_leaf(self, /) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def paths(self, /) -> list[tuple[Any, ...]]: + def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None: + if treespec.is_leaf(): + paths.append(path_prefix) + return + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper(subspec, path_prefix + [entry]) + + paths: list[list[Any]] = [] + helper(self, []) + return [tuple(path) for path in paths] + + def accessors(self, /) -> list[optree.PyTreeAccessor]: + def helper( + treespec: PyTreeSpec, + entry_path_prefix: list[optree.PyTreeEntry], + ) -> None: + if treespec.is_leaf(): + entry_paths.append(entry_path_prefix) + return + + node_type = treespec.type + assert node_type is not None + handler = optree.register_pytree_node.get( + node_type, namespace=treespec.namespace + ) + assert handler is not None + kind: optree.PyTreeKind = handler.kind + path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper( + subspec, + entry_path_prefix + [path_entry_type(entry, node_type, kind)], + ) + + entry_paths: list[list[optree.PyTreeEntry]] = [] + helper(self, []) + return [optree.PyTreeAccessor(path) for path in entry_paths] + + def children(self, /) -> list[PyTreeSpec]: + return list(self._children) + + def child(self, index: int, /) -> PyTreeSpec: + return self._children[index] + + def entries(self, /) -> list[Any]: + return list(self._entries) + + def entry(self, index: int, /) -> Any: + return self._entries[index] + + def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]: + def helper( + treespec: PyTreeSpec, + node: PyTree, + subtrees: list[PyTree], + ) -> None: + if treespec.is_leaf(): + subtrees.append(node) + return + + node_type = type(node) + if treespec.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != treespec.type: + raise ValueError( + f"Type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=self.none_is_leaf, + namespace=self.namespace, + ) + if len(children) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(children)}.", + ) + if metadata != treespec._metadata: + raise ValueError( + f"Node context mismatch for custom node type {treespec.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + treespec.type in STANDARD_DICT_TYPES + and node_type in STANDARD_DICT_TYPES + ) + if not both_standard_dict and node_type != treespec.type: + raise ValueError( + f"Node type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + if len(node) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(node)}.", + ) + + if both_standard_dict: + # dictionary types are compatible with each other + expected_keys = treespec.entries() + got_key_set = set(node) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + children = [node[key] for key in expected_keys] + else: + # node_type is treespec.type + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=self.none_is_leaf, + namespace=self.namespace, + ) + if ( + node_type is not deque # ignore mismatch of `maxlen` for deque + ) and metadata != treespec._metadata: + raise ValueError( + f"Node metadata mismatch for node type {treespec.type!r}; " + f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch + ) + + for subtree, subspec in zip(children, treespec._children, strict=True): + helper(subspec, subtree, subtrees) + + subtrees: list[PyTree] = [] + helper(self, tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any], /) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + # Recursively unflatten the children + start = 0 + end = 0 + subtrees = [] + for subspec in self._children: + end += subspec.num_leaves + subtrees.append(subspec.unflatten(leaves[start:end])) + start = end + + assert callable(self._unflatten_func) + return self._unflatten_func(self._metadata, subtrees) + + +def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec | python_pytree.TreeSpec]: + return isinstance(obj, (PyTreeSpec, python_pytree.TreeSpec)) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_leaf, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_leaf( + *, + none_is_leaf: bool = False, + namespace: str = "", # unused +) -> PyTreeSpec: + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace="", + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_tuple, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_tuple( + iterable: Iterable[PyTreeSpec] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + children = tuple(iterable) + if any(not _is_pytreespec_instance(child) for child in children): + raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") + if any(child.none_is_leaf != none_is_leaf for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {children!r}.", + ) + if any(child.namespace not in (namespace, "") for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {children!r}.", + ) + handler = optree.register_pytree_node.get(tuple, namespace=namespace) + assert handler is not None + return PyTreeSpec( + tuple(children), + tuple, + None, + tuple(range(len(children))), + handler.unflatten_func, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_dict, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_dict( + mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", + **kwargs: PyTreeSpec, +) -> PyTreeSpec: + dct = dict(mapping, **kwargs) + if any(not _is_pytreespec_instance(child) for child in dct.values()): + raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") + if any(child.none_is_leaf != none_is_leaf for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {dct!r}.", + ) + if any(child.namespace not in (namespace, "") for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {dct!r}.", + ) + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] + dct, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return PyTreeSpec( + tuple(children), # type: ignore[arg-type] + dict, + metadata, + entries, + unflatten_func, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[Any], PyTreeSpec]: + def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: + if tree_is_leaf( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ): + leaves.append(node) + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + # Recursively flatten the children + subspecs = tuple(helper(child, leaves) for child in children) + return PyTreeSpec( + subspecs, + type(node), + metadata, + entries, + unflatten_func, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) # type: ignore[arg-type] + + leaves: list[Any] = [] + treespec = helper(tree, leaves) + return leaves, treespec + + +@substitute_in_graph( # type: ignore[arg-type] + optree._C.flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def _C_flatten( + tree: PyTree, + /, + leaf_predicate: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[Any], PyTreeSpec]: + return tree_flatten( # type: ignore[return-value] + tree, + is_leaf=leaf_predicate, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten_with_path, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten_with_path( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + leaves, treespec = tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return treespec.paths(), leaves, treespec # type: ignore[return-value] + + +@substitute_in_graph( # type: ignore[arg-type] + optree._C.flatten_with_path, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def _C_flatten_with_path( + tree: PyTree, + /, + leaf_predicate: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + return tree_flatten_with_path( # type: ignore[return-value] + tree, + is_leaf=leaf_predicate, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_structure, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_structure( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + return tree_flatten( # type: ignore[return-value] + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + )[1] + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_unflatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"Expected `treespec` to be an instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + return treespec.unflatten(leaves) + + +_none_registration = optree.register_pytree_node.get(type(None)) +assert _none_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _none_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def none_unflatten(_: None, children: Iterable[_T], /) -> None: + if len(list(children)) != 0: + raise ValueError("Expected no children.") + return None + + +with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_flatten( + dct: dict[_KT, _VT], / +) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, +) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values, strict=True)) + return d # type: ignore[return-value] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py new file mode 100644 index 0000000000000000000000000000000000000000..f4522a12f7323e51da6f4454814e87daf82cea98 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/struct.py @@ -0,0 +1,27 @@ +""" +Python polyfills for struct +""" + +from __future__ import annotations + +import struct +from typing import Any +from typing_extensions import Buffer + +from ..decorators import substitute_in_graph + + +__all__ = [ + "pack", + "unpack", +] + + +@substitute_in_graph(struct.pack, can_constant_fold_through=True) # type: ignore[arg-type] +def pack(fmt: bytes | str, /, *v: Any) -> bytes: + return struct.pack(fmt, *v) + + +@substitute_in_graph(struct.unpack, can_constant_fold_through=True) # type: ignore[arg-type] +def unpack(format: bytes | str, buffer: Buffer, /) -> tuple[Any, ...]: + return struct.unpack(format, buffer) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py new file mode 100644 index 0000000000000000000000000000000000000000..ab666c385806f9cd56e489038a0884be861c0bf3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/sys.py @@ -0,0 +1,34 @@ +""" +Python polyfills for sys +""" + +from __future__ import annotations + +import sys + +from ..decorators import substitute_in_graph + + +__all__ = [ + "intern", + "getrecursionlimit", +] + + +@substitute_in_graph(sys.intern, can_constant_fold_through=True) +def intern(string: str, /) -> str: + return string + + +@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True) +def getrecursionlimit() -> int: + return sys.getrecursionlimit() + + +if hasattr(sys, "get_int_max_str_digits"): + + @substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True) + def get_int_max_str_digits() -> int: + return sys.get_int_max_str_digits() + + __all__ += ["get_int_max_str_digits"] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..dffa98f60f3b578810a2386255964d03858afa37 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,40 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + with torch._C.DisableTorchFunctionSubclass(): + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. + # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, ( + "_make_subclass only supports requires_grad as keyword arg" + ) + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. + return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py new file mode 100644 index 0000000000000000000000000000000000000000..25ef68a111080a42e97e7fe738203e5a42e1f9df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py @@ -0,0 +1,1281 @@ +""" +Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. + +This module provides tools and infrastructure for: +1. Generating minimal reproducible test cases ("repros") from failing compilations +2. Analyzing accuracy issues between eager and compiled execution +3. Minifying large models/inputs to isolate problematic patterns +4. Debugging compiler errors and accuracy divergences + +The main components include: +- Repro generation: Creates standalone Python files that reproduce compiler issues +- Minification: Reduces large graphs to minimal failing examples +- Accuracy analysis: Compares compiled vs eager execution, with fp64 reference +- Debug tools: Dumps graph state, tracks intermediates, analyzes divergences + +This is primarily used by PyTorch developers and researchers to debug issues in +the Dynamo AOT compilation pipeline, particularly for the Inductor backend. +""" + +from __future__ import annotations + +import argparse +import copy +import functools +import io +import logging +import os +import shutil +import subprocess +import sys +import textwrap +import uuid +from importlib import import_module +from tempfile import TemporaryFile +from typing import Any, IO, Optional, TYPE_CHECKING, Union +from typing_extensions import Unpack + +import sympy + + +try: + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction +except ImportError: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + class Heuristics: # type: ignore[no-redef] + pass + + +import torch +import torch.fx as fx +import torch.nn as nn +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + AccuracyError, + backend_accuracy_fails, + BuckTargetWriter, + cast_to_fp64, + extra_deps, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + MAX_CONSTANT_NUMEL_INLINE, + minifier_dir, + NNModuleToString, + NopInputReader, + same_two_models, +) +from torch._dynamo.utils import clone_inputs, counters, same +from torch._environment import is_fbcode +from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table +from torch._inductor.cpp_builder import normalize_path_separator +from torch._library.fake_class_registry import FakeScriptObject +from torch._ops import OpOverload +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + fx_placeholder_targets, + has_free_symbols, +) +from torch.hub import tqdm + +from .. import config + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def wrap_compiler_debug( + unconfigured_compiler_fn: _CompileFxCallable, + compiler_name: str, +) -> _CompileFxCallable: + """ + Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both + forward and backward call separately with the backend compiler_fn - like + inductor or nvfuser. Intercepting after Aot Autograd presents neat + abstraction, where all the params are lifted as graph inputs, making it easy + to save the graph as a string. + """ + + @functools.wraps(unconfigured_compiler_fn) + def debug_wrapper( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], + ) -> OutputCode: + from torch._subclasses import FakeTensorMode + + compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + + from torch._functorch.aot_autograd import get_aot_graph_name + + graph_name = get_aot_graph_name() + + # TODO: why do we need to deepcopy the original graph? + orig_graph = copy.deepcopy(gm.graph) + assert config.repro_after in ("dynamo", "aot", None) + + try: + # Call the compiler_fn - which is either aot_autograd or inductor + # with fake inputs + inner_compiled_fn = compiler_fn(gm, example_inputs) + except Exception: + # TODO: Failures here are troublesome because no real inputs, + # need a different serialization strategy + if config.repro_after == "aot": + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + log.error("CompilerError") + raise + + # We may run regular PyTorch compute that may trigger Dynamo, do NOT + # recursively attempt to accuracy minify in that case! + def deferred_for_real_inputs( + real_inputs: Sequence[InputType], **_kwargs: object + ) -> Any: + # This is a bit obscure: if we recursively try to accuracy minify + # the SAME function, this would trigger. But most of the time + # we should never hit this branch + assert not _kwargs + if config.repro_after != "aot": + assert not isinstance(inner_compiled_fn, str) + return inner_compiled_fn(real_inputs) + with config.patch(repro_after=None): + return inner_debug_fn(real_inputs) + + def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any: + """ + Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, + example_inputs can be fake tensors. We can call compiler_fn (which is + inductor or nvfuser) with fake tensors but the actually compiled_fn + should be called with real tensors. Therefore, the actual invocation + is deferred. + """ + # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor + # because inductor clears the tensor list in its codegen. And example_inputs + # are available only for the first invocation. + fake_mode = FakeTensorMode() + copy_tensor_attrs = [ + fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x + for x in real_inputs + ] + if config.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + fx.GraphModule(gm, orig_graph), real_inputs, compiler_name + ) + + if config.repro_level == 4: + if compiler_name != "inductor": + raise NotImplementedError( + "Accuracy minification is supported for inductor only" + ) + failed = not same_two_models( + gm, + inner_compiled_fn, # type: ignore[arg-type] + real_inputs, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + if failed: + log.warning( + "Accuracy failed for the AOT Autograd graph %s", graph_name + ) + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + dump_to_minify( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + raise AccuracyError("Bad accuracy detected") + else: + # Call the compiled function with real inputs + return inner_compiled_fn(real_inputs) # type: ignore[operator] + else: + try: + # Call the compiled function with real inputs + out = inner_compiled_fn(real_inputs) # type: ignore[operator] + # sync cuda kernels to ensure IMA detection + for arg in example_inputs: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + torch.cuda.synchronize() + break + return out + except Exception: + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + raise + + if config.repro_after == "aot": + compiled_fn = deferred_for_real_inputs + compiled_fn._boxed_call = True # type: ignore[attr-defined] + return compiled_fn # type: ignore[return-value] + else: + return inner_compiled_fn + + return debug_wrapper + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def maybe_fbcode_instructions() -> str: + if is_fbcode(): + extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) + if len(extra_deps_formatted) > 0: + extra_deps_formatted = "\n" + extra_deps_formatted + return f"""\ +\"\"\" +To run this script in fbcode: +- Create a directory (//scripts/{{your_unixname}}/repro) +- Put this file in scripts/{{your_unixname}}/repro/fx_graph_runnable.py +- Add a TARGETS file that looks like the following +- `buck2 run //scripts/{{your_unixname}}/repro:repro` + +NOTE: you may need additional deps to actually be able to run the script. +``` +# Contents of TARGETS file +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name = "repro", + main_src = "fx_graph_runnable.py", + deps = [ + "//caffe2:torch",{extra_deps_formatted} + ], +) +``` +\"\"\" +""" + else: + return "" + + +def generate_compiler_repro_string( + gm: torch.fx.GraphModule, + args: Sequence[Any], + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + stable_hash: bool = False, + has_distributed_ops: bool = False, +) -> str: + if save_dir is not None: + save_dir = normalize_path_separator(save_dir) + # Add distributed imports if needed + distributed_imports = "" + if has_distributed_ops: + distributed_imports = textwrap.dedent( + """ +import torch.distributed as dist +from torch.testing._internal.distributed.fake_pg import FakeStore + """ + ).strip() + + triton_imports = "" + + if len(kernel_side_table.id_to_kernel) > 0: + triton_imports = textwrap.dedent( + """ +import triton +import triton.language as tl + """ + ).strip() + + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims +{distributed_imports} +{triton_imports} + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + +{maybe_fbcode_instructions()} + """ + ) + model_str += textwrap.dedent( + """ +if "__compile_source__" in globals(): + import inspect as __after_aot_inspect + import linecache as __after_aot_linecache + __after_aot_filename = __after_aot_inspect.currentframe().f_code.co_filename + __after_aot_linecache.cache[__after_aot_filename] = ( + len(__compile_source__), + None, + __compile_source__.splitlines(True), + __after_aot_filename, + ) +""" + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + kernel_side_table_prefix = ( + "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" + ) + # Track which grid entry corresponds to the best config + for id in kernel_side_table.id_to_kernel: + kernel = kernel_side_table.get_kernel(id) + + try: + if isinstance(kernel, Autotuner): + # pyrefly: ignore [missing-attribute] + if isinstance(kernel.fn, Heuristics): + model_str += "ERROR: Repro will not work as intended, " + model_str += "triton.runtime.autotuner.Heuristics is not currently supported\n" + break + + config_strs = [] + # pyrefly: ignore [missing-attribute] + for kernel_config in kernel.configs: + # pyrefly: ignore [bad-argument-type] + config_strs.append(f"""triton.Config( + {str(kernel_config.kwargs)}, + num_warps={kernel_config.num_warps}, + num_stages={kernel_config.num_stages}, + )""") + + config_str = ",".join(config_strs) + model_str += textwrap.dedent(f""" + @triton.autotune( + configs=[ + {config_str} + ], + key=[] + ) + """).strip() + + model_str += "\n@triton.jit\n" + # pyrefly: ignore [missing-attribute] + src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src + fn_name = ( + # pyrefly: ignore [missing-attribute] + kernel._fn_name + if isinstance(kernel, JITFunction) + # pyrefly: ignore # missing-attribute + else kernel.fn._fn_name + ) + fn_name = fn_name.split(".")[-1] + + model_str += src_code + model_str += "\n" + model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n" + except AttributeError as e: + model_str += "ERROR: Repro will not work as intended, " + model_str += f"User defined triton kernel exception: {e}\n" + + # pyrefly: ignore [unbound-name] + if len(kernel_side_table.constant_args) > 0: + # pyrefly: ignore [unbound-name] + model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n" + + model_str += NNModuleToString.convert(gm) + + writer = InputWriter(save_dir, stable_hash=stable_hash) + used_syms = {} + + # Extract from graph placeholders and their corresponding arguments + placeholder_targets = fx_placeholder_targets(gm) + for placeholder, arg in zip(placeholder_targets, args): + # pyrefly: ignore [unbound-name] + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + # pyrefly: ignore [unbound-name] + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + elif arg is None: + writer.const(placeholder) + else: + writer.unsupported(placeholder, arg) + + # Extract symbolic variables from the same arguments + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(arg, torch.SymInt) + # By checking sympy.Symbol, we are excluding any symbolic expressions. + # TODO: we may need to solve expressions to extract symbol definitions. + and isinstance(arg.node.expr, sympy.Symbol) + and arg.node.hint is not None + ): + used_syms[str(arg.node)] = arg.node.hint + # pyrefly: ignore [unbound-name] + elif isinstance(arg, torch.Tensor): + # Extract symbolic variables from tensor shapes and strides + for dim in arg.shape: + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(dim, torch.SymInt) + and isinstance(dim.node.expr, sympy.Symbol) + and dim.node.hint is not None + ): + used_syms[str(dim.node)] = dim.node.hint + for stride in arg.stride(): + # pyrefly: ignore [unbound-name] + if ( + # pyrefly: ignore [unbound-name] + isinstance(stride, torch.SymInt) + and isinstance(stride.node.expr, sympy.Symbol) + and stride.node.hint is not None + ): + used_syms[str(stride.node)] = stride.node.hint + # Add symbolic variable definitions to the top of the generated code + if used_syms: + hint_lines = "\n".join( + f"{name} = {hint}" for name, hint in sorted(used_syms.items()) + ) + model_str = f"{hint_lines}\n\n{model_str}" + + load_args_lines = writer.lines() + load_args_code = "\n".join(load_args_lines) + model_str += load_args_code + "\n" + + model_str += "mod = Repro()\n" + return model_str + + +def save_graph_repro( + fd: IO[Any], + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + stable_hash: bool = False, +) -> None: + if any( + isinstance(arg, torch.fx.experimental._backward_state.BackwardState) + for arg in args + ): + fd.write( + "Repro is not generated due to existence of BackwardState in graph input" + ) + return + + if save_dir is not None: + save_dir = normalize_path_separator(save_dir) + + # Check if the graph contains distributed operations + has_distributed_ops = any( + node.op == "call_function" + and isinstance(node.target, OpOverload) + and node.target.namespace in {"_c10d_functional", "c10d_functional"} + for node in gm.graph.nodes + ) + + fd.write( + generate_compiler_repro_string( + gm, + args, + stable_output=stable_output, + save_dir=save_dir, + stable_hash=stable_hash, + has_distributed_ops=has_distributed_ops, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + if tracing_mode is None: + tracing_mode = "real" + if any( + has_free_symbols(a) for a in args if not isinstance(a, FakeScriptObject) + ): + tracing_mode = "symbolic" + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") + + # Add distributed initialization before run_repro if needed + if has_distributed_ops: + fd.write( + " # Initialize FakeProcessGroup for distributed operations\n" + " store = FakeStore()\n" + " dist.init_process_group(\n" + ' backend="fake",\n' + " rank=0,\n" + " world_size=2,\n" + " store=store\n" + " )\n" + ) + + fd.write( + f" with torch.no_grad():\n" + f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # To run it separately, do \n" + f" # mod, args = run_repro(mod, load_args, accuracy={accuracy!r}, command='get_args', " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # mod(*args)" + ) + + # Add distributed cleanup after run_repro + if has_distributed_ops: + fd.write("\n dist.destroy_process_group()\n") + + +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + accuracy: Optional[Union[str, bool]] = None, +) -> None: + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro( + fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP MINIFIER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str +) -> None: + out = io.StringIO() + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify") + return helper_for_dump_minify(out.getvalue()) + + +def isolate_fails( + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + env: Optional[dict[str, Any]] = None, + save_dir: Optional[str] = None, + accuracy: Optional[Union[bool, str]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, +) -> bool: + if env is None: + env = {} + subdir = os.path.join(os.getcwd(), "isolate") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") + with open(file_name, "w") as fd: + save_graph_repro( + fd, + fx_g, + args, + compiler_name, + save_dir=save_dir, + command="minifier-query", + accuracy=accuracy, + tracing_mode=tracing_mode, + check_str=check_str, + ) + # with open(file_name, "r") as fd: + # print(fd.read()) + new_env = os.environ.copy() + new_env = {**new_env, **env} + if use_buck: + cmd = BuckTargetWriter(file_name).write(print_msg=False) + else: + cmd = [sys.executable, file_name] + with ( + TemporaryFile() as stdout, + TemporaryFile() as stderr, + subprocess.Popen( + cmd, + cwd=subdir, + stdout=stdout, + stderr=stderr, + env=new_env, + ) as p, + ): + p.wait() + + stdout.seek(0) + stderr.seek(0) + print( + textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), + file=sys.stdout, + ) + print( + textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), + file=sys.stderr, + ) + # print(f"Isolated test failed - {file_name}") + return p.returncode != 0 + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER TOOLS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def inductor_fails( + fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None +) -> bool: + has_cuda = False + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + has_cuda = True + break + + def sync() -> None: + if has_cuda: + # Ensures that segfaults are surfaced + torch.cuda.synchronize() + + from torch._inductor.compile_fx import compile_fx_inner + + try: + result = fx_g(*args) + assert isinstance(result, (tuple, list)) + assert not any(isinstance(x, (tuple, list)) for x in result) + except Exception: + return False + + sync() + + try: + compile_mod = compile_fx_inner(fx_g, args) + assert not isinstance(compile_mod, str) + compile_mod(args) + sync() + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + print(repr(e)) + return True + return False + + +def inductor_accuracy_fails( + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + check_str: Optional[str] = None, + *, + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: + from torch._inductor.compile_fx import compile_fx_inner + + return backend_aot_accuracy_fails( + fx_g, + args, # type: ignore[arg-type] + compile_fx_inner, # type: ignore[arg-type] + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + + +backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def repro_common( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, Sequence[Any]]: + # Invariant for graphs we generate with the repro script + assert not any(mod.named_parameters()) + for n, b in mod.named_buffers(): + if b.numel() > MAX_CONSTANT_NUMEL_INLINE: + log.warning( + "Constant %s was not serialized, generated random data instead. " + "If you think this is affecting you, please comment on " + "https://github.com/pytorch/pytorch/issues/100468", + n, + ) + + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + # Turn mod into a GraphModule the slow way + # TODO: speed this up + mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args) + + # pyrefly: ignore [bad-assignment] + torch._inductor.config.generate_intermediate_hooks = True + + return mod, args + + +ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { + "": inductor_fails, + # This might look inverted but it's not. strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + "accuracy": functools.partial( + inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True + ), + "strict_accuracy": inductor_accuracy_fails, +} + + +def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: + mod, args = repro_common(options, mod, load_args) + fail_fn = functools.partial( + ACCURACY_FAILS[options.accuracy], + check_str=options.check_str, # type: ignore[call-arg] + ) + if fail_fn(mod, args): + sys.exit(1) + else: + sys.exit(0) + + +def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: + from functorch.compile import minifier + + mod, args = repro_common(options, mod, load_args) + compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor" + + favored_device = 1 if torch.cuda.device_count() >= 2 else 0 + env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)} + + module_fails: Any + if options.isolate: + module_fails = functools.partial( + isolate_fails, + env=env_variables, + compiler_name=compiler_name, + save_dir=options.save_dir, + accuracy=options.accuracy, + tracing_mode=options.tracing_mode, + ) + else: + module_fails = ACCURACY_FAILS[options.accuracy] + + minifier( + mod, + args, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, compiler_name=compiler_name + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.hooks import intermediate_hook + + mod, args = repro_common(options, mod, load_args) + + # TODO: The logic for cloning inputs/models here is intentionally + # modeled off of run_fwd_maybe_bwd, but arguably it is better not to + # clone inputs (as you are doubling your effective GPU memory usage). + # It is certainly faster though! It probably makes sense to let the + # user specify the offload strategy. + + with tqdm(desc="Compiling"): + compiled = compile_fx_inner(mod, args) + total = counters["inductor"]["intermediate_hooks"] + + known_names = set() + + def save_hook(name: str, val: Any) -> None: + known_names.add(name) + if not options.skip_saving_inductor_intermediates: + writer.write_tensor(os.path.join("inductor", name), val) + pbar.update(1) # type: ignore[has-type] + + writer = torch.utils._content_store.ContentStoreWriter( + options.save_dir, stable_hash=options.stable_hash + ) + reader = torch.utils._content_store.ContentStoreReader(options.save_dir) + + new_args = clone_inputs(args) + with ( + intermediate_hook(save_hook), + tqdm(desc="Saving inductor intermediates", total=total) as pbar, + ): + assert not isinstance(compiled, str) + compiled(new_args) # type: ignore[arg-type] + assert not new_args + + def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: + diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] + diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] + + if not diff_values: + return None + else: + return " and ".join(f"{a} != {b}" for a, b in diff_values) + + def check_hook(name: str, val: Any) -> None: + meta = writer.compute_tensor_metadata(val) + meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})") + pbar.update(1) + + if not options.skip_check_deterministic: + new_args = clone_inputs(args) + with ( + intermediate_hook(check_hook), + tqdm(desc="Checking inductor determinism", total=total) as pbar, + ): + compiled(new_args) # type: ignore[arg-type] + assert not new_args + + class WriterInterp(fx.Interpreter): + def __init__(self, mod: torch.nn.Module, subdir: str) -> None: + super().__init__(mod) + self.subdir = subdir + + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + pbar.update(1) + writer.write_tensor(os.path.join(self.subdir, name), r) + return r + + # NB: the module cast doesn't actually do anything, since there are no + # parameters/buffers on the module + if not options.skip_saving_float64_intermediates: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] + with tqdm(desc="Saving float64 intermediates", total=total) as pbar: + WriterInterp(new_mod, "float64").boxed_run(new_args) + assert not new_args + + class ExactReaderInterp(fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + meta = writer.compute_tensor_metadata(r) + meta2 = reader.read_tensor_metadata(os.path.join("float64", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})") + pbar.update(1) + return r + + # TODO: check eager determinism + + if not options.skip_check_deterministic: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] + with tqdm(desc="Checking float64 determinism", total=total) as pbar: + ExactReaderInterp(new_mod).boxed_run(new_args) + assert not new_args + + # Now that we've saved everything, interp through the eager graph + # and do comparisons + class ReaderInterp(fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Any: + r = super().run_node(n) + name = n.name + if name in known_names: + inductor = reader.read_tensor(os.path.join("inductor", name)) + float64 = reader.read_tensor(os.path.join("float64", name)) + logged = False + + def log_error(msg: str, *args: Any) -> None: + nonlocal logged + logged = True + pbar.write(f"DIVERGED at {name}: {msg % args}") + + if not same( + r, + inductor, + float64, + tol=torch._dynamo.config.repro_tolerance, + equal_nan=True, + log_error=log_error, + ): + assert logged + pbar.update(1) + return r + + with tqdm(desc="Checking divergence", total=total) as pbar: + ReaderInterp(mod).boxed_run(args) + assert not args + + +def repro_get_args( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, list[Any]]: + mod, args = repro_common(options, mod, load_args) + return mod, args # type: ignore[return-value] + + +def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: + from torch._inductor.compile_fx import compile_fx_inner + + mod, args = repro_common(options, mod, load_args) + + from torch.cuda import synchronize + + compiled = compile_fx_inner(mod, args) + assert not isinstance(compiled, str) + + if options.accuracy != "": + # We don't really respect --accuracy vs --strict-accuracy here, it + # seems counterintuitive + if not same_two_models( + mod, + compiled, # type: ignore[arg-type] + args, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Bad accuracy detected") + else: + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + compiled(list(args)) + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +# TODO: lazily load the inputs or something, rather than cloning them +def run_repro( + mod: nn.Module, + load_args: Any, + *, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + patch_code: Optional[str] = None, + check_str: Optional[str] = None, + **kwargs: Any, +) -> Any: + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + if patch_code is not None: + log.warning( + "patch_code no longer works on this version of PyTorch, silently ignoring" + ) + + parser = argparse.ArgumentParser( + description=f"""\ +An after_aot repro script, typically triggering a bug in PyTorch Inductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--tracing-mode", + type=str, + metavar="{real,fake,symbolic}", + default=tracing_mode, + help="how to trace the repro module into a GraphModule with metadata", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify,analyze}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify_isolate = parser_minify.add_mutually_exclusive_group() + parser_minify_isolate.add_argument( + "--isolate", + action="store_true", + default=True, + help="run in separate processes to avoid interference (default)", + ) + parser_minify_isolate.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + help="speed up by running all compilation in same process", + ) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + # TODO: make this an option for --analyze too + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + parser_analyze = subparsers.add_parser( + "analyze", help="run the accuracy analyzer on the repro" + ) + common_flags(parser_analyze) + parser_analyze.add_argument( + "--skip-saving-inductor-intermediates", + action="store_true", + help="skip saving inductor intermediates on --analyze", + ) + parser_analyze.add_argument( + "--skip-saving-float64-intermediates", + action="store_true", + help="skip saving float64 intermediates", + ) + parser_analyze.add_argument( + "--skip-check-deterministic", + action="store_true", + help="skip checking that the network is deterministic", + ) + parser_analyze.add_argument( + "--stable-hash", + action="store_true", + help="use SHA-1 checksum instead of fast (but possibly unsound) hash", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "analyze": repro_analyze, + "minifier-query": repro_minifier_query, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command](options, mod, load_args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py new file mode 100644 index 0000000000000000000000000000000000000000..a17518fc6c74d7c64477964f3fc7d1176fc67019 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py @@ -0,0 +1,637 @@ +""" +Utilities for reproducing and debugging issues in Dynamo after graph capture. + +This file provides tools and infrastructure for debugging problems that occur +after Dynamo has captured the graph but before/during backend compilation. +Key components include: + +- Minification tools to reduce large graphs to minimal failing examples +- Accuracy testing to validate compiled graph outputs match eager mode +- Repro generation to create standalone reproduction scripts +- Debug backends for capturing and analyzing failures +- Utilities for saving/loading graph states and inputs + +The tools here focus specifically on the post-graph-capture stage, making them +useful for debugging backend compilation issues, AOTAutograd problems, and +accuracy discrepancies between compiled and eager execution. +""" + +import argparse +import copy +import functools +import logging +import os +import shutil +import sys +import textwrap +from collections.abc import Callable, Sequence +from importlib import import_module +from typing import Any, Optional, Union + +import torch +import torch.fx as fx +from torch._dynamo.debug_utils import ( + AccuracyError, + backend_accuracy_fails, + BUCK_CMD_PREFIX, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + minifier_dir, + NNModuleToString, + NopInputReader, + run_fwd_maybe_bwd, + same_two_models, +) +from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets +from torch.hub import tqdm + +from .. import config +from ..backends.registry import CompilerFn, lookup_backend, register_debug_backend +from ..debug_utils import clone_inputs_retaining_gradness + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _accuracy_fails( + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], +) -> bool: + return backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + +class WrapBackendDebug: + def __init__( + self, unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] + ) -> None: + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_backend = unconfigured_compiler_fn + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name # type: ignore[attr-defined] + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[Any], **kwargs: Any + ) -> torch.fx.GraphModule: + compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) + assert config.repro_after in ("dynamo", "aot", None) + + if config.repro_after == "dynamo": + + def add_paths(exc: Exception) -> None: + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") # type: ignore[attr-defined] + if use_buck: + exc.buck_command = " ".join( # type: ignore[attr-defined] + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] # type: ignore[attr-defined] + ) + + if config.repro_level == 3: + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) + + # Check for either accuracy (level 4) or other type of failures. + if config.repro_level == 4: + # Check Accuracy + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] + log.warning( + "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." + ) + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + exc = AccuracyError("Bad accuracy detected.") + add_paths(exc) + raise exc + else: + try: + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] + except Exception as exc: + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + if config.repro_level == 1: + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=self._compiler_name + ) + dump_state_fn( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs + ) + elif config.repro_level == 2: + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + add_paths(exc) + raise + else: + compiled_gm = compiler_fn(gm, example_inputs) + + return compiled_gm # type: ignore[return-value] + + +def wrap_backend_debug( + unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] +) -> WrapBackendDebug: + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO DUMPERS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_dynamo_fx_repro_string( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, + *, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", +) -> str: + """ + Generate a repro string for backend-agnostic minified version. + """ + + model_str = NNModuleToString.convert(gm) + + # TODO: Figure out why torch.compile'd hash isn't work on this codepath + writer = InputWriter(save_dir, stable_hash=True) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + else: + raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") + load_args = "\n".join(writer.lines()) + + return textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +from math import inf +import torch +from torch import tensor, device +import torch.fx as fx +import torch._dynamo +from torch._dynamo.testing import rand_strided +from torch._dynamo.debug_utils import run_fwd_maybe_bwd + +{generate_config_string(stable_output=stable_output)} + +{extra_imports} + +{model_str} +mod = Repro() + +{load_args} + +if __name__ == '__main__': + from torch._dynamo.repro.after_dynamo import run_repro + run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r}, + save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r}) +""" + ) + + +def dump_backend_repro_as_file( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: + """ + Saves the repro to a repro.py file + """ + curdir = os.getcwd() + subdir = os.path.join(os.getcwd(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + + with open(file_name, "w") as fd: + fd.write( + generate_dynamo_fx_repro_string( + gm, args, compiler_name, check_accuracy, save_dir=subdir + ) + ) + latest_repro = os.path.join(curdir, "repro.py") + log.warning("Copying %s to %s for convenience", file_name, latest_repro) + + if use_buck: + BuckTargetWriter(latest_repro).write() + + shutil.copyfile(file_name, latest_repro) + + +def dump_backend_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: + """ + Dumps the dynamo graph to repro the issue. + 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a + repro.py file. + 2) If we can't convert Fx GraphModule to a string, we use to_folder to save + the module and save a tar file. + """ + assert NNModuleToString.can_convert_to_string(gm) + return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy) + # return dump_backend_repro_as_tarfile(gm, args, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER DUMPER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify_after_dynamo( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str] +) -> None: + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + helper_for_dump_minify( + generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=config.repro_level == 4, + save_dir=subdir, + command="minify", + ) + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER BACKENDS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] + + # TODO: It's inconsistent to pass SymInt inputs but REAL tensors. + # We should pass ints and look at the GraphModule placeholders + # to resolve them to SymInt (if necessary) + example_inputs = [ + i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs + ] + + try: + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] + raise ValueError("No issue was detected") + except Exception as exc: + orig_failure = str(exc) + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + fails_fn = functools.partial( + backend_fails, + compiler_fn=compiler_fn, + orig_failure=orig_failure, + ) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + return gm + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_accuracy_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] + + # Set the eval mode to remove randomness. + gm.eval() + + # Check Accuracy + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] + log.warning("Accuracy failed for the TorchDynamo produced graph") + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name, check_accuracy=True + ) + fails_fn = functools.partial( + _accuracy_fails, + compiler_fn=compiler_fn, # type: ignore[arg-type] + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + else: + log.error("Input graph does not fail accuracy testing") + return gm + + +def backend_fails( + gm: fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: CompilerFn, + orig_failure: Sequence[Any], +) -> bool: + """ + Minifier uses this function to identify if the minified graph module fails + with the same error. + + One caveat is that minifier can potentially go into a wrong direction when + the resulting graph module fails for a different reason. To avoid this, we + save the string for the original exception and check similarity between new + and old exception. They can be somewhat different in some cases, when the + exception string depends on the failing node information. So, we have a + loose similarity metric to guide the minifier path. + """ + from difflib import SequenceMatcher + + try: + # Run the original gm to check eager validity + run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) # type: ignore[arg-type] + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) # type: ignore[arg-type] + except Exception as e: + new_failure = str(e) + if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: + return True + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[Any]: + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return args + + +def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: + args = run_load_args(options, mod, load_args) + + # Setup debug minifier compiler + if not options.accuracy: + compiler_fn = lookup_backend("dynamo_minifier_backend") + else: + compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend") + + if options.backend is None: + raise RuntimeError( + "Compiler name is None - this likely means that a custom compiler " + "was called by torchdynamo. Please remove this error, import your " + "custom compiler function, and replace the backend=None " + "line in run_repro to backend=" + ) + + dynamo_minifier_backend = functools.partial( + compiler_fn, + compiler_name=options.backend, # type: ignore[call-arg] + ) + opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod) + + with torch.amp.autocast("cuda", enabled=options.autocast): + opt_mod(*args) + + +def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: + opt_mod = torch._dynamo.optimize(options.backend)(mod) + + if options.accuracy != "": + mod.eval() + opt_mod.eval() # type: ignore[union-attr] + + with torch.amp.autocast("cuda", enabled=options.autocast): + # TODO: disable clone + args = run_load_args(options, mod, load_args) + assert same_two_models(mod, mod, args), "Eager itself failed" # type: ignore[arg-type] + if not same_two_models( + mod, # type: ignore[arg-type] + opt_mod, # type: ignore[arg-type] + args, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Dynamo failed") + else: + with torch.amp.autocast("cuda", enabled=options.autocast): + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) # type: ignore[arg-type] + del args + + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd( + opt_mod, # type: ignore[arg-type] + args, + only_fwd=options.only_fwd, + disable_clone=True, # type: ignore[arg-type] + ) + + +def run_repro( + mod: torch.nn.Module, + load_args: Any, + *, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + autocast: bool = False, + backend: str = "inductor", + **kwargs: Any, +) -> None: + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An after_dynamo repro script, typically triggering a bug in Dynamo or +AOTAutograd. When run with no arguments, this script defaults to running +'{command}'. Extra flags may be available; to find out more, try '{command} +--help'. There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {save_dir=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="test accuracy", + ) + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + default=False, + help="no isolate (doesn't do anything for after_dynamo)", + ) + parser.add_argument( + "--autocast", + default=autocast, + action="store_true", + help="use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--no-autocast", + dest="autocast", + action="store_false", + help="don't use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--backend", + type=str, + default=backend, + metavar="BACKEND", + help="torch.compile backend to use", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + parser_run.add_argument( + "--only-fwd", + action="store_true", + help="don't run backwards compilation for testing", + ) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + } + COMMAND_FNS[options.command](options, mod, load_args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f556787695c92b070166c364a3fbf85e262631 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/aoti.py @@ -0,0 +1,661 @@ +""" +Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. + +This file provides tools and utilities for: +- Generating minimal reproducible test cases (minification) +- Handling exported programs and graph modules +- Creating debug repros for AOTI compilation issues +- Supporting both accuracy testing and error reproduction +- Managing configuration and environment for repro cases + +The main components include: +- Minification tools to reduce test cases while preserving errors +- Repro generation utilities for exported programs +- Error handling specific to AOTI compilation +- Command-line interface for running and managing repros +""" + +import argparse +import functools +import io +import logging +import os +import re +import shutil +import sys +import textwrap +from collections.abc import Sequence +from importlib import import_module +from typing import Any, IO, Optional, Union + +import torch +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + minifier_dir, + NNModuleToString, + NopInputReader, +) +from torch.export import ExportedProgram +from torch.hub import tqdm + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + + +class AOTIMinifierError(Exception): + def __init__(self, original_exception: Union[str, Exception]) -> None: + additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" + full_message = f"{additional_message}: {str(original_exception)}" + super().__init__(full_message) + self.original_exception = original_exception + + +def dump_to_minify( + exported_program: ExportedProgram, + compiler_name: str, + command: str = "minify", + options: Optional[dict[str, Any]] = None, +) -> None: + """ + If command is "minify": + Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. + If command is "run": + Dump exported_program to `cwd/repro.py`, with run command. + """ + assert command in ["minify", "run"] + + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + if command == "minify": + out = io.StringIO() + save_graph_repro_ep( + out, + compiler_name, + exported_program=exported_program, + save_dir=subdir, + command="minify", + config_patches=options, + ) + return helper_for_dump_minify(out.getvalue()) + else: + curdir = os.getcwd() + file_name = os.path.join(curdir, "repro.py") + try: + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + exported_program=exported_program, + config_patches=options, + save_dir=subdir, + command="run", + module_in_comment=True, + ) + log.warning("Writing repro file to %s", file_name) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", file_name) + + +def get_module_string(gm: torch.fx.GraphModule) -> str: + def _convert_to_comment(s_: str) -> str: + s = s_.split("\n") + if len(s) == 1: + return "# " + s_ + first = s.pop(0) + for i in range(len(s)): + line = s[i] + if line.strip() != "": + s[i] = "# " + line + else: + s[i] = "" + s = "\n".join(s) + s = first + "\n" + s + return s + + module_string = NNModuleToString.convert(gm) + return _convert_to_comment(module_string) + + +def save_graph_repro_ep( + fd: IO[Any], + compiler_name: str, + *, + exported_program: Optional[ExportedProgram] = None, + gm: Optional[torch.nn.Module] = None, + args: Optional[tuple[Any]] = None, + config_patches: Optional[dict[str, str]] = None, + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + check_str: Optional[str] = None, + module_in_comment: bool = False, + strict: bool = False, +) -> None: + # Save graph for reproducing the error. + # Either exported_program or gm will be saved, depending on which one is defined. + # Only one of exported_program and gm should be defined. + + if exported_program is None and gm is None: + raise AOTIMinifierError("One of exported_program and gm must be defined") + if exported_program is not None and gm is not None: + raise AOTIMinifierError("Only one of exported_program and gm can be defined") + if gm is not None and args is None: + raise AOTIMinifierError("If gm is defined, args should also be defined") + + if exported_program is None: + assert gm is not None + assert args is not None + exported_program = torch.export.export(gm, args, strict=strict) + elif gm is None: + gm = exported_program.module(check_guards=False) + + # save a graph preview using gm + module_string = get_module_string(gm) # type: ignore[arg-type] + fd.write(module_string) + + # save a graph repro using exported_program + fd.write( + generate_compiler_repro_exported_program( + exported_program, + options=config_patches, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.aoti import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, check_str={check_str!r})\n" + ) + + +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + config_patches: Optional[dict[str, str]] = None, + accuracy: Optional[Union[str, bool]] = None, + strict: bool = False, +) -> None: + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + gm=gm, + args=tuple(args), + config_patches=config_patches, + save_dir=subdir, + accuracy=accuracy, + module_in_comment=True, + strict=strict, + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_compiler_repro_exported_program( + exported_program: ExportedProgram, + *, + options: Optional[dict[str, str]] = None, + stable_output: bool = False, + save_dir: Optional[str] = None, +) -> str: + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + if save_dir: + ep_path = os.path.join(save_dir, "exported_program.pt2") + else: + ep_path = "exported_program.pt2" + torch.export.save(exported_program, ep_path) + + model_str += f"exported_program = torch.export.load('{ep_path}')\n" + model_str += "# print(exported_program.graph)\n" + model_str += f"config_patches={options}\n" + return model_str + + +def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return tuple(args) + + +def repro_common( + options: Any, exported_program: ExportedProgram +) -> tuple[torch.fx.GraphModule, Any, Any]: + # pyrefly: ignore [bad-assignment] + torch._inductor.config.generate_intermediate_hooks = True + mod = exported_program.module(check_guards=False) + args, kwargs = exported_program.example_inputs + return mod, args, kwargs # type: ignore[return-value] + + +def repro_get_args( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> tuple[torch.fx.GraphModule, Any, Any]: + mod, args, kwargs = repro_common(options, exported_program) + return mod, args, kwargs + + +def repro_run( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: + from torch._inductor import _aoti_compile_and_package_inner + + gm, args, kwargs = repro_common(options, exported_program) + + from torch.cuda import synchronize + + _aoti_compile_and_package_inner( + gm, + args, + kwargs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=config_patches, + ) + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +def export_for_aoti_minifier( + gm: torch.nn.Module, + tuple_inputs: tuple[Any], + strict: bool = False, + skip_export_error: bool = True, +) -> Optional[torch.nn.Module]: + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + # In these case, we return None. Otherwise, we return the exported graph module. + # This won't affect the minifier result because the minifier is only responsible for catching + # errors in AOTI, not export. + # + # Please add to this list of illegal graphs if you change the implementation here. + # - graph output is not allowed by export + # + # If skip_export_error=True, then the errors in export will not be raised, and the minifier + # will keep exploring and ignore this graph. + from torch._dynamo.exc import UserError, UserErrorType + + try: + ep = torch.export.export(gm, tuple_inputs, strict=strict) + gm = ep.module(check_guards=False) + return gm + except Exception as e: + if skip_export_error: + return None + if isinstance(e, UserError) and e.error_type == UserErrorType.INVALID_OUTPUT: + # graph output is not allowed by export when strict=True + return None + if isinstance(e, RuntimeError): + # graph output is not allowed by export when strict=False + pattern = r"Found .* in output, which is not a known type\." + if re.search(pattern, str(e)) is not None: + return None + raise AOTIMinifierError(e) from e + # we should never reach here + return None + + +def repro_minify( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: + from functorch.compile import minifier + from torch._inductor import _aoti_compile_and_package_inner + from torch._inductor.compile_fx import _aoti_flatten_inputs + + mod, args, kwargs = repro_common(options, exported_program) + + # update serialized_in_spec and serialized_out_spec + flat_example_inputs, inductor_configs = _aoti_flatten_inputs( + mod, args, kwargs, options=config_patches + ) + compiler_name = "aot_inductor" + assert options.minifier_export_mode in ["dynamo", "python"] + strict = options.minifier_export_mode == "dynamo" + skip_export_error = options.skip_export_error + + from torch.cuda import synchronize + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + def module_fails( + gm: torch.fx.GraphModule, + flat_example_inputs: list[Any], + check_str: Optional[str] = None, + ) -> bool: + # Need to export first so the in_spec and out_spec are populated + tuple_inputs = tuple(flat_example_inputs) + # pyrefly: ignore [bad-assignment] + gm = export_for_aoti_minifier( + gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error + ) + + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + if gm is None: + return False + + assert isinstance(gm, torch.fx.GraphModule) + + try: + _aoti_compile_and_package_inner( + gm, + tuple_inputs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=inductor_configs, + ) + if need_sync: + synchronize() # ensure segfaults are surfaced + return False + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + return True + + minifier( + mod, + flat_example_inputs, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, + compiler_name=compiler_name, + config_patches=config_patches, + accuracy=options.accuracy, + strict=strict, + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def run_repro( + exported_program: ExportedProgram, + *, + config_patches: Optional[dict[str, str]] = None, + command: str = "run", + accuracy: Union[bool, str] = "", + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + minifier_export_mode: str = "python", + skip_export_error: bool = True, + **more_kwargs: Any, +) -> Any: + for k in more_kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An AOTI repro script, typically triggering a bug in PyTorch AOTInductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser: argparse.ArgumentParser) -> None: + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + parser_minify.add_argument( + "--minifier-export-mode", + type=str, + default=minifier_export_mode, + help=( + "The export mode used in minifier, either dynamo or python." + "`dynamo` corresponds to strict=True, and `python` corresponds to strict=False." + ), + ) + parser_minify.add_argument( + "--skip-export-error", + type=bool, + default=skip_export_error, + help="Skip intermediate graphs that cannot be exported.", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command]( + options, exported_program, config_patches=config_patches + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac31eeee5362e1d1becbdeb6199ec70cea5c0e2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__init__.py @@ -0,0 +1,230 @@ +""" +This package implements variable tracking and symbolic execution capabilities for Dynamo, +which are essential for converting Python code into FX graphs. It provides a comprehensive +set of variable types that handle different Python constructs during tracing. + +Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible +for tracking and symbolically executing operations on specific Python objects. This enables +Dynamo to: +- Track the flow of values through Python code +- Maintain correct semantics during graph conversion +- Handle complex Python features like context managers, iterators, and custom objects +- Support both eager and symbolic execution modes + +The VariableTracker base class provides the foundation for all variable types, with each +subclass implementing specific behavior for different Python constructs. This modular design +allows Dynamo to accurately trace and optimize Python code while preserving its semantics. +""" + +from .base import VariableTracker +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + CatchWarningsCtxManagerVariable, + ContextWrappingVariable, + CUDADeviceVariable, + DeterministicAlgorithmsVariable, + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + DynamoConfigPatchVariable, + ErrorOnGraphBreakVariable, + FSDPParamGroupUseTrainingStateVariable, + FxTracebackAnnotateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + TemporarilyPopInterpreterStackCtxManagerVariable, + VmapIncrementNestingCtxManagerVariable, + WithEnterFunctionVariable, + WithExitFunctionVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + NNModuleHooksDictVariable, + SetVariable, +) +from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctionDecoratedByContextlibContextManagerVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + NestedUserFunctionVariable, + PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, + SkipFunctionVariable, + TMADescriptorExperimentalVariable, + TMADescriptorStableVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, + WrapperUserMethodVariable, +) +from .higher_order_ops import ( + FunctionalCallVariable, + FunctorchHigherOrderVariable, + ReparametrizeModuleCallVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ( + CountIteratorVariable, + FilterVariable, + IteratorVariable, + ItertoolsVariable, + MapVariable, + ObjectIteratorVariable, + RepeatIteratorVariable, + ZipVariable, +) +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradFunctionContextVariable, + AutogradFunctionVariable, + CellVariable, + DeletedVariable, + ExceptionVariable, + GetAttrVariable, + LambdaVariable, + MethodWrapperVariable, + NewGlobalVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + StringFormatVariable, + SuperVariable, + TorchVersionVariable, + TypingVariable, + UnknownVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + NNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .sdpa import SDPAParamsVariable +from .streams import EventVariable, StreamContextVariable, StreamVariable +from .tensor import ( + DataPtrVariable, + FakeItemVariable, + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, + UntypedStorageVariable, +) +from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable +from .user_defined import ( + FrozenDataClassVariable, + MutableMappingVariable, + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedExceptionObjectVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedSetVariable, + UserDefinedTupleVariable, +) + + +__all__ = [ + "AutogradFunctionContextVariable", + "AutogradFunctionVariable", + "BackwardHookVariable", + "BaseListVariable", + "BuiltinVariable", + "CatchWarningsCtxManagerVariable", + "ConstantVariable", + "ConstDictVariable", + "ContextWrappingVariable", + "CountIteratorVariable", + "CreateTMADescriptorExperimentalVariable", + "CreateTMADescriptorStableVariable", + "CUDADeviceVariable", + "DataPtrVariable", + "DefaultDictVariable", + "DeletedVariable", + "DeterministicAlgorithmsVariable", + "DictKeySetVariable", + "DynamoConfigPatchVariable", + "EnumVariable", + "FakeItemVariable", + "GetAttrVariable", + "GradModeVariable", + "IteratorVariable", + "ItertoolsVariable", + "LambdaVariable", + "LazyVariableTracker", + "ListIteratorVariable", + "ListVariable", + "NamedTupleVariable", + "NestedUserFunctionVariable", + "CellVariable", + "NewGlobalVariable", + "NNModuleVariable", + "NumpyNdarrayVariable", + "NumpyVariable", + "OptimizerVariable", + "PlacementVariable", + "PolyfilledFunctionVariable", + "PythonModuleVariable", + "RangeVariable", + "RemovableHandleVariable", + "RepeatIteratorVariable", + "SDPAParamsVariable", + "ErrorOnGraphBreakVariable", + "SkipFunctionVariable", + "SliceVariable", + "StringFormatVariable", + "SuperVariable", + "TemporarilyPopInterpreterStackCtxManagerVariable", + "TensorVariable", + "TMADescriptorExperimentalVariable", + "TMADescriptorStableVariable", + "TorchCtxManagerClassVariable", + "TorchInGraphFunctionVariable", + "TorchVersionVariable", + "TupleVariable", + "UnknownVariable", + "UnspecializedNNModuleVariable", + "UnspecializedPythonVariable", + "UntypedStorageVariable", + "UserDefinedClassVariable", + "UserDefinedTupleVariable", + "UserDefinedObjectVariable", + "UserFunctionVariable", + "UserMethodVariable", + "VariableTracker", + "WithEnterFunctionVariable", + "WithExitFunctionVariable", + "MappingProxyVariable", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb83052e85bc5a019211912da8807b0479a7cc7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ee3e3b5d8cf48734e6250837d2a3af5ceea28f6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7907ee85675c4b7b4f06e3a64309e64d6710ce Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa1b140abf2bb3c1f8dfc8f30c8e3560b6ea68f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ad7ead2b81a88e820541757f053ea264f28d49 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f6a95eacee61071506a8b65e0064927da74a4e3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e028cb713218c710086945934b1ecd91c852aef0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da27bcb16e8769011c7832a09297d6350bc6b4da Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1dc6b2504301bd0d1c39d8252f3805138fd2257 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc359544c8c6b79e86fc2e3e8fe3fc3cc3bca088 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c30606d485d089ccc665b35e0f5d56728dc717d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bae185b2add89904f31b5a701e8c970dd564349 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9aa510773a9056bd08aab7e98e56592d6a2db4f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3362315d2f0e86531d599392b15c323d457f020d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4480bf385c74bb88347a8feb27fb62f34309f38 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/streams.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b54f107a432ffa81dc2c5b898768af04c4bbe2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63aeda064c8b06fa8df7e73e0a0e1633c932f4e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/base.py new file mode 100644 index 0000000000000000000000000000000000000000..af63c4c9d75999a677d6b1c327ea58b165b2520b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/base.py @@ -0,0 +1,825 @@ +""" +Core variable tracking functionality for Dynamo. This module defines the fundamental +classes and systems used to track and manage variables during Dynamo's operation. + +The module provides: +1. VariableTracker - The base class for tracking variables during compilation +2. MutationType system - Classes for tracking and managing mutations to variables +3. Source type management - Utilities for tracking variable origins and scope +4. Variable state management - Tools for managing variable state and transformations + +These components form the foundation of Dynamo's variable handling system, +enabling accurate tracking and transformation of Python code into optimized +computations. +""" + +import collections +import logging +from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView +from enum import Enum +from typing import Any, NoReturn, Optional, TYPE_CHECKING + +from torch._guards import Guard +from torch.fx.proxy import Node + +from .. import graph_break_hints, variables +from ..current_scope_id import current_scope_id +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, Source +from ..utils import cmp_name_to_op_mapping, istype + + +if TYPE_CHECKING: + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) + + +class SourceType(Enum): + """ + This Enum divides VariableTracker into 2 cases, depending on the variable + it represents: + - already existed that Dynamo began tracking while introspection (Existing) + - is a new variable that is created during Dynamo introspection (New) + + In general, we have these invariants: + 1. for `VariableTracker` associated with `Existing`, its `source` field must not be None. + 2. for `VariableTracker` associated with `New`, most of the time its + `source` field is None, except for cases like side effect codegen for + `AttributeMutationNew`, during which we generate a + `LocalSource('tmp...')` for such variable, to facilitate codegen. + """ + + Existing = 0 + New = 1 + + +class MutationType: + """ + Base class for Variable.mutation_type. It encodes information about + 1. The type of mutation Dynamo allows on the variable. + 2. Whether the value represented by this variable already existed before + Dynamo tracing. + """ + + def __init__(self, typ: SourceType) -> None: + # In HigherOrderOperator tracing, we need to distinguish + # between MutationTypes inside the HigherOrderOperator and + # ones outside it. For example, it is not safe to mutate + # `a` in the following example because it was constructed + # in a different scope. + # + # def f(x): + # a = 1 + # def g(x): + # nonlocal a + # a = 2 + # return x + # return wrap(g, x) + a + # + # We use self.scope to distinguish this. + # scope == 0: The object was an existing variable + # scope == 1: The object was created while Dynamo + # was introspecting a function + # (and no HigherOrderOps were involved) + # scope >= 2: The object was created through + # Dynamo introspection of a HigherOrderOp. + # The exact number corresponds to the level + # of nested HigherOrderOps. + if typ is SourceType.Existing: + self.scope = 0 + elif typ is SourceType.New: + self.scope = current_scope_id() + else: + unimplemented( + gb_type="Unsupported SourceType", + context=f"MutationType.__init__ {self} {typ}", + explanation=f"Dynamo does not support the type `{typ}`", + hints=[ + "This branch is not supposed to be reachable.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +class ValueMutationNew(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value is created by the bytecode Dynamo is tracing through. + + For instance, Dynamo could model a newly created list with this marker, + indicating that while we need to model mutations to this list, we don't have + to emit bytecode for these mutations if the list doesn't escape into the + Python world. + """ + + def __init__(self) -> None: + super().__init__(SourceType.New) + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: object) -> bool: + return self is other + + +class ValueMutationExisting(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing list with this marker, + indicating that if we encounter mutations to this list, we need to buffer + and re-apply those mutations after the graph runs, since the list might be + used afterwards in Python. + """ + + # A flag to indicate whether mutation happened on the associated + # `VariableTracker`. This enables SideEffects to accurately and quickly + # filter out which pre-existing values it needs to generate mutation for. + is_modified: bool + + def __init__(self, is_modified: bool = False) -> None: + super().__init__(SourceType.Existing) + self.is_modified = is_modified + + +class AttributeMutation(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates that Dynamo + allows mutation on the value's attributes. + """ + + +class AttributeMutationExisting(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing object with this marker, + indicating that if we encounter mutations to this object, we need to buffer + then re-apply those mutations after the graph runs, since the object might + be used afterwards in Python. + """ + + def __init__(self) -> None: + super().__init__(SourceType.Existing) + + +class AttributeMutationNew(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value is created by the bytecode Dynamo is tracing through. + + For instance, Dynamo could model a newly created object with this marker, + indicating that while we need to model mutations to this object, we don't + have to emit bytecode for these mutations if the object doesn't escape into + the Python world. + """ + + def __init__(self, cls_source: Optional[Source] = None) -> None: + super().__init__(SourceType.New) + self.cls_source = cls_source + + +def _is_top_level_scope(scope_id: int) -> bool: + return scope_id == 1 + + +def is_side_effect_safe(m: MutationType) -> bool: + scope_id = current_scope_id() + + # In the top-level scope (if no HigherOrderOperators are involved), + # we are allowed to modify variables created in this scope as well + # as existing variables. + if _is_top_level_scope(scope_id): + return True + # Otherwise, only allow local mutation of variables created in the current scope + return m.scope == scope_id + + +# This helps users of `as_python_constant` to catch unimplemented error with +# more information; it inherits `NotImplementedError` for backward +# compatibility reasons. +class AsPythonConstantNotImplementedError(NotImplementedError): + vt: "VariableTracker" + + def __init__(self, vt: "VariableTracker") -> None: + super().__init__(f"{vt} is not a constant") + self.vt = vt + + +class VariableTrackerMeta(type): + all_subclasses: list[type] = [] + + def __new__( + mcs: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> type: + # Determine which metaclass to use based on the class attributes + # Classes with _no_implicit_realize = True should NOT implicitly realize + # (they need standard isinstance behavior to avoid infinite recursion) + # Check if any base class has _no_implicit_realize set, or if it's in attrs + no_implicit_realize = attrs.get("_no_implicit_realize", False) or any( + getattr(base, "_no_implicit_realize", False) for base in bases + ) + if no_implicit_realize or name == "VariableTracker": + # Use base VariableTrackerMeta (no custom __instancecheck__) + return super().__new__(VariableTrackerMeta, name, bases, attrs) + else: + # Use ImplicitRealizingVariableTrackerMeta for all other subclasses + return super().__new__( + ImplicitRealizingVariableTrackerMeta, name, bases, attrs + ) + + def __init__( + cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: + super().__init__(name, bases, attrs) # type: ignore[misc] + VariableTrackerMeta.all_subclasses.append(cls) + + +class ImplicitRealizingVariableTrackerMeta(VariableTrackerMeta): + def __instancecheck__(self, instance: object) -> bool: + """Make isinstance work with LazyVariableTracker""" + if instancecheck(LazyVariableTracker, instance): + return instance.lazy_isinstance(self) # pyrefly: ignore[missing-attribute] + return instancecheck(self, instance) + + +class VariableTracker(metaclass=VariableTrackerMeta): + """ + Base class for tracked locals and stack values + + VariableTracker instances are immutable and should be copied in + order to change them. + + Prefer the factory function VariableTracker.build() over VariableTracker.__init__(). + """ + + # fields to leave unmodified in apply() + _nonvar_fields = { + "value", + "guards", + "source", + "mutation_type", + "parents_tracker", + "user_code_variable_name", + } + + def clone(self, **kwargs: Any) -> "VariableTracker": + """Shallow copy with some (optional) changes""" + args = dict(self.__dict__) + args.update(kwargs) + return self.__class__(**args) + + @classmethod + def visit( + cls, + fn: Callable[["VariableTracker"], None], + value: Any, + cache: Optional[dict[int, Any]] = None, + ) -> None: + """ + Walk value and call fn on all the VariableTracker instances + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = value + + if isinstance(value, VariableTracker): + value = value.unwrap() + fn(value) + value = value.unwrap() # calling fn() might have realized it + nonvars = value._nonvar_fields + for key, subvalue in value.__dict__.items(): + if key not in nonvars: + cls.visit(fn, subvalue, cache) + elif istype(value, (list, tuple)): + for subvalue in value: + cls.visit(fn, subvalue, cache) + elif istype(value, (dict, collections.OrderedDict)): + for subvalue in value.values(): + cls.visit(fn, subvalue, cache) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def debug_repr(self) -> str: + # Intended to be overridden to provide more info + try: + return repr(self.as_python_constant()) + except NotImplementedError: + return repr(self) + + def python_type(self) -> type: + """ + Abstract method to be implemented by subclasses of VariableTracker. + + This method should return the type represented by the instance of the subclass. + The purpose is to provide a standardized way to retrieve the Python type information + of the variable being tracked. + + Returns: + type: The Python type (such as int, str, list, etc.) of the variable tracked by + the subclass. If the type cannot be determined or is not relevant, + leaving it undefined or invoking super() is always sound. + + Note: + This is an abstract method and may be overridden in subclasses. + + Example: + class SetVariable(VariableTracker): + def python_type(self): + return set + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + """ + try: + return type(self.as_python_constant()) + except NotImplementedError: + raise NotImplementedError(f"{self} has no type") from None + + def python_type_name(self) -> str: + try: + return self.python_type().__name__ + except NotImplementedError: + return "" + + def as_python_constant(self) -> Any: + """For constants""" + raise AsPythonConstantNotImplementedError(self) + + def guard_as_python_constant(self) -> Any: + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + try: + return self.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="Not a Python constant", + context=f"guard_as_python_constant {self}", + explanation=f"Failed to convert {self} into a Python constant.", + hints=[], + ) + + def is_python_constant(self) -> bool: + try: + self.as_python_constant() + return True + except NotImplementedError: + return False + + def is_constant_match(self, *values: Any) -> bool: + """ + Check if this variable is a python constant matching one of the given values. + + Examples: + var.is_constant_match(None) # True if var is constant None + var.is_constant_match(True, False) # True if var is constant True or False + var.is_constant_match(NotImplemented) # True if var is constant NotImplemented + """ + return False + + def is_constant_none(self) -> bool: + """Check if this variable is a constant None value.""" + return False + + def make_guard(self, fn: Callable[..., Any]) -> Guard: + if self.source: + return self.source.make_guard(fn) + raise NotImplementedError + + # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase` + # and cascade that (large blast radius) + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + """getattr(self, name) returning a python constant""" + raise NotImplementedError + + def is_symnode_like(self) -> bool: + """Return True for values that can participate in SymNode operations""" + return False + + def is_tensor(self) -> bool: + """Return True for TensorVariable instances""" + return False + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + """getattr(self, name) returning a new variable""" + value = self.const_getattr(tx, name) + if not variables.ConstantVariable.is_literal(value): + raise NotImplementedError + source = self.source and AttrSource(self.source, name) + if source and not self.is_python_constant(): + # The second condition is to avoid guards on const getattr objects + # like __code__.co_argcount + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def is_proxy(self) -> bool: + try: + self.as_proxy() + return True + except NotImplementedError: + return False + + def as_proxy(self) -> Any: + raise NotImplementedError(str(self)) + + def maybe_fx_node(self) -> Optional[Node]: + try: + proxy = self.as_proxy() + import torch.fx + + if isinstance(proxy, torch.fx.Proxy): + return proxy.node + return None + except NotImplementedError: + return None + + def reconstruct(self, codegen: "PyCodegen") -> None: + raise NotImplementedError + + def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + raise NotImplementedError + + def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + + def has_unpack_var_sequence(self, tx: Any) -> bool: + try: + self.unpack_var_sequence(tx) + return True + except NotImplementedError: + return False + + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx: Any) -> bool: + return self.has_unpack_var_sequence(tx) + + # Forces unpacking the var sequence while also applying a function to each element. + # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). + # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! + def force_apply_to_var_sequence( + self, tx: Any, fn: Callable[["VariableTracker"], Any] + ) -> None: + assert self.has_force_unpack_var_sequence(tx) + for v in self.unpack_var_sequence(tx): + fn(v) + + def call_obj_hasattr(self, tx: Any, name: str) -> "ConstantVariable": + unimplemented( + gb_type="Unsupported hasattr call", + context=f"call_obj_hasattr {self} {name}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_function( + self, + tx: Any, + args: Sequence["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + unimplemented( + gb_type="Unsupported function call", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch.", + ], + ) + + def call_method( + self, + tx: Any, + name: str, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__len__" and self.has_unpack_var_sequence(tx): + assert not (args or kwargs) + return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx))) + elif ( + name == "__getattr__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + ): + return self.var_getattr(tx, args[0].as_python_constant()) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(self, type(other)) and not ( + isinstance(self, variables.GetAttrVariable) + or isinstance(other, variables.GetAttrVariable) + ): + # NB: GetAttrVariable is a special case because sometimes an + # object can map to GetAttrVariable but other time as + # SkipFunctionVariable if it is an input to the compiled + # function, e.g. tensor.data_ptr + return variables.ConstantVariable.create(NotImplemented) + # NB : Checking for mutation is necessary because we compare + # constant values + if ( + not self.is_python_constant() + or not other.is_python_constant() + or tx.output.side_effects.has_pending_mutation(self) + or tx.output.side_effects.has_pending_mutation(other) + ): + unimplemented( + gb_type="Builtin `operator.*` comparison with constant `self` failed", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Failed to compare {self} with {other}, " + + f"because {other} is not a Python constant or its mutation check fails.", + hints=[], + ) + + try: + return variables.ConstantVariable.create( + cmp_name_to_op_mapping[name]( + self.as_python_constant(), other.as_python_constant() + ) + ) + except Exception as e: + raise_observed_exception( + type(e), + tx, + args=[list(map(variables.ConstantVariable.create, e.args))], + ) + hints = [ + f"Avoid calling `{self.python_type_name()}.{name}` in your code.", + "Please report an issue to PyTorch.", + ] + # additional hint for method calls on improperly constructed iterators + if isinstance(self, variables.UserDefinedObjectVariable) and name in ( + "__iter__", + "__next__", + ): + if isinstance(self.value, (KeysView, ItemsView, ValuesView)): + hints.append( + "Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) " + "to the compiled region, instead of passing it as an input to the compiled region." + ) + hints.append( + "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) " + "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). " + "This can happen unintentionally if a previous graph break happens with a builtin iterator " + "in the local scope." + ) + hints.append( + "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo " + "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, " + "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a " + "function, or (4) use Python 3.12+." + ) + unimplemented( + gb_type="Unsupported method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + hints=hints, + ) + + def call_tree_map( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Performance optimization to implement optree.tree_map faster than tracing it""" + is_leaf_var = tree_map_kwargs.get("is_leaf") + if is_leaf_var is not None and not is_leaf_var.is_constant_none(): + pred_result = is_leaf_var.call_function(tx, [self], {}) + try: + leaf_decision = pred_result.as_python_constant() + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + if leaf_decision: + return map_fn.call_function(tx, [self, *rest], {}) + + return self.call_tree_map_branch( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def call_tree_map_branch( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)""" + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _tree_map_fallback( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + tree_map_fn_copy = tree_map_fn.clone() + tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] + log.debug( + "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)", + self, + rest, + tree_map_kwargs, + ) + return tree_map_fn_copy.call_function( + tx, + [map_fn, self, *rest], + tree_map_kwargs, + ) + + def set_name_hint(self, name: str) -> None: + pass + + def realize(self) -> "VariableTracker": + """Used by LazyVariableTracker to build the real VariableTracker""" + return self + + def unwrap(self) -> "VariableTracker": + """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" + return self + + def is_realized(self) -> bool: + """Used by LazyVariableTracker to indicate an unrealized node""" + return True + + def next_variable(self, tx: Any) -> "VariableTracker": + unimplemented( + gb_type="Unsupported next() call", + context=f"next({self})", + explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def is_strict_mode(self, tx: Any) -> bool: + return bool(tx.strict_checks_fn and tx.strict_checks_fn(self)) + + def is_mutable(self) -> bool: + """Whether Dynamo allows mutation on this variable.""" + return not self.is_immutable() + + def is_immutable(self) -> bool: + """Whether Dynamo bans mutation on this variable.""" + return self.mutation_type is None + + @staticmethod + def build( + tx: Any, + value: Any, + source: Optional[Source] = None, + ) -> Any: + """Create a new VariableTracker from a value and optional Source""" + if source is None: + return builder.SourcelessBuilder.create(tx, value) + else: + return variables.LazyVariableTracker.create(value, source) + + def is_python_hashable(self): + """ + Unlike the variable tracker's own __hash__, this method checks whether + the underlying Python object referenced by this variable tracker is hashable. + """ + try: + type_self = self.python_type() + except NotImplementedError: + type_self = type(self) + + unimplemented( + gb_type="Dynamo cannot determine whether the underlying object is hashable", + context=f"is_python_hashable {self}", + explanation=f"Dynamo does not know whether the underlying python object for {self} is hashable", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {type_self}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def get_python_hash(self): + """ + Unlike the variable tracker’s own __hash__, this method is used by + ConstDictVariableTracker to compute the hash of the underlying key object. + """ + unimplemented( + gb_type="Dynamo cannot determine the hash of an object", + context=f"get_python_hash {self}", + explanation=f"Dynamo does not know the hash of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def is_python_equal(self, other): + """ + NB - Deliberately not overriding the __eq__ method because that can + disable the __hash__ for the vt itself. + """ + unimplemented( + gb_type="Dynamo cannot determine the equality comparison of an object", + context=f"is_python_equal {self}", + explanation=f"Dynamo does not know the equality comparison of the underlying python object for {self}", + hints=[ + ( + f"Consider using a different type of object as the dictionary key instead of {self.python_type()}." + ), + *graph_break_hints.SUPPORTABLE, + ], + ) + + def __init__( + self, + *, + source: Optional[Source] = None, + mutation_type: Optional[MutationType] = None, + ) -> None: + super().__init__() + self.source = source + self.mutation_type = mutation_type + + # NOTE sometimes mutation_type is set afterwards for implementation + # convenience, we don't validate those cases at the moment. + if mutation_type is not None: + if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)): + # If this fails, it's either + # 1. one mistakenly passed in a source + # 2. `mutation_type` is incorrect + assert source is None + else: + assert isinstance( + mutation_type, (ValueMutationExisting, AttributeMutationExisting) + ) + # If this fails, it's either + # 1. one forgot to pass in a source + # 2. `mutation_type` is incorrect + assert source is not None + + +def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn: + msg = variables.ConstantVariable.create(msg_str) + raise_observed_exception(TypeError, tx, args=[msg]) + + +def typestr(*objs: object) -> str: + if len(objs) == 1: + (obj,) = objs + if isinstance(obj, VariableTracker): + return str(obj) + else: + return type(obj).__name__ + else: + return " ".join(map(typestr, objs)) + + +instancecheck = type.__instancecheck__ +from . import builder +from .lazy import LazyVariableTracker diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..62c9bb896ef9bc4b92455f3ea71ecabdcb148be4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py @@ -0,0 +1,3957 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for building variable trackers in Dynamo. +Variable trackers are used to convert Python values into symbolic representations +that can be traced and transformed during graph capture. + +The key classes are: + +- VariableBuilder: Handles source-tracked objects that need guards and proper + reconstruction in the output graph. Used for inputs, module attributes, etc. + +- SourcelessBuilder: Handles ephemeral objects created during tracing that don't + need source tracking or guards. Used for temporary lists, intermediate values, etc. + +Variable trackers enable Dynamo to track the flow of values through the program, +maintain guards for dynamic properties, and reconstruct values in the output graph. +The builders in this module handle converting Python values into appropriate +VariableTracker instances based on their type and usage context. +""" + +import abc +import collections +import contextlib +import copy +import dataclasses +import enum +import functools +import inspect +import itertools +import logging +import math +import operator +import random +import re +import sys +import traceback +import types +import weakref +from collections.abc import Callable, MutableMapping +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch import SymInt +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.graph_bytecode_inputs import ( + get_external_object_by_index, + register_user_object, +) +from torch._dynamo.utils import ( + get_metrics_context, + is_int_specialization_case, + is_torch_sym, + set_feature_use, +) +from torch._guards import TracingContext +from torch._higher_order_ops.flat_apply import flat_apply +from torch._higher_order_ops.torchbind import call_torchbind +from torch._library.opaque_object import is_opaque_type, is_opaque_value_type +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +from torch._subclasses.meta_utils import is_sparse_any, safe_grad +from torch._utils_internal import justknobs_check +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental._dynamism import normalize_source_name +from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt +from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + _nested_int_aware_sort, + DimDynamic, + RelaxedUnspecConstraint, + StatefulSymbolicContext, + SubclassSymbolicContext, + SymbolicContext, + SymIntSymbolicContext, + TrackedFake, +) +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils.weak import TensorWeakRef + +from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules +from ..device_interface import get_registered_device_interfaces +from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard, make_dupe_guard +from ..pgo import ( + auto_dynamic, + auto_unset, + FrameStateSizeEntry, + InferStride, + process_automatic_dynamic, +) +from ..side_effects import SideEffects +from ..source import ( + AttrProxySource, + AttrSource, + CallMethodItemSource, + ChainedSource, + ConstDictKeySource, + ConvertIntSource, + DictGetItemSource, + DictSubclassGetItemSource, + DynamicScalarSource, + FloatTensorSource, + GetItemSource, + GradSource, + is_constant_source, + is_from_closure_source, + is_from_global_source, + is_from_nonlocal_source, + is_from_optimizer_source, + is_from_unspecialized_nn_module_source, + ListGetItemSource, + LocalSource, + NonSerializableSetGetItemSource, + NumpyTensorSource, + OptimizerSource, + RandomValueSource, + Source, + SubclassAttrListSource, + TupleIteratorGetItemSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + _extract_tensor_dict, + build_checkpoint_variable, + build_invoke_subgraph_variable, + clone_input, + common_constant_types, + dict_keys, + get_fake_value, + get_items_from_dict, + get_locals_to_steal, + get_static_address_type, + is_frozen_dataclass, + is_function, + is_function_or_wrapper, + is_invoke_subgraph, + is_lru_cache_wrapped_function, + is_namedtuple, + is_parameter_freezing, + is_typing, + is_utils_checkpoint, + is_wrapper_or_member_descriptor, + istype, + namedtuple_fields, + odict_values, + proxy_args_kwargs, + range_iterator, + set_example_value, + tensor_always_has_static_shape, + tuple_iterator, + tuple_iterator_getitem, + tuple_iterator_len, + unwrap_with_attr_name_if_wrapper, + wrap_fake_exception, +) +from .base import ( + AttributeMutationNew, + typestr, + ValueMutationExisting, + ValueMutationNew, + VariableTracker, + VariableTrackerMeta, +) +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + AutocastModeVariable, + DynamoConfigPatchVariable, + ErrorOnGraphBreakVariable, + NullContextVariable, + PreserveVersionContextVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + SetVariable, +) +from .distributed import ( + DeviceMeshVariable, + PlacementClassVariable, + PlacementVariable, + ProcessGroupVariable, + WorldMetaClassVariable, +) +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CollectiveFunctionRewriteVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + SysFunctionVariable, + TracebackVariable, + TritonKernelVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, +) +from .higher_order_ops import ( + LocalMapWrappedHigherOrderVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ItertoolsVariable +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SizeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradEngineVariable, + AutogradFunctionContextVariable, + AutogradFunctionVariable, + ComptimeVariable, + ConstantLikeVariable, + DebuggingVariable, + DelayGraphBreakVariable, + GetAttrVariable, + GetSetDescriptorVariable, + LambdaVariable, + LoggingLoggerVariable, + MethodWrapperVariable, + NumpyDTypeVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + SavedTensorBox, + TorchVersionVariable, + TypingVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .script_object import OpaqueObjectClassVariable, TorchScriptObjectVariable +from .sdpa import SDPAParamsVariable +from .streams import EventVariable, StreamContextVariable, StreamVariable +from .tensor import ( + NumpyNdarrayVariable, + supported_const_comparison_op_values, + SymNodeVariable, + TensorSubclassVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .torch import ( + DispatchKeySetVariable, + FuncTorchInterpreterVariable, + TorchCtxManagerClassVariable, + TorchInGraphFunctionVariable, +) +from .torch_function import ( + TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, + TorchFunctionModeVariable, +) +from .user_defined import ( + FrozenDataClassVariable, + IntWrapperVariable, + KeyedJaggedTensorVariable, + MutableMappingVariable, + SourcelessGraphModuleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedSetVariable, + UserDefinedTupleVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +DimList = list + + +def safe_has_grad(t): + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): + return hasattr(t, "grad") + + +class _missing: + pass + + +@dataclasses.dataclass +class GraphArg: + source: Source + # TODO: storing a SymInt here but not a FakeTensor is a pretty strange + # thing to do. Probably should have example (which stores an int) and + # fake_example + _example: Union[TensorWeakRef, torch.SymInt] + # When True, this indicates that this GraphArg is a Python quantity (e.g., + # a float or int) which we pass to the FX graph as a Tensor. This + # controls how we codegen calls into the Dynamo graph: we will call + # torch.as_tensor on the quantity before passing it in. + # + # Note that we typically do not pass dynamic integers as tensors, because + # they will most frequently just be used for size computation. But this + # is a policy decision that we can change our mind on; in particular, when + # an int comes from a random number generator (e.g., random.randint), we + # DO pass it as a tensor. + # + # It's also worth noting that our current tracing rules for + # pass_arg_as_tensor as subtly broken: we just pun the variable as a + # 0d scalar Tensor and pray that the semantics are the same. Which they + # often are, but not necessarily. ezyang(May 2024) plans to fix this + # soon. + pass_arg_as_tensor: bool + fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] + # UnspecializedPythonVariable often masquerades as a tensor. + # We MUST NOT generate shape guard code + # that actually tries to access tensor properties on these values. + # is_tensor lets us tell if this graph arg actually is a tensor + # or not. + is_tensor: bool = True + # Sometimes, the Tensor we pass to example is freshly allocated (smh). + # Then we cannot only keep a weak reference to it. This lets you + # stash a strong reference too. + example_strong_ref: Optional[torch.Tensor] = None + + def __setattr__(self, name, value): + # Use object.__setattr__ to bypass Dynamo's STORE_ATTR interception. + # This is needed because when PYTORCH_TEST_WITH_DYNAMO=1, even internal + # GraphArg creation can be traced, and with replay_side_effects=False, + # normal STORE_ATTR bytecode only records mutations without applying them. + object.__setattr__(self, name, value) + + @property + def example(self): + if isinstance(self._example, TensorWeakRef): + r = self._example() + assert r is not None + return r + else: + return self._example + + def __post_init__(self): + if isinstance(self._example, torch.Tensor): + self._example = TensorWeakRef(self._example) + assert is_fake(self.fake_tensor) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.source) + + def erase(self): + self._example = None + self.example_strong_ref = None + + def __eq__(self, other): + return self.source.name == other.source.name + + +class BackwardStateGraphArg(GraphArg): + def __init__(self) -> None: + super().__init__( + source=None, + _example=BackwardState(), + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + + def reconstruct(self, codegen: "PyCodegen"): + assert codegen.tx.output.backward_state_var + codegen.add_push_null( + lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState") + ) + codegen.call_function(0, False) + codegen.dup_top() + codegen.store(codegen.tx.output.backward_state_var) + + +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set() + +# Capture fn pointer at import time +# This is to guard against trying to mark the iterated tensors +# as static in case user overrides fn ptr +og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers +og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters + + +class VariableBuilder: + """Wrap a python value in a VariableTracker() instance""" + + def __init__( + self, + tx, + source: Source, + ) -> None: + assert source is not None, ( + "Consider SourcelessBuilder for ephemeral objects, usually objects created locally." + ) + assert TracingContext.try_get() is not None, "Expected active TracingContext" + super().__init__() + self.tx = tx + self.source = source + self.name = source.name + + def __call__(self, value): + if value in self.tx.output.side_effects: + side_effect_result = self.tx.output.side_effects[value] + dup_guard = make_dupe_guard(self.source, side_effect_result.source) + if dup_guard: + self.install_guards(dup_guard) + + if isinstance(value, torch.nn.Module) and isinstance( + side_effect_result, UnspecializedNNModuleVariable + ): + # This means that two nn module instances with different sources + # have the same id. NN modules are somewhat special objects, + # because we have to track their nn_module_stack for ease of + # use. But if we don't do anything, we will just return the + # older variable tracker with the older nn_module_stack. So, + # lets return the old variable tracker but update its + # nn_module_stack + side_effect_result.set_nn_module_stack_source(self.source) + return side_effect_result + + cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source) + if cached_vt: + return cached_vt + + vt = self._wrap(value) + + if vt.source is None: + vt.source = self.source + + def _is_deduplicable_sym_variable(value, vt): + # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we + # should NOT track them. If we use a single SymNodeVariable instance to track them + # across multiple uses, then guards created for one usage will incorrectly apply to + # all other usages of that constant, leading to unnecessary recompilations. + return ( + is_torch_sym(value) or isinstance(value, _DynamicScalar) + ) and isinstance(vt, SymNodeVariable) + + if ( + ( + self._can_lift_attrs_to_inputs(vt) + or _is_deduplicable_sym_variable(value, vt) + ) + and value not in self.tx.output.side_effects + and not is_wrapper_or_member_descriptor(value) + ): + vt = self.tx.output.side_effects.track_object_existing(value, vt) + + self.tx.output.variable_tracker_cache.add(value, self.source, vt) + return vt + + def _can_lift_attrs_to_inputs(self, vt): + return type(vt) in { + TensorVariable, + TensorWithTFOverrideVariable, + UserDefinedObjectVariable, + NumpyNdarrayVariable, + } + + def get_source(self): + return self.source + + def install_guards(self, *guards): + source = self.get_source() + try: + tmp = [source.make_guard(guard) for guard in guards] + except NotImplementedError: + return None + install_guard(*tmp, skip=1) + return {} + + @classmethod + def _type_dispatch(cls): + return cls._type_dispatch_impl(config.trace_numpy) + + @classmethod + @functools.cache + def _type_dispatch_impl(cls, trace_numpy): + # NB: Careful not to close over self to avoid ref cycle from lru_cache + entries = [ + ( + ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + cls.wrap_tensor, + ), + ( + (tuple, list, odict_values, collections.deque, torch.Size), + cls.wrap_listlike, + ), + (tuple_iterator, cls.wrap_tuple_iterator), + (range_iterator, cls.wrap_range_iterator), + ((slice, range), cls.wrap_slice_range), + (tuple(common_constant_types), cls.wrap_literal), + (re.Pattern, cls.wrap_regex_pattern), + (weakref.ReferenceType, cls.wrap_weakref), + (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle), + (torch.jit.ScriptFunction, cls.wrap_jit_function), + (types.MappingProxyType, cls.wrap_mapping_proxy), + ] + + if trace_numpy and np: + entries.append((np.ndarray, cls.wrap_numpy_ndarray)) + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, tuple) else (ts,): + assert t not in result + result[t] = fn + + return result + + def wrap_regex_pattern(self, value: re.Pattern): + # TODO(jansel): something like a REPR_MATCH might be more robust here + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantLikeVariable(value) + + def wrap_weakref(self, value: weakref.ReferenceType): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WeakRefVariable.build(self.tx, value, source=self.source) + + def wrap_removable_handle(self, value): + # This means that the removable handle was created in some other frame. + # Our current infra requires the hook to be registered and removed in + # the same frame. So graph break. + # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks + unimplemented( + gb_type="Attempted to represent unregistered RemovableHandle", + context="", + explanation="Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, " + "which is not supported. This happens because the RemovableHandle was created in another frame.", + hints=[], + ) + + def wrap_jit_function(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + + def wrap_mapping_proxy(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + # This might be suboptimal compared to dict guards. But mappingproxy is + # not very common, so its ok to guard on all keys. + self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK) + all_const = all(ConstantVariable.is_literal(k) for k in value) + + if not all_const: + unimplemented( + gb_type="non-const keys in mappingproxy", + context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", # noqa: SIM118 + explanation="Dynamo expects mappingproxy keys to be constants.", + hints=[ + "Ensure your mappingproxy keys are constants (e.g. int, float, strings)", + ], + ) + + def build_key_value(k, v): + key = ConstantVariable.create(k) + source_key = k + + source_value = GetItemSource(self.get_source(), source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + items = dict(build_key_value(k, v) for k, v in value.items()) + + # Create a dict_vt to be used in the mapping proxy variable + dict_vt = ConstDictVariable(items, source=None) + result = MappingProxyVariable(dict_vt, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + @classmethod + @functools.cache + def _id_dispatch( + cls, + ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: + from ..comptime import comptime + + entries = [ + (comptime, lambda self, value: ComptimeVariable()), + ( + dataclasses.fields, + lambda self, value: LambdaVariable( + _dataclasses_fields_lambda, + source=self.source, + **self.install_guards(GuardBuilder.CLOSURE_MATCH), + ), + ), + (torch.__version__, lambda self, value: TorchVersionVariable()), + ] + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, (tuple, list)) else (ts,): + assert t not in result + result[id(t)] = fn + + return result + + def _wrap(self, value): + # import here to avoid circular dependencies + from torch.utils._triton import ( + has_triton, + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, + ) + + from ..decorators import ( + DynamoConfigPatchProxy, + ErrorOnGraphBreakDecoratorContextManager, + ) + + if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + else: + + class JITFunction: + pass + + class Autotuner: + pass + + # default implementations, in case we don't have triton (or the wrong triton version) + def create_1d_tma_descriptor(): + pass + + def create_2d_tma_descriptor(): + pass + + class TensorDescriptor: + @staticmethod + def from_tensor(): + pass + + if has_triton_experimental_host_tma(): + from triton.tools.experimental_descriptor import ( # noqa: F811 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 + + # Handle exact type() match + type_dispatch = self._type_dispatch().get(type(value)) + if type_dispatch is not None: + return type_dispatch(self, value) + + # Handle exact id() match + id_dispatch = self._id_dispatch().get(id(value)) + if id_dispatch is not None: + return id_dispatch(self, value) + + # Everything else (NB: order matters!) + if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses + ): + if ( + type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + or is_traceable_wrapper_subclass(value) + ): + return self.wrap_tensor(value) + + if is_namedtuple(value): + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + output = [ + LazyVariableTracker.create( + getattr(value, name), + source=AttrSource(self.source, name), + ) + for name in namedtuple_fields(type(value)) + ] + result = NamedTupleVariable( + output, tuple_cls=type(value), source=self.source + ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): + self.install_guards(GuardBuilder.TYPE_MATCH) + all_const = all(ConstantVariable.is_literal(k) for k in value) + + # For all_const, we don't have to guard on anything yet. We guard on + # keys lazily by adding a dict_getitem entry for each accessed key. + # For cases where we need to guard on all keys, we lazily put guards + # during the dict call_method (check dicts.py) + if not all_const: + # Guard on the key order + # This is not ideal, i.e., there is no need to guard on the key + # order. But we guard on the key order because of the complexity + # + # 1) For non-constant objects, we can't save the key in the + # guard context because it can be memory heavy. We can add + # weakrefs but this complicates the accesses. + # + # 2) For non-constant objects, we also have to guard on the keys + # (like TENSOR_MATCH on tensor). We might also have guards on + # the attributes of the keys (like tensor.grad). To make this + # work in tree structure is complicated. + # + # So, instead we guard on the key order. While guarding on key + # order, we just save the indices and use it to access keys and + # values. Indices are cheap to save. + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + if all_const: + key = ConstantVariable.create(k) + source_key = k + else: + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + source_value = DictGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. + result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + if istype(value, collections.defaultdict): + factory_source = AttrSource(self.source, "default_factory") + result = DefaultDictVariable( + result, + type(value), + default_factory=VariableBuilder(self.tx, factory_source)( + value.default_factory + ), + source=self.source, + ) + else: + result = ConstDictVariable( + result, user_cls=type(value), source=self.source + ) + + return self.tx.output.side_effects.track_mutable(value, result) + elif isinstance(value, torch.nn.Module): + return self.wrap_module(value) + elif ConstantVariable.is_literal(value): # non-atomic literals + return self.wrap_literal(value) + elif isinstance(value, torch.overrides.TorchFunctionMode): + var = TorchFunctionModeVariable(value, source=self.source) + self.tx.output.side_effects.track_object_existing(value, var) + return var + elif istype(value, set): + if any(isinstance(x, torch.Tensor) for x in value): + unimplemented( + gb_type="Attempted to wrap a set with tensors", + context="Python set containing torch.Tensor elements", + explanation=( + "Dynamo cannot trace sets of tensors. To get a stable ordering, " + "Dynamo needs to convert the set into a list and the order might not be " + "stable if the set contains tensors." + ), + hints=[ + "Use a dictionary where the keys are tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # The list gives a ordering for the set items. The ordering is based + # on the Python hash and it is not related to object ordering inside + # the set object. The order being incorrect at runtime will lead to + # a recompilation. + L = list(value) + items = [ + LazyVariableTracker.create( + v, source=NonSerializableSetGetItemSource(self.source, i) + ) + for i, v in enumerate(L) + ] + result = SetVariable(items, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif istype(value, frozenset) and all( + ( + # For DBR quantization, we could get a frozenset of torch funcs. + (type(x) is types.BuiltinMethodType and x.__module__ == "torch") + or + # Another commonly used frozenset of types. + x in torch.utils._pytree.BUILTIN_TYPES + ) + for x in value + ): + # For the limited cases of frozenset here, we know the items won't + # change across runs, so we can safely create sourceless VTs for + # them and only guard on the frozenset id. + # TODO support source for sets and remove the special logics here. + items = [SourcelessBuilder.create(self.tx, v) for v in value] + self.install_guards(GuardBuilder.ID_MATCH) + return FrozensetVariable(items, source=self.source) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + self.install_guards(GuardBuilder.ID_MATCH) + return EnumVariable(value=value, source=self.source) + elif DebuggingVariable.is_reorderable_logging_function(value): + # Put this above builtin_callable so that print() can be handled + # along with other builtin debugging functions + self.install_guards(GuardBuilder.BUILTIN_MATCH) + return DebuggingVariable(value, source=self.source) + elif isinstance(value, logging.Logger): + self.install_guards(GuardBuilder.TYPE_MATCH) + return LoggingLoggerVariable(value, source=self.source) + elif is_utils_checkpoint(value): + return build_checkpoint_variable(source=self.source) + elif is_invoke_subgraph(value): + return build_invoke_subgraph_variable(source=self.source) + elif LocalMapWrappedHigherOrderVariable.should_wrap_in_hop(value): + return LocalMapWrappedHigherOrderVariable.build(source=self.source) + elif isinstance(value, functools.partial): + func_src = AttrSource(self.get_source(), "func") + func_obj = VariableBuilder(self.tx, func_src)(value.func) + + args = [] + args_source = AttrSource(self.get_source(), "args") + for i, arg in enumerate(value.args): + args.append( + VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) + ) + + keywords = {} + keywords_source = AttrSource(self.get_source(), "keywords") + for k, v in value.keywords.items(): + if not ConstantVariable.is_literal(k): + unimplemented( + gb_type="functools.partial() with non-literal keyword", + context=f"non-literal keyword: {k}", + explanation="functools.partial() expects literal/string keywords", + hints=[*graph_break_hints.USER_ERROR], + ) + keywords[k] = VariableBuilder( + self.tx, DictGetItemSource(keywords_source, k) + )(v) + + install_guard( + self.get_source().make_guard(GuardBuilder.TYPE_MATCH), + keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), + args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + ) + return FunctoolsPartialVariable(func_obj, args, keywords) + elif is_typing(value): + # typing.List, typing.Mapping, etc. + self.install_guards(GuardBuilder.ID_MATCH) + return TypingVariable( + value, + source=self.source, + ) + elif np is not None and isinstance(value, np.generic): + # numpy array scalars: convert to 0D arrays + return self.wrap_numpy_ndarray(np.asarray(value)) + elif trace_rules.is_numpy(value): + assert np + if istype(value, types.MethodType): + # Dont guard on cython functions as they dont change ids + if inspect.isfunction(value.__func__): + install_guard( + AttrSource(self.source, "__func__").make_guard( + GuardBuilder.CLOSURE_MATCH + ) + ) + elif inspect.isclass(value): + self.install_guards(GuardBuilder.CLASS_MATCH) + elif inspect.isfunction(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + elif callable(value): + self.install_guards(GuardBuilder.ID_MATCH) + else: + self.install_guards(GuardBuilder.TYPE_MATCH) + return NumpyVariable(value, source=self.source) + elif trace_rules.is_numpy_dtype(value): + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyDTypeVariable(value, source=self.source) + elif trace_rules.is_numpy_type_info(value): + if isinstance(value, np.iinfo): + self.install_guards(GuardBuilder.TYPE_MATCH) + dt_source = AttrSource(self.source, "dtype") + install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantLikeVariable(value, source=self.source) + # NB: These can't be put in type_dispatch, they have to run later + elif CollectiveFunctionRewriteVariable.can_rewrite(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return CollectiveFunctionRewriteVariable.create( + self.tx, + value, + source=self.source, + ) + elif istype(value, torch.autograd.function.FunctionMeta): + self.install_guards(GuardBuilder.CLASS_MATCH) + return AutogradFunctionVariable( + value, + source=self.source, + ) + elif isinstance(value, torch.autograd.function.FunctionCtx): + actual_saved_tensors = None + try: + actual_saved_tensors = value.saved_tensors + except RuntimeError: + pass + + saved_tensors = [] + guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)] + if isinstance(actual_saved_tensors, tuple): + saved_tensors_source = AttrSource(self.source, "saved_tensors") + guards.append( + saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + for i, v in enumerate(actual_saved_tensors): + saved_tensors.append( + VariableBuilder( + self.tx, GetItemSource(saved_tensors_source, i) + )(v) + ) + install_guard(*guards) + + return self.tx.output.side_effects.track_object_existing( + value, + AutogradFunctionContextVariable( + value, + source=self.source, + saved_tensors=SavedTensorBox(saved_tensors), + ), + ) + elif ( + isinstance(value, types.MethodType) + and istype( + getattr(value, "__self__", None), torch.autograd.function.FunctionMeta + ) + and getattr(value, "__name__", "") == "apply" + and value == getattr(value.__self__, "apply", None) + ): + # handle aliased autograd function `apply` calls + install_guard( + AttrSource(self.get_source(), "__func__").make_guard( + GuardBuilder.CLOSURE_MATCH + ) + ) + return GetAttrVariable( + AutogradFunctionVariable( + value.__self__, source=AttrSource(self.source, member="__self__") + ), + "apply", + ) + elif isinstance(value, torch._C._ImperativeEngine): + self.install_guards(GuardBuilder.ID_MATCH) + return AutogradEngineVariable(value, source=self.source) + elif ( + value + is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub + ): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return LambdaVariable( + lambda: UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, + ).call_function( + self.tx, + (self.tx.output.side_effects.get_ca_final_callbacks_var(),), + {}, + ) + ) + elif isinstance(value, DynamoConfigPatchProxy): + return DynamoConfigPatchVariable(value.changes) + elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager): + return ErrorOnGraphBreakVariable(value.error_on_graph_break) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + self.tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value).create_with_source( + value, source=self.source + ) + elif np and isinstance(value, np.number): + return self.wrap_unspecialized_primitive(value) + elif isinstance(value, HigherOrderOperator): + if value is torch._higher_order_ops.invoke_subgraph: + unimplemented( + gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", + context="", + explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", + hints=[], + ) + self.install_guards(GuardBuilder.TYPE_MATCH) + return TorchHigherOrderOperatorVariable.make(value, source=self.source) + elif isinstance(value, torch.cuda.StreamContext): + self.install_guards(GuardBuilder.ID_MATCH) + stream_source = AttrSource(self.source, "stream") + stream_var = VariableBuilder(self.tx, stream_source)(value.stream) + return StreamContextVariable.create(self.tx, stream_var) + elif isinstance(value, torch.Stream): + # This refers to the device-agnostic torch.Stream + self.install_guards(GuardBuilder.TYPE_MATCH) + index = register_user_object(value, self.source) + stream_proxy = self.tx.output.create_proxy( + "call_function", get_external_object_by_index, (index,), {} + ) + set_example_value(stream_proxy.node, value) + var = StreamVariable( + stream_proxy, value, source=self.source, user_object_index=index + ) + return self.tx.output.side_effects.track_object_existing(value, var) + elif isinstance(value, (torch._C._SDPAParams)): + self.install_guards(GuardBuilder.TYPE_MATCH) + return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter): + self.install_guards(GuardBuilder.ID_MATCH) + return FuncTorchInterpreterVariable(value) + elif isinstance(value, torch.Event): + self.install_guards(GuardBuilder.TYPE_MATCH) + index = register_user_object(value, self.source) + event_proxy = self.tx.output.create_proxy( + "call_function", + get_external_object_by_index, + (index,), + {}, + ) + set_example_value(event_proxy.node, value) + return EventVariable( + event_proxy, + value, + index, + source=self.source, + ) + elif ( + istype(value, contextlib.nullcontext) + and inspect.getattr_static(value, "enter_result", None) is None + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return NullContextVariable(source=self.source) + elif KeyedJaggedTensorVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = KeyedJaggedTensorVariable(value, source=self.source) + # TODO: this doing it manually is bad + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, torch.optim.Optimizer): + self.install_guards(GuardBuilder.ID_MATCH) + self.source = OptimizerSource(self.source) + return OptimizerVariable(value, source=self.source) + elif isinstance(value, torch.DispatchKeySet): + self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH) + return DispatchKeySetVariable(value) + elif WorldMetaClassVariable.is_group_member_type(value): + return WorldMetaClassVariable(value, source=self.source) + elif ProcessGroupVariable.is_process_group(value): + self.install_guards(GuardBuilder.ID_MATCH) + return ProcessGroupVariable(value, source=self.source) + elif DeviceMeshVariable.is_device_mesh(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return DeviceMeshVariable(value, source=self.source) + elif PlacementClassVariable.is_placement_type(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return PlacementClassVariable(value, source=self.source) + elif PlacementVariable.is_placement(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return PlacementVariable( + value, + source=self.source, + ) + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): + self.install_guards(GuardBuilder.CLASS_MATCH) + return ItertoolsVariable(value, source=self.source) + elif isinstance(value, _DynamicScalar): + is_int = isinstance(value, DynamicInt) + source = DynamicScalarSource(self.source, is_int) + if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes: + # If we've already seen this dynamic scalar, reuse the existing + # SymInt/SymFloat node. + node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)] + else: + sym = self.tx.output.shape_env.create_unspecified_symbol( + value.real, + source=source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + node = self.tx.output.shape_env.create_symintnode( + sym, + hint=value.real, + source=source, + ) + + # Bind to graph input + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(node), + node, + source=source, + ) + sym_node_proxy.node.meta["grapharg"] = GraphArg( + source, + node, + False, + None, + is_tensor=False, + example_strong_ref=node, + ) + sym_expr = node.node.expr + assert isinstance(sym_expr, sympy.Symbol), ( + f"{sym_expr} is not a basic Symbol." + ) + self.tx.output.tracked_fakes.append(TrackedFake(node, source, None)) + return SymNodeVariable.create(self.tx, sym_node_proxy, node) + elif is_torch_sym(value): + # Note: this doesn't handle nested symints. + # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo. + + # Concretely, + # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). + # so that guards on the SymInts can be effectively applied on the original SymBool in user program. + # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program + # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly. + source = ( + self.source + if isinstance(value, torch.SymInt) + else ConvertIntSource(self.source) + ) + if value.node.has_hint(): + new_symint = ( + self.tx.output.shape_env.create_unspecified_symint_and_symbol( + int(value.node.hint), + source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + ) + else: + if isinstance(value, torch.SymBool): + # We need to create an unbacked symint to replace the unbacked symbool. + new_symint = self.tx.output.shape_env.create_unbacked_symint() + else: + # TODO (yidi): we need to figure out a way to propagate the guards + # we accumulated when tracing the subggraph to outer shape_env. For normal symints, + # this is automatically done by evaluating the guards once but this + # will cause data-dependent error when we evaluate the outer unbacked symints. + # The test case that triggers this graph break is test_cond_unbacked_symint_closure + unimplemented( + gb_type="Attempted to wrap unbacked SymInt", + context="", + explanation="Unbacked SymInt input is not supported yet.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(new_symint), + new_symint, + source=source, + ) + + sym_node_proxy.node.meta["grapharg"] = GraphArg( + source, + new_symint, + False, + None, + is_tensor=False, + example_strong_ref=new_symint, + ) + # We bind the new_symint to graph input. + sym_expr = new_symint.node.expr + assert isinstance(sym_expr, sympy.Symbol), ( + f"{sym_expr} is not a basic Symbol." + ) + self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None)) + + tracing_symint = ( + new_symint if isinstance(value, torch.SymInt) else new_symint == 1 + ) # cast it back to symbool for tracing + return SymNodeVariable(sym_node_proxy, tracing_symint) + + elif isinstance(value, (JITFunction, Autotuner)): + self.install_guards(GuardBuilder.ID_MATCH) + return TritonKernelVariable( + value, + None, # No kernel idx provided + None, # No grid provided + source=self.source, + ) + elif value is create_1d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=1) + elif value is create_2d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=2) + elif value is TensorDescriptor.from_tensor: + return CreateTMADescriptorStableVariable() + elif isinstance(value, torch.amp.autocast_mode.autocast): + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) + elif TorchCtxManagerClassVariable.is_matching_cls(value): + if inspect.isclass(value): + self.install_guards(GuardBuilder.CLASS_MATCH) + elif inspect.isfunction(value): + self.install_guards(GuardBuilder.CLOSURE_MATCH) + return TorchCtxManagerClassVariable(value, source=self.source) + elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "__original_fn", source=self.source + ) + elif is_lru_cache_wrapped_function(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif value is traceback.clear_frames: + return TracebackVariable(source=self.source) + elif value is sys.exc_info or ( + sys.version_info >= (3, 11) and value is sys.exception + ): + return SysFunctionVariable(value, source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + elif value is functools.wraps: + self.install_guards(GuardBuilder.ID_MATCH) + return FunctoolsWrapsVariable(value, source=self.source) + elif value is collections.namedtuple: + self.install_guards(GuardBuilder.ID_MATCH) + return CollectionsNamedTupleFunction(value, source=self.source) + elif isinstance( + value, types.BuiltinMethodType + ) and BuiltinMethodVariable.is_supported_builtin_method(value): + self.install_guards(GuardBuilder.ID_MATCH) + return BuiltinMethodVariable(value, source=self.source) + elif is_function(value) and value in (float.fromhex, float.hex): + self.install_guards(GuardBuilder.ID_MATCH) + return GetAttrVariable( + BuiltinVariable(float, source=self.source), + value.__name__, + ) + elif is_function_or_wrapper(value): + value, attr_name = unwrap_with_attr_name_if_wrapper(value) + # For these wrappers, Dynamo points to the wrapped function, + # so source needs to be updated as well. + if attr_name is not None: + self.source = AttrSource(self.source, attr_name) + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + elif value is random.Random: + self.install_guards(GuardBuilder.ID_MATCH) + return RandomClassVariable(source=self.source) + elif istype(value, random.Random) and RandomVariable.is_supported_random_obj( + value + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = RandomVariable(value, source=self.source) + self.tx.output.side_effects.track_mutable(value, result) + return result + # Don't use istype, since some python modules are not subclasses of types.ModuleType directly. + # E.g, type(torch.ops) -> , + # type(torch.backends.cudnn) -> + elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): + self.install_guards(GuardBuilder.MODULE_MATCH) + result = PythonModuleVariable( + value, + source=self.source, + ) + self.tx.output.side_effects.track_object_existing(value, result) + return result + elif isinstance(value, types.MethodType) and isinstance( + value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) + ): + # don't let MethodTypes fall through to UserDefinedObject, + # which doesn't support 'CALL_FUNCTION' + + # TODO(whc): Why do we limit this to methods on NNModules? + # I don't have a good reason for this, but it preserves the existing behavior + # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. + # I suspect we probably want to relax this check and dig deeper there. + + # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, + # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here + # and then `__func__` gets wrapped inside UserMethodVariable. + self_obj = VariableBuilder( + self.tx, source=AttrSource(self.source, "__self__") + )(value.__self__) + assert self_obj and isinstance(self_obj, VariableTracker), ( + "Failed to produce a valid self obj" + ) + return UserMethodVariable( + value.__func__, + self_obj, + source=self.source, + ) + elif isinstance(value, types.GetSetDescriptorType): + # GetSet descriptors are C functions attached to an attribute lookup + # using PyGetSetDef. Python, on attribute lookup, can decide to + # create a new object on the fly, and therefore the `id` of the + # descriptors is not guaranteed to be same for different attribute + # accesses. Since these are unlikely to change during the program + # execution, we can skip guarding on them. + return GetSetDescriptorVariable(value) + elif isinstance(value, types.MethodWrapperType): + # Method-wrappers are written in C, and they are not guaranteed to + # return the same object on attribute lookup. Therefore, we cannot + # insert a ID_MATCH guard here. method-wrappers are very + # unlikely to change, so its ok to skip the guard here. + return MethodWrapperVariable(value) + elif issubclass(type(value), type) and issubclass(value, BaseException): + # match user defined exceptions + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedExceptionClassVariable(value) + elif issubclass(type(value), type): + if value in ( + torch.utils.hooks.BackwardHook, + torch.nn.Parameter, + torch.nn.Buffer, + ): + # TODO(jansel): combine this case with the one above + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + if value is torch.autograd._unsafe_preserve_version_counter: + self.install_guards(GuardBuilder.CLASS_MATCH) + return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). + and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) + + if not is_from_closure_source(self.source): + # For closure source, the variable comes from LOAD_SUPER_ATTR, + # which calls self.__class__. This is internal Cpython + # implementation, and it is rare for the user to modify + # self.__class__ manually. + # For other cases, this is a userdefined class, so install an + # ID_MATCH even if its a global variable. + self.install_guards(GuardBuilder.CLASS_MATCH) + + if is_opaque_type(value): + return OpaqueObjectClassVariable( + value, + source=self.source, + ) + + return UserDefinedClassVariable( + value, + source=self.source, + ) + elif TorchScriptObjectVariable.is_matching_cls(type(value)): + from ..source import ( + FlattenScriptObjectSource, + ScriptObjectQualifiedNameSource, + ) + + if torch._library.fake_class_registry.tracing_with_real(value): + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, value + ) + return TorchScriptObjectVariable.create( + proxy, + value, + source=self.source, + ) + + if is_opaque_type(type(value)): + # Check if this is a value-type opaque object (registered as both opaque type and constant) + if is_opaque_value_type(type(value)): + # Value-type: guard on equality (will use __eq__) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TorchScriptObjectVariable.create( + value, + value, + source=self.source, + ) + else: + # Reference-type: guard only on type/identity + self.install_guards(GuardBuilder.TYPE_MATCH) + + elif not hasattr(value, "__obj_flatten__"): + # This exists to allow a smoother transition. + # The implications are: + # The script objects won't be tracked as proxies. + # Methods on these objects won't show up in the graph. + # The original script object might be mutated. + return self.wrap_user_defined(value) + else: + # Install the guards on the fully qualified name of the script object + LazyVariableTracker.realize_all( + VariableBuilder( + self.tx, ScriptObjectQualifiedNameSource(self.source) + )( + value._type().qualified_name() # type: ignore[attr-defined] + ) + ) + # Install the guards on the content of the script object by setting the source + # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents. + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))( + value.__obj_flatten__() + ) + ) + + fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( + self.tx.output.fake_mode, value + ) + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + fake_script_obj, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, fake_script_obj + ) + return TorchScriptObjectVariable.create( + proxy, + fake_script_obj, + source=self.source, + ) + elif ( + isinstance(value, (dict, collections.OrderedDict)) + and type(value).__new__ is dict.__new__ + ): + # Construct a dict_vt that will reside inside the UserDefinedDictVariable + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Guard on the key order + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + + source_value = DictSubclassGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. + result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + dict_vt = ConstDictVariable( + result, + user_cls=( + collections.OrderedDict + if isinstance(value, collections.OrderedDict) + else dict + ), + mutation_type=ValueMutationExisting(), + source=self.source, + ) + # Force this to reconstruct on mutation to keep the reconstruction + # bytecode simple + dict_vt.should_reconstruct_all = True + + result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, tuple): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying tuple data structure. + output = [ + LazyVariableTracker.create( + tuple.__getitem__(value, i), + source=GetItemSource(self.get_source(), i), + ) + for i in range(tuple.__len__(value)) + ] + + tuple_vt = TupleVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedTupleVariable( + value, tuple_vt=tuple_vt, source=self.source + ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, list): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying list data structure. + output = [ + LazyVariableTracker.create( + list.__getitem__(value, i), + source=ListGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(value)) + ] + list_vt = ListVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, (set, frozenset)): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + L = list(dict.fromkeys(value)) + output = [ + LazyVariableTracker.create( + list.__getitem__(L, i), + source=NonSerializableSetGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(L)) + ] + set_vt_cls = SetVariable if isinstance(value, set) else FrozensetVariable + set_vt = set_vt_cls( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedSetVariable(value, set_vt=set_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass(type(value), MutableMapping): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = MutableMappingVariable(value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif is_frozen_dataclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, dict_keys): + if all(ConstantVariable.is_literal(k) for k in value): + # If the dict_keys object is passed from outside the compile region, it must either be passed along with + # the corresponding dict object or treated as a set (when only the keys are passed into the compiled region). + # - If it is passed along with the dict, the dict object itself is already guarded. + # - If only the dict_keys object is passed, we add EQUALS_MATCH and SEQUENCE_LENGTH guards + # to ensure it remains unchanged across multiple runs. + items = [SourcelessBuilder.create(self.tx, v) for v in value] + install_guard( + self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH), + self.get_source().make_guard(GuardBuilder.EQUALS_MATCH), + ) + return DictKeySetVariable(items, source=self.source) + else: + unimplemented( + gb_type="non-const keys in dict_keys", + context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", + explanation="Dynamo expects dict_keys keys to be constants.", + hints=[ + "Ensure your dict_keys keys are constants (e.g. int, float, strings)", + ], + ) + elif IntWrapperVariable.is_matching_object(value): + from torch.export.dynamic_shapes import _DimHintType + + if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC: + return self.wrap_symint(value.val) + elif value.dynamism.type == _DimHintType.DYNAMIC: + log.debug( + "%s marked %s via IntWrapper", + self.source.name, + DimDynamic.DYNAMIC, + ) + return self.wrap_symint( + value.val, + dynamism=DimDynamic.DYNAMIC, + context=SymIntSymbolicContext( + constraint=RelaxedUnspecConstraint(warn_only=False) + ), + ) + elif value.dynamism.type == _DimHintType.AUTO: + log.debug( + "%s marked %s via IntWrapper", + self.source.name, + DimDynamic.DYNAMIC, + ) + return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC) + else: + raise RuntimeError(f"Undefined dynamism {value.dynamism}") + else: + return self.wrap_user_defined(value) + + def wrap_user_defined(self, value: Any): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = UserDefinedObjectVariable(value, source=self.source) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + + def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): + for item in value: + if item is value: + unimplemented( + gb_type="list elements are pointing to the list itself", + context="", + explanation="Dynamo does not support lists whose items reference to itself", + hints=["Avoid using self referential list"], + ) + + if config.specialize_int and type(value) is torch.Size: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + # One can index a tensor with a list/tuple. Therefore, we need to + # have a stricter match. + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Tuples are immutable objects, so we should mark its items static. This + # avoids wrapping of tuple items as symints. This helps for nn module + # attributes like conv2d strides, dilations. + if ( + istype(value, tuple) + and all(ConstantVariable.is_literal(item) for item in value) + and self.source.guard_source.is_unspecialized_nn_module() + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TupleVariable([ConstantVariable.create(item) for item in value]) + + output = [ + LazyVariableTracker.create( + item, + source=GetItemSource(self.get_source(), i), + ) + for i, item in enumerate(value) + ] + + maybe_gm = self.tx.output.local_scope.get("self") + if isinstance( + self.source, LocalSource + ) and self.source.local_name in get_locals_to_steal(maybe_gm): + # The input tensor list to dynamo from compiled autograd may contain activations + # which are freed as they are used in inductor. Dynamo's default behavior is to + # lift all tensors to the graph inputs, but this will cause dynamo to hold an + # extra reference to the activation tensors and increase peak memory usage. + # To allow freeing ASAP, we keep the list as graph argument to the dynamo output + # graph, and unpack it locally. + # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have + # `def forward(self, L_inputs_):` + source = self.source + assert isinstance(value, list) + tensor_list_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=source, + ) + tensor_list_proxy.node.meta["steal_arg"] = True + + list_variable = wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=self.tx, + proxy=tensor_list_proxy, + example_value=value, + subclass_type=None, + source=source, + ) + + # Apply relevant logic from `VariableTracker.build(value[i])` + # (except for the `create_graph_input` stuff). + guards = [] + for i, tensor_variable in enumerate(list_variable.items): + source_i = GetItemSource(base=source, index=i, index_is_slice=False) + # access unpacked tensor from this list instead of from a lifted arg + self.tx.output.input_source_to_var[source_i] = tensor_variable + tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict( + value[i] + ) + guard = functools.partial( + GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i]) + ) + guards.append(source_i.make_guard(guard)) + + install_guard(*guards, skip=1) + + grapharg = GraphArg( + source, + value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + tensor_list_proxy.node.meta["grapharg"] = grapharg + + # The following is very important for maintaining the "python object + # <==> variable tracker" 1-to-1 mapping, which is mainly handled via + # `side_effects`. Note that constructing `tensor_variable` above + # already adds it to graph arg, but we never registered it with + # `side_effects`. The preemptive `realize` calls here basically + # does that registration (at the end of `self.__call__`). + # + # A slightly cleaner alternative is to register the + # `tensor_variable`s above with `side_effects` directly, and just + # return the `list_variable`, but that breaks some tensor-subclass + # related tests like `test_inputs_aliasing_bytecode_stack_restore`, + # because `tensor_variable` is constructed via + # `handle_traced_output`, which doesn't really expect/handle tensor + # subclass. + # + # Eventually, we expect to fix remove all of these by having Dynamo + # auto-boxing inputs to the compiled graph, see + # https://github.com/pytorch/pytorch/issues/153701. + for vt in output: + vt.realize() + + result = BaseListVariable.cls_for_instance(value)(output, source=self.source) + if istype(value, (list, collections.deque)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def wrap_tuple_iterator(self, value: tuple_iterator): + self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN) + output = [ + VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( + tuple_iterator_getitem(value, i) + ) + for i in range(tuple_iterator_len(value)) + ] + result = TupleIteratorVariable(output, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_range_iterator(self, value: range_iterator): + self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH) + # Get all the values from the range iterator; no need to install guards + # on items since `RANGE_ITERATOR_MATCH` guarantees the same items. + items = [ConstantVariable.create(v) for v in copy.deepcopy(value)] + result = ListIteratorVariable(items, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_slice_range(self, value: Union[slice, range]): + items = [ + VariableBuilder(self.tx, AttrSource(self.get_source(), k))( + getattr(value, k) + ) + for k in ("start", "stop", "step") + ] + self.install_guards(GuardBuilder.TYPE_MATCH) + if isinstance(value, slice): + return SliceVariable(items, self.tx, source=self.source) + else: + return RangeVariable(items, source=self.source) + + def mark_static_input(self, value: torch.Tensor, guard: bool): + from ..decorators import mark_static_address + + static_inputs_log.debug( + "Marking static input %s, id: %s)", self.source.name, id(value) + ) + mark_static_address(value, guard=guard) + + # Check if we've seen this tensor before and update graph metadata if needed + # As long as this runs before AOT this is sound + if value in self.tx.output.side_effects: + var = self.tx.output.side_effects[value] + var.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = ( + value._dynamo_static_input_type + ) + + def wrap_module(self, value: torch.nn.Module): + from ..eval_frame import OptimizedModule + + if len(value.__dict__) == 0: + unimplemented( + gb_type="Uninitialized nn.Module", + context=typestr(value), + explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", + hints=[ + *graph_break_hints.USER_ERROR, + "Ensure your nn.Module instance has called `super().__init__()`.", + ], + ) + if istype(value, OptimizedModule): + # Check if the optimized module was disabled + if inspect.getattr_static(value.forward, "_torchdynamo_disable", False): + # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If + # we graph break here, Dynamo does not know how to create + # continuation functions for such bytecodes. So, we delay the + # graph break to CALL_FUNCTION. + msg = inspect.getattr_static( + value.forward, "_torchdynamo_disable_msg", None + ) + return DelayGraphBreakVariable( + source=self.source, + msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.source = AttrSource(self.source, "_orig_mod") + return self.wrap_module(value._orig_mod) + + if ( + isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) + and not config.allow_rnn + ): + unimplemented( + gb_type="Attempted to wrap RNN, GRU, or LSTM", + context=str(value), + explanation="Dynamo does not support RNN, GRU, or LSTM.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if getattr(value, "_is_fsdp_managed_module", False): + # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + # in fully_sharded_data_parallel.py for more information + + # we can't do this assert inside FSDP constructor, + # since we don't know yet whether dynamo will be used + if not getattr(value, "_fsdp_use_orig_params", False): + unimplemented( + gb_type="FSDP with use_orig_params=False", + context="", + explanation="Dynamo only supports FSDP with use_orig_params=True", + hints=[], + ) + + # Note on FSDP guarding + # Eager FSDP already assumes (requires, but without enforcement) + # that users don't mutate their model parameters/structure after + # FSDP wrapping, because FSDP wouldn't notice or update its + # FlatParams. + # + # Therefore, torch.compile can skip guarding on params or submodule + # structure of fsdp_managed modules, by using FSDPNNModuleSource as + # the guard source. This behavior is gated on + # config.skip_fsdp_guards. + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FSDPManagedNNModuleVariable(value, source=self.get_source()) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): + # created dynamically, don't specialize on it + + # Note [Tracing a torch.compiled function] + # when make_fx tracing a compiled function, we need + if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): + value = value.get_base() + self.source = AttrProxySource(self.source) + + if torch._dynamo.config.inline_inbuilt_nn_modules: + freezing = is_parameter_freezing() + + # Guard against the case where user may overwrite named parameters + # / named buffers + # NOTE: This is not likely to happen but worth guarding to avoid + # exception + if ( + callable(value.named_parameters) + and value.named_parameters.__func__ + is og_module_named_parameters_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, p in value.named_parameters(): + self.mark_static_input(p, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if ( + callable(value.named_buffers) + and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, b in value.named_buffers(): + self.mark_static_input(b, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if freezing: + # we need to add the module to tracing context + # in order to allow its params to get invalidated + # this will get cleaned up once compile ends + self.tx.output.nn_modules[self.name] = value + + if ( + value.__module__.startswith(("torch.nn.modules", "torch.ao.")) + and not value.__module__.startswith("torch.nn.modules.container") + ) or getattr(value.__class__, "_dynamo_marked_static", False): + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedBuiltinNNModuleSource(self.source) + result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + else: + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedNNModuleSource(self.source) + result = UnspecializedNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + + self.tx.output.add_fqn_info_for_inlined_modules(value, self.source) + + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass( + value.__class__, torch.nn.parallel.distributed.DistributedDataParallel + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return UnspecializedNNModuleVariable(value, source=self.get_source()) + else: + return self.tx.output.register_attr_or_module( + value, + self.name, + source=self.get_source(), + # Guards are added inside register_attr_or_module + ) + + def wrap_literal(self, value): + if type(value) is int: + # allowlist has higher precedence over specialization control. + if is_dynamic_source(self.source.name): + log.debug("%s marked dynamic via source whitelist", self.source.name) + return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC) + + if is_unbacked_source(self.source.name): + log.debug("%s marked unbacked via source whitelist", self.source.name) + return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED) + + if not config.specialize_int: + # unspecializing int by default, but still + # specialize for the following conditions + if is_int_specialization_case(value, self.source): + recompile_hint = None + if ( + self.source.guard_source.is_unspecialized_builtin_nn_module() + or self.source.guard_source.is_unspecialized_nn_module() + ): + # This means that it is an integer from a NN module. + # Dynamo considers nn module int attributes to be static + # (a good heuristic). But a user might want to mark the + # int attribute to be a symint, so track this integer + # for recompilation later. + recompile_hint = ( + "torch.compile considers integer attributes of the nn.Module to be static. " + "If you are observing recompilation, you might want to make this integer dynamic " + "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this " + "integer into a tensor." + ) + + process_automatic_dynamic( + self.tx, + self.source.name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + self.install_guards( + functools.partial( + GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint + ) + ) + return ConstantVariable.create(value=value, source=self.source) + + return self.wrap_symint(value) + elif not config.specialize_float and type(value) is float: + return self.wrap_symfloat(value) + else: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + result = ConstantVariable.create(value=value, source=self.source) + if isinstance(value, (list, set)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): + if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: + raise InternalTorchDynamoError( + "Cannot wrap a Tensor that has already been", + "wrapped by this instance of Dynamo", + ) + + def wrap_tensor(self, value: torch.Tensor): + source = self.get_source() + + # We cannot already be tracking the tensor, which implies + # it would have already been wrapped + assert value not in self.tx.output.side_effects + + is_static_input = get_static_address_type(value) is not None + + if ( + config.inline_inbuilt_nn_modules + and not is_static_input + and ( + isinstance(value, torch.nn.Parameter) + # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior + # compatible with previous behavior. + or (source and source.guard_source.is_unspecialized_nn_module()) + ) + ): + self.mark_static_input(value, guard=is_parameter_freezing()) + is_static_input = True + + # Install any tensors which are "free" variables; that is: + # 1. Globals + # 2. NonLocals + # 3. tensors that are attributes of nn module + should_install_free_tensor = config.install_free_tensors and ( + is_from_global_source(source) + or is_from_nonlocal_source(source) + or is_from_unspecialized_nn_module_source(source) + ) + + make_graph_attribute = is_static_input and ( + not config.inline_inbuilt_nn_modules + or is_parameter_freezing() + or torch._dynamo.config.prepare_freezing + ) + + if should_install_free_tensor or ( + (source.guard_source.is_specialized_nn_module() or make_graph_attribute) + and not source.guard_source.is_fsdp_module() + ): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if get_static_address_type(value) == "guarded": + # If it's a guarded tensor, we can install the parameter directly + # into the Fx graph instead of lifting it as an input. Lifting + # offers no benefit, such as regional compilation, since we still + # guard on the tensor's ID. Moreover, installing it in the Fx graph + # eliminates the pre-graph bytecode required to extract the tensor + # from locals/globals, reducing overhead. This can lead to + # significant cost savings, especially for optimizers handling many + # tensors. + self.install_guards(GuardBuilder.ID_MATCH) + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if is_constant_source(source): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=source, + # Guards are added inside register_attr_or_module + ) + + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + subclass_type = infer_subclass_type(value) + if subclass_type is not None: + self.install_guards(GuardBuilder.TYPE_MATCH) + + if get_static_address_type(value) == "guarded": + self.install_guards(GuardBuilder.ID_MATCH) + + # By this point, we should have deduplicated all tensors + self.assert_not_wrapped_by_this_graph(value) + + if ( + isinstance(value, torch.Tensor) + and value.is_nested + and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) + ): + unimplemented( + gb_type="Attempted to wrap strided NestedTensor", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards + if ( + isinstance(value, torch.Tensor) + and is_sparse_any(value) + and (not self.tx.export or not config.capture_sparse_compute) + ): + # A hot fix for sparse tensors + torch.compile. Support for + # export + sparsity is being added but we need to create + # SPARSE_TENSOR_GUARDS for guards to work properly. + unimplemented( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented( + gb_type="dtype mismatch between tensor and its gradient", + context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", + explanation="Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # tx.output has multiple tracers if we're introspecting HigherOrderOperator. + # When we've discovered an untracked tensor, then we actually need + # to get Dynamo to track the tensor (which is what this function does) + # and put it as a graph input on the root tracer. Later on, + # if the input is actually used in the body of the HigherOrderOperator, + # then the relevant SubgraphTracer will lift it to being an input of + # the subgraph. + # See NOTE [HigherOrderOperator tracing design] for more details. + + example_value = wrap_to_fake_tensor_and_record( + value, tx=self.tx, is_tensor=True, source=source + ) + + tensor_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, tensor_proxy, value) + + tensor_variable = wrap_fx_proxy( + tx=self.tx, + proxy=tensor_proxy, + example_value=example_value, + subclass_type=subclass_type, + source=source, + **options, + ) + + if value._is_view(): + # If value is a view, add its base tensor to the tracked fakes list. + # This is so we are able to access the correct source for its symbolic + # shape values, in case we need them. + wrap_to_fake_tensor_and_record( + value._base, + tx=self.tx, + source=AttrSource(source, "_base"), + is_tensor=True, + ) + + guard_type = GuardBuilder.TENSOR_MATCH + + if isinstance(source, GradSource) and is_from_optimizer_source(source): + guard_type = GuardBuilder.NOT_NONE_MATCH + + is_dtensor = torch.distributed.is_available() and isinstance( + value, torch.distributed.tensor.DTensor + ) + if not is_dtensor: + # We guard on the _local_tensor and the _spec, and therefore we dont + # have to guard on the outer DTensor. + self.install_guards( + functools.partial( + guard_type, + value=( + value + if isinstance(source, NumpyTensorSource) + else TensorWeakRef(value) + ), + ) + ) + + # We install TYPE_MATCH guards for traceable wrapper subclass object, + # and recursively install corresponding guard for each inner attribute. + if is_traceable_wrapper_subclass(value): + # Tensor subclass guards are very expensive because they are + # implemented in Python. Since DTensor is PyTorch-maintained class, + # we can skip a lot of these guards. + if is_dtensor: + self.install_guards(GuardBuilder.TYPE_MATCH) + + # The inner tensor name is always _local_tensor. If its not, we + # raise assertion to update the check accordingly. + inner_tensor_name = value.__tensor_flatten__()[0][0] + if inner_tensor_name != "_local_tensor": + raise RuntimeError( + "Expecting Dtensor inner tensor name to be _local_tensor" + ) + + # Now selectively guard on the flattening context + flattening_ctx = value.__tensor_flatten__()[1] + # This is supposed to be (self._spec, self.requires_grad) + if not ( + len(flattening_ctx) == 2 + and flattening_ctx[0] == value._spec + and flattening_ctx[1] == value.requires_grad + ): + # If not, raise an assertion to update to the new guards + raise RuntimeError( + "Expecting Dtensor flattening ctx to be _spec, requires_grad" + ) + # Guard on the dtensor spec + install_guard( + AttrSource(self.source, "_spec").make_guard( + GuardBuilder.DTENSOR_SPEC_MATCH + ) + ) + # Move this to C++ + install_guard( + AttrSource(self.source, "requires_grad").make_guard( + GuardBuilder.EQUALS_MATCH + ) + ) + else: + self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) + self.install_guards(GuardBuilder.TYPE_MATCH) + install_guard( + SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) + ) + + attrs, _ = value.__tensor_flatten__() + for attr in attrs: + inner_value = getattr(value, attr) + inner_source = AttrSource(self.source, attr) + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, inner_source)(inner_value) + ) + + self.tx.output.input_source_to_var[source] = tensor_variable + assert "tensor_dict" not in tensor_proxy.node.meta + tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value) + + # Note: this information is conveyed via subclass_type now + fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] + if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: + raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") + + grapharg = GraphArg(source, value, False, fake_tensor_value) + tensor_proxy.node.meta["grapharg"] = grapharg + return tensor_variable + + def wrap_numpy_ndarray(self, value): + assert np is not None + assert isinstance(value, np.ndarray) + + source = NumpyTensorSource(self.get_source()) + + from torch._numpy import _util + + readonly = not value.flags.writeable + if readonly: + try: + value.flags.writeable = True + except ValueError: + # One can not easily make nditer elements writable, + # but warning is not the end of the world + assert isinstance(value.base, np.nditer) + + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented( + gb_type="failed to convert numpy.ndarray to Tensor", + context=str(value), + explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor", + hints=[], + from_exc=e, + ) + + # We do this because we want the full behavior of guarding the numpy ndarray as if it were + # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here + # that there's not another great way to do this atm. + # This creates the right graphargs, as well as registration for guards in tensor names and shape env. + LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value)) + example_value = wrap_to_fake_tensor_and_record( + tensor_value, + tx=self.tx, + is_tensor=False, + source=source, + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(tensor_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, tensor_value) + options = {"source": source} + numpy_ndarray_variable = wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + self.tx.output.input_source_to_var[source] = numpy_ndarray_variable + example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] + + # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be + # converted to a tensor. + grapharg = GraphArg( + source, + tensor_value, + pass_arg_as_tensor=True, + fake_tensor=example_value, + is_tensor=True, + example_strong_ref=tensor_value, + ) + proxy.node.meta["grapharg"] = grapharg + + # TODO - Why do we need to set the source of the np ndarray vt back to + # original source. Many tests fails. + numpy_ndarray_variable.source = self.source + + return numpy_ndarray_variable + + def wrap_symint( + self, + value, + dynamism: Optional[DimDynamic] = None, + context: Optional[SymIntSymbolicContext] = None, + ): + assert type(value) is int + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + shape_env = self.tx.output.shape_env + if TracingContext.get().force_unspec_int_unbacked_size_like: + wrapped_value = shape_env.create_unbacked_symint() + _constrain_range_for_size(wrapped_value) + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, None) + ) + + # NB: We do not do float. For motivation, see + # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit + # but the general idea is that we generate kernels that can + # take unspecialized floats and use them in sizevar computation + elif not is_constant_source(self.get_source()): + if dynamism is None and torch._dynamo.config.specialize_int: + # If specialize_int is False, also return + # a constant (but this should have been handled + # in the caller, TBH). But if `dynamism` is set, then actually + # turn it into a symint + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + name = self.source.name + + frame_state_entry = process_automatic_dynamic( + self.tx, + name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + + # TODO: This should be dynamic, as we in general do not + # know if bare integers are actually going to be sizevars + # and it is inappropriate to eagerly duck size them with + # real sizevars + normalized_source_name = normalize_source_name(self.source.name) + base_source = self.source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if dynamism is not None: + dynamic_dim = dynamism + elif ( + config.automatic_dynamic_shapes + and frame_state_entry.scalar is auto_dynamic + ): + set_feature_use("dynamo.automatic_dynamic_shapes", True) + dynamic_dim = get_automatic_dynamic_shapes_mark_as() + elif ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {0: False})[ + 0 + ] + ) or not config.assume_static_by_default: + dynamic_dim = DimDynamic.DYNAMIC + else: # assume_static_by_default + # TODO: dynamic_dim = DimDynamic.STATIC should work but + # for some reason it doesn't + if frame_state_entry.scalar is auto_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + wrapped_value = shape_env.create_unspecified_symint_and_symbol( + value, + source=self.source, + dynamic_dim=dynamic_dim, + ) + + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, context) + ) + else: + assert is_constant_source(self.get_source()) + # TODO: Do I actually need guard for constant source? + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + wrapped_value, + source=self.get_source(), + ) + + sym_expr = wrapped_value.node.expr + assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol." + self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy + unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if not is_constant_source(self.get_source()): + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + return unspec_var + + def wrap_symfloat(self, value): + # SymFloat wrapping is special. We first wrap it in the same way we + # do an unspecialized primitive, and then we item() it into a + # SymFloat. Removal of the item() call is left to a later FX pass, + # mostly because that pass is more easily done after we have lowered + # to ATen ops. (Dynamo doesn't do decomposition right now). + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + frame_state_entry = process_automatic_dynamic( + self.tx, + self.source.name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source.is_unspecialized_nn_module(), + ) + + # NB: we specialize on nan input, because our guard modeling in + # ShapeEnv cannot deal with nan + if ( + torch._dynamo.config.specialize_float + or is_constant_source(self.get_source()) + or math.isnan(value) + or math.isinf(value) + # We don't support cudagraphs for now. Without this cudagraphs + # break because they expect all cuda inputs but our tensorified + # float will be a f64[] cpu tensor. Fixes the following test + # when specialize_float=False + # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950 + or torch._inductor.config.triton.cudagraphs + or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False) + or ( + config.assume_static_by_default + and frame_state_entry.scalar is not auto_dynamic + ) + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # NB: At the point we've gotten here, we don't assume static by + # default. Since we have a guard mechanism, there isn't really any + # downside to trying to be dynamic for float all the time. Unlike + # ints, this won't make codegen perf worse. Modest cost to compile + # time. + + wrapped_value = torch.tensor(value, dtype=torch.float64) + + # We don't support specializing floats for grad checking tensors + # See https://github.com/pytorch/pytorch/pull/140828 for more + # context. + if torch._C._functorch.is_gradtrackingtensor(wrapped_value): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # TODO: Switch RandomValueSource over to use this, this is more + # accurate + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + # The FloatTensorSource here is just for pedantic correctness: if you + # guard against an UnspecializedPythonVariable, you need to guard + # against the tensor-ified version of the local, otherwise it's not a + # Tensor. However, we never let the UnspecializedPythonVariable escape + # here, so there should never actually be any guards against this + # source. + source = FloatTensorSource(self.get_source()) + options = {"source": source, "raw_value": value} + + # TODO: Maybe the tensor-ification should be built into the source, + # rather than by special pattern match + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=source + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + assert isinstance(unspec_var, UnspecializedPythonVariable) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + # There's something a bit incoherent about pass_arg_as_tensor, + # specifically regarding sources. + # + # Specifically, suppose we have "x: float" local argument. We + # eventually end up with an UnspecializedPythonVariable denoting + # torch.as_tensor(x)... but it's source is still L['x'] (which if you + # accessed it directly is a float!) So you gotta be careful when + # setting up your guards, because it's still going to be a float at + # this point, the conversion happens only precisely at the point we're + # actually calling the FX graph. This happens to be what we want for + # shape guard generation, but it's kind of unintuitive. + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + # Directly do item to bypass capture_scalar_outputs + r = wrap_fx_proxy( + self.tx, + self.tx.output.create_proxy( + "call_method", + "item", + *proxy_args_kwargs([unspec_var], {}), + ), + ) + self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None)) + + get_metrics_context().set("tensorify_float_attempt", True, overwrite=True) + + return r + + def wrap_unspecialized_primitive(self, value): + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + wrapped_value = torch.tensor(value) + if not isinstance(self.get_source(), RandomValueSource): + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + options.update({"raw_value": value}) + + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source() + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=self.get_source(), + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + self.tx.output.unspec_variable_map[self.name] = unspec_var + if not is_constant_source(self.get_source()): + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + if unspec_var.is_python_constant(): + # TODO: when can this happen? + example_value = unspec_var.as_python_constant() + else: + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + return unspec_var + + +def _dataclasses_fields_lambda(obj): + if isinstance(obj, UserDefinedObjectVariable): + value = obj.value + else: + unimplemented( + gb_type="dataclass fields failure", + context=f"obj: {obj}; variable type: {type(obj)}", + explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.", + hints=[], + ) + items = [] + for field in dataclasses.fields(value): + source = None + if obj.source: + base_src = AttrSource(obj.source, "__dataclass_fields__") + source = DictGetItemSource(base_src, field.name) + items.append(UserDefinedObjectVariable(field, source=source)) + return TupleVariable(items) + + +def _clone_input(value, fake_mode): + if isinstance(value, torch.Tensor): + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not ( + isinstance(value, FakeTensor) + or ( + # Is functional tensor fakeified by this instance of Dynamo + torch._is_functional_tensor(value) + and maybe_get_fake_mode(value) is fake_mode + ) + or value.is_nested + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + +def wrap_fx_proxy( + tx, proxy, example_value=None, subclass_type=None, **options +) -> VariableTracker: + kwargs = { + "tx": tx, + "proxy": proxy, + "example_value": example_value, + "subclass_type": subclass_type, + **options, + } + if subclass_type is None: + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + else: + result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) + result.install_global(tx) + return result + + +def cache_real_value_when_export(tx, proxy, example_value): + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value, tx.fake_mode + ) + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +# +# This is a horribly complicated function that does too many things, to +# explain what it does, let's first talk about the classic usage wrap_fx_proxy +# for a TensorVariable. There are two primary modes of use: +# +# 1. Wrapping a pre-existing Tensor. In this case, example_value is set +# to the pre-existing Tensor. (Note that this example_value will NOT +# be the final example_value we put into node.meta['example_value'], +# instead it is converted into a fake tensor using +# wrap_to_fake_tensor_and_record and registered as a graph input.) +# +# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In +# this case, example_value is None (and we are going to figure it out +# ourselves using FakeTensors, via get_fake_value, which will run +# the operation represented by the (singular!) FX node referenced by +# the passed in proxy.) +# +# The expectation is you end up with a Tensor output, and everything is +# straightforwardly traced into the graph. +# +# In all cases, the returned `TensorVariable` subclass will have an `example_value` +# and that `example_value` must be a `FakeTensor` produced by the currently running +# instance of Dynamo. +# +# Upon closer inspection, you may notice that there are a slurry of non-Tensor +# output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the +# graph that don't involve tensors. +# +# * Some operators return tuples; we need to recursively handle their +# contents +# +# * Some operators have side effects that will affect subsequent AOTAutograd +# tracing but don't otherwise return anything. +# +# * Some operators return symbolic ints/floats/bools which can go in the +# graph and be traced (but only if they're actually symbolic! If they're +# static you don't want to put them in the graph, which means you +# shouldn't call this function.) +# +# The common theme is that you only use this function WHEN YOU ARE TRACING +# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call +# this function without a proxy. +def wrap_fx_proxy_cls( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + if example_value is None: + out = _wrap_fx_proxy( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + elif isinstance(example_value, torch.Tensor): + out = _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + else: + # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported + # data structures. In essence this just handles tracing some other value which may + # contain Fake Tensors or is otherwise proxyable. + out = handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + if ( + isinstance( + out, + ( + torch._dynamo.variables.TensorVariable, + torch._dynamo.variables.SymNodeVariable, + ), + ) + and proxy.node.op != "placeholder" + ): + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out + + +# This is 1 above (wrapping a preexisting tensor) +def _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, tensor, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tensor, torch.Tensor), ( + f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}" + ) + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + # Placeholders always carry example_value in node.meta. + # non-placeholders always have no example_value in node.meta + if proxy.node.op == "placeholder": + assert "example_value" in proxy.node.meta, ( + f"placeholder {proxy} doesn't have 'example_value' in node.meta" + ) + else: + assert "example_value" not in proxy.node.meta, ( + f"{proxy.node.meta['example_value']}" + ) + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # Handle recursive calls here + if maybe_get_fake_mode(tensor) is tx.fake_mode: + pass + else: + cache_real_value_when_export(tx, proxy, tensor) + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + tensor, tx.fake_mode + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + kwargs = { + "is_tensor": target_cls + in (TensorVariable, TensorWithTFOverrideVariable), + } + assert "source" in options and options["source"] is not None + kwargs["source"] = options["source"] + tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs) + + if tensor.device.type != "meta" and ( + maybe_get_fake_mode(tensor) is not tx.fake_mode + ): + raise InternalTorchDynamoError( + "`tensor` needs to be a `FakeTensor`" + f"wrapped by this instance of Dynamo. Found: {tensor}" + ) + + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) + + +# This is 2 in the above comment (wrapping the output of a traced op) +def _wrap_fx_proxy( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # with preserve_rng_state(): + # only allow_non_graph_fake in this instance because we handle the non-fake + # cases properly below. + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + + return handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + +# This handles wrapping of the output of an op traced into the graph +def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls): + import torch._functorch.vmap + import torch._subclasses.fake_tensor + import torch._utils + + if isinstance(example_value, torch.Tensor): + # Check if the result is a sparse tensor - + # We generally don't support sparse tensor so better to graph break here + if is_sparse_any(example_value) and ( + not tx.export or not config.capture_sparse_compute + ): + unimplemented( + gb_type="Attempted to wrap sparse Tensor with VariableTracker", + context=str(example_value), + explanation="torch.compile does not support sparse Tensors with VariableTracker", + hints=[*graph_break_hints.SUPPORTABLE], + ) + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # specially here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. + # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target is torch.random.set_rng_state + ): + return TorchInGraphFunctionVariable(proxy.node.target) + elif ( + proxy.node.target is torch._C._DisableFuncTorch + or proxy.node.target is torch.cuda._is_in_bad_fork + ): + return UserDefinedObjectVariable(example_value) + elif istype(example_value, torch.Size) and all( + isinstance(x, int) for x in example_value + ): + sizes = [ConstantVariable.create(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + set_example_value(proxy.node, example_value) + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable.create(None, **options), + ) + else: + proxy_i = proxy.tracer.create_proxy( + kind="call_function", + target=operator.getitem, + args=(proxy, i), + kwargs={}, + ) + + if "source" in options: + # This path should only trigger for list stealing, so it's + # safe to use `GetItemSource`. + assert isinstance(example_value, list) + source = options["source"] + options_i = options.copy() + options_i["source"] = GetItemSource( + base=source, index=i, index_is_slice=False + ) + else: + # use the same options object as parent + options_i = options + + # WARNING: this assumes the same target_cls as this tuple/list call + unpacked.append( + wrap_fx_proxy_cls( + target_cls=target_cls, + tx=tx, + proxy=proxy_i, + example_value=val, + **options_i, + ) + ) + if isinstance(example_value, torch.Size): + # NB: Keep the old proxy around. See SizeVariable for an + # explanation why + return SizeVariable(unpacked, proxy, **options) + elif istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ( + f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" + ) + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable.create(None, **options) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + tx.output.current_tracer.track_produced_symints(example_value, proxy) + set_example_value(proxy.node, example_value) + return SymNodeVariable.create(tx, proxy, example_value, **options) + elif ( + isinstance(example_value, torch.Stream) + and proxy.node.target is get_external_object_by_index + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + index = None + if proxy.node.target is get_external_object_by_index: + index = proxy.node.args[0] + return StreamVariable(proxy, example_value, index, **options) + elif ( + isinstance(example_value, torch.Event) + and proxy.node.target is get_external_object_by_index + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + index = None + if proxy.node.target is get_external_object_by_index: + index = proxy.node.args[0] + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, index, **options) + elif ( + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Event) + ) or proxy.node.target in [ + device_interface.Event + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, None, **options) + elif proxy.node.target == "query" and proxy.node.op == "call_method": + set_example_value(proxy.node, example_value) + return ConstantVariable(example_value, **options) + elif ( + example_value is not None + and isinstance(example_value, torch.Event) + and proxy.node.target == "record_event" + and proxy.node.op == "call_method" + ): + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, None, **options) + elif isinstance(example_value, int) and ( + proxy.node.target + in [ + torch.sym_int, + getattr, + operator.getitem, + torch._utils._element_size, + torch.seed, + operator.mod, + torch._functorch.vmap._validate_and_get_batch_size, + torch._functorch.predispatch._vmap_increment_nesting, + torch._functorch.predispatch._vmap_decrement_nesting, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + # This always wants to be in the graph, even if the constraint + # results in a constant int + torch._constrain_as_size, + ] + or ( + # TODO: this is a little sus, because we didn't check what the self is + proxy.node.op == "call_method" and proxy.node.target == "bit_length" + ) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch.backends.cuda.SDPAParams): + from .sdpa import SDPAParamsVariable + + set_example_value(proxy.node, example_value) + return SDPAParamsVariable(proxy, **options) + elif isinstance(example_value, bool) and ( + proxy.node.target + in [ + torch._C._are_functorch_transforms_active, + torch._C._functorch.is_batchedtensor, + torch.backends.cuda.is_flash_attention_available, + torch.backends.cuda.can_use_flash_attention, + torch.backends.cuda.can_use_efficient_attention, + torch._C._get_cudnn_sdp_enabled, + torch._C._get_flash_sdp_enabled, + torch._C._get_mem_efficient_sdp_enabled, + torch._C._get_math_sdp_enabled, + torch._C._get_overrideable_sdp_enabled, + "is_integer", + ] + + list(supported_const_comparison_op_values.keys()) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, (int, float, bool)) and ( + proxy.node.target is call_torchbind + or proxy.node.target is flat_apply + or (proxy.node.op == "call_method" and proxy.node.target == "item") + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + else: + unimplemented( + gb_type="torch.* op returned non-Tensor", + context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", + explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", + hints=[], + ) + + +def infer_subclass_type(value): + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. + return None + else: + return type(value) + + +def get_specialized_props(target_cls, tx, example_value, subclass_type): + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + return specialized_props + + +def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options +): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. + if proxy.node.op != "placeholder": + tx.output.current_tracer.track_produced_symints(example_value, proxy) + options.update(get_specialized_props(target_cls, tx, example_value, subclass_type)) + return target_cls(proxy, **options) + + +def get_automatic_dynamic_shapes_mark_as(): + if config.automatic_dynamic_shapes_mark_as == "dynamic": + return DimDynamic.DYNAMIC + elif config.automatic_dynamic_shapes_mark_as == "unbacked": + return DimDynamic.SIZE_LIKE_UNBACKED + elif config.automatic_dynamic_shapes_mark_as == "oblivious": + return DimDynamic.OBLIVIOUS_SIZE + else: + raise ValueError( + f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}" + ) + + +_DYNAMIC_SOURCES: Optional[set[str]] = None +_DYNAMIC_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_dynamic_sources() -> set[str]: + global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.dynamic_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash: + return _DYNAMIC_SOURCES + + # Config has changed or first time, (re)calculate the sources + _DYNAMIC_SOURCES = { + s + for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",") + if s + } + _DYNAMIC_SOURCES_CONFIG_HASH = current_hash + + return _DYNAMIC_SOURCES + + +def is_dynamic_source(source_name: str) -> bool: + dynamic_sources = get_dynamic_sources() + for pattern in dynamic_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked dynamic due to dynamic source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +def record_automatic_dynamic( + tx: "InstructionTranslator", name: str, e: torch.Tensor +) -> FrameStateSizeEntry: + # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset + ex_size = e.size() + if not is_sparse_any(e): + ex_stride = e.stride() + dim = e.dim() + + stride = [None] * dim + pending = [(ex_stride[i], -i) for i in range(dim)] + pending.sort(key=_nested_int_aware_sort) + candidates = {} + for i_stride, neg_i in pending: + i = -neg_i + stride[i] = candidates.get(i_stride, i_stride) + candidates.setdefault(i_stride * ex_size[i], InferStride(i)) + else: + stride = [] + + return process_automatic_dynamic( + tx, name, FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride)) + ) + + +_UNBACKED_SOURCES: Optional[set[str]] = None +_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_unbacked_sources() -> set[str]: + global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.unbacked_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash: + return _UNBACKED_SOURCES + + # Config has changed or first time, (re)calculate the sources + _UNBACKED_SOURCES = { + s + for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",") + if s + } + _UNBACKED_SOURCES_CONFIG_HASH = current_hash + + return _UNBACKED_SOURCES + + +def is_unbacked_source(source_name: str) -> bool: + unbacked_sources = get_unbacked_sources() + for pattern in unbacked_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked unbacked due to unbacked source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +# Performs automatic dynamic dim determination. +# Returns a SymbolicContext +def _automatic_dynamic( + e, tx, source, static_shapes, outer_only=False +) -> SymbolicContext: + # strided NT not supported + if e.is_nested and not isinstance( + e, torch.nested._internal.nested_tensor.NestedTensor + ): + unimplemented( + gb_type="Encountered strided NestedTensor in automatic dynamic dim determination", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + name = source.name + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + shape_env_to_source_to_symbol_cache = ( + prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None + ) + + # Get base context if the tensor is a view + view_base_context: Optional[SymbolicContext] = None + if e._is_view(): + base_source = AttrSource(source, "_base") + view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) + + if is_traceable_wrapper_subclass(e) and not outer_only: + # Get symbolic context for outer tensor + outer_context = _automatic_dynamic( + e, tx, source, static_shapes, outer_only=True + ) + + # Get symbolic contexts for inner tensors + inner_contexts = {} # mapping from attr -> symbolic context + attrs, _ = type(e).__tensor_flatten__(e) + for attr in attrs: + inner_tensor = getattr(e, attr) + inner_source = AttrSource(source, attr) + inner_contexts[attr] = _automatic_dynamic( + inner_tensor, tx, inner_source, static_shapes + ) + + return SubclassSymbolicContext( + dynamic_sizes=outer_context.dynamic_sizes, + dynamic_strides=outer_context.dynamic_strides, + constraint_sizes=outer_context.constraint_sizes, + constraint_strides=outer_context.constraint_strides, + view_base_context=view_base_context, + tensor_source=outer_context.tensor_source, + shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, + inner_contexts=inner_contexts, + ) + + if static_shapes and not is_dynamic_source(name): + return StatefulSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * e.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # We preserve the dynamism of inputs. For example, when users call + # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. + from torch.fx.experimental.symbolic_shapes import is_nested_int + + if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): + return StatefulSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC + for s in e.size() + ], + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # Prep for automatic dynamic + frame_state_entry = record_automatic_dynamic(tx, name, e) + + # TODO: index export_constraints ahead of time so we don't have to + # do a linear scan every time here + t_id = id(e) + dim2constraint = {} + + def update_dim2constraint(dim, constraint_range, name): + if dim in dim2constraint: + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + old_constraint_range, old_name = dim2constraint[dim] + new_constraint_range = StrictMinMaxConstraint( + vr=constraint_range.vr & old_constraint_range.vr, + warn_only=False, + ) + # It is possible for (non-None) old_name and name to be different + # but this will only happen the corresponding Dims can be derived equal. + new_name = old_name or name + dim2constraint[dim] = new_constraint_range, new_name + else: + dim2constraint[dim] = constraint_range, name + + from torch.export.dynamic_shapes import _RelaxedConstraint + + if tx.output.export_constraints: + for constraint in tx.output.export_constraints: + if isinstance(constraint, _RelaxedConstraint): + continue + if constraint.t_id == t_id: + update_dim2constraint( + constraint.dim, constraint.constraint_range, constraint.name + ) + + dynamic_sizes = [] + dynamic_strides = [] + constraint_sizes = [] + constraint_strides = [] + specialize_on = [] + for i in range(e.dim()): + # NB: mark dynamic has precedence over static + marked_strict_unbacked = i in getattr( + e, "_dynamo_strict_unbacked_indices", set() + ) + marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) + marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) + marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) + marked_static = i in getattr(e, "_dynamo_static_indices", set()) + + specialize_on.append(getattr(e, "_specialize_on", {}).get(i, [])) + + # Reflect the user directive in the frame_state + # For dynamic, apply None always + + normalized_source_name = normalize_source_name(source.name) + base_source = source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if marked_dynamic or ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i] + ): + # TODO: This can be batched + # TODO: Doing this here is kind of sus, maybe better to set this + # up when we initially created the FrameStateSizeEntry to bong + # into the mutable state + log.debug("automatic dynamic %s marked dynamic", name) + mark_size = [auto_unset] * e.dim() + mark_size[i] = auto_dynamic + frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size) + + # NB: both static and dynamic have precedence over + automatic_dynamic_size = ( + config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i) + ) + # NB: previously, if size was dynamic, we wouldn't make its stride + # dynamic. But now, because of InferStride concept, we will properly + # not make stride dynamic even if it's wobbling + automatic_dynamic_stride = ( + config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i) + ) + + if is_dynamic_source(name): + log.debug("%s marked dynamic via source whitelist", name) + automatic_dynamic_size = True + + if is_unbacked_source(name): + log.debug("%s marked unbacked via source whitelist", name) + automatic_dynamic_size = True + + automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride + + # We will process constraints first, as they will imply that we + # have a dynamic dimension + # Precedence: export constraints > eager constraints + constraint = dim2constraint.get(i) + if constraint is None: + constraint_size = None + constraint_stride = None + if marked_dynamic and not config.allow_ignore_mark_dynamic: + # constraint_stride is deliberaly kept None because no easy way to provide value ranges for mark dynamic + constraint_stride = None + if hasattr(e, "_dynamo_dynamic_range"): + dim_range = [ + dr for dr in e._dynamo_dynamic_range if dr.dim == i + ].pop() + if dim_range.min is None and dim_range.max is None: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + else: + from torch.fx.experimental.symbolic_shapes import ( + StrictMinMaxConstraint, + ) + + constraint_size = StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), + warn_only=False, + ) + else: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif marked_strict_unbacked: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif not marked_static and automatic_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", True) + if automatic_dynamic_size: + constraint_size = RelaxedUnspecConstraint(warn_only=True) + if automatic_dynamic_stride: + constraint_stride = RelaxedUnspecConstraint(warn_only=True) + else: + if not marked_static and not config.automatic_dynamic_shapes: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + constraint_size = None + constraint_stride = None + else: + constraint_size, name_ = constraint + constraint_stride = None + dim_name = f"{name}.size()[{i}]" + tx.output.shape_env.source_name_to_debug_name[dim_name] = name_ + constraint_sizes.append(constraint_size) + constraint_strides.append(constraint_stride) + + if marked_unbacked or is_unbacked_source(name): + dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED + elif ( + constraint_size is not None + or marked_dynamic + or marked_weak_dynamic + or is_nested_int(e.size()[i]) + ): + # NB: We could assert static_shapes is False here, but it + # seems better to allow the user to override symbolic_context in this + # case + if automatic_dynamic: + dynamic_size = get_automatic_dynamic_shapes_mark_as() + else: + dynamic_size = DimDynamic.DYNAMIC + elif static_shapes or config.assume_static_by_default or marked_static: + dynamic_size = DimDynamic.STATIC + else: + # TODO: When does this show up? + dynamic_size = DimDynamic.DUCK + + if constraint_stride is not None: + dynamic_stride = DimDynamic.DYNAMIC + else: + dynamic_stride = DimDynamic.INFER_STRIDE + + dynamic_sizes.append(dynamic_size) + dynamic_strides.append(dynamic_stride) + + return StatefulSymbolicContext( + dynamic_sizes=dynamic_sizes, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + specialize_on=specialize_on, + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + +# See note [Tensor Fakification and Symbol Caching] +def wrap_to_fake_tensor_and_record( + e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None +): + if ( + type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) + or isinstance(e, torch.Tensor) + or is_traceable_wrapper_subclass(e) + ): + assert source is not None + static_shapes, _reason = tensor_always_has_static_shape( + e, + is_tensor, + tensor_source=source, + ) + + if not parent_context: + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + else: + # Parent contexts are passed in when we are recursively creating + # fake tensors for subclasses. A better design would be not to create a + # parent/child relationship, but to recursively call _automatic_dynamic + # as we recursively call wrap_to_fake_tensor_and_record. This runs + # into bugs around how meta_utils knows and works to create fake tensors + # with tensor subclasses. Ideally, dynamo would drive both the recursive + # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. + assert isinstance(source, AttrSource) + inner_context_name = source.member + symbolic_context = parent_context.inner_contexts[inner_context_name] + + log.debug( + "wrap_to_fake %s %s %s %s", + source.name, + tuple(e.shape), + symbolic_context, + type(e), + ) + + # Note [enable_python_dispatcher in dynamo] + # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses + # have no way to know (purely based off of global state) if they are currently being run under compile or not. + # we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors + # can check it to know if they are running in an eager context or not + with enable_python_dispatcher(): + fake_e = wrap_fake_exception( + lambda: tx.fake_mode.from_tensor( + e, + source=source, + symbolic_context=symbolic_context, + ) + ) + if ( + source is not None + and isinstance(fake_e, FakeTensor) + and (sym_val := fake_e.item_memo) is not None + ): + tx.output.tracked_fakes.append( + TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context) + ) + + if is_traceable_wrapper_subclass(fake_e): + attrs, _ = fake_e.__tensor_flatten__() + for attr in attrs: + fake_inner = getattr(fake_e, attr) + inner = getattr(e, attr) + inner_source = AttrSource(source, attr) + wrap_to_fake_tensor_and_record( + inner, + tx, + source=inner_source, + is_tensor=isinstance(fake_inner, torch.Tensor), + parent_context=symbolic_context, + ) + + tx.output.tracing_context.tensor_to_context[e] = symbolic_context + if is_sparse_any(fake_e): + # TODO: for TensorGuards, this eventually may need more + # fields for the size/stride of any other constituents + values = fake_e._values() if fake_e.is_sparse else fake_e.values() + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + # TODO: revise this, but for now this stride instead of () + # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1 + "stride": (1,) * fake_e.ndim, + "values_size": values.size(), + "values_stride": values.stride(), + } + else: + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + "stride": fake_e.stride(), + } + + if ( + is_tensor + and not (static_shapes and source.is_specialized_nn_module()) + and not is_constant_source(source) + ): + tx.output.tracked_fakes.append( + TrackedFake(fake_e, source, symbolic_context) + ) + tx.output.tracked_fakes_id_to_source[id(e)].append(source) + + return fake_e + else: + return e + + +class SourcelessBuilder: + """ + Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects + that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over + .), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, + there may be reasons to represent it as a ListVariable internally. + + NOTE - Objects produced here are born UNGUARDED due to the nature of sources! + + NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant + if/else type->VariableTracker trees that were cropping up all over dynamo. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + fast_handler = SourcelessBuilder._type_handlers.get(value_type) + if fast_handler: + return fast_handler(tx, value) + + if isinstance(value, VariableTracker): + # This is always valid to call, and useful for recursive calls. + return value + elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): + return UserDefinedObjectVariable(value) + elif ConstantVariable.is_literal(value): + return ConstantVariable.create(value) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value)(value) + elif callable(value) and UserDefinedClassVariable.is_supported_new_method( + value + ): + # NamedTuple._make uses an alias of tuple.__new__ + obj = trace_rules.lookup_callable(value.__self__)(value.__self__) + return GetAttrVariable(obj, "__new__") + elif is_function_or_wrapper(value): + return trace_rules.lookup(value)(value) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + return EnumVariable(value) + elif isinstance(value, (type, abc.ABCMeta)): + return UserDefinedClassVariable(value) + elif isinstance(value, types.MethodWrapperType): + return MethodWrapperVariable(value) + elif ( + isinstance(value, types.MethodType) + # We only want to support sourceless class objects here + # An instance variable is not allowed and it should have source + and isinstance(value.__self__, (type, abc.ABCMeta)) + ): + # value is a classmethod + assert getattr(value.__self__, value.__func__.__name__) == value + cls_obj_vt = SourcelessBuilder.create(tx, value.__self__) + try: + return cls_obj_vt.var_getattr(tx, value.__func__.__name__) + except NotImplementedError: + pass # failthrough to unimplemented branch + elif isinstance(value, torch.fx.graph_module.GraphModule): + return SourcelessGraphModuleVariable(value) + elif isinstance(value, torch.utils._pytree.TreeSpec): + return UserDefinedObjectVariable(value) + elif PlacementVariable.is_placement(value): + return PlacementVariable(value) + elif DeviceMeshVariable.is_device_mesh(value): + return DeviceMeshVariable(value) + elif value is functools.wraps: + return FunctoolsWrapsVariable(value) + elif isinstance(value, re.Pattern): + return ConstantLikeVariable(value) + elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString): + return ConstantVariable.create(str(value)) + elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)): + return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable( + value + ) + elif isinstance(value, (types.GenericAlias, types.UnionType)): + return TypingVariable(value) + elif is_namedtuple(value): + output = [ + SourcelessBuilder.create(tx, getattr(value, name)) + for name in namedtuple_fields(type(value)) + ] + return NamedTupleVariable(output, tuple_cls=type(value)) + elif ( + isinstance(value, torch.SymInt) + and value.node.expr in tx.output.bound_symbols + ): + proxy = tx.output.bound_symbols[value.node.expr] + return SymNodeVariable.create(tx, proxy) + unimplemented( + gb_type="Unexpected type in sourceless builder", + context=f"{value_type.__module__}.{value_type.__qualname__}", + explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + @staticmethod + def wrap_constant_literal(value): + assert ConstantVariable.is_literal(value) + return ConstantVariable.create(value=value) + + @staticmethod + def make_type_handlers(): + create = SourcelessBuilder.create + handlers = {} + for t in common_constant_types: + handlers[t] = lambda tx, value: ConstantVariable(value) + handlers[set] = lambda tx, value: SetVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[dict] = lambda tx, value: ConstDictVariable( + {create(tx, k): create(tx, v) for k, v in value.items()}, + type(value), + mutation_type=ValueMutationNew(), + ) + handlers[list] = lambda tx, value: ListVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[tuple] = lambda tx, value: TupleVariable( + [create(tx, x) for x in value] + ) + handlers[torch.Size] = lambda tx, value: SizeVariable( + [create(tx, x) for x in value] + ) + handlers[collections.OrderedDict] = handlers[dict] + handlers[immutable_dict] = handlers[dict] + handlers[immutable_list] = handlers[list] + handlers[random.Random] = lambda tx, value: RandomClassVariable() + handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) + + handlers[torch.DispatchKeySet] = lambda tx, value: DispatchKeySetVariable( + value, mutation_type=ValueMutationNew() + ) + handlers[torch._functorch.pyfunctorch.FuncTorchInterpreter] = ( + lambda tx, value: FuncTorchInterpreterVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + handlers[torch.distributions.constraints._Real] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints._Interval] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints.Constraint] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + def passthrough(tx: "InstructionTranslator", value): + return value + + for cls in VariableTrackerMeta.all_subclasses: + handlers[cls] = passthrough + return handlers + + +SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers() + + +class SourcelessUserDefinedObjectBuilder: + """ + SourceLessBuilder does not return a UserDefinedObjectVariable, but in some + cases it might be ok to return UserDefinedObjects. In such case, use this + builder. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + if issubclass(value_type, MutableMapping): + return MutableMappingVariable(value, mutation_type=ValueMutationNew()) + elif isinstance(value, torch.nn.Module): + return UnspecializedNNModuleVariable( + value, mutation_type=ValueMutationNew() + ) + else: + return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..44fca37314a62b79df1374270065f6d5837bfaab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py @@ -0,0 +1,3286 @@ +""" +Built-in function and type variable tracking for TorchDynamo's symbolic execution. + +This module contains variable tracker classes for Python built-in functions, types, +and operations during graph compilation. It handles symbolic execution of: + +- Built-in functions (len, getattr, isinstance, etc.) +- Type constructors (int, float, str, list, dict, etc.) +- Built-in operators and methods +- Special Python constructs (super, hasattr, etc.) + +Key classes: +- BuiltinVariable: Tracks built-in functions and handles their execution +- TypeVariable: Manages type constructor calls and type checking +- SuperVariable: Handles super() calls in class hierarchies + +These variable trackers ensure that built-in Python operations are correctly +handled during symbolic execution, either by executing them directly when safe +or by creating appropriate graph nodes when needed. +""" + +import contextlib +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +import types +import typing +import unittest +from collections import defaultdict, OrderedDict +from collections.abc import Callable, Iterable, KeysView, Sequence +from typing import Any, cast, TYPE_CHECKING, Union + +import torch +from torch import sym_float, sym_int +from torch._subclasses.meta_utils import is_sparse_any +from torch.overrides import BaseTorchFunctionMode +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config, graph_break_hints, polyfills, variables +from ..exc import ( + AttributeMutationError, + ObservedAttributeError, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, + Unsupported, + UserError, + UserErrorType, +) +from ..guards import GuardBuilder, install_guard +from ..replay_record import DummyModule +from ..source import ( + AttrSource, + GetItemSource, + GlobalSource, + is_constant_source, + Source, + TypeSource, +) +from ..utils import ( + check_constant_args, + check_numpy_ndarray_args, + check_unspec_or_constant_args, + check_unspec_python_args, + cmp_name_to_op_mapping, + dict_methods, + extract_fake_example_value, + frozenset_methods, + get_fake_value, + guard_if_dyn, + is_tensor_getset_descriptor, + is_wrapper_or_member_descriptor, + istype, + numpy_operator_wrapper, + proxy_args_kwargs, + raise_args_mismatch, + set_methods, + str_methods, + tensortype_to_dtype, +) +from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeysVariable, + DictViewVariable, + FrozensetVariable, + is_hashable, + SetVariable, +) +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SizeVariable, + TupleIteratorVariable, + TupleVariable, +) +from .streams import EventVariable, StreamVariable +from .tensor import ( + FakeItemVariable, + supported_comparison_ops, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .user_defined import ( + MutableMappingVariable, + UserDefinedDictVariable, + UserDefinedObjectVariable, + UserDefinedVariable, +) + + +if TYPE_CHECKING: + # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +log = logging.getLogger(__name__) + + +IN_PLACE_DESUGARING_MAP = { + operator.iadd: operator.add, + operator.isub: operator.sub, + operator.imul: operator.mul, + operator.ifloordiv: operator.floordiv, + operator.itruediv: operator.truediv, + operator.imod: operator.mod, + operator.imatmul: operator.imatmul, + operator.ilshift: operator.lshift, + operator.irshift: operator.rshift, + operator.ipow: operator.pow, + operator.iand: operator.and_, + operator.ior: operator.or_, + operator.ixor: operator.xor, +} + + +_HandlerCallback = Callable[ + ["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None +] +_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]] +polyfill_fn_mapping = { + operator.eq: polyfills.cmp_eq, + operator.ne: polyfills.cmp_ne, + operator.lt: polyfills.cmp_lt, + operator.le: polyfills.cmp_le, + operator.gt: polyfills.cmp_gt, + operator.ge: polyfills.cmp_ge, +} + +bin_ops = ( + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +) + +bin_int_ops = ( + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +) + +un_int_ops = (operator.invert,) + +tensor_and_int_ops = ( + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +) + +un_ops = ( + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +) + +BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {} + +# These functions represent the r* versions of the above ops +# Basically, if __add__(1, Tensor) is called, it is translated +# to __radd__(Tensor, 1). +# In the builtin var, we check if there is a tensor in the first args position, +# if not, we swap the args and use the r* version of the op. +BUILTIN_TO_TENSOR_RFN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {} + + +def populate_builtin_to_tensor_fn_map() -> None: + global BUILTIN_TO_TENSOR_FN_MAP + if len(BUILTIN_TO_TENSOR_FN_MAP) > 0: + # Only populate once; after there are elements present no need to + # repopulate + return + most_recent_func: Callable[..., Any] | None = None + + class GetMethodMode(BaseTorchFunctionMode): + """ + Mode to extract the correct methods from torch function invocations + (Used to get the correct torch.Tensor methods from builtins) + """ + + def __torch_function__( + self, + func: Callable[..., Any], + types: Any, + args: Sequence[Any] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + kwargs = kwargs or {} + nonlocal most_recent_func + most_recent_func = func + return func(*args, **kwargs) + + inp0 = torch.ones(1) + inp1 = torch.ones(1) + inp0_int = torch.ones(1, dtype=torch.int32) + inp1_int = torch.ones(1, dtype=torch.int32) + with GetMethodMode(): + setups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [ + (lambda o: o(inp0), un_ops), + (lambda o: o(inp0_int), un_int_ops), + (lambda o: o(inp0, inp1), bin_ops), + (lambda o: o(inp0_int, inp1_int), bin_int_ops), + (lambda o: o(inp0_int, 0), tensor_and_int_ops), + ] + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + setup_fn(op) + assert most_recent_func is not None + BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func + + # gather the reverse functions + rsetups_and_oplists: list[tuple[Callable[..., Any], Iterable[Any]]] = [ + ( + lambda o: o(1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: o(1, inp1_int), bin_int_ops), + (lambda o: o(0, inp0_int), tensor_and_int_ops), + ] + + rskips = {operator.matmul, operator.imatmul, operator.getitem} + for setup_fn, op_list in rsetups_and_oplists: + for op in op_list: + if op in rskips: + continue + setup_fn(op) + assert most_recent_func is not None + if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: + BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func + + +class BuiltinVariable(VariableTracker): + """ + A VariableTracker that represents a built-in value (functions and operators). + A lot of the code here assumes it will be a function object. + + The BuiltinVariable class wraps Python built-in functions (like len, isinstance, etc.) + and operators (like +, -, *, etc.) to enable symbolic execution during tracing. This allows + Dynamo to properly handle these operations when converting Python code to FX graphs while + maintaining correct semantics and enabling optimizations. + """ + + _SENTINEL = object() + _nonvar_fields = { + "fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + def create_with_source(cls, value: Any, source: Source) -> "BuiltinVariable": + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + return cls(value, source=source) + + @staticmethod + @functools.cache + def _constant_fold_functions() -> set[Callable[..., Any]]: + fns: set[Callable[..., Any]] = { + abs, + all, + any, + bool, + callable, + chr, + complex, + divmod, + float, + getattr, + int, + len, + max, + min, + ord, + pow, + repr, + round, + str, + str.format, + sum, + type, + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.truth, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.sub, + operator.getitem, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + operator.index, + } + from .tensor import supported_comparison_ops + + fns.update(supported_comparison_ops.values()) + fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) + return fns + + def can_constant_fold_through(self) -> bool: + return self.fn in self._constant_fold_functions() + + @staticmethod + @functools.cache + def _fx_graph_functions() -> set[Callable[..., Any]]: + fns = { + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.getitem, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + } + return fns # type: ignore[return-value] + + @staticmethod + @functools.cache + def _binops() -> dict[ + Callable[..., object], tuple[list[str], Callable[..., object]] + ]: + # function -> ([forward name, reverse name, in-place name], in-place op) + fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = { + operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), + operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), + operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), + operator.truediv: ( + ["__truediv__", "__rtruediv__", "__itruediv__"], + operator.itruediv, + ), + operator.floordiv: ( + ["__floordiv__", "__rfloordiv__", "__ifloordiv__"], + operator.ifloordiv, + ), + operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), + pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.lshift: ( + ["__lshift__", "__rlshift__", "__ilshift__"], + operator.ilshift, + ), + operator.rshift: ( + ["__rshift__", "__rrshift__", "__irshift__"], + operator.irshift, + ), + operator.xor: (["__xor__", "__rxor__", "__ixor__"], operator.xor), + # NB: The follow binary operators are not supported for now, since the + # corresponding magic methods aren't defined on SymInt / SymFloat: + # operator.matmul + # divmod + # operator.and_ + # operator.or_ + } + return fns + + @staticmethod + @functools.cache + def _binop_handlers() -> dict[ + Callable[..., object], + list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ], + ]: + # Multiple dispatch mechanism defining custom binop behavior for certain type + # combinations. Handlers are attempted in order, and will be used if the type checks + # match. They are expected to have the signature: + # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker + from .functions import BaseUserFunctionVariable, UserFunctionVariable + from .nn_module import NNModuleVariable + from .tensor import supported_const_comparison_ops + from .torch import BaseTorchVariable + from .user_defined import ( + UserDefinedClassVariable, + UserDefinedObjectVariable, + UserDefinedVariable, + ) + + # Override table contains: op_fn -> [list of handlers] + op_handlers: dict[Any, list[Any]] = {} + for ( + op, + (magic_method_names, in_place_op), + ) in BuiltinVariable._binops().items(): + op_handlers[op] = [] + op_handlers[in_place_op] = [] + + forward_name, reverse_name, inplace_name = magic_method_names + + # User-defined args (highest precedence) + def user_defined_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + forward_name: str = forward_name, + reverse_name: str = reverse_name, + ) -> VariableTracker: + # Manually handle reversing logic if needed (e.g. call __radd__) + + # TODO: If we expand this to handle tensor args, we need to manually + # handle cases like this: + # + # class A(int): + # def __radd__(self, other): + # print("woof") + # torch.randn(3) + A(3) + # + # In this example, A.__radd__() is not called -> nothing is printed, because + # Tensor.__add__ only does a subtype test against int, ignoring the subclass. + # To be fully correct, we should not call A.__radd__() here, and there may be + # other cases to reason about and add exceptions for. + if isinstance(a, UserDefinedVariable): + return a.call_method(tx, forward_name, [b], {}) + else: + return b.call_method(tx, reverse_name, [a], {}) + + op_handlers[op].append( + ((UserDefinedVariable, VariableTracker), user_defined_handler) + ) + op_handlers[op].append( + ((VariableTracker, UserDefinedVariable), user_defined_handler) + ) + + def user_defined_inplace_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + forward_name: str = inplace_name, + ) -> VariableTracker: + return a.call_method(tx, forward_name, [b], {}) + + op_handlers[in_place_op].append( + ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler) + ) + + # Dynamic shape args + def dynamic_handler( + tx: "InstructionTranslator", + a: VariableTracker, + b: VariableTracker, + *, + fn: Callable[..., Any] = op, + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", fn, *proxy_args_kwargs([a, b], {}) + ), + ) + + op_handlers[op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # NB: Prefer out-of-place op when calling in-place op to generate valid graph + op_handlers[in_place_op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # Special cases - lower precedence but still prefer these over constant folding + + # List-like addition (e.g. [1, 2] + [3, 4]) + def tuple_add_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> VariableTracker: + return TupleVariable([*a.items, *b.unpack_var_sequence(tx)]) + + def size_add_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> VariableTracker: + return SizeVariable([*a.items, *b.unpack_var_sequence(tx)]) + + list_like_addition_handlers: list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ] = [ + # NB: Prefer the tuple-specific logic over base logic because of + # some SizeVariable weirdness. Specifically, the tuple-specific logic + # drops the subclass type (e.g. SizeVariable) and returns TupleVariables. + ( + (SizeVariable, SizeVariable), + size_add_handler, + ), + ( + (SizeVariable, TupleVariable), + size_add_handler, + ), + ( + (TupleVariable, SizeVariable), + size_add_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ( + (ConstantVariable, TupleVariable), + lambda tx, a, b: TupleVariable( + [ + *a.unpack_var_sequence(tx), + *b.items, + ], + ), + ), + ( + ( + ListVariable, + (BaseListVariable, ConstantVariable, ListIteratorVariable), + ), + lambda tx, a, b: ListVariable( + [*a.items, *b.unpack_var_sequence(tx)], + mutation_type=ValueMutationNew(), + ), + ), + ( + (BaseListVariable, BaseListVariable), + lambda tx, a, b: type(a)( + [ + *a.items, + *b.items, + ] + ), + ), + ] + op_handlers[operator.add].extend(list_like_addition_handlers) + + def list_iadd_handler( + tx: "InstructionTranslator", a: BaseListVariable, b: VariableTracker + ) -> Any: + if a.is_immutable() or not b.has_unpack_var_sequence(tx): + # Handler doesn't apply + return None + + seq = b.unpack_var_sequence(tx) + tx.output.side_effects.mutation(a) + a.items.extend(seq) + return a + + list_like_iadd_handlers: list[Any] = [ + ( + (ListVariable, VariableTracker), + list_iadd_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ] + op_handlers[operator.iadd].extend(list_like_iadd_handlers) + + # List-like expansion (e.g. [1, 2, 3] * 3) + def expand_list_like( + tx: "InstructionTranslator", lst: VariableTracker, const: VariableTracker + ) -> VariableTracker: + if not isinstance(lst, BaseListVariable) and lst.is_python_constant(): + lst, const = const, lst + try: + assert isinstance(lst, BaseListVariable) + return lst.__class__( + items=lst.items * const.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + except MemoryError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + list_like_expansion_handlers: list[ + tuple[ + tuple[type[VariableTracker], type[VariableTracker]], + _HandlerCallback, + ] + ] = [ + ((ListVariable, ConstantVariable), expand_list_like), + ((TupleVariable, ConstantVariable), expand_list_like), + ((ConstantVariable, ListVariable), expand_list_like), + ((ConstantVariable, TupleVariable), expand_list_like), + ] + op_handlers[operator.mul].extend(list_like_expansion_handlers) + + def create_cmp_op_handlers( + op: Callable[..., Any], + ) -> list[tuple[tuple[_TrackersType, _TrackersType], _HandlerCallback]]: + def compare_by_value( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + try: + return ConstantVariable(op(a.value, b.value)) # type: ignore[attr-defined] + except TypeError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + result: list[ + tuple[ + tuple[ + _TrackersType, + _TrackersType, + ], + _HandlerCallback, + ] + ] = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in polyfill_fn_mapping: + # For constants, speedup the comparison instead of using + # polyfill. Removing this line causes major regression for pr + # time benchmark - add_loop_eager. + result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + op_var = BuiltinVariable(op) + # Special handling of SymNode variable + result.extend( + [ + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handler( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {} + ) + + result.append(((VariableTracker, VariableTracker), handler)) + return result + + result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in supported_const_comparison_ops.values() and op.__name__.startswith( + "is_" + ): + # Tensor is None, List is not None, etc + none_result = op(object(), None) + + def never( + tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return ConstantVariable(none_result) + + obj_op_none = never + none_op_obj = never + + types_that_are_never_none = ( + TensorVariable, + SymNodeVariable, + NNModuleVariable, + BaseListVariable, + UserDefinedVariable, + BaseUserFunctionVariable, + ConstDictVariable, + BaseTorchVariable, + ) + result.extend( + [ + ( + (types_that_are_never_none, ConstantVariable), + obj_op_none, + ), + ( + (ConstantVariable, types_that_are_never_none), + none_op_obj, + ), + ] + ) + + op_var = BuiltinVariable(op) + result.extend( + [ + ( + ( + (UserFunctionVariable, BuiltinVariable), + (UserFunctionVariable, BuiltinVariable), + ), + lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)), + ), + ( + ( + NNModuleVariable, + NNModuleVariable, + ), + lambda tx, a, b: ConstantVariable( + op( + tx.output.get_submodule(a.module_key), + tx.output.get_submodule(b.module_key), + ) + ), + ), + ( + (UserDefinedObjectVariable, UserDefinedObjectVariable), + compare_by_value, + ), + ( + (UserDefinedClassVariable, UserDefinedClassVariable), + compare_by_value, + ), + ( + ( + (StreamVariable, EventVariable, ConstantVariable), + (StreamVariable, EventVariable, ConstantVariable), + ), + compare_by_value, + ), + ( + (TensorVariable, VariableTracker), + op_var._comparison_with_tensor, + ), + ( + (VariableTracker, TensorVariable), + op_var._comparison_with_tensor, + ), + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handle_is( + tx: "InstructionTranslator", + left: VariableTracker, + right: VariableTracker, + ) -> VariableTracker | None: + # If the two objects are of different type, we can safely return False + # and True for `is` and `is not`, respectively + if type(left) is not type(right): + return ConstantVariable.create(op.__name__ != "is_") + if left is right: + return ConstantVariable.create(op(left, right)) + if ( + istype(left, variables.ExceptionVariable) + and istype(right, variables.ExceptionVariable) + and left.exc_type is not right.exc_type + ): + return ConstantVariable.create(op(left, right)) + return None + + result.append(((VariableTracker, VariableTracker), handle_is)) # type: ignore[arg-type] + + return result + + for op in supported_comparison_ops.values(): + assert callable(op) + assert op not in op_handlers + op_handlers[op] = create_cmp_op_handlers(op) + + return op_handlers + + @staticmethod + def _find_binop_handler( + op: Callable[..., Any], a_type: type[VariableTracker], b_type: type + ) -> list[_HandlerCallback] | None: + handlers = BuiltinVariable._binop_handlers().get(op) + if handlers is None: + return None + + matches = [] + for (type1, type2), handler in handlers: + if issubclass(a_type, type1) and issubclass(b_type, type2): + matches.append(handler) + return matches + + def can_insert_in_graph(self) -> bool: + return self.fn in self._fx_graph_functions() + + def __init__(self, fn: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.fn = fn + + def __repr__(self) -> str: + if self.fn is None: + name = "None" + else: + name = self.fn.__name__ + + return f"{self.__class__.__name__}({name})" + + def as_python_constant(self) -> Any: + return self.fn + + def as_proxy(self) -> Any: + DTYPE = { + bool: torch.bool, + int: torch.int64, + float: torch.float64, + } + if self.fn in DTYPE: + return DTYPE[self.fn] + return super().as_proxy() + + def reconstruct(self, codegen: "PyCodegen") -> None: + name = self.fn.__name__ + assert self.fn.__module__ == "builtins" + assert name not in codegen.tx.f_globals, "shadowed global" + codegen.append_output(codegen.create_load_global(name, add=True)) + + def constant_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool: + return check_constant_args(args, kwargs) + + def tensor_args(self, *args: VariableTracker) -> bool: + any_tensor = False + for arg in args: + if isinstance(arg, variables.GetAttrVariable): + return False + any_tensor = any_tensor or arg.is_tensor() + return any_tensor + + def tensor_args_type(self, arg_types: list[type]) -> bool: + any_tensor = False + for arg_type in arg_types: + if issubclass(arg_type, variables.GetAttrVariable): + return False + any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable) + return any_tensor + + def python_and_tensor_constant_only( + self, *args: VariableTracker, **kwargs: VariableTracker + ) -> bool: + tensor_args = [] + non_tensor_args = [] + for i in itertools.chain(args, kwargs.values()): + if i.is_tensor(): + tensor_args.append(i) + else: + non_tensor_args.append(i) + return all( + is_constant_source(t.source) if t.source is not None else False + for t in tensor_args + ) and self.constant_args(*non_tensor_args) + + @staticmethod + def unwrap_unspec_args_kwargs( + args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> tuple[list[Any], dict[str, Any]]: + return [x.as_python_constant() for x in args], { + k: v.as_python_constant() for k, v in kwargs.items() + } + + def has_constant_handler( + self, args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> bool: + return self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ) + + @staticmethod + def _make_handler( + fn: Callable[..., Any], arg_types: list[type], has_kwargs: bool + ) -> Callable[ + [ + "InstructionTranslator", + Sequence[VariableTracker], + dict[str, VariableTracker], + ], + VariableTracker | None, + ]: + from .lazy import LazyVariableTracker + + obj = BuiltinVariable(fn) + handlers: list[_HandlerCallback] = [] + + if any(issubclass(t, LazyVariableTracker) for t in arg_types): + return lambda tx, args, kwargs: obj.call_function( + tx, [v.realize() for v in args], kwargs + ) + + if inspect.isclass(fn) and ( + issubclass(fn, Exception) + # GeneratorExit doesn't inherit from Exception + # >>> issubclass(GeneratorExit, Exception) + # False + or fn is GeneratorExit + ): + + def create_exception_class_object( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if fn is AssertionError and not all( + x.is_python_constant() and isinstance(x.as_python_constant(), str) + for x in args + ): + unimplemented( + gb_type="assert with non-string message", + context=str(args), + explanation="Dynamo only supports asserts with string messages", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return variables.ExceptionVariable(fn, args, kwargs) + + return create_exception_class_object + + if obj.can_insert_in_graph() and not ( + fn is operator.getitem + and not issubclass(arg_types[0], variables.TensorVariable) + ): + if obj.tensor_args_type(arg_types): + return obj._handle_insert_op_in_graph + elif has_kwargs: + # need runtime check for kwargs + handlers.append(obj._handle_insert_op_in_graph) + + # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.) + # NB: Tensor args are handled above and not here + if len(arg_types) == 2 and not has_kwargs: + # Try to find a handler for the arg types; otherwise, fall through to constant handler + binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types) + if not binop_handlers: + pass + elif len(binop_handlers) == 1: + (binop_handler,) = binop_handlers + handlers.append(lambda tx, args, _: binop_handler(tx, *args)) + else: + + def call_binop_handlers( + tx: "InstructionTranslator", args: Any, _: Any + ) -> Any: + # pyrefly: ignore [not-iterable] + for fn in binop_handlers: + rv = fn(tx, *args) + if rv: + return rv + return None + + handlers.append(call_binop_handlers) + + self_handler = getattr(obj, f"call_{fn.__name__}", None) + if self_handler: + + def call_self_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + try: + # pyrefly: ignore [not-callable] + return self_handler(tx, *args, **kwargs) + except TypeError: + # Check if binding is bad. inspect signature bind is expensive. + # So check only when handler call fails. + try: + # pyrefly: ignore [bad-argument-type] + inspect.signature(self_handler).bind(tx, *args, **kwargs) + except TypeError as e: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + log.warning( # noqa: G200 + "incorrect arg count %s %s and no constant handler", + self_handler, + e, + ) + unimplemented( + gb_type="invalid call to builtin op handler", + context=f"invalid args to {self_handler}: {args} {kwargs}", + explanation=f"Encountered TypeError when trying to handle op {fn.__name__}", + hints=[*graph_break_hints.DIFFICULT], + ) + else: + raise + except Unsupported as exc: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + raise + # Actually, we will handle this just fine + exc.remove_from_stats() + return None + + handlers.append(call_self_handler) + + if obj.can_constant_fold_through(): + if ( + all(issubclass(x, ConstantVariable) for x in arg_types) + and not has_kwargs + ): + + def constant_fold_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + # fast path + try: + res = fn( + *[x.as_python_constant() for x in args], + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, res) + + else: + + def constant_fold_handler( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + # path with a runtime check + if check_unspec_or_constant_args(args, kwargs): + try: + res = fn( + *[x.as_python_constant() for x in args], + **{ + k: v.as_python_constant() for k, v in kwargs.items() + }, + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, res) + return None + + handlers.append(constant_fold_handler) + + def call_unimplemented(args: Sequence[VariableTracker]) -> None: + real_arg_types = [arg.python_type_name() for arg in args] + unimplemented( + gb_type="Failed to trace builtin operator", + context=f"builtin {fn.__name__} {arg_types} {has_kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` " + f"with argument types {real_arg_types} (has_kwargs {has_kwargs})", + hints=[ + f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. " + f"Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch.", + ], + ) + + if len(handlers) == 0: + return lambda tx, args, kwargs: call_unimplemented(args) + elif len(handlers) == 1: + (handler,) = handlers + + def builtin_dispatch( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + rv = handler(tx, args, kwargs) + if rv: + return rv + call_unimplemented(args) + return rv + + else: + + def builtin_dispatch( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + rv = None + for fn in handlers: + rv = fn(tx, args, kwargs) + if rv: + return rv + call_unimplemented(args) + return rv + + return builtin_dispatch + + def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + if len(args) == 0: + unimplemented( + gb_type="unimplemented builtin op vars() with no arguments", + context=f"vars: {self} {args}", + explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments", + hints=[*graph_break_hints.SUPPORTABLE], + ) + assert len(args) == 1 + # vars(obj) is obj.__dict__ if __dict__ is present else TypeError + try: + return args[0].var_getattr(tx, "__dict__") + except ObservedAttributeError: + raise_observed_exception(TypeError, tx) + + def _handle_insert_op_in_graph( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker | None: + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + if kwargs and not self.tensor_args(*args, *kwargs.values()): + return None + + # insert handling for torch function here + from .builder import SourcelessBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP + if can_dispatch_torch_function(tx, args, kwargs): + # Only remap the fn to tensor methods if we aren't exporting + # export serde does not handle method descriptors today + if not tx.export: + # Ensure the builtin maps are populated before accessing them + populate_builtin_to_tensor_fn_map() + # Use sourceless builder, we built the map ourselves + if not args[0].is_tensor(): + if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: + func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + + tmp = args[0] + # swap args and call reverse version of func + args[0] = args[1] # type: ignore[index] + args[1] = tmp # type: ignore[index] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + else: + func = self.fn + + fn_var = SourcelessBuilder.create(tx, func) + + return dispatch_torch_function(tx, fn_var, args, kwargs) + + fn = self.fn + try: + # Constant fold for constant tensor and python constants + if self.python_and_tensor_constant_only(*args, **kwargs): + from ..bytecode_transformation import unique_id + from .functions import invoke_and_store_as_constant + + return invoke_and_store_as_constant( + tx, fn, unique_id(fn.__name__), args, kwargs + ) + + if fn in IN_PLACE_DESUGARING_MAP and isinstance( + args[0], variables.ConstantVariable + ): + # In-place operators like += usually mustate tensor + # values, but in the edge case of immutable values they + # re-bind the variable. + # + # The easiest way to keep the graph consistent in this + # scenario is to de-sugar eagerly. + fn = IN_PLACE_DESUGARING_MAP[fn] + args = [args[0], args[1]] # type: ignore[assignment] + + if fn is operator.getitem and isinstance(args[1], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. Rewrite as a regular torch op which will + # trace fine + fn = torch.select + args = [ + args[0], + variables.ConstantVariable.create(0), + args[1], + ] # type: ignore[assignment] + + # Interaction between ndarray and tensors: + # We prefer the tensor op whenever there are tensors involved + # NB: Use exact type check here - NumpyNdarrayVariable is a TensorVariable + # subclass but should NOT trigger the tensor path + if check_numpy_ndarray_args(args, kwargs) and not any( + type(arg) is TensorVariable for arg in args + ): + proxy = tx.output.create_proxy( + "call_function", + numpy_operator_wrapper(fn), + *proxy_args_kwargs(args, kwargs), + ) + + return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy) + + if fn is operator.eq and len(args) == 2 and args[0].is_tensor(): + # Dynamo expects `__eq__` str while operator.eq gives just `eq` + # TODO - supporting all comparison operators could also work but + # it fails lots of tests because graph str changes. + return args[0].call_method(tx, "__eq__", list(args[1:]), kwargs) + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs(args, kwargs), + ) + if any(isinstance(arg, FakeItemVariable) for arg in args): + return wrap_fx_proxy_cls( + FakeItemVariable, + tx, + proxy, + ) + elif check_unspec_python_args(args, kwargs): + _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) + raw_value = fn(*_args, **_kwargs) + + need_unwrap = any( + x.need_unwrap + for x in itertools.chain(args, kwargs.values()) + if isinstance(x, variables.UnspecializedPythonVariable) + ) + + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx, + proxy, + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + elif all(isinstance(x, SymNodeVariable) for x in args): + return SymNodeVariable.create(tx, proxy, None) + else: + # Work around for vision_maskrcnn due to precision difference + # specialize the dividend when float divide by tensor + if fn is operator.truediv and isinstance( + args[0], variables.UnspecializedPythonVariable + ): + args = list(args) + args[0] = args[0].as_python_constant() + return wrap_fx_proxy(tx, proxy) + + except NotImplementedError: + unimplemented( + gb_type="unimplemented builtin op on tensor arguments", + context=f"partial tensor op: {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + call_function_handler_cache: dict[ + tuple[object, ...], + Callable[ + [ + "InstructionTranslator", + Sequence[VariableTracker], + dict[str, VariableTracker], + ], + VariableTracker, + ], + ] = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + key: tuple[object, ...] + if kwargs: + kwargs = {k: v.realize() for k, v in kwargs.items()} + key = (self.fn, *(type(x) for x in args), True) + else: + key = (self.fn, *(type(x) for x in args)) + + handler = self.call_function_handler_cache.get(key) + if not handler: + self.call_function_handler_cache[key] = handler = self._make_handler( # type: ignore[assignment] + self.fn, [type(x) for x in args], bool(kwargs) + ) + assert handler is not None + return handler(tx, args, kwargs) # type: ignore[return-value] + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.fn is object and name == "__setattr__": + assert len(args) == 3 + assert len(kwargs) == 0 + obj, name_var, val = args + obj = obj.realize() + if ( + isinstance(obj, UserDefinedObjectVariable) + and tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + return obj.method_setattr_standard(tx, name_var, val) + + if name == "__new__": + # Supported __new__ methods + if self.fn is object and len(args) == 1: + assert len(kwargs) == 0 + return tx.output.side_effects.track_new_user_defined_object( + self, args[0], args[1:] + ) + + if self.fn is dict and len(args) == 1 and not kwargs: + dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is dict: + return dict_vt + # We don't have to set the underlying dict_vt in + # UserDefinedDictVariable because it will be set to empty + # ConstDictVariableTracker in the constructor. + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if ( + self.fn is tuple + and len(args) == 2 + and args[1].has_force_unpack_var_sequence(tx) + and not kwargs + ): + if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: + init_args = args[1].force_unpack_var_sequence(tx) + return variables.TupleVariable( + init_args, mutation_type=ValueMutationNew() + ) + + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if self.fn is list: + list_vt = ListVariable([], mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is list: + return list_vt + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if ( + self.fn in (float, complex) + and len(args) == 1 + and ( + (self.fn is float and name in ("fromhex", "hex")) + or (name == "from_number" and sys.version_info >= (3, 14)) + ) + ): + if args[0].is_python_constant(): + try: + fn = getattr(self.fn, name) + res = fn(args[0].as_python_constant()) + return variables.ConstantVariable.create(res) + except (OverflowError, ValueError) as e: + raise_observed_exception( + type(e), + tx, + args=list(map(ConstantVariable.create, e.args)), + ) + + if self.fn is object and name == "__init__": + # object.__init__ is a no-op + return variables.ConstantVariable(None) + + if self.fn is dict and name == "fromkeys": + return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) + + if self.fn is dict: + resolved_fn = getattr(self.fn, name) + if resolved_fn in dict_methods: + if isinstance(args[0], variables.UserDefinedDictVariable): + # pyrefly: ignore [missing-attribute] + return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.ConstDictVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is set: + resolved_fn = getattr(self.fn, name) + if resolved_fn in set_methods: + if isinstance(args[0], variables.UserDefinedSetVariable): + # pyrefly: ignore [missing-attribute] + return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.SetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is frozenset: + resolved_fn = getattr(self.fn, name) + if resolved_fn in frozenset_methods: + if isinstance(args[0], variables.FrozensetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is str and len(args) >= 1: + resolved_fn = getattr(self.fn, name) + if resolved_fn in str_methods: + # Only delegate to ConstantVariable, not other types that happen to be constants + if isinstance(args[0], ConstantVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is float and len(args) >= 1: + # Only delegate to ConstantVariable, not other types that happen to be constants + if isinstance(args[0], ConstantVariable): + return ConstantVariable.create( + getattr(float, name)(args[0].as_python_constant()) + ) + + return super().call_method(tx, name, args, kwargs) + + def _call_int_float( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Handle cases like int(torch.seed()) + # Also handle sym_float to sym_int cases + if arg.is_tensor() or isinstance(arg, SymNodeVariable): + if arg.is_tensor(): + item = arg.call_method(tx, "item", [], {}) + else: + item = arg + fn_ = sym_int if self.fn is int else sym_float + from torch._dynamo.variables.builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (item.as_proxy(),), + {}, + ), + ) + return None + + call_int = _call_int_float + call_float = _call_int_float + + def call_bool( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. + # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 + if isinstance(arg, SymNodeVariable): + # Note that we delay specializing on symbolic values to avoid + # unnecessary guards. Specialization will happen later if, e.g., the + # resulting boolean is used for branching. + if isinstance(arg.sym_num, torch.SymBool): + return arg + + # Emulate `nb_bool` of int/float objects + # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944 + # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882 + assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) + return SymNodeVariable.create(tx, arg.as_proxy() != 0) + + # TODO handle more cases and merge this with this with `generic_jump`. + return None + + def call_repr(self, tx: "InstructionTranslator", arg): + """Handle repr() on user defined objects.""" + if isinstance(arg, variables.UserDefinedObjectVariable): + repr_method = arg.value.__repr__ + + if type(arg.value).__repr__ is object.__repr__: + # Default repr - build and trace it + fn_vt = VariableTracker.build(tx, repr_method) + return fn_vt.call_function(tx, [], {}) + else: + # Custom repr - inline the method for tracing + bound_method = repr_method.__func__ + fn_vt = VariableTracker.build(tx, bound_method) + return fn_vt.call_function(tx, [arg], {}) + + def call_str( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + # Handle `str` on a user defined function or object + if isinstance(arg, (variables.UserFunctionVariable)): + return variables.ConstantVariable.create(value=str(arg.fn)) + elif isinstance(arg, (variables.UserDefinedObjectVariable)): + # Check if object has __str__ method + if hasattr(arg.value, "__str__"): + str_method = arg.value.__str__ + elif hasattr(arg.value, "__repr__"): + # account for __repr__ functions when __str__ is absent + str_method = arg.value.__repr__ + else: + unimplemented( + gb_type="failed to call str() on user defined object", + context=str(arg), + explanation="User defined object has no __str__ or __repr__ method", + hints=[*graph_break_hints.USER_ERROR], + ) + + if type(arg.value).__str__ is object.__str__: + # Rely on the object str method + try: + # pyrefly: ignore [unbound-name] + return variables.ConstantVariable.create(value=str_method()) + except AttributeError: + # Graph break + return None + # pyrefly: ignore [unbound-name] + elif is_wrapper_or_member_descriptor(str_method): + unimplemented( + gb_type="Attempted to a str() method implemented in C/C++", + context="", + explanation=f"{type(arg.value)} has a C/C++ based str method. This is not supported.", + hints=["Write the str method in Python"], + ) + else: + # Overrides for custom str method + # Pass method as function to call tx.inline_user_function_return + bound_method = str_method.__func__ # type: ignore[attr-defined] + + try: + # Only supports certain function types + user_func_variable = VariableTracker.build(tx, bound_method) + except AssertionError: + # Won't be able to do inline the str method, return to avoid graph break + log.warning("Failed to create UserFunctionVariable", exc_info=True) + return None + + # Inline the user function + return user_func_variable.call_function(tx, [arg], {}) + elif isinstance(arg, (variables.ExceptionVariable,)): + if len(arg.args) == 0: + value = f"{arg.exc_type}" + else: + value = ", ".join(a.as_python_constant() for a in arg.args) + return variables.ConstantVariable.create(value=value) + return None + + def _call_min_max( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker | None: + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + return self._call_min_max_seq(tx, items) + elif len(args) == 2: + return self._call_min_max_binary(tx, args[0], args[1]) + elif len(args) > 2: + return self._call_min_max_seq(tx, args) + return None + + def _call_min_max_seq( + self, tx: "InstructionTranslator", items: Sequence[VariableTracker] + ) -> VariableTracker: + assert len(items) > 0 + if len(items) == 1: + return items[0] + + return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) # type: ignore[arg-type,return-value] + + def _call_min_max_binary( + self, + tx: "InstructionTranslator", + a: VariableTracker | None, + b: VariableTracker | None, + ) -> VariableTracker | None: + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return None + if self.tensor_args(a, b): + if not a.is_tensor(): + a, b = b, a + assert a.is_tensor() + + # result of an item call is a scalar convert to a tensor + if isinstance(a, FakeItemVariable): + a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function( + tx, [a], {} + ) + + # Dynamic input does not get resolved, rather, gets stored as call_function + if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + type(a), + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.fn, + *proxy_args_kwargs([a, b], {}), + ), + ) + + # convert min/max to torch ops + if b.is_python_constant(): + fn: VariableTracker + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + fn = variables.NumpyVariable(np.clip) + else: + fn = variables.TorchInGraphFunctionVariable(torch.clamp) + kwargs = {"min": b} if (self.fn is max) else {"max": b} + result = fn.call_function(tx, [a], kwargs) + else: + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + np_fn = {max: np.maximum, min: np.minimum}[self.fn] + fn = variables.NumpyVariable(np_fn) + else: + torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn] + fn = variables.TorchInGraphFunctionVariable(torch_fn) + result = fn.call_function(tx, [a, b], {}) + + # return unspec if both a, b are unspec or const + if all( + isinstance( + i, + ( + variables.UnspecializedPythonVariable, + variables.ConstantVariable, + ), + ) + for i in [a, b] + ): + if any(isinstance(val, FakeItemVariable) for val in [a, b]): + return variables.FakeItemVariable.from_tensor_variable(result) + + if b.is_python_constant(): + raw_b = b.as_python_constant() + else: + raw_b = b.raw_value # type: ignore[attr-defined] + if self.fn is max: + raw_res = max(a.raw_value, raw_b) # type: ignore[attr-defined] + else: + raw_res = min(a.raw_value, raw_b) # type: ignore[attr-defined] + + need_unwrap = any( + x.need_unwrap + for x in [a, b] + if isinstance(x, variables.UnspecializedPythonVariable) + ) + return variables.UnspecializedPythonVariable.from_tensor_variable( + result, raw_res, need_unwrap + ) + # otherwise return tensor + else: + return result + elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + py_fn = torch.sym_max if self.fn is max else torch.sym_min + proxy = tx.output.create_proxy( + "call_function", py_fn, *proxy_args_kwargs([a, b], {}) + ) + return SymNodeVariable.create(tx, proxy, None) + elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + value = self.fn( + a.as_python_constant(), + b.as_python_constant(), + ) + return ConstantVariable.create(value) + return None + + call_min = _call_min_max + call_max = _call_min_max + + def call_abs( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # Call arg.__abs__() + abs_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__abs__")], {} + ) + return abs_method.call_function(tx, [], {}) + + def call_pos( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # Call arg.__pos__() + pos_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__pos__")], {} + ) + return pos_method.call_function(tx, [], {}) + + def call_index( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + if arg.is_tensor(): + unimplemented( + gb_type="unsupported index(Tensor)", + context="", + explanation="Dynamo does not support tracing builtin index() on a Tensor", + hints=[], + ) + + arg = guard_if_dyn(arg) + constant_value = operator.index(arg) + return variables.ConstantVariable.create(constant_value) + + def call_round( + self, + tx: "InstructionTranslator", + arg: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # Call arg.__round__() + round_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__round__")], {} + ) + return round_method.call_function(tx, args, kwargs) + + def call_range( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker | None: + if check_unspec_or_constant_args(args, {}): + return variables.RangeVariable(args) + elif self._dynamic_args(*args): + args = tuple( + variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args + ) + return variables.RangeVariable(args) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args: VariableTracker, **kwargs: VariableTracker) -> bool: + return any(isinstance(x, SymNodeVariable) for x in args) or any( + isinstance(x, SymNodeVariable) for x in kwargs.values() + ) + + def call_slice( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.SliceVariable(args, tx) + + def _dyn_proxy( + self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + ) + + # NOTE must handle IteratorVariable separately! + def _call_iter_tuple_list( + self, + tx: "InstructionTranslator", + obj: VariableTracker | None = None, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + assert not isinstance(obj, variables.IteratorVariable) + + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + cls = variables.BaseListVariable.cls_for(self.fn) + if obj is None: + return cls( + [], + mutation_type=ValueMutationNew(), + ) + elif obj.has_unpack_var_sequence(tx): + if obj.source and not is_constant_source(obj.source): + if isinstance(obj, TupleIteratorVariable): + install_guard( + obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) + ) + else: + if ( + getattr(obj, "source", False) + and isinstance(obj, ConstDictVariable) + and not istype(obj, (SetVariable, FrozensetVariable)) + ): + tx.output.guard_on_key_order.add(obj.source) + + if isinstance(obj, variables.MappingProxyVariable): + # This could be an overguarding, but its rare to iterate + # through a mapping proxy and not use the keys. + install_guard( + obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK) + ) + elif not isinstance(obj, variables.UnspecializedNNModuleVariable): + # Prevent calling __len__ method for guards, the tracing + # of __iter__ will insert the right guards later. + install_guard( + obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + + return cls( + list(obj.unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + return None + + def _call_iter_tuple_generator( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), # exhaust generator + mutation_type=ValueMutationNew(), + ) + + def _call_tuple_list( + self, + tx: "InstructionTranslator", + obj: VariableTracker | None = None, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + elif isinstance(obj, variables.LocalGeneratorObjectVariable) or ( + isinstance(obj, UserDefinedObjectVariable) + and obj.has_force_unpack_var_sequence(tx) + ): + return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + + def call_iter( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # avoid the overhead of tracing the polyfill if we already know the class implemented __iter__ + if isinstance( + obj, + ( + variables.ListVariable, + variables.RangeVariable, + variables.IteratorVariable, + variables.ConstDictVariable, + variables.NNModuleVariable, + variables.TensorVariable, + ), + ): + return obj.call_method(tx, "__iter__", [], {}) + else: + # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. + # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call + # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. + # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() + # with an integer argument starting at 0, until __getitem__ raises IndexError + ret = variables.UserFunctionVariable( + polyfills.builtins.iter_ # type: ignore[arg-type] + ).call_function(tx, [obj, *args], {}) + + if args: + # iter(obj, sentinel) returns an object that implements + # __iter__ and __next__ methods (UserDefinedObjectVariable) + # Wrap the return value in a IteratorVariable subclass (LazyObjectIteratorVariable) + # that forwards the next_variable call to the object. + ret = variables.ObjectIteratorVariable(ret) + return ret + + call_tuple = _call_tuple_list + call_list = _call_tuple_list + + def call_callable( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + from .functions import BaseUserFunctionVariable, FunctoolsPartialVariable + from .nn_module import NNModuleVariable + + if isinstance( + arg, + ( + variables.UserDefinedClassVariable, + BaseUserFunctionVariable, + FunctoolsPartialVariable, + NNModuleVariable, + ), + ): + return variables.ConstantVariable.create(True) + elif isinstance(arg, UserDefinedVariable): + return variables.ConstantVariable.create(callable(arg.value)) + elif isinstance( + arg, + ( + ConstantVariable, + SymNodeVariable, + TensorVariable, + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ): + return variables.ConstantVariable.create(False) + else: + return None + + def call_cast( + self, _: Any, *args: VariableTracker, **kwargs: VariableTracker + ) -> VariableTracker | None: + if len(args) == 2: + return args[1] + + unimplemented( + gb_type="bad args to builtin cast()", + context=f"got args {args} {kwargs}", + explanation="Dynamo expects exactly 2 args to builtin cast().", + hints=["Ensure your call to cast() has exactly 2 arguments."], + ) + + def call_dir( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + if isinstance(arg, variables.UserDefinedClassVariable): + return VariableTracker.build(tx, dir(arg.value)) + if isinstance(arg, BuiltinVariable): + return VariableTracker.build(tx, dir(arg.fn)) + return None + + def call_dict( + self, + tx: "InstructionTranslator", + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + @staticmethod + def call_custom_dict( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + args_list = list(args) + if ( + len(args_list) == 1 + and isinstance(args_list[0], variables.GetAttrVariable) + and isinstance(args_list[0].obj, variables.UserDefinedClassVariable) + and not tx.output.side_effects.has_pending_mutation(args_list[0].obj) + ): + # Forward the GetAttrVariable(foo, "__dict__") to a realized vt of + # VT(foo.__dict__). This simplifies the construction of the new + # dict. + args_list[0] = args_list[0].get_forwarded_dict(tx) + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [VariableTracker.build(tx, user_cls), *args_list], + kwargs, + ) + + @staticmethod + def call_custom_dict_fromkeys( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + if user_cls not in {dict, OrderedDict, defaultdict}: + unimplemented( + gb_type="Unsupported dict type for fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict", + hints=[ + f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.", + ], + ) + if kwargs: + # Only `OrderedDict.fromkeys` accepts `value` passed by keyword + if ( + user_cls is not OrderedDict + or len(args) != 1 + or len(kwargs) != 1 + or "value" not in kwargs + ): + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "1 args and 1 kwargs (`value`)", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + args = (*args, kwargs.pop("value")) + if len(args) == 0: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "at least 1 args", + f"{len(args)} args", + ) + if len(args) == 1: + args = (*args, ConstantVariable.create(None)) + if len(args) != 2: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "2 args", + f"{len(args)} args", + ) + # pyrefly: ignore [bad-unpacking] + arg, value = args + DictVariableType = ( + ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable + ) + + if isinstance(arg, dict): + arg_list = [ConstantVariable.create(k) for k in arg] + return DictVariableType( + # pyrefly: ignore [bad-argument-type] + dict.fromkeys(arg_list, value), + user_cls, + mutation_type=ValueMutationNew(), + ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + # pyrefly: ignore [bad-argument-type] + dict.fromkeys(keys, value), + user_cls, + mutation_type=ValueMutationNew(), + ) + + unimplemented( + gb_type="failed to call dict.fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + "arguments could not be automatically converted to a list, " + "or some dict key is not hashable.", + hints=[ + "Manually convert the argument to a list.", + "Ensure all keys are hashable.", + ], + ) + + def call_set( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + # Can we merge this implementation and call_dict's one? + assert not kwargs + if not args: + return SetVariable([], mutation_type=ValueMutationNew()) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"set() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if istype(arg, variables.SetVariable): + return arg.clone(mutation_type=ValueMutationNew()) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return SetVariable(items, mutation_type=ValueMutationNew()) + elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, KeysView + ): + iter_fn = arg.var_getattr(tx, "__iter__") + if isinstance(iter_fn, variables.UserMethodVariable): + out = tx.inline_user_function_return(iter_fn, args, kwargs) + if isinstance(out, SetVariable): + return out + return BuiltinVariable(set).call_set(tx, out) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin set()")], + ) + + def call_frozenset( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + assert not kwargs + if not args: + return FrozensetVariable([]) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"frozenset() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if istype(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return FrozensetVariable(items) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin frozenset()")], + ) + + def call_zip( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + if kwargs: + if not (len(kwargs) == 1 and "strict" in kwargs): + raise_args_mismatch( + tx, + "zip", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", + ) + strict = kwargs.pop("strict", ConstantVariable.create(False)) + iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args] + return variables.ZipVariable( + iter_args, + strict=strict.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + + def call_len( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + try: + return args[0].call_method(tx, "__len__", list(args[1:]), kwargs) + except AttributeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) + + def call_getitem( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return args[0].call_method(tx, "__getitem__", list(args[1:]), kwargs) + + def call_isinstance( + self, + tx: "InstructionTranslator", + arg: VariableTracker, + isinstance_type_var: VariableTracker, + ) -> VariableTracker: + try: + arg_type = arg.python_type() + except NotImplementedError: + unimplemented( + gb_type="builtin isinstance() cannot determine type of argument", + context=f"isinstance({arg}, {isinstance_type_var})", + explanation=f"Dynamo doesn't have a rule to determine the type of argument {arg}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + isinstance_type = isinstance_type_var.as_python_constant() + if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: + + def _tensor_isinstance( + tensor_var: VariableTracker, tensor_type: Any + ) -> bool: + def check_type(ty: Any) -> bool: + if ty not in tensortype_to_dtype: + example_val = arg.as_proxy().node.meta["example_value"] + if ( + is_traceable_wrapper_subclass(example_val) + and ty is torch.nn.parameter.Parameter + ): + # N.B: we are calling isinstance directly on the example value. + # torch.nn.Parameter has a meta-class that overrides __isinstance__, + # the isinstance check here allows us to invoke that logic. + return isinstance(example_val, ty) + else: + return issubclass(arg.python_type(), ty) + + dtypes = tensortype_to_dtype[ty] + # pyrefly: ignore [missing-attribute] + return arg.dtype in dtypes + + if type(tensor_type) is tuple: + return any(check_type(ty) for ty in tensor_type) + else: + return check_type(tensor_type) + + return variables.ConstantVariable.create( + _tensor_isinstance(arg, isinstance_type) + ) + # UserDefinedObject with C extensions can have torch.Tensor attributes, + # so break graph. + if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, types.MemberDescriptorType + ): + unimplemented( + gb_type="isinstance() called on user defined object with C extensions", + context=f"isinstance({arg}, {isinstance_type})", + explanation="User-defined object with C extensions can have torch.Tensor " + "attributes; intentionally graph breaking.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + # handle __instancecheck__ defined in user class + if ( + isinstance(arg, variables.UserDefinedObjectVariable) + and "__instancecheck__" in isinstance_type.__class__.__dict__ + ): + return variables.ConstantVariable.create( + isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) + ) + + if isinstance(arg, variables.UserDefinedExceptionClassVariable): + # pyrefly: ignore [unbound-name] + return ConstantVariable.create(isinstance(arg_type, isinstance_type)) + + isinstance_type_tuple: tuple[type, ...] + if isinstance(isinstance_type, type) or callable( + # E.g. isinstance(obj, typing.Sequence) + getattr(isinstance_type, "__instancecheck__", None) + ): + isinstance_type_tuple = (isinstance_type,) + elif isinstance(isinstance_type, types.UnionType): + isinstance_type_tuple = isinstance_type.__args__ + elif isinstance(isinstance_type, tuple) and all( + isinstance(tp, type) or callable(getattr(tp, "__instancecheck__", None)) + for tp in isinstance_type + ): + isinstance_type_tuple = isinstance_type + else: + raise_observed_exception( + TypeError, + tx, + args=[ + "isinstance() arg 2 must be a type, a tuple of types, or a union" + ], + ) + + try: + # NB: `isinstance()` does not call `__subclasscheck__` but use `__instancecheck__`. + # But usually `isinstance(obj, type_info)` and `issubclass(type(obj), type_info)` gives + # the same result. + # WARNING: This might run arbitrary user code `__subclasscheck__` and we did not trace + # through it. This is a limitation of the current implementation. + # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it + # might not be a big issue and we trade off it for performance. + # pyrefly: ignore [unbound-name] + val = issubclass(arg_type, isinstance_type_tuple) + except TypeError: + # pyrefly: ignore [unbound-name] + val = arg_type in isinstance_type_tuple + return variables.ConstantVariable.create(val) + + def call_issubclass( + self, + tx: "InstructionTranslator", + left_ty: VariableTracker, + right_ty: VariableTracker, + ) -> VariableTracker: + """Checks if first arg is subclass of right arg""" + try: + left_ty_py = left_ty.as_python_constant() + right_ty_py = right_ty.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="issubclass() with non-constant arguments", + context=f"issubclass({left_ty}, {right_ty})", + explanation="issubclass() with non-constant arguments not supported.", + hints=[ + "Make sure your arguments are types.", + *graph_break_hints.USER_ERROR, + ], + ) + + # WARNING: This might run arbitrary user code `__subclasscheck__`. + # See the comment in call_isinstance above. + # pyrefly: ignore [unbound-name] + return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) + + def call_super( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return variables.SuperVariable(a, b) + + def call_next( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + arg = args[0] + try: + return arg.next_variable(tx) + except ObservedUserStopIteration: + if len(args) == 2: + return args[1] + raise + except Unsupported as ex: + if isinstance(arg, variables.BaseListVariable): + ex.remove_from_stats() + return arg.items[0] + raise + + def call_hasattr( + self, tx: "InstructionTranslator", obj: VariableTracker, attr: VariableTracker + ) -> VariableTracker | None: + if attr.is_python_constant(): + name = attr.as_python_constant() + if isinstance(obj, variables.BuiltinVariable): + return variables.ConstantVariable(hasattr(obj.fn, name)) + return obj.call_obj_hasattr(tx, name) + return None + + def call_map( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + *seqs: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + strict = ConstantVariable.create(False) + if kwargs: + if sys.version_info >= (3, 14): + if not (len(kwargs) == 1 and "strict" in kwargs): + raise_args_mismatch( + tx, + "map", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", + ) + strict = kwargs.pop("strict", ConstantVariable.create(False)) + else: + raise_args_mismatch( + tx, + "map", + "0 kwargs", + f"{len(kwargs)} kwargs", + ) + + seq_list = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable( + fn, + seq_list, # type: ignore[arg-type] + strict=strict.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + + def call_filter( + self, tx: "InstructionTranslator", fn: VariableTracker, seq: VariableTracker + ) -> VariableTracker: + seq_or_list = ( + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + ) + return variables.FilterVariable( + fn, + seq_or_list, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + if self.fn is object: + # for object, we can just directly read the attribute + try: + value = getattr(self.fn, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + # pyrefly: ignore [unbound-name] + if not callable(value): + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, value, source) + return variables.GetAttrVariable(self, name, source=source) + + def call_getattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + default: VariableTracker | None = None, + ) -> VariableTracker | None: + if not name_var.is_python_constant(): + unimplemented( + gb_type="getattr() with non-constant name argument", + context=f"getattr({obj}, {name_var}, {default})", + explanation="getattr() with non-constant name argument is not supported", + hints=["Ensure the name argument of getattr() is a string"], + ) + + name = name_var.as_python_constant() + + # See NOTE [Tensor "grad" and "_grad" attr] + if obj.is_tensor() and name == "_grad": + name = "grad" + + if tx.output.side_effects.is_attribute_mutation(obj): + if isinstance(obj, variables.UnspecializedNNModuleVariable): + if ( + name + in ( + "named_parameters", + "parameters", + "named_buffers", + "buffers", + "named_modules", + "modules", + ) + and obj.is_state_mutated + and tx.output.side_effects.has_pending_mutation(obj) + ): + unimplemented( + gb_type="getattr() on nn.Module with pending mutation", + context=f"getattr({obj}, {name}, {default})", + explanation="Intentionally graph breaking on getattr() on a nn.Module " + "with a pending mutation", + hints=[], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): + return tx.output.side_effects.load_attr(obj, name) + + if default is not None: + hasattr_var = self.call_hasattr(tx, obj, name_var) + if hasattr_var is not None: + assert hasattr_var.is_constant_match(True, False) + if not hasattr_var.as_python_constant(): + return default + else: + return default + + source = obj.source and AttrSource(obj.source, name) + if name in {"__bases__", "__base__", "__flags__"}: + try: + value = obj.as_python_constant() + if isinstance(value, type): + if name == "__bases__": + tuple_args = [ + VariableTracker.build( + tx, b, source and GetItemSource(source, i) + ) + for i, b in enumerate(value.__bases__) + ] + return variables.TupleVariable(tuple_args, source=source) + if name == "__base__": + return VariableTracker.build(tx, value.__base__, source) + if name == "__flags__": + return ConstantVariable.create(value.__flags__) + except NotImplementedError: + pass + + if isinstance(obj, variables.NNModuleVariable): + return obj.var_getattr(tx, name) + elif isinstance( + obj, + ( + variables.TensorVariable, + variables.NamedTupleVariable, + variables.ConstantVariable, + variables.DistributedVariable, + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertWarns", + ) + ): + unimplemented( + gb_type="Failed to trace unittest method", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace unittest method `{name}` ", + hints=[ + f"Avoid calling `TestCase.{name}`. " + "Please report an issue to PyTorch.", + ], + ) + if obj.is_tensor(): + fake_val = obj.as_proxy().node.meta["example_value"] + if ( + isinstance(fake_val, torch.Tensor) + and is_sparse_any(fake_val) + and (not tx.export or not config.capture_sparse_compute) + ): + unimplemented( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + try: + return obj.var_getattr(tx, name) + except AsPythonConstantNotImplementedError: + # dont fallback on as_python_constant error because this leads + # to a failure later on, and leads to a wrong stacktrace + raise + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, variables.TorchInGraphFunctionVariable): + # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. + member = getattr(obj.value, name) + if isinstance( + member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member): + return variables.TorchInGraphFunctionVariable(member, source=source) + elif name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(obj, name, source=source) + else: + return None + elif isinstance(obj, DummyModule): + # TODO(mlazos) - Do we need this? + if obj.is_torch or name not in obj.value.__dict__: + member = getattr(obj.value, name) + else: + member = obj.value.__dict__[name] + + if config.replay_record_enabled: + tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr] + return VariableTracker.build(tx, member, source) + + elif istype(obj, variables.UserFunctionVariable) and name in ( + "__name__", + "__module__", + ): + return ConstantVariable.create(getattr(obj.fn, name)) + else: + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + + def call_setattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + val: VariableTracker, + ) -> VariableTracker | None: + if isinstance( + obj, + ( + variables.PlacementVariable, + variables.NamedTupleVariable, + variables.UserDefinedObjectVariable, + variables.NestedUserFunctionVariable, + variables.ExceptionVariable, + ), + ): + return obj.call_method(tx, "__setattr__", [name_var, val], {}) + elif ( + tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + name = name_var.as_python_constant() + if obj.is_tensor(): + from .builder import wrap_fx_proxy + + # Some special handling for tensor attributes. + if name == "requires_grad": + # TODO(voz): Make it work properly + unimplemented( + gb_type="setattr() on Tensor.requires_grad", + context=f"setattr({obj}, {name}, {val})", + explanation="setattr() on Tensor.requires_grad not supported. " + "Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " + "the middle of the graph, which AOTAutograd does not currently know how to handle.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + elif obj.dtype != val.dtype: # type: ignore[attr-defined] + unimplemented( + gb_type="Failed to mutate tensor data attribute to different dtype", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor to a new one with the same dtype", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + + # Remove the old reference in tracked fakes - if we don't do this + # new .data value size and shape differences will cause + # tracked fakes to produce incorrect guards. This is sound because the TensorVariable + # coming out of set_() below will be a new one, and get + # installed in tracked fakes. + to_remove = [ + tf for tf in tx.output.tracked_fakes if tf.source == obj.source + ] + for tf in to_remove: + tx.output.tracked_fakes.remove(tf) + + # Step 1 - disable grads + with dynamo_disable_grad(tx), torch.no_grad(): + # Step 2 - call `set_` + out = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + torch.Tensor.set_, + *proxy_args_kwargs([obj, val], {}), + ), + ) + + # Step 3 - drop the version counter - this is a step required to get + # .data setting to play correctly with the autograd engine. + # Essentially, dynamo is trying to faithfully preserve the (absurd) + # behavior of .data= from eager mode + def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor: + version = x._version + if version > 0: + version = version - 1 + torch._C._autograd._unsafe_set_version_counter((x,), (version,)) + return x + + tx.output.create_proxy( + "call_function", + _lower_version_count_by_1, + (out.as_proxy(),), + {}, + ) + _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) + # This handles options prop, guards and ends with a clone + # Step 4 - replace all reference to the current object with the new one + return out + elif name in ("_grad", "grad"): + # NOTE: [Tensor "grad" and "_grad" attr] + # _grad and grad share the same setter/getter, see + # THPVariable_properties, and here we make sure setting one + # enables reading `val` from the other, by routing all + # read/write to `grad`. + name = "grad" + elif is_tensor_getset_descriptor(name): + # Attribute like `torch.Tensor.real` has special setters we + # don't yet support; it's not as simple adding an entry to + # the side effect mapping. + unimplemented( + gb_type="Failed to set tensor attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo doesn't support setting these tensor attributes", + hints=[ + f"Don't mutate attribute '{name}' on tensors, or " + "move the mutation out of `torch.compile` region", + ], + ) + + tx.output.side_effects.store_attr(obj, name, val) + return val + elif isinstance(obj, variables.NNModuleVariable): + if not tx.output.is_root_tracer(): + raise AttributeMutationError( + "Can't inplace modify module params/buffers inside HigherOrderOp" + ) + if name_var.is_python_constant() and isinstance( + val, variables.TensorVariable + ): + assigning_fake_val = get_fake_value(val.as_proxy().node, tx) + + try: + getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) + except (AttributeError, ObservedAttributeError): + getattr_var = None + + if getattr_var is not None and getattr_var.is_tensor(): + # get_fake_val will get the same fake tensor + existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) + + # same tensor identity, setattr is a no-op + mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") + if ( + existing_fake_attr is assigning_fake_val + and mod_setattr is torch.nn.Module.__setattr__ + ): + return getattr_var + + obj.convert_to_unspecialized(tx) + return None + + def call_delattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + ) -> VariableTracker: + return obj.call_method(tx, "__delattr__", [name_var], {}) + + def call_type( + self, tx: "InstructionTranslator", obj: VariableTracker + ) -> VariableTracker: + try: + py_type = obj.python_type() + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None + + source = obj.source and TypeSource(obj.source) + if ( + source is None + and isinstance(obj, variables.UserDefinedObjectVariable) + and obj.cls_source + ): + source = obj.cls_source + if py_type is torch.Tensor: + # In some cases torch isn't available in globals + name = tx.output.install_global_by_id("", torch) + source = AttrSource(GlobalSource(name), "Tensor") + + return VariableTracker.build(tx, py_type, source) + + def call_reversed( + self, tx: "InstructionTranslator", obj: VariableTracker + ) -> VariableTracker | None: + if obj.has_unpack_var_sequence(tx): + items = list(reversed(obj.unpack_var_sequence(tx))) + return variables.TupleVariable(items) + return None + + def call_sorted( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable + ): + list_var = variables.ListVariable( + obj.force_unpack_var_sequence(tx), + mutation_type=ValueMutationNew(), + ) + list_var.call_method(tx, "sort", [], kwargs) + return list_var + return None + + # neg is a constant fold function, so we only get here if constant fold is not valid + def call_neg( + self, tx: "InstructionTranslator", a: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + (operator.neg)(a.as_proxy()), + sym_num=None, + ) + + if ( + isinstance(a, UserDefinedObjectVariable) + and a.call_obj_hasattr(tx, "__neg__").value # type: ignore[attr-defined] + ): + return a.call_method(tx, "__neg__", [], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_format( + self, + tx: "InstructionTranslator", + _format_string: VariableTracker, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + format_string = _format_string.as_python_constant() + format_string = str(format_string) + return variables.StringFormatVariable.create(format_string, args, kwargs) + + def call_id( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): + nn_mod_variable = args[0] + mod = tx.output.get_submodule(nn_mod_variable.module_key) + return variables.ConstantVariable.create(id(mod)) + elif len(args) == 1 and isinstance( + args[0], + (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable), + ): + if args[0].source: + if isinstance(args[0], variables.UserDefinedClassVariable): + install_guard(args[0].source.make_guard(GuardBuilder.CLASS_MATCH)) + else: + install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) + constant_result = id(args[0].value) + return variables.ConstantVariable.create(constant_result) + elif len(args) == 1 and args[0].is_tensor(): + tensor_variable = cast(TensorVariable, args[0]) + return tensor_variable.call_id(tx) + elif istype(args[0], variables.UserFunctionVariable): + return variables.ConstantVariable.create(id(args[0].fn)) + elif istype(args[0], variables.SkipFunctionVariable): + return variables.ConstantVariable.create(id(args[0].value)) + elif istype(args[0], variables.FunctoolsPartialVariable): + return variables.ConstantVariable.create(id(args[0].fake_value)) + else: + unimplemented( + gb_type="id() with unsupported args", + context=str(args), + explanation=f"Dynamo doesn't know how to trace id() call with args {args}", + hints=[ + "Supported args are Tensors, and functions/nn.Modules/user-defined objects " + "from outside the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_deepcopy( + self, tx: "InstructionTranslator", x: VariableTracker + ) -> VariableTracker: + unimplemented( + gb_type="copy.deepcopy()", + context=f"copy.deepcopy({x})", + explanation="Dynamo does not support copy.deepcopy()", + hints=[ + "Avoid calling copy.deepcopy()", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def _comparison_with_tensor( + self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker + ) -> VariableTracker: + from .builder import wrap_fx_proxy_cls + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op in [operator.is_, operator.is_not]: + is_result = ( + left.is_tensor() + and right.is_tensor() + and id(extract_fake_example_value(left.as_proxy().node)) + == id(extract_fake_example_value(right.as_proxy().node)) + ) + if op is operator.is_: + return ConstantVariable.create(is_result) + else: + return ConstantVariable.create(not is_result) + + if op not in supported_tensor_comparison_op_values: + unimplemented( + gb_type="unsupported Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with Tensor arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and (left.size and right.size) is not None + and left.size != right.size + ): + try: + torch.broadcast_shapes(left.size, right.size) + except RuntimeError: + # not broadcastable, can't be compared + unimplemented( + gb_type="failed to broadcast when attempting Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo was unable to broad cast the arguments {left}, {right} " + f"when attempting to trace the comparison op {op.__name__}.", + hints=[*graph_break_hints.USER_ERROR], + ) + tensor_cls = left if left.is_tensor() else right + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return wrap_fx_proxy_cls( + type(tensor_cls), # handle Ndarrays and Tensors + tx, + proxy, + ) + + def _comparison_with_symnode( + self, tx: "InstructionTranslator", left: VariableTracker, right: VariableTracker + ) -> VariableTracker: + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op not in supported_tensor_comparison_op_values: + unimplemented( + gb_type="unsupported SymNode comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with SymNode arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # This is seen in inspect signature where we check if the value is a default value + if isinstance(right, variables.UserDefinedClassVariable): + return variables.ConstantVariable(op(object(), None)) + + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return SymNodeVariable.create( + tx, + proxy, + sym_num=None, + ) + + def call_xor( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.xor, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + if isinstance( + a, + (DictKeysVariable, SetVariable, UserDefinedObjectVariable), + ): + return a.call_method(tx, "__xor__", [b], {}) + return None + + def call_ixor( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__ixor__", [b], {}) + return None + + def call_sub( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__sub__", [b], {}) + return None + + def call_isub( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__isub__", [b], {}) + return None + + def call_and_( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__and__", [b], {}) + # None no-ops this handler and lets the driving function proceed + return None + + def call_iand( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.iand, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + return a.call_method(tx, "__iand__", [b], {}) + return None + + def call_or_( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. + if isinstance( + a, + ( + ConstDictVariable, + DictKeysVariable, + MutableMappingVariable, + SetVariable, + UserDefinedDictVariable, + UserDefinedObjectVariable, + ), + ): + # TODO(guilhermeleobas): forward the call to b.__ror__(a) if + # a.__ror__(b) returns NotImplemented + return a.call_method(tx, "__or__", [b], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_ior( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker | None: + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if a.is_symnode_like() and b.is_symnode_like(): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.ior, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + + # This call looks like `{"one": torch.ones(1)} |= {"two": torch.ones(2)}`. + if isinstance( + a, + ( + ConstDictVariable, + DictKeysVariable, + MutableMappingVariable, + SetVariable, + UserDefinedObjectVariable, + ), + ): + return a.call_method(tx, "__ior__", [b], {}) + + # None no-ops this handler and lets the driving function proceed + return None + + def call_not_( + self, tx: "InstructionTranslator", a: VariableTracker + ) -> VariableTracker | None: + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.not_, *proxy_args_kwargs([a], {}) + ), + sym_num=None, + ) + + # Unwrap the underlying ConstDictVariable + if isinstance(a, DictViewVariable): + a = a.dv_dict + if isinstance(a, (ListVariable, ConstDictVariable)): + return ConstantVariable.create(len(a.items) == 0) + + return None + + def call_contains( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ) -> VariableTracker: + return a.call_method(tx, "__contains__", [b], {}) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn + + +@contextlib.contextmanager +def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: + from . import GradModeVariable + + gmv = GradModeVariable.create(tx, False) + try: + gmv.enter(tx) + yield + finally: + gmv.exit(tx) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..86b5301b63e7233fd4061858f081695511517537 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/constant.py @@ -0,0 +1,421 @@ +""" +Constant and enum variable tracking in Dynamo. + +This module is fundamental to Dynamo's ability to track and propagate constant +values during compilation, ensuring proper handling of Python literals and +maintaining type safety through the compilation process. +""" + +import enum +import operator +from collections.abc import Sequence +from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import override + +import torch +from torch._dynamo.source import AttrSource, GetItemSource + +from .. import graph_break_hints, variables +from ..exc import raise_observed_exception, unimplemented +from ..utils import ( + cmp_name_to_op_mapping, + common_constant_types, + istype, + np, + raise_args_mismatch, + raise_on_overridden_hash, +) +from .base import ValueMutationNew, VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +class ConstantVariable(VariableTracker): + """ + Variable tracker for Python literals and basic immutable types, with automatic + routing support for collection types (lists, tuples, sets, etc.). + + The create() method intelligently constructs appropriate variable types for + nested collections. + """ + + @overload + @staticmethod + def create(value: bool) -> "ConstantVariable": ... + + @overload + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: ... + + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: + """ + Create a `ConstantVariable` based on the given value, and supports + automatic routing for collection types like `tuple` (in which case we'd + create `ConstantVariable` for the leaf items). + + NOTE: the caller must install the proper guards if needed; most often + the guard will be `CONSTANT_MATCH`. + """ + source = kwargs.get("source") + + # Routing for supported collection literals. + if isinstance(value, set): + items = [ConstantVariable.create(x) for x in value] + return variables.SetVariable(items, **kwargs) # type: ignore[arg-type] + elif isinstance(value, frozenset): + items = [ConstantVariable.create(x) for x in value] + return variables.FrozensetVariable(items, **kwargs) # type: ignore[arg-type] + elif isinstance(value, slice): + slice_args = (value.start, value.stop, value.step) + slice_args_vars = tuple(ConstantVariable.create(arg) for arg in slice_args) + return variables.SliceVariable(slice_args_vars, **kwargs) + elif isinstance(value, (list, tuple)): + items = [] + for i, x in enumerate(value): + item_source = GetItemSource(source, i) if source else None + items.append( + ConstantVariable.create( + x, + source=item_source, + ) + ) + return variables.BaseListVariable.cls_for(type(value))(items, **kwargs) + + return ConstantVariable(value, **kwargs) + + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert ConstantVariable.is_base_literal(value), f""" +Cannot construct `ConstantVariable` for value of type {type(value)}. + +This failure likely due to PyTorch-internal use of `ConstantVariable` on +non-literal python values, please try using `VariableTracker.build` instead. If +you believe it's a necessary and legitimate use case (the value is immutable and +can't easily be represented with another `VariableTracker` class), please add +its type to `common_constant_types`. +""" + if np is not None and isinstance(value, np.number): + self.value = value.item() + else: + self.value = value + + def as_proxy(self) -> Any: + return self.value + + def __repr__(self) -> str: + return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" + + def as_python_constant(self) -> Any: + return self.value + + def is_python_constant(self) -> Literal[True]: + return True + + def is_symnode_like(self) -> bool: + return isinstance(self.value, (int, bool)) + + def is_constant_match(self, *values: Any) -> bool: + return self.value in values + + def is_constant_none(self) -> bool: + return self.value is None + + @property + def items(self) -> list[VariableTracker]: + """ + Need this when adding a BaseListVariable and a ConstantVariable together. + Happens in detectron2. + """ + return self.unpack_var_sequence(tx=None) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + return ConstantVariable.create( + self.value[arg.as_python_constant()], + ) + + @staticmethod + def is_base_literal(obj: object) -> bool: + return type(obj) in common_constant_types + + @staticmethod + def is_literal(obj: object) -> bool: + if type(obj) in (list, tuple, set, frozenset, torch.Size): + return all(ConstantVariable.is_literal(x) for x in obj) # type: ignore[attr-defined] + return ConstantVariable.is_base_literal(obj) + + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] + ) -> list[VariableTracker]: + try: + return [ConstantVariable.create(x) for x in self.as_python_constant()] + except TypeError as e: + raise NotImplementedError from e + + def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if not hasattr(self.value, name): + raise_observed_exception(AttributeError, tx, args=[name]) + member = getattr(self.value, name) + if callable(member): + raise NotImplementedError + return member + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "format" and istype(self.value, str): + return variables.BuiltinVariable(str.format).call_function( + tx, [self, *args], kwargs + ) + elif name == "join" and istype(self.value, str): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + elif name == "__iter__" and istype(self.value, str): + # this could be some generic iterator to avoid the circular import, + # but ListIterator does what we want + from .lists import ListIteratorVariable + + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + + if any(isinstance(x, SymNodeVariable) for x in args): + # Promote to SymNodeVariable for operations involving dynamic shapes. + return variables.SymNodeVariable.create( + tx, self.as_proxy(), self.value + ).call_method(tx, name, args, kwargs) + + try: + const_args = [a.as_python_constant() for a in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + + if isinstance(self.value, str) and name in str.__dict__: + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) + elif isinstance(self.value, (float, int)): + if not (args or kwargs): + try: + return ConstantVariable.create(getattr(self.value, name)()) + except (OverflowError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + if ( + hasattr(operator, name) + and len(args) == 1 + and args[0].is_python_constant() + ): + add_target = const_args[0] + op = getattr(operator, name) + if isinstance( + add_target, (torch.SymBool, torch.SymFloat, torch.SymInt) + ): + # Addition between a non sym and sym makes a sym + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return SymNodeVariable.create(tx, proxy, add_target) + else: + try: + return ConstantVariable.create(op(self.value, add_target)) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + elif isinstance(self.value, bytes) and name == "decode": + method = getattr(self.value, name) + return ConstantVariable.create(method(*const_args, **const_kwargs)) + elif type(self.value) is complex and name in complex.__dict__: + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) + + if name == "__len__" and not (args or kwargs): + # pyrefly: ignore [bad-argument-type] + return ConstantVariable.create(len(self.value)) + elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): + try: + return ConstantVariable.create( + # pyrefly: ignore [no-matching-overload] + round(self.value, args[0].as_python_constant()) + ) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): + assert not kwargs + search = args[0].as_python_constant() + try: + # pyrefly: ignore [unsupported-operation] + result = search in self.value + return ConstantVariable.create(result) + except TypeError as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + return super().call_method(tx, name, args, kwargs) + + def call_tree_map( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is None: + none_is_leaf_var = tree_map_kwargs.get("none_is_leaf") + if none_is_leaf_var is not None: + try: + none_is_leaf = bool(none_is_leaf_var.as_python_constant()) + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + else: + tree_map_module = getattr( + getattr(tree_map_fn, "fn", None), "__module__", "" + ) + # torch.utils._pytree and torch.utils._cxx_pytree treat None as a leaf + # by default, while optree keeps it as an internal node unless + # none_is_leaf=True is provided. + none_is_leaf = not tree_map_module.startswith("optree") + if none_is_leaf: + return map_fn.call_function(tx, [self, *rest], {}) + else: + for other in rest: + if not other.is_constant_none(): + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + return self.clone() + if isinstance(self.value, (int, float, bool, complex, str, bytes, torch.dtype)): + return map_fn.call_function(tx, [self, *rest], {}) + return super().call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + @override + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + # Could be an EnumVariable as well + from .tensor import SymNodeVariable + + if isinstance(other, SymNodeVariable): + return self.as_python_constant() == other.evaluate_expr() + return self.as_python_constant() == other.as_python_constant() + + +class EnumVariable(VariableTracker): + """VariableTracker for enum.Enum and enum.IntEnum instances + + Provides specialized handling for Python enum types, supporting + both standard Enum and IntEnum with proper value tracking and comparison. + """ + + def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def create( + cls, cls_type: Any, value_vt: VariableTracker, options: Any + ) -> "EnumVariable": + if value_vt.is_python_constant(): + for member in list(cls_type): + if member.value == value_vt.as_python_constant(): + return cls(member, **options) + unimplemented( + gb_type="Failed to construct Enum variable", + context=f"value: {value_vt}, allowed enum values: {list(cls_type)}", + explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) " + "or is not an acceptable value for the Enum. " + f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", + hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], + ) + + def as_proxy(self) -> Union[enum.Enum, int]: + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int + return self.value + + def __repr__(self) -> str: + return f"EnumVariable({type(self.value)})" + + def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if not hasattr(self.value, name): + raise NotImplementedError + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + member = getattr(self.value, name) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, member, source=source) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9c08a2d12eb96d3bf94880d17fe9064f9ea53975 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/ctx_manager.py @@ -0,0 +1,1529 @@ +""" +This file contains a collection of context manager classes used by Dynamo for tracking +and managing various PyTorch runtime states during graph compilation. These context +managers handle different aspects of PyTorch's execution environment, including: + +- Autograd states (grad mode, inference mode) +- CUDA streams and events +- Profiling contexts +- Deterministic algorithms +- Forward/backward AD modes +- SDPA (Scaled Dot Product Attention) kernels +- FSDP (Fully Sharded Data Parallel) states +- AMP (Automatic Mixed Precision) autocast states + +The context managers ensure proper state transitions during graph compilation by +tracking enter/exit points and managing cleanup operations. They help maintain +consistency between eager execution and compiled graph behavior by capturing and +restoring state changes. +""" + +import inspect +import sys +import warnings +from collections.abc import Callable, Sequence, Sized +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch._C +from torch._guards import Guard + +from .. import graph_break_hints, variables +from ..bytecode_transformation import ( + create_call_function, + create_instruction, + create_setup_with, +) +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalStateSource +from ..utils import _get_error_on_graph_break, _set_error_on_graph_break +from .base import VariableTracker +from .functions import ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + WrappedNestedUserFunctionVariable, + WrappedSkipFunctionVariable, + WrappedUserFunctionVariable, + WrappedUserMethodVariable, +) +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ContextWrappingVariable(VariableTracker): + _nonvar_fields = { + "cm_obj", + "target_values", + "initial_values", + "state", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, target_values: Any, initial_values: Optional[Any] = None, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.target_values = target_values + self.initial_values = initial_values + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + if hasattr(self, "_call_func"): + self._call_func(tx, self.target_values) + self.set_cleanup_hook(tx) + return variables.ConstantVariable.create(None) + + def set_cleanup_hook( + self, tx: "InstructionTranslator", fn: Optional[Callable[..., Any]] = None + ) -> None: + if fn is None: + + def fn() -> None: + if hasattr(self, "_call_func"): + self._call_func(tx, self.initial_values) + + self.cleanup_fn: Optional[Callable[..., Any]] = fn + tx.output.add_cleanup_hook(self.cleanup) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + return variables.ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + codegen( + AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: self.reconstruct_type(codegen)) + target_values = self.target_values + if not target_values: + target_values = () + codegen.extend_output([codegen.create_load_const(val) for val in target_values]) + codegen.extend_output(create_call_function(len(target_values), False)) + + def module_name(self) -> str: + raise NotImplementedError("module_name called on base") + + def fn_name(self) -> str: + raise NotImplementedError("fn_name called on base") + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert len(args) == 1 + assert isinstance( + args[0], + ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserMethodVariable, + UserFunctionVariable, + ), + ) + + if isinstance(args[0], NestedUserFunctionVariable): + return WrappedNestedUserFunctionVariable(args[0], self) + elif isinstance(args[0], SkipFunctionVariable): + return WrappedSkipFunctionVariable(args[0], self) + elif isinstance(args[0], UserMethodVariable): + return WrappedUserMethodVariable(args[0], self) + elif isinstance(args[0], UserFunctionVariable): + return WrappedUserFunctionVariable(args[0], self) + else: + raise AssertionError("Unexpected arg type") + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return True + + def cleanup(self) -> None: + if self.cleanup_fn is not None: + self.cleanup_fn() + self.cleanup_fn = None + + def cleanup_assert(self) -> None: + assert self.cleanup_fn, "multiple exits?" + self.cleanup() + + +class GenericContextWrappingVariable(UserDefinedObjectVariable): + # Some methods in ContextWrappingVariable assumes the arguments are + # python constants. Which might not always be the case here. + def __init__(self, cm_obj: AbstractContextManager[Any], **kwargs: Any) -> None: + assert cm_obj is not None + super().__init__( + value=cm_obj, + value_type=cm_obj.__class__, + **kwargs, + ) + self.cm_obj = cm_obj + + def module_name(self) -> str: + return self.cm_obj.__module__ + + def fn_name(self) -> str: + return type(self.cm_obj).__name__ + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + source = None if self.source is None else AttrSource(self.source, "__enter__") + return variables.UserMethodVariable( + self.cm_obj.__enter__.__func__, # type: ignore[attr-defined] + self, + source=source, + ).call_function(tx, [], {}) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + source = None if self.source is None else AttrSource(self.source, "__exit__") + x = variables.UserMethodVariable( + self.cm_obj.__exit__.__func__, # type: ignore[attr-defined] + self, + source=source, + ).call_function(tx, list(args), {}) + tx.active_generic_context_managers.pop() + return x + + def supports_graph_breaks(self) -> bool: + return False + + def exit_on_graph_break(self) -> bool: + return True + + +class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): + def __init__(self, ctx_manager_vt: ContextWrappingVariable, mod: Any) -> None: + self.cm_vt = ctx_manager_vt + self.mod = mod + # We don't call super().__init__() because we're delegating most methods to cm_vt + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + # Custom enter implementation with side effects + + self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() + self.old_buffer_var = self.mod.var_getattr(tx, "_buffers").realize() + tx.output.side_effects.ignore_mutations_on(self.old_parameters_var) + tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) + return self.cm_vt.enter(tx) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # Custom exit implementation with side effects + x = self.cm_vt.exit(tx, *args) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) + tx.output.side_effects.stop_ignoring_mutations_on(self.old_parameters_var) + return x + + # Forward all other method calls to self.cm_vt + def __getattr__(self, name: str) -> Any: + # This will be called for any attribute not explicitly defined in this class + return getattr(self.cm_vt, name) + + +class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): + """represents torch grad requires grad""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "GradInplaceRequiresGradCtxManagerVariable": + return GradInplaceRequiresGradCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + [enabled] = self.target_values + self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed(enabled) + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.set_inplace_requires_grad_allowed( + self.prev_state + ), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (enabled,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): + """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "TemporarilyPopInterpreterStackCtxManagerVariable": + return TemporarilyPopInterpreterStackCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self.saved = torch._C._functorch.pop_dynamic_layer_stack() + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.pop_dynamic_layer_stack, + (), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.push_dynamic_layer_stack, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.jvp increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if jvp is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a jvp + # call from eager that calls the compiled function, as the jvp levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "JvpIncrementNestingCtxManagerVariable": + var = JvpIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() + self.set_cleanup_hook( + tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting() + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._jvp_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(jvp_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class SetFwdGradEnabledContextManager(ContextWrappingVariable): + """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "SetFwdGradEnabledContextManager": + return SetFwdGradEnabledContextManager( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + [mode] = self.target_values + self.prev_state = torch._C._is_fwd_grad_enabled() + torch._C._set_fwd_grad_enabled(mode) + self.set_cleanup_hook( + tx, + lambda: torch._C._set_fwd_grad_enabled(self.prev_state), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (mode,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class DualLevelContextManager(ContextWrappingVariable): + """Represents torch.autograd.forward_ad.dual_level ctx manager""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) # type: ignore[arg-type] + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs: Any) -> "DualLevelContextManager": + return DualLevelContextManager( + target_values=None, + initial_values=None, + **kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + self.new_level = torch.autograd.forward_ad.enter_dual_level() + self.set_cleanup_hook( + tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level) + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._enter_dual_level, + (), + {}, + ) + return variables.ConstantVariable.create(self.new_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._exit_dual_level, + (self.new_level,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.grad increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if grad is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a grad + # call from eager that calls the compiled function, as the grad levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "GradIncrementNestingCtxManagerVariable": + var = GradIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + grad_level = torch._C._functorch._grad_increment_nesting() + self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._grad_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(grad_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._grad_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class CatchWarningsCtxManagerVariable(ContextWrappingVariable): + """Delay a call to warnings.catch_warnings""" + + @staticmethod + def create( + tx: "InstructionTranslator", catch_warnings_args: dict[str, VariableTracker] + ) -> "CatchWarningsCtxManagerVariable": + return CatchWarningsCtxManagerVariable( + catch_warnings_args=catch_warnings_args, + target_values=None, + initial_values=None, + ) + + def __init__( + self, + catch_warnings_args: dict[str, VariableTracker], + target_values: Optional[Any] = None, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + assert isinstance(catch_warnings_args, dict), catch_warnings_args + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.catch_warnings_args = catch_warnings_args + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + kwargs = { + k: v.as_python_constant() for k, v in self.catch_warnings_args.items() + } + ctx_val = warnings.catch_warnings(**kwargs) + self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) + return variables.ConstantVariable.create(ctx_val.__enter__()) + + def reconstruct(self, cg: "PyCodegen") -> None: + cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) + cg.foreach(self.catch_warnings_args.values()) + keys = tuple(self.catch_warnings_args.keys()) + cg.extend_output(cg.create_call_function_kw(len(keys), keys, False)) + + +class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch VMap increment/decrement nesting""" + + # A guard is needed as the vmap level is baked into the torch FX graph + # generated. This is fine if vmap is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a vmap + # call from eager that calls the compiled function, as the vmap levels + # may be different. + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + target_values: Sequence[VariableTracker], + **kwargs: Any, + ) -> "VmapIncrementNestingCtxManagerVariable": + var = VmapIncrementNestingCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + install_guard(self._guards_singleton) + batch_size, randomness = self.target_values + if isinstance(batch_size, variables.SymNodeVariable): + batch_size_value = batch_size.sym_num + else: + batch_size_value = batch_size.as_python_constant() + randomness = randomness.as_python_constant() + vmap_level = torch._C._functorch._vmap_increment_nesting( + batch_size_value, randomness + ) + self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting()) + self.proxy = tx.output.create_proxy( + "call_function", + torch._functorch.predispatch._vmap_increment_nesting, + (batch_size.as_proxy(), randomness), + {}, + ) + return variables.ConstantVariable.create(vmap_level) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup() + tx.output.create_node( + "call_function", + torch._functorch.predispatch._vmap_decrement_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradModeVariable(ContextWrappingVariable): + """represents torch.{no_grad,enable_grad,set_grad_mode}()""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + target_value: Any, + initialized: bool = False, + **kwargs: Any, + ) -> "GradModeVariable": + var = GradModeVariable( + target_values=[target_value], + initial_values=[torch.is_grad_enabled()], + **kwargs, + ) + if initialized: + var._call_func(tx, var.target_values) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[Sequence[bool]] = None, + initialized: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: + assert len(values) == 1 + value = values[0] + # Coalesce grad mode mutations + if torch.is_grad_enabled() != value: + tx.output.create_node( + "call_function", torch._C._set_grad_enabled, (value,), {} + ) + torch._C._set_grad_enabled(value) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "set_grad_enabled" + + +class InferenceModeVariable(ContextWrappingVariable): + @staticmethod + def create( + tx: "InstructionTranslator", target_value: Any, **kwargs: Any + ) -> "InferenceModeVariable": + var = InferenceModeVariable( + [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs + ) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if initial_values is None: + # This must be called here since function defaults are evaluated at import time + initial_values = torch.is_inference_mode_enabled() + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.autograd.grad_mode._exit_inference_mode, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(None) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + disabled_inference_mode_forcibly = False + if ( + torch._dynamo.config.fake_tensor_disable_inference_mode + and self.target_values[0] + ): + # Do not set the inference mode because we keep it off during + # compilation. Set the grad_enabled to False to reflect the relevant + # part of inference_mode to torch.compile. + disabled_inference_mode_forcibly = True + prior = torch.is_grad_enabled() + torch._C._set_grad_enabled(False) + else: + ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) + + def cleanup_hook() -> None: + if disabled_inference_mode_forcibly: + torch._C._set_grad_enabled(prior) + else: + torch.autograd.grad_mode._exit_inference_mode(ctx) + + self.set_cleanup_hook(tx, cleanup_hook) + self.proxy = tx.output.create_node( + "call_function", + torch.autograd.grad_mode._enter_inference_mode, + (*self.target_values,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "inference_mode" + + +class CUDADeviceVariable(ContextWrappingVariable): + """represents torch.cuda.device""" + + @staticmethod + def create( + tx: "InstructionTranslator", device: Any, **kwargs: Any + ) -> "CUDADeviceVariable": + var = CUDADeviceVariable( + target_values=[torch.cuda._get_device_index(device, optional=True)], + initial_values=None, + **kwargs, + ) + return var + + def __init__( + self, + target_values: Any, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.cuda._maybe_exchange_device, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(False) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + prev_idx = torch.cuda._exchange_device(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) + self.proxy = tx.output.create_node( + "call_function", + torch.cuda._exchange_device, + (*self.target_values,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.cuda" + + def fn_name(self) -> str: + return "device" + + +class TorchFunctionDisableVariable(ContextWrappingVariable): + """represents whether torch function overrides are enabled or not""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "TorchFunctionDisableVariable": + var = TorchFunctionDisableVariable( + target_values=[], + initial_values=[], + **kwargs, + ) + return var + + def __init__( + self, + target_values: Sized, + initial_values: Optional[Sized] = None, + only_subclass: bool = True, + **kwargs: Any, + ) -> None: + assert len(target_values) == 0 + assert initial_values is not None and len(initial_values) == 0 + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + self.only_subclass = only_subclass + self.initial_torch_function_subclass_enabled = ( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + self.initial_torch_function_mode_enabled = ( + tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def set_cleanup_hook( + self, + tx: "InstructionTranslator", + cleanup_fn: Optional[Callable[..., Any]] = None, + ) -> None: + if cleanup_fn is None: + + def cleanup_fn() -> None: + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + self.initial_torch_function_subclass_enabled + ) + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = ( + self.initial_torch_function_subclass_enabled + ) + + self.cleanup_fn = cleanup_fn + tx.output.add_cleanup_hook(self.cleanup) + + def _call_func(self, tx: "InstructionTranslator", values: Sized) -> None: + assert len(values) == 0 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = False + + def module_name(self) -> str: + return "torch._C" + + def fn_name(self) -> str: + if self.only_subclass: + return "DisableTorchFunctionSubclass" + return "DisableTorchFunction" + + +class DeterministicAlgorithmsVariable(ContextWrappingVariable): + """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" + + _guards_singleton = Guard( + GlobalStateSource(), + GuardBuilder.DETERMINISTIC_ALGORITHMS, # type: ignore[arg-type] + ) + + @staticmethod + def create( + tx: "InstructionTranslator", target_value: bool, **kwargs: Any + ) -> "DeterministicAlgorithmsVariable": + var = DeterministicAlgorithmsVariable( + target_values=[target_value], + initial_values=[torch.are_deterministic_algorithms_enabled()], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__( + self, + target_values: Sequence[bool], + initial_values: Optional[Sequence[bool]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: + assert len(values) == 1 + value = values[0] + tx.output.create_node( + "call_function", torch._C._set_deterministic_algorithms, (value,), {} + ) + torch._C._set_deterministic_algorithms(value) + + def module_name(self) -> str: + return "torch" + + def fn_name(self) -> str: + return "use_deterministic_algorithms" + + +class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): + """represents torch.autograd.graph.disable_saved_tensors_hook.""" + + @staticmethod + def create( + tx: "InstructionTranslator", target_value: Optional[str], **kwargs: Any + ) -> "DisabledSavedTensorsHooksVariable": + var = DisabledSavedTensorsHooksVariable( + target_values=[target_value], + initial_values=[ + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__( + self, + target_values: Sequence[Optional[str]], + initial_values: Optional[Sequence[Optional[str]]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def _call_func( + self, tx: "InstructionTranslator", values: Sequence[Optional[str]] + ) -> None: + assert len(values) == 1 + value = values[0] + if value is not None: + # Disable `saved_tensors_hooks` with message (`value`) + # OR + # we are exiting this context and restoring the previous message. + tx.output.create_node( + "call_function", + torch._C._autograd._saved_tensors_hooks_disable, + (value,), + {}, + ) + torch._C._autograd._saved_tensors_hooks_disable(value) + else: + # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`. + tx.output.create_node( + "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {} + ) + torch._C._autograd._saved_tensors_hooks_enable() + + def module_name(self) -> str: + return "torch.autograd.graph" + + def fn_name(self) -> str: + return "disable_saved_tensors_hooks" + + +class AutocastModeVariable(ContextWrappingVariable): + @staticmethod + def create( + func: torch.amp.autocast_mode.autocast, + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> "AutocastModeVariable": + assert func in [ + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ] + # device_type : str, + # dtype : Optional[_dtype] = None, + # enabled : bool = True, + # cache_enabled : Optional[bool] = None):cache_enabled + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + target_values = [] + kwargs.clear() + + for key in ["device_type", "dtype", "enabled", "cache_enabled"]: + if key == "device_type" and func in [ + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ]: + arg = "cuda" if func is torch.cuda.amp.autocast else "cpu" + else: + arg = bound_args.arguments[key] + if isinstance(arg, VariableTracker): + target_values.append(arg.as_python_constant()) + else: + target_values.append(arg) + + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) + return var + + def __init__( + self, + target_values: Sequence[Any], + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + tx.output.create_node( + "call_function", torch.amp._exit_autocast, (self.proxy,), {} + ) + return variables.ConstantVariable.create(None) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + ctx = torch.amp._enter_autocast(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) + self.proxy = tx.output.create_node( + "call_function", torch.amp._enter_autocast, (*self.target_values,), {} + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.amp.autocast_mode" + + def fn_name(self) -> str: + return "autocast" + + +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + """ + + def __init__(self, target_values: Optional[Any] = None, **kwargs: Any) -> None: + super().__init__(target_values=target_values, **kwargs) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + none = variables.ConstantVariable.create(None) + return self.target_values if self.target_values else none + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "contextlib" + + def fn_name(self) -> str: + return "nullcontext" + + +class ProfilerContextVariable(ContextWrappingVariable): + """ + This class represents a set of torch profiler context objects, where Dynamo + ignores all the side-effects in the __init__, __enter__ and __exit__ methods + by treating the object mostly as a `contextlib.nullcontext`, except for edge + cases like the `__enter__` method which returns the object itself rather + than `None`, per implementation of the torch objects. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(target_values=None, **kwargs) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return self + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "contextlib" + + def fn_name(self) -> str: + return "nullcontext" + + def reconstruct(self, cg: "PyCodegen") -> None: + unimplemented( + gb_type="torch.profiler object escaped from compiled region", + context=str(self), + explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class PreserveVersionContextVariable(ContextWrappingVariable): + """ + Wraps torch.autograd._unsafe_preserve_version_counter + """ + + @staticmethod + def _create_lambda_from_tensors( + tx: "InstructionTranslator", + tensors: VariableTracker, + ) -> "PreserveVersionContextVariable": + if tensors.is_tensor(): + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in [tensors]] + ) + tensors_tuple = variables.TupleVariable([tensors]) + else: + assert isinstance(tensors, variables.TupleVariable) + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in tensors.items] + ) + tensors_tuple = tensors + return PreserveVersionContextVariable(tensors_tuple, versions) + + @staticmethod + def constructor(tx: "InstructionTranslator") -> VariableTracker: + return variables.LambdaVariable( + lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( + tx, tensors + ) + ) + + def __init__( + self, + tensors: VariableTracker, + prev_versions: VariableTracker, + **kwargs: Any, + ) -> None: + kwargs.setdefault("target_values", None) + super().__init__(**kwargs) + self.tensors = tensors + self.prev_versions = prev_versions + # The context manager accepts Union[Tensor, Tuple[Tensor]] + if self.tensors.is_tensor(): + self.tensors = variables.TupleVariable([self.tensors]) + if self.prev_versions.is_symnode_like(): + self.prev_versions = variables.TupleVariable([self.prev_versions]) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + from ..tensor_version_op import _unsafe_set_version_counter + + return variables.TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [self.tensors, self.prev_versions], {}) + + def reconstruct(self, codegen: "PyCodegen") -> None: + unimplemented( + gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", + context=str(self), + explanation=( + "Dynamo doesn't support compiling a region that returns " + "a torch.autograd._unsafe_preserve_version_counter context manager." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) # type: ignore[arg-type] + + @staticmethod + def create( + tx: "InstructionTranslator", + param_group_var: Any, + target_value: Any, + **kwargs: Any, + ) -> "FSDPParamGroupUseTrainingStateVariable": + var = FSDPParamGroupUseTrainingStateVariable( + param_group_var=param_group_var, + target_values=[target_value], + initial_values=[param_group_var.value._training_state], + **kwargs, + ) + return var + + def __init__( + self, + param_group_var: Any, + target_values: Sequence[Any], + initial_values: Optional[Sequence[Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.param_group_var = param_group_var + install_guard(self._guards_singleton) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # type: ignore[arg-type] + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # undo eager initialization + self._call_func(tx, self.initial_values) # type: ignore[arg-type] + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[Any]) -> None: + assert len(values) == 1 + value = values[0] + if self.param_group_var.value._training_state != value: + self.param_group_var.call_method( + tx, + "__setattr__", + ( + variables.ConstantVariable.create("_training_state"), + variables.EnumVariable(value), + ), + {}, + ) + self.param_group_var.value._training_state = value + + def module_name(self) -> str: + return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" + + def fn_name(self) -> str: + return "use_training_state" + + +class SDPAKernelVariable(ContextWrappingVariable): + """represents torch.nn.attention.sdpa_kernel""" + + @staticmethod + def create( + tx: "InstructionTranslator", + backends: Any, + set_priority: bool = False, + **kwargs: Any, + ) -> "SDPAKernelVariable": + if isinstance(backends, torch.nn.attention.SDPBackend): + backends = [backends] + var = SDPAKernelVariable( + target_values=backends, + initial_values=None, + set_priority=set_priority, + **kwargs, + ) + return var + + def __init__( + self, + target_values: list[torch.nn.attention.SDPBackend], + initial_values: Any = None, + set_priority: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.set_priority = set_priority + + @staticmethod + def _backends_to_nodes( + tx: "InstructionTranslator", + backends: list[Any], + ) -> list[Any]: + # convert to/from string in order to bake the backend into FX graph + nodes = [ + tx.output.create_node( + "call_function", + torch.nn.attention._backend_from_string, + (backend.name,), + {}, + ) + for backend in backends + ] + return nodes + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( + with_priority=self.set_priority + ) + self.set_cleanup_hook( + tx, + lambda: torch.nn.attention._sdpa_kernel( + self.prev_backends, set_priority=self.set_priority + ), + ) + torch.nn.attention._sdpa_kernel( + self.target_values, set_priority=self.set_priority + ) + arg = self._backends_to_nodes(tx, self.target_values) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self.cleanup_assert() + arg = self._backends_to_nodes(tx, self.prev_backends) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.nn.attention" + + # use a private version of sdpa_kernel that accepts variadic arguments + # since dynamo reconstructs the contents of target_values one-by-one + def fn_name(self) -> str: + return "_sdpa_kernel_variadic" + + +class FxTracebackAnnotateVariable(ContextWrappingVariable): + """ + fx.traceback.annotate is a context manager that allows users to annotate the + fx graph nodes with custom metadata. In the context of Dynamo, we don't have + to trace the body of the context manager. Instead we want to directly run + the body of the context manager, so the Dynamo created Fx graphs have the + right custom metadata. This variable tracker just runs __enter__ and + __exit__ method (instead of tracing). + """ + + def __init__( + self, target_values: Any, initial_values: Any = None, **kwargs: Any + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # Run the annotation ctx manager in eager. Also ensure that + # preserve_node_meta context manager is setup. This is important to pass + # on the metadata to the create_proxy nodes. + stack = ExitStack() + stack.enter_context(torch.fx.traceback.annotate(self.target_values)) + stack.enter_context(torch.fx.traceback.preserve_node_meta()) + self.set_cleanup_hook(tx, lambda: stack.close()) + return variables.ConstantVariable.create(None) + + def module_name(self) -> str: + return "torch.fx.traceback" + + def fn_name(self) -> str: + return "annotate" + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + unimplemented( + gb_type="torch.fx.traceback.annotate escaped from compiled region", + context=str(self), + explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class DynamoConfigPatchVariable(ContextWrappingVariable): + """represents torch._dynamo.patch_dynamo_config""" + + # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness + # (though it may affect tracing behavior) + def __init__(self, target_values: dict[str, Any], **kwargs: Any) -> None: + target_values_tuple = tuple(target_values.items()) + super().__init__( + target_values=(target_values_tuple,), initial_values=None, **kwargs + ) + initial_values_dict = {} + for key, _ in target_values_tuple: + initial_values_dict[key] = torch._dynamo.config.__getattr__(key) # type: ignore[attr-defined] + self.initial_values = (tuple(initial_values_dict.items()),) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: + assert len(values) == 1 + value = values[0] + # manually patch dynamo config + for key, val in value: + torch._dynamo.config.__setattr__(key, val) # type: ignore[attr-defined] + # No need to keep track of global side effects because + # dynamo will properly restore this context manager for + # unsupported instructions and continuation functions. + # Dynamo config also should not affect the semantics of the compiled graph. + + def module_name(self) -> str: + return "torch._dynamo" + + def fn_name(self) -> str: + return "patch_dynamo_config" + + +class ErrorOnGraphBreakVariable(ContextWrappingVariable): + """represents torch._dynamo.error_on_graph_break""" + + def __init__(self, error_on_graph_break: bool, **kwargs: Any) -> None: + super().__init__( + target_values=(error_on_graph_break,), + initial_values=(_get_error_on_graph_break(),), + **kwargs, + ) + + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: + assert len(values) == 1 + _set_error_on_graph_break(values[0]) + + def module_name(self) -> str: + return "torch._dynamo" + + def fn_name(self) -> str: + return "error_on_graph_break" + + +class WithEnterFunctionVariable(VariableTracker): + def __init__( + self, + ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.ctx = ctx + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert not args + assert not kwargs + # NOTE: we assume that the instruction immediately after the current CALL instruction + # is the first instruction of the block. + # pyrefly: ignore [bad-argument-type] + return tx.enter_ctx(self.ctx, tx.current_instruction) + + def reconstruct(self, codegen: "PyCodegen") -> None: + try: + type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" + except NotImplementedError: + type_str = str(type(self.ctx)) + unimplemented( + gb_type="Attempted to reconstruct context manager's __enter__ method", + context=str(self.ctx), + explanation=f"Attempted to reconstruct context manager {type_str} while tracing `with ...:`", + hints=[ + "It is likely there is a graph break while tracing `with ctx:` " + "but outside the actual `ctx.__enter__()` method. " + "`torch.compile` does not expect this to happen.", + *graph_break_hints.DIFFICULT, + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +class WithExitFunctionVariable(VariableTracker): + _nonvar_fields = { + "target", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], + target: Any, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ) + self.ctx = ctx + self.target = target + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert not kwargs + return self.ctx.exit(tx, *args) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Note here we reconstruct the context manager rather than the + # exit function. The handler generated by BlockStackEntry + # will re-enter the context in the resume function. + self.ctx.reconstruct_type(codegen) # type: ignore[union-attr] + if codegen.tx.output.partial_convert: + if sys.version_info >= (3, 11): + codegen.append_output(create_instruction("PUSH_NULL")) + if sys.version_info < (3, 13): + codegen.append_output(create_instruction("SWAP", arg=2)) + # We rely on classes subtyping `GenericContextWrappingVariable` + # to implement these fns and have these attributes + codegen.extend_output( + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr] + ) + codegen.extend_output( + create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr] + ) + codegen.append_output(create_setup_with(self.target)) + codegen.append_output(create_instruction("POP_TOP")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..3a07bc1ac03cea5d41890904ce988f5608c96a82 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py @@ -0,0 +1,1555 @@ +""" +Dictionary-related variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for different types of dictionary-like objects: +- Regular Python dictionaries (dict) +- Ordered dictionaries (collections.OrderedDict) +- Default dictionaries (collections.defaultdict) +- Dictionary views (keys and values) +- Sets and frozensets (implemented internally using dictionaries) + +These classes are responsible for tracking dictionary operations during graph compilation, +maintaining proper guards for dictionary mutations and key existence checks. They handle +dictionary creation, modification, key/value access, and view operations while ensuring +correct behavior in the compiled code through appropriate guard installation. + +The implementation uses a special _HashableTracker wrapper to handle dictionary keys +while preserving proper aliasing semantics. Sets are implemented as dictionaries with +None values for efficiency and code reuse. +""" + +import collections +import functools +import operator +import types +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING, Union + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import is_constant_source, is_from_local_source +from ..utils import ( + cmp_name_to_op_mapping, + dict_items, + dict_keys, + dict_values, + istype, + raise_args_mismatch, + specialize_symnode, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .lists import ListIteratorVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +# [Adding a new supported class within the keys of ConstDictVariable] +# - Implement is_python_hashable() method in the VariableTracker subclass +# - Implement get_python_hash() and is_python_equal() methods for hashable types + + +def was_instancecheck_override(obj: Any) -> bool: + return type(obj).__dict__.get("__instancecheck__", False) + + +def raise_unhashable( + arg: VariableTracker, tx: Optional["InstructionTranslator"] = None +) -> None: + if tx is None: + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + try: + arg_type = arg.python_type() + except Exception: + arg_type = type(arg) + + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable( + f"unhashable type: {arg_type!r} and variable tracker = {type(arg.realize())}" + ) + ], + ) + + +def is_hashable(x: VariableTracker) -> bool: + # NB - performing isinstance check on a LazVT realizes the VT, accidentally + # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at + # the underlying value without realizing the VT. Consider updating the + # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. + if ( + isinstance(x, variables.LazyVariableTracker) + and not x.is_realized() + and x.is_hashable() + ): + return True + return x.is_python_hashable() + + +class ConstDictVariable(VariableTracker): + CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS + + _nonvar_fields = { + "user_cls", + *VariableTracker._nonvar_fields, + } + + class _HashableTracker: + """ + Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable + This should not be seen or touched by anything outside of ConstDictVariable and its children + Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing + """ + + def __init__(self, vt: VariableTracker) -> None: + # We specialize SymNodes + vt = specialize_symnode(vt) + + # If Dynamo does not know the hashability of the vt, it will raise unsupported here + if not is_hashable(vt): + raise_unhashable(vt) + self.vt = vt + + def __hash__(self) -> int: + """ + Computes the hash value for the wrapped VariableTracker. + + For unrealized LazyVariableTrackers, uses the hash of the original value + to avoid realizing the tracker and inserting unnecessary guards. + For all other cases, delegates to the VariableTracker's get_python_hash method. + + Returns: + The hash value of the underlying variable tracker + """ + if ( + isinstance(self.vt, variables.LazyVariableTracker) + and not self.vt.is_realized() + and self.vt.is_hashable() + ): + return hash(self.vt.original_value()) + return self.vt.get_python_hash() + + def __eq__(self, other) -> bool: + """ + Checks equality between two _HashableTracker instances. + + Delegates to the VariableTracker's is_python_equal method to compare + the underlying variable trackers for Python-level equality. + + Args: + other: Another _HashableTracker instance to compare with + + Returns: + True if the underlying variable trackers are Python-equal, False otherwise + """ + if self.vt is other.vt: + return True + return self.vt.is_python_equal(other.vt) + + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type = dict, + **kwargs: Any, + ) -> None: + # .clone() pass these arguments in kwargs but they're recreated a few + # lines below + if "original_items" in kwargs: + kwargs.pop("original_items") + if "should_reconstruct_all" in kwargs: + kwargs.pop("should_reconstruct_all") + + super().__init__(**kwargs) + + Hashable = ConstDictVariable._HashableTracker + + # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers + assert all( + isinstance(x, (VariableTracker, Hashable)) + and isinstance(v, VariableTracker) + for x, v in items.items() + ) + + def make_hashable( + key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], + ) -> "ConstDictVariable._HashableTracker": + return key if isinstance(key, Hashable) else Hashable(key) + + dict_cls = self._get_dict_cls_from_user_cls(user_cls) + self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) + # need to reconstruct everything if the dictionary is an intermediate value + # or if a pop/delitem was executed + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) + self.original_items = items.copy() + self.user_cls = user_cls + + def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: + accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) + + # avoid executing user code if user_cls is a dict subclass + if user_cls in accepted_dict_types: + dict_cls = user_cls + else: + # + dict_cls = next( + base for base in user_cls.__mro__ if base in accepted_dict_types + ) + assert dict_cls in accepted_dict_types, dict_cls + + # Use a dict instead as the call "defaultdict({make_hashable(x): v ..})" + # would fail as defaultdict expects a callable as first argument + if dict_cls is collections.defaultdict: + dict_cls = dict + return dict_cls + + def as_proxy(self) -> dict[Any, Any]: + return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} + + def debug_repr(self) -> str: + return ( + "{" + + ", ".join( + f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items() + ) + + "}" + ) + + def as_python_constant(self) -> dict[Any, Any]: + return { + k.vt.as_python_constant(): v.as_python_constant() + for k, v in self.items.items() + } + + def keys_as_python_constant(self) -> dict[Any, VariableTracker]: + self.install_dict_keys_match_guard() + return {k.vt.as_python_constant(): v for k, v in self.items.items()} + + def python_type(self) -> type: + return self.user_cls + + def __contains__(self, vt: VariableTracker) -> bool: + assert isinstance(vt, VariableTracker) + Hashable = ConstDictVariable._HashableTracker + return ( + vt.is_python_hashable() + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) + + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + other_dicts: list[ConstDictVariable] = [] + for candidate in rest: + candidate = candidate.realize() + if not isinstance(candidate, ConstDictVariable) or len( + candidate.items + ) != len(self.items): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_dicts.append(candidate) + + new_items_hashed = type(self.items)() + for key_tracker, value in self.items.items(): + sibling_leaves: list[VariableTracker] = [] + for candidate in other_dicts: + try: + sibling_leaves.append(candidate.items[key_tracker]) + except KeyError: + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + new_items_hashed[key_tracker] = value.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + + updated_original_items = { + key_tracker.vt: new_items_hashed[key_tracker] + for key_tracker in new_items_hashed + } + + return self.clone( + items=new_items_hashed, + original_items=updated_original_items, + should_reconstruct_all=True, + source=None, + mutation_type=ValueMutationNew(), + ) + + def len(self) -> int: + return sum( + not isinstance(x, variables.DeletedVariable) for x in self.items.values() + ) + + def has_new_items(self) -> bool: + return self.should_reconstruct_all or any( + self.is_new_item(self.original_items.get(key.vt), value) + for key, value in self.items.items() + ) + + def is_new_item( + self, value: Optional[VariableTracker], other: VariableTracker + ) -> bool: + # compare the id of the realized values if both values are not lazy VTs + if value and value.is_realized() and other.is_realized(): + return id(value.realize()) != id(other.realize()) + return id(value) != id(other) + + def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: + # Build a dictionary that contains the keys and values. + num_args = 0 + for key, value in self.items.items(): + # We can safely call realize() here as it won't introduce any new guards + item = self.original_items.get(key.vt) + if self.is_new_item(item, value) or self.should_reconstruct_all: + codegen(key.vt) + codegen(value) + num_args += 1 + codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) + + def reconstruct(self, codegen: "PyCodegen") -> None: + if self.user_cls is collections.OrderedDict: + # emit `OrderedDict(constructed_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("OrderedDict"), + ] + ) + ) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(1, False)) + else: + self.reconstruct_kvs_into_new_dict(codegen) + + def getitem_const_raise_exception_if_absent( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + try: + error_message = ( + f"Dict key lookup failed for {str(arg)}. " + f"Debug representation of the key is {arg.debug_repr()!r}" + ) + except Exception: + error_message = ConstantVariable.create( + f"Dict key lookup failed for {str(arg)}" + ) + raise_observed_exception(KeyError, tx, args=[error_message]) + return self.items[key] + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] + unimplemented( + gb_type="key not found in dict", + context=f"Key {arg.value}", # type: ignore[attr-defined] + explanation=msg, + hints=[ + "Check if the key exists in the dictionary before accessing it.", + *graph_break_hints.USER_ERROR, + ], + ) + return self.items[key] + + def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + return None + return self.items[key] + + def realize_key_vt(self, arg: VariableTracker) -> None: + # Realize the LazyVT on a particular index + assert arg in self + key = ConstDictVariable._HashableTracker(arg) + index = tuple(self.items.keys()).index(key) + original_key_vt = tuple(self.original_items.keys())[index] + if isinstance(original_key_vt, variables.LazyVariableTracker): + original_key_vt.realize() + + def install_dict_keys_match_guard(self) -> None: + if self.source: + install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + # Key guarding - These are the cases to consider + # 1) The dict has been mutated. In this case, we would have already + # inserted a DICT_KEYS_MATCH guard, so we can skip. + # + # 2) args[0].source is None. This happens for const keys. Here, we + # have to insert the DICT_CONTAINS guard. + # + # 3) args[0].source is not None. This can happen for non-const VTs. + # 3a) contains=True. In this case, we can access the lazyVT from + # original_items and selectively realize it. + # 3b) contains=False. There is no easy way to selectively apply this + # DICT_NOT_CONTAINS guard because our guard are represented via trees. + # Be conservative and add DICT_KEYS_MATCH guard. + + if not self.source: + return + + if tx.output.side_effects.is_modified(self): + return + + contains = args[0] in self + if args[0].source is None and args[0].is_python_constant(): + install_guard( + self.make_guard( + functools.partial( + type(self).CONTAINS_GUARD, + key=args[0].as_python_constant(), + invert=not contains, + ) + ) + ) + elif args[0].source: + if contains: + self.realize_key_vt(args[0]) + else: + self.install_dict_keys_match_guard() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NB - Both key and value are LazyVariableTrackers in the beginning. So, + # we have to insert guards when a dict method is accessed. For this to + # be simple, we are conservative and overguard. We skip guard only for + # get/__getitem__ because the key guard will be inserted by the + # corresponding value VT. For __contains__, we add a DICT_CONTAINS + # guard. But for all the other methods, we insert the DICT_KEYS_MATCH + # guard to be conservative. + from . import BuiltinVariable, ConstantVariable + + Hashable = ConstDictVariable._HashableTracker + + if name == "__init__": + temp_dict_vt = variables.BuiltinVariable(dict).call_dict( + tx, *args, **kwargs + ) + tx.output.side_effects.mutation(self) + self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] + return ConstantVariable.create(None) + elif name == "__getitem__": + # Key guarding - Nothing to do. LazyVT for value will take care. + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + return self.getitem_const_raise_exception_if_absent(tx, args[0]) + elif name == "items": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + return DictItemsVariable(self) + elif name == "keys": + if len(args): + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + return DictKeysVariable(self) + elif name == "values": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + if args or kwargs: + raise_observed_exception(TypeError, tx) + return DictValuesVariable(self) + elif name == "copy": + self.install_dict_keys_match_guard() + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None + ) + elif name == "__len__": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.install_dict_keys_match_guard() + return ConstantVariable.create(len(self.items)) + elif name == "__setitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_keys_match_guard() + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = args[1] + return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + arg_hashable = args and is_hashable(args[0]) + if arg_hashable: + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + elif name == "get": + if len(args) not in (1, 2): + raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + if args[0] not in self: + self.install_dict_contains_guard(tx, args) + if len(args) == 1: + # if default is not given, return None + return ConstantVariable.create(None) + return args[1] + # Key guarding - Nothing to do. + return self.getitem_const(tx, args[0]) + elif name == "pop" and self.is_mutable(): + if len(args) not in (1, 2): + raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + if args[0] not in self: + # missing item, return the default value. Install no DICT_CONTAINS guard. + self.install_dict_contains_guard(tx, args) + if len(args) == 1: + # if default is not given, raise KeyError + raise_observed_exception(KeyError, tx) + return args[1] + + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + return self.items.pop(Hashable(args[0])) + elif name == "popitem" and self.is_mutable(): + if ( + issubclass(self.user_cls, dict) + and not issubclass(self.user_cls, collections.OrderedDict) + and len(args) + ): + raise_args_mismatch(tx, name) + + if not self.items: + msg = ConstantVariable.create("popitem(): dictionary is empty") + raise_observed_exception(KeyError, tx, args=[msg]) + + if self.user_cls is collections.OrderedDict and ( + len(args) == 1 or "last" in kwargs + ): + if len(args) == 1 and args[0].is_python_constant(): + last = args[0].as_python_constant() + elif (v := kwargs.get("last")) and v.is_python_constant(): + last = v.as_python_constant() + else: + raise_args_mismatch(tx, name) + k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] + else: + k, v = self.items.popitem() + + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + + return variables.TupleVariable([k.vt, v]) + elif name == "clear": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif name == "update" and self.is_mutable(): + # In general, this call looks like `a.update(b, x=1, y=2, ...)`. + # Either `b` or the kwargs is omittable, but not both. + self.install_dict_keys_match_guard() + has_arg = len(args) == 1 + has_kwargs = len(kwargs) > 0 + if has_arg or has_kwargs: + tx.output.side_effects.mutation(self) + if has_arg: + if isinstance(args[0], ConstDictVariable): + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() + dict_vt: ConstDictVariable = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + self.items.update(dict_vt.items) # type: ignore[attr-defined] + if has_kwargs: + # Handle kwargs + kwargs_hashable = { + Hashable(ConstantVariable.create(k)): v + for k, v in kwargs.items() + } + self.items.update(kwargs_hashable) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + elif name == "__contains__": + if not len(args): + raise_args_mismatch( + tx, + name, + "more than 1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_contains_guard(tx, args) + contains = args[0] in self + return ConstantVariable.create(contains) + elif name == "setdefault" and self.is_mutable(): + if len(args) not in (1, 2): + raise_args_mismatch( + tx, + name, + "1 or 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + arg_hashable = args and is_hashable(args[0]) + if not arg_hashable: + raise_unhashable(args[0], tx) + + self.install_dict_keys_match_guard() + if kwargs or len(args) > 2: + raise_args_mismatch( + tx, + name, + "at most 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + value = self.maybe_getitem_const(args[0]) + if value is not None: + return value + else: + if len(args) == 1: + x = ConstantVariable.create(None) + else: + x = args[1] + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = x + return x + elif name == "move_to_end": + self.install_dict_keys_match_guard() + tx.output.side_effects.mutation(self) + if args[0] not in self: + raise_observed_exception(KeyError, tx) + + last = True + if len(args) == 2 and args[1].is_python_constant(): + last = args[1].as_python_constant() + + if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): + last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] + + key = Hashable(args[0]) + self.items.move_to_end(key, last=last) + return ConstantVariable.create(None) + elif name == "__eq__" and istype( + self, ConstDictVariable + ): # don't let Set use this function + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + return variables.UserFunctionVariable(polyfills.dict___eq__).call_function( + tx, [self, args[0]], {} + ) + elif name == "__ne__": + return ConstantVariable.create( + not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] + ) + elif name == "__or__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + other = args[0] + + # Method resolution for binops works as follow (using __or__ as example): + # (1) dict.__or__(dict) => dict + # (2) dict.__or__(subclass): return NotImplemented + # (3) Check if subclass implements __ror__ => forward the call + # to subclass.__ror__(dict) + + # Let's not forward the call to __ror__ yet because __ror__ can be + # implemented in C (i.e. OrderedDict subclass) which Dynamo cannot + # trace + # if istype(other, variables.UserDefinedDictVariable): + # if other.call_obj_hasattr(tx, "__ror__").value: + # return other.call_method(tx, "__ror__", [self], kwargs) + + # The three dict types Dynamo can handle are dict, OrderedDict and + # defaultdict. + + # TODO(guilhermeleobas): this check should be on builtin.py::call_or_ + if not istype( + other, (ConstDictVariable, variables.UserDefinedDictVariable) + ): + err_msg = ( + f"unsupported operand type(s) for |: '{self.python_type().__name__}'" + f"and '{other.python_type().__name__}'" + ) + raise_observed_exception(TypeError, tx, args=[err_msg]) + + # OrderedDict overloads __ror__ + ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] + user_cls = ( + collections.OrderedDict + if any(issubclass(t, collections.OrderedDict) for t in ts) + else dict + ) + + self.install_dict_keys_match_guard() + new_dict_vt = self.clone( + items=self.items.copy(), + mutation_type=ValueMutationNew(), + source=None, + user_cls=user_cls, + ) + + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] + new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] + return new_dict_vt + elif name == "__ior__": + self.call_method(tx, "update", args, kwargs) + return self + elif name == "__iter__": + if self.source and not is_constant_source(self.source): + tx.output.guard_on_key_order.add(self.source) + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + else: + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + self.install_dict_keys_match_guard() + return [x.vt for x in self.items] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + # dict not allow setting arbitrary attributes. OrderedDict and + # defaultdict allow arbitrary setattr, but not deletion of default attrs + if any( + self.user_cls is t + for t in (dict, collections.OrderedDict, collections.defaultdict) + ): + if hasattr(self.user_cls, name): + return ConstantVariable.create(True) + if self.user_cls is dict: + return ConstantVariable.create(False) + + msg = f"hasattr on {self.user_cls} is not supported" + unimplemented( + gb_type="unsupported hasattr operation", + context=f"Class {self.user_cls}", + explanation=msg, + hints=[ + "Consider using a regular dictionary instead", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def clone(self, **kwargs: Any) -> VariableTracker: + self.install_dict_keys_match_guard() + return super().clone(**kwargs) + + def is_python_hashable(self): + """ + Dictionaries are mutable and therefore not hashable in Python. + """ + return False + + +class MappingProxyVariable(VariableTracker): + # proxies to the original dict_vt + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + def python_type(self) -> type: + return types.MappingProxyType + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return self.dv_dict.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # load types.MappingProxyType + if self.source: + msg = ( + f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed " + "because the connection to the original dict will be lost." + ) + unimplemented( + gb_type="mapping proxy cannot be reconstructed", + context=f"Source: {self.source}", + explanation=msg, + hints=[ + "Use a mapping proxy constructed in the same `torch.compile` region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(types), + codegen.create_load_attr("MappingProxyType"), + ] + ) + ) + codegen(self.dv_dict) + codegen.extend_output(create_call_function(1, False)) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.source and tx.output.side_effects.has_existing_dict_mutation(): + msg = ( + "A dict has been modified while we have an existing mappingproxy object. " + "A mapping proxy object, as the name suggest, proxies a mapping " + "object (usually a dict). If the original dict object mutates, it " + "is reflected in the proxy object as well. For an existing proxy " + "object, we do not know the original dict it points to. Therefore, " + "for correctness we graph break when there is dict mutation and we " + "are trying to access a proxy object." + ) + + unimplemented( + gb_type="mapping proxy affected by dictionary mutation", + context=f"Source: {self.source}, Dict mutation detected", + explanation=msg, + hints=[ + "Avoid modifying dictionaries that might be referenced by mapping proxy objects", + "Or avoid using the mapping proxy objects after modifying its underlying dictionary", + ], + ) + return self.dv_dict.call_method(tx, name, args, kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is types.MappingProxyType: + return ConstantVariable.create(name in types.MappingProxyType.__dict__) + return super().call_obj_hasattr(tx, name) + + +class NNModuleHooksDictVariable(ConstDictVariable): + # Special class to avoid adding any guards on the nn module hook ids. + def install_dict_keys_match_guard(self) -> None: + pass + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + pass + + +class DefaultDictVariable(ConstDictVariable): + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type, + default_factory: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + super().__init__(items, user_cls, **kwargs) + assert user_cls is collections.defaultdict + if default_factory is None: + default_factory = ConstantVariable.create(None) + self.default_factory = default_factory + + def is_python_constant(self) -> bool: + # Return false for unsupported defaults. This ensures that a bad handler + # path is not taken in BuiltinVariable for getitem. + if self.default_factory not in [list, tuple, dict] and not self.items: + return False + return super().is_python_constant() + + def debug_repr(self) -> str: + assert self.default_factory is not None + return ( + f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" + ) + + @staticmethod + def is_supported_arg(arg: VariableTracker) -> bool: + if isinstance(arg, variables.BuiltinVariable): + return arg.fn in (list, tuple, dict, set) + else: + return isinstance( + arg, + ( + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + ), + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + if args[0] in self: + return self.getitem_const(tx, args[0]) + else: + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[args[0]]) + else: + default_var = self.default_factory.call_function(tx, [], {}) + super().call_method( + tx, "__setitem__", [args[0], default_var], kwargs + ) + return default_var + else: + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # emit `defaultdict(default_factory, new_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("defaultdict"), + ] + ) + ) + codegen(self.default_factory) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(2, False)) + + +# TODO: Implementing this via inheritance rather than composition is a +# footgun, because self method calls in dict will route back to the set +# implementation, which is almost assuredly wrong +class SetVariable(ConstDictVariable): + """We model a sets as dictionary with None values""" + + CONTAINS_GUARD = GuardBuilder.SET_CONTAINS + + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + # pyrefly: ignore[bad-assignment] + items = dict.fromkeys(items, SetVariable._default_value()) + # pyrefly: ignore[bad-argument-type] + super().__init__(items, **kwargs) + + def debug_repr(self) -> str: + if not self.items: + return "set()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" + + @property + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: + return set(self.items.keys()) + + @staticmethod + def _default_value() -> VariableTracker: + # Variable to fill in he keys of the dictionary + return ConstantVariable.create(None) + + def as_proxy(self) -> Any: + return {k.vt.as_proxy() for k in self.set_items} + + def python_type(self) -> type: + return set + + def as_python_constant(self) -> Any: + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) + + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + res = fn( + *[x.as_python_constant() for x in [self, *args]], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + except Exception as exc: + raise_observed_exception( + type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) + ) + # pyrefly: ignore[unbound-name] + return VariableTracker.build(tx, res) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # We forward the calls to the dictionary model + from ..utils import check_constant_args + + if ( + name + in ( + "isdisjoint", + "union", + "intersection", + "difference", + "symmetric_difference", + ) + and check_constant_args(args, kwargs) + and self.python_type() is set + ): + py_type = self.python_type() + return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) + + if name == "__init__": + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) + tx.output.side_effects.mutation(self) + self.items.clear() + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] + return ConstantVariable.create(None) + elif name == "add": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + name = "__setitem__" + args = [args[0], SetVariable._default_value()] + elif name == "pop": + if kwargs or args: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # Choose an item at random and pop it via the Dict.pop method + try: + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] + except KeyError as e: + raise_observed_exception( + KeyError, tx, args=list(map(ConstantVariable.create, e.args)) + ) + # pyrefly: ignore[unbound-name] + super().call_method(tx, name, [result], kwargs) + # pyrefly: ignore[unbound-name] + return result + elif name == "isdisjoint": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_isdisjoint + ).call_function(tx, [self, args[0]], {}) + elif name == "intersection": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_intersection + ).call_function(tx, [self, *args], {}) + elif name == "intersection_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_intersection_update + ).call_function(tx, [self, *args], {}) + elif name == "union": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable(polyfills.set_union).call_function( + tx, [self, *args], {} + ) + elif name == "difference": + if kwargs: + raise_args_mismatch( + tx, name, f"Expect: 0 kwargs, Actual: {len(kwargs)} kwargs" + ) + return variables.UserFunctionVariable( + polyfills.set_difference + ).call_function(tx, [self, *args], {}) + elif name == "difference_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable( + polyfills.set_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference_update": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "update" and self.is_mutable(): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return variables.UserFunctionVariable(polyfills.set_update).call_function( + tx, [self, *args], {} + ) + elif name == "remove": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] not in self: + raise_observed_exception(KeyError, tx, args=args) + return super().call_method(tx, "pop", args, kwargs) + elif name == "discard": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] in self: + return super().call_method(tx, "pop", args, kwargs) + else: + return ConstantVariable.create(value=None) + elif name in ("issubset", "issuperset"): + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + op = { + "issubset": operator.le, + "issuperset": operator.ge, + } + other = args[0].realize() + if not istype(other, SetVariable): + other = variables.BuiltinVariable(set).call_function(tx, [other], {}) + return variables.BuiltinVariable(op.get(name)).call_function( + tx, [self, other], {} + ) + elif name in ("__and__", "__or__", "__xor__", "__sub__"): + m = { + "__and__": "intersection", + "__or__": "union", + "__xor__": "symmetric_difference", + "__sub__": "difference", + }.get(name) + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + assert m is not None + return self.call_method(tx, m, args, kwargs) + elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + msg = ConstantVariable.create( + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + m = { + "__iand__": "intersection_update", + "__ior__": "update", + "__ixor__": "symmetric_difference_update", + "__isub__": "difference_update", + }.get(name) + assert m is not None + self.call_method(tx, m, args, kwargs) + return self + elif name == "__eq__": + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(False) + r = self.call_method(tx, "symmetric_difference", args, kwargs) + return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] + elif name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] + ) + return super().call_method(tx, name, args, kwargs) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + raise RuntimeError("Illegal to getitem on a set") + + def install_dict_keys_match_guard(self) -> None: + # Already EQUALS_MATCH guarded + pass + + +class FrozensetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "frozenset()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" + + @property + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: + return self.items.keys() + + def python_type(self) -> type: + return frozenset + + def as_python_constant(self) -> Any: + return frozenset({k.vt.as_python_constant() for k in self.set_items}) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach([x.vt for x in self.set_items]) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + elif name == "__init__": + # frozenset is immutable. Calling __init__ again shouldn't have any effect + # In[1]: s = frozenset([1, 2]) + # + # In[2]: s.__init__([3, 4]) + # + # In[3]: s + # frozenset({1, 2}) + return ConstantVariable.create(None) + elif name in ( + "copy", + "difference", + "intersection", + "symmetric_difference", + ): + r = super().call_method(tx, name, args, kwargs) + return FrozensetVariable(r.items) # type: ignore[attr-defined] + return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Frozensets are immutable and hashable in Python. + """ + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class DictKeySetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "dict_keys([])" + else: + return ( + "dict_keys([" + ",".join(k.vt.debug_repr() for k in self.items) + "])" + ) + + def install_dict_keys_match_guard(self) -> None: + # Already EQUALS_MATCH guarded + pass + + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + # Already EQUALS_MATCH guarded + pass + + @property + def set_items(self) -> Any: + return self.items + + def python_type(self) -> type: + return dict_keys + + def as_python_constant(self) -> Any: + return dict.fromkeys( + {k.vt.as_python_constant() for k in self.set_items}, None + ).keys() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a dict_keys") + return super().call_method(tx, name, args, kwargs) + + +class DictViewVariable(VariableTracker): + """ + Models _PyDictViewObject + + This is an "abstract" class. Subclasses will override kv and the items method + """ + + kv: Optional[str] = None + + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: + super().__init__(**kwargs) + assert self.kv in ("keys", "values", "items") + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + @property + def view_items(self) -> Any: + assert self.kv is not None + return getattr(self.dv_dict.items, self.kv)() + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + # Implement in the subclasses + raise NotImplementedError + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return self.view_items_vt + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.kv is not None + codegen(self.dv_dict) + codegen.load_method(self.kv) + codegen.call_method(0) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + assert self.kv is not None + if name in self.python_type().__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__len__": + return self.dv_dict.call_method(tx, name, args, kwargs) + elif name == "__iter__": + return ListIteratorVariable( + self.view_items_vt, mutation_type=ValueMutationNew() + ) + return super().call_method(tx, name, args, kwargs) + + +class DictKeysVariable(DictViewVariable): + kv = "keys" + + @property + def set_items(self) -> set[VariableTracker]: + return set(self.view_items) + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + return [x.vt for x in self.view_items] + + def python_type(self) -> type: + return dict_keys + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__contains__": + return self.dv_dict.call_method(tx, name, args, kwargs) + elif name in ( + "__and__", + "__iand__", + "__or__", + "__ior__", + "__sub__", + "__isub__", + "__xor__", + "__ixor__", + ): + # These methods always returns a set + m = getattr(self.set_items, name) + r = m(args[0].set_items) # type: ignore[attr-defined] + return SetVariable(r) + if name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, DictKeysVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] + ) + return super().call_method(tx, name, args, kwargs) + + +class DictValuesVariable(DictViewVariable): + # DictValuesVariable is an iterable but cannot be compared. + kv = "values" + + @property + def view_items_vt(self) -> list[VariableTracker]: + return list(self.view_items) + + def python_type(self) -> type: + return dict_values + + +class DictItemsVariable(DictViewVariable): + kv = "items" + + @property + def view_items_vt(self) -> list[VariableTracker]: + # Returns an iterable of the unpacked items + return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] + + def python_type(self) -> type: + return dict_items + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # TODO(guilhermeleobas): This should actually check if args[0] + # implements the mapping protocol. + if name == "__eq__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + if isinstance(args[0], DictItemsVariable): + return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) + return ConstantVariable.create(False) + return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self): + """ + Dictionary item views are not hashable in Python. + """ + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf80e45bd0ed597c2d9ae4e3c7e131da52f2d34 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/distributed.py @@ -0,0 +1,507 @@ +""" +Distributed computing variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for distributed computing components: +- Process Groups (for collective communication) +- Device Meshes (for distributed tensor sharding) +- Placement Types (for specifying distribution strategies) +- Distributed Tensors and their operations +- Backward hooks for distributed module operations + +These classes are responsible for tracking distributed operations during graph +compilation while maintaining proper guards and handling distributed-specific +behaviors. They ensure correct handling of distributed components like process +groups, device meshes, and placement strategies while preserving proper semantics +for distributed tensor operations in the compiled code. + +The implementation provides special handling for distributed package availability +checks and proper tracking of distributed state and operations across processes. +""" + +import functools +import inspect +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING + +import torch +from torch.fx.experimental._backward_state import BackwardState + +from .. import compiled_autograd, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..bytecode_transformation import create_call_function +from ..exc import unimplemented +from ..external_utils import call_module_hooks_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import istype +from .base import VariableTracker +from .constant import ConstantVariable, EnumVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class DistributedVariable(VariableTracker): + """ + The base distributed variable that encapsulates common methods + for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.). + Concrete distributed objects could inherit this class and add object + specific logic. + + i.e. It provides the check on the distributed package existence + and hold the tracking value for the corresponding distributed object. + """ + + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not DistributedVariable.is_available(): + unimplemented( + gb_type="torch.distributed package is not available!", + context="", + explanation="The PyTorch package doesn't include torch.distributed when building from source.", + hints=[ + "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." + ], + ) + self.value = value + + def python_type(self) -> type: + return type(self.value) + + @staticmethod + def is_available() -> bool: + # check if the distributed package is available or not + return torch.distributed.is_available() + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +def is_from_local(value: object) -> bool: + if not DistributedVariable.is_available(): + return False + from torch.distributed.tensor import DTensor + + return inspect.isfunction(value) and value is DTensor.from_local + + +def is_constant_pg_functions(value: object) -> bool: + if not DistributedVariable.is_available(): + return False + + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + + constant_processgroup_functions = [ + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ] + + return inspect.isfunction(value) and value in constant_processgroup_functions + + +class WorldMetaClassVariable(DistributedVariable): + """ + Tracks torch.distributed.GroupMember and torch.distributed.group, which are + instances of the metaclass _WorldMeta. + """ + + @classmethod + def is_group_member_type(cls, value: object) -> bool: + if not cls.is_available(): + return False + + from torch.distributed.distributed_c10d import _WorldMeta + + return type(value) is _WorldMeta + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "WORLD": + assert self.source + source = AttrSource(base=self.source, member="WORLD") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return ProcessGroupVariable(self.value.WORLD) + elif name == "NON_GROUP_MEMBER": + assert self.source + source = AttrSource(base=self.source, member="NON_GROUP_MEMBER") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return EnumVariable(self.value.NON_GROUP_MEMBER) + return super().var_getattr(tx, name) + + +class PlacementClassVariable(DistributedVariable): + @staticmethod + def is_placement_type(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.tensor.placement_types import Placement + + return isinstance(value, type) and issubclass(value, Placement) + + def as_python_constant(self) -> Any: + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.source: + # NOTE: we don't need to track mutations to the placement class as they + # are supposed to be immutable. + new_obj = self.value.__new__(self.value) + var = PlacementVariable(new_obj) + if inspect.getattr_static(self.value, "__init__", None): + var.call_method(tx, "__init__", args, kwargs) + return var + + return super().call_function(tx, args, kwargs) + + +class PlacementVariable(DistributedVariable): + @staticmethod + def is_placement(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch.distributed.tensor.placement_types import Placement + + return isinstance(value, Placement) + + def as_python_constant(self) -> Any: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "dim": + return ConstantVariable.create(self.value.dim) + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from . import ConstantVariable + + # Placement types dynamo tracking only allows following methods + # and __setattr__ is for case like `Shard(dim)` and methods. + # Methods in the list must satisfy: + # 1. Input arguments are constants and do not need to be guarded on; + # 2. Output is constant with respect to their inputs + constant_fold_functions = [ + "__init__", + "__setattr__", + "is_shard", + "is_partial", + "is_replicate", + ] + + if name in constant_fold_functions: + try: + value_type = type(self.value) + if inspect.getattr_static(value_type, "__getattr__", None) is not None: + unimplemented( + gb_type="Placement with custom __getattr__ not supported", + context=f"{value_type.__name__} with custom __getattr__", + explanation="Dynamo does not support Placement types with custom __getattr__ methods", + hints=[ + "Use Placement types without custom __getattr__ methods", + "Move the Placement usage outside the compiled region", + ], + ) + method = inspect.getattr_static(value_type, name) + except AttributeError: + method = None + if method is object.__init__: + return ConstantVariable.create(None) + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + assert method is not None + if name == "__setattr__": + method(self.value, *args, **kwargs) + return self + constant_val = method(self.value, *args, **kwargs) + return ConstantVariable.create(constant_val) + + return super().call_method(tx, name, args, kwargs) # type: ignore[arg-type] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Reconstruct the Placement object by calling its constructor + # e.g., Shard(0), Replicate(), Partial() + from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + + placement_type = type(self.value) + + # Load the placement class + codegen.add_push_null( + lambda: codegen.load_import_from( + "torch.distributed.tensor.placement_types", placement_type.__name__ + ) + ) + + # For Shard, we need to pass the dim argument + if isinstance(self.value, Shard): + codegen(ConstantVariable.create(self.value.dim)) + codegen.extend_output(create_call_function(1, False)) + # Replicate and Partial have no required args + elif istype(self.value, (Replicate, Partial)): + codegen.extend_output(create_call_function(0, False)) + else: + super().reconstruct(codegen) + + +class DeviceMeshVariable(DistributedVariable): + @staticmethod + def is_device_mesh(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.device_mesh import DeviceMesh + + return istype(value, DeviceMesh) + + def as_python_constant(self) -> Any: + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "ndim": + return ConstantVariable.create(self.value.ndim) + if name == "device_type": + return ConstantVariable.create(self.value.device_type) + if name == "mesh_dim_names": + source = self.source + if source: + source = AttrSource(base=source, member="mesh_dim_names") + return VariableTracker.build(tx, self.value.mesh_dim_names, source) + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create(self.value.size(*const_args, **const_kwargs)) + if name == "get_coordinate": + return ConstantVariable.create(self.value.get_coordinate()) + if name == "get_rank": + return ConstantVariable.create(self.value.get_rank()) + if name == "get_local_rank": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create( + self.value.get_local_rank(*const_args, **const_kwargs) + ) + if name == "get_group": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ProcessGroupVariable( + self.value.get_group(*const_args, **const_kwargs) + ) + if name == "_get_or_create_default_group": + return ProcessGroupVariable(self.value._get_or_create_default_group()) + if name == "_flatten": + from .builder import SourcelessBuilder + + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return SourcelessBuilder.create( + tx, self.value._flatten(*const_args, **const_kwargs) + ) + return super().call_method(tx, name, args, kwargs) + + +class ProcessGroupVariable(DistributedVariable): + """ + We don't want a ProcessGroup object to end up in our output graph. + + But it's common for dynamo to intercept a PG that is then used to get info like + rank() or world_size(), as well as passed to utility functions in distributed_c10d + which desugar it into plain types like a ranklist and tag. + + For convenience and proper guarding, we construct a variable type. + + TODO: make it possible to use ProcessGroupVariable as input to simple functions + like _expand_group without dynamo complaining about making a proxy for it. + It is not a tensor-like type, and we don't want a proxy- but dynamo assumes + torch library functions are dealing with tensor-like types and would have proxies + for their args. + TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors + or just graph-break whenever one of our special cases is not hit? + """ + + def as_python_constant(self) -> Any: + return self.value + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "rank": + return variables.ConstantVariable.create(self.value.rank()) + if name == "size": + return variables.ConstantVariable.create(self.value.size()) + if name == "_get_backend_name": + return variables.ConstantVariable.create(self.value._get_backend_name()) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "group_name": + return variables.ConstantVariable.create(self.value.group_name) + if name in ["rank", "size"]: + return variables.LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + # TODO should this just raise unimplemented? + return super().var_getattr(tx, name) + + @staticmethod + def is_process_group(value: object) -> bool: + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch._C._distributed_c10d import ProcessGroup + from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + + return istype(value, (ProcessGroup, FakeProcessGroup)) + + +class BackwardHookVariable(VariableTracker): + """ + Handles torch.utils.hooks.BackwardHook for module-level backward + hooks. + """ + + @staticmethod + def create( + tx: "InstructionTranslator", + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + ) -> "BackwardHookVariable": + if not compiled_autograd.compiled_autograd_enabled: + unimplemented( + gb_type="Module-level backwards hooks require compiled autograd.", + context="", + explanation="", + hints=[ + "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True." + ], + ) + + def _in_graph_bw_hooks( + bw_state: BackwardState, + ) -> torch.utils.hooks.BackwardHook: + """ + Rather than installing the user hooks in the graph (which + don't survive AotAutograd), we install hooks that will call + trace_wrapped in the backward pass that CompiledAutograd + can turn into actual hook calls. + """ + return torch.utils.hooks.BackwardHook( + None, + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_hooks_name, + module_name=module_name, + ), + ), + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_pre_hooks_name, + module_name=module_name, + ), + ), + ) + + module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod") + user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks) + user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks) + proxy = tx.output.create_proxy( + "call_function", + _in_graph_bw_hooks, + (bw_state_proxy,), + {}, + ) + proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ()) + return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks) + + def __init__( + self, + proxy: torch.fx.Proxy, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + **options: Any, + ) -> None: + super().__init__(**options) + self.proxy = proxy + self.module = module + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + + def as_proxy(self) -> torch.fx.Proxy: + return self.proxy + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ("setup_input_hook", "setup_output_hook"): + return self._setup_hook(tx, name, *args, **kwargs) + return super().call_method(tx, name, args, kwargs) + + def _setup_hook( + self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + hook_method_name, + (self.as_proxy(), args.as_proxy()), + {}, + ), + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..9638278300bcf7df327cdd338d927c35f6b6cdad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py @@ -0,0 +1,3059 @@ +""" +Function-related variable tracking classes for Dynamo's symbolic execution. + +This module contains classes that track different types of functions during graph +compilation, including: +- User-defined functions and methods +- Built-in functions and methods +- Wrapped functions (e.g. from decorators) +- Special function types (e.g. functools.partial) +- Triton kernels and related function types + +These classes are responsible for: +- Tracking function calls and their arguments +- Managing function closures and cell variables +- Handling function attributes and special methods +- Maintaining guards for function identity and closure contents +- Supporting function inlining and specialization +- Enabling proper symbolic execution of different function types + +The variable trackers here work together with the rest of Dynamo to enable +accurate graph capture while handling Python's various function-related behaviors. +""" + +import builtins +import functools +import inspect +import itertools +import logging +import sys +import traceback +import types +from collections import namedtuple +from collections.abc import Callable, Sequence +from types import CellType, FunctionType +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import Never +from weakref import WeakKeyDictionary + +import torch +from torch._dynamo.exc import get_stack_above_dynamo +from torch._guards import Source +from torch.utils._pytree import is_namedtuple_class + +from .. import config, graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_rot_n, is_generator +from ..exc import ( + format_skip_frame_message, + get_dynamo_observed_exception, + handle_observed_exception, + InfiniteGeneratorError, + ObservedException, + ObservedGeneratorExit, + ObservedUserStopIteration, + raise_observed_exception, + SkipFrame, + StepUnsupported, + unimplemented, + Unsupported, +) +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + ClosureSource, + CollectionsSource, + ConstantSource, + DefaultsSource, + GetItemSource, + SkipGuardSource, + TorchSource, + TypeSource, +) +from ..utils import ( + check_constant_args, + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + identity, + is_function, + is_wrapper_or_member_descriptor, + istype, + make_cell, +) +from .base import ( + AsPythonConstantNotImplementedError, + AttributeMutationNew, + raise_type_error_exc, + ValueMutationNew, + VariableTracker, +) +from .constant import ConstantVariable + + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import ( + InliningGeneratorInstructionTranslator, + InliningInstructionTranslator, + InstructionTranslator, + InstructionTranslatorBase, + ) + from torch._dynamo.variables.ctx_manager import ContextWrappingVariable + from torch._higher_order_ops.triton_kernel_wrap import ( + TritonGridType, + TritonKernelType, + ) + + from .lists import BaseListVariable, ListVariable + from .tensor import TensorVariable + + +_F = TypeVar("_F", bound=Callable[..., Any]) +CO_VARARGS = 0x04 +CO_VARKEYWORDS = 0x08 +_SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"}) +_TREE_MAP_ONLY_SUPPORTED_KWARGS = frozenset({"is_leaf"}) + + +# Module-level cache keyed by the function object +_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() + + +@functools.lru_cache +def get_pytree_SUPPORTED_NODES_source(): + return AttrSource( + AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" + ) + + +class FunctionSpec: + def __init__(self, func: FunctionType): + code = func.__code__ + vn = code.co_varnames + + self.posonly_count = code.co_posonlyargcount + self.arg_count = code.co_argcount + self.kwonly_count = code.co_kwonlyargcount + + self.posonly_names = vn[: self.posonly_count] + self.pos_or_kw_names = vn[self.posonly_count : self.arg_count] + self.all_pos_names = self.posonly_names + self.pos_or_kw_names + self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count] + + off = self.arg_count + self.kwonly_count + self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None + off += 1 if self.varargs_name else 0 + self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None + + def update_defaults(self, func: FunctionType) -> None: + # Defaults can change from function call to function call. So re-update + # them on every call. + self.defaults = func.__defaults__ or () + self.kwdefaults = func.__kwdefaults__ or {} + + # Map positional-default names → their index in self.defaults + self.pos_default_map = dict( + zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults))) + ) + + +def _get_spec(func: FunctionType) -> FunctionSpec: + spec = _spec_cache.get(func) + if spec is None: + spec = FunctionSpec(func) + _spec_cache[func] = spec + return spec + + +def bind_args_cached( + func: FunctionType, + tx: "InstructionTranslator", + fn_source: Optional[Source], + args: Sequence[Any], + kwargs: dict[str, Any], +) -> dict[str, VariableTracker]: + spec = _get_spec(func) + spec.update_defaults(func) + ba = {} + rem_kw = dict(kwargs) + + # 1) Bind all positional (pos-only + pos-or-kw) + # 1.1) Apply pos-defaults first (maybe overridden later) + for name, idx in spec.pos_default_map.items(): + default_source = None + if fn_source and not ( + ConstantVariable.is_literal(spec.defaults[idx]) + and config.skip_guards_on_constant_func_defaults + ): + default_source = DefaultsSource(fn_source, idx) + ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) + # 1.2) Fill in provided positional args + for i, name in enumerate(spec.all_pos_names): + if i < len(args): + # Maybe override pos-defaults applied above + ba[name] = wrap_bound_arg(tx, args[i]) + elif name in rem_kw and ( + # `kwargs` can have the same key as a pos-only arg `name`. + # If this case happens, we should not consume the `name` here and + # keep it in `kwargs`: + # >>> def fn(a, /, **kwargs): return (a, kwargs) + # >>> fn(1, a=2) + # (1, {'a': 2}) + name not in spec.posonly_names + ): + # Maybe override pos-defaults applied above + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name not in ba: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Missing required positional argument: {name}" + ) + ], + ) + + # 2) *args + extra = args[len(spec.all_pos_names) :] + if spec.varargs_name: + ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra)) + elif extra: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" + ) + ], + ) + + # 3) Keyword-only + for name in spec.kwonly_names: + if name in rem_kw: + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name in spec.kwdefaults: + kwdefault_source = None + if fn_source: + kwdefault_source = DefaultsSource(fn_source, name, is_kw=True) + ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source) + else: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"Missing required keyword-only argument: {name}" + ) + ], + ) + + # 4) **kwargs + if spec.varkw_name: + ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw) + elif rem_kw: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}") + ], + ) + + return ba + + +def wrap_bound_arg( + tx: "InstructionTranslator", val: Any, source: Optional[Source] = None +) -> VariableTracker: + # Source propagation is best effort since not every object we encounter has a source to begin with. + if isinstance(val, VariableTracker): + return val + elif not source: + return VariableTracker.build(tx, val) + else: + # Create a lazy variable to avoid guarding on __defaults__ unless really + # needed. + return variables.LazyVariableTracker.create(val, source) + + +def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None: + for k, v in list(result.items()): + if isinstance(v, (tuple, dict)): + # args/kwargs + result[k] = wrap_bound_arg(tx, v) + + +def init_cellvars( + parent: "InstructionTranslator", + result: dict[str, VariableTracker], + code: types.CodeType, +) -> None: + """ + Update `result` to add mapping from local name to new cells created + directly by `code`, or update SideEffects in `parent` if the a local cell is + already in `result` (cell argument). + """ + side_effects = parent.output.side_effects + + for name in code.co_cellvars: + new_cell = side_effects.track_cell_new() + if name in result: + # This handles when a function argument is a cell (e.g., captured by + # a nested func). See `MAKE_CELL` bytecode for more info. + side_effects.store_cell(new_cell, result.pop(name)) + result[name] = new_cell + + +def _create_nested_fn( + code: types.CodeType, + f_globals: dict[str, Any], + name: str, + defaults: Optional[tuple[object, ...]], + closure: Optional[tuple[CellType]], + kwdefaults: Optional[dict[str, Any]], + annotations: Optional[dict[str, Any]], +) -> types.FunctionType: + from types import FunctionType + + func = FunctionType(code, f_globals, name, defaults, closure) + func.__kwdefaults__ = kwdefaults + + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert annotations is None or isinstance(annotations, dict) + func.__annotations__ = annotations # type: ignore[assignment] + + return func + + +fn_known_dunder_attrs = { + "__annotations__", + "__defaults__", + "__kwdefaults__", + "__code__", + "__globals__", + "__closure__", + "__doc__", +} + + +def fn_var_getattr( + tx: "InstructionTranslator", fn: object, source: Optional[Source], name: str +) -> VariableTracker: + source = source and AttrSource(source, name) + + if source and name == "__annotations__": + # We get a large number of silly guards from annotations from inspect + # module. Changing annotations is rare, and it impacting the extracted + # graph is even rarer. So skip guards. + source = SkipGuardSource(source) + + subobj = None + try: + subobj = inspect.getattr_static(fn, name) + except AttributeError: + # function does not have a __getattr__ or __getattribute__ method, + # so we can safely assume that this attribute is absent + raise_observed_exception(AttributeError, tx) + + # Special handling for known dunder attributes + if name in fn_known_dunder_attrs: + subobj = getattr(fn, name) + if source: + return variables.LazyVariableTracker.create(subobj, source) + return VariableTracker.build(tx, subobj) + + +class BaseUserFunctionVariable(VariableTracker): + def get_filename(self) -> str: + return self.get_code().co_filename # type: ignore[attr-defined] + + def get_name(self) -> str: + return self.get_code().co_name # type: ignore[attr-defined] + + def get_globals(self): + raise NotImplementedError + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Ignore patch_track_step_called from torch/optim/lr_scheduler.py - it just patches + # the optimizer.step method and we don't need to trace it + if ( + self.get_name() == "patch_track_step_called" + and self.get_filename().endswith("torch/optim/lr_scheduler.py") + ): + return ConstantVariable.create(None) + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + result = False + + try: + result = hasattr(self.get_function(), name) # type: ignore[attr-defined] + except NotImplementedError: + if name == "__name__" and isinstance(self, NestedUserFunctionVariable): + result = True + return variables.ConstantVariable.create(result) + + def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]: + return {} + + # Override to set whether or not nested graph breaks should be allowed + # if we create an inlining tx for this BaseUserFunctionVariable. + # See symbolic_convert.py for where this function is called. + def should_allow_nested_graph_breaks(self): + return True + + +class UserFunctionVariable(BaseUserFunctionVariable): + """Some unsupported user-defined global function""" + + _nonvar_fields = { + "fn", + "is_constant", + *BaseUserFunctionVariable._nonvar_fields, + } + + _TREE_MAP_MODULES = frozenset( + { + "optree", + "optree.ops", + "torch.utils._pytree", + "torch.utils._cxx_pytree", + } + ) + + @classmethod + def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls(value, source=source) + + def __init__( + self, + fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg] + is_constant: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if getattr(fn, "_dynamo_marked_constant", False): + # This method should be treated as a constant for the purposes of compilation + self.is_constant = True + else: + self.is_constant = False + + # TODO putting this here to avoid duplication, because we could hit this + # from several paths (e.g., SuperVariable or `var_getattr`s). + if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): + unimplemented( + gb_type="can't handle functions not implemented in python ", + context=f"{fn}", + explanation="Dynamo can only handle functions defined in python", + hints=[ + "Move usage of this function out of `torch.compile` region", + *graph_break_hints.INFERENCE_MODE, + ], + ) + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. + # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + self.fn = fn + + def as_python_constant(self) -> Any: + if istype(self, UserFunctionVariable): + return self.fn + # subclasses (such as methods) usually aren't a constant + return super().as_python_constant() + + def self_args(self) -> list[VariableTracker]: + return [] + + def get_function(self) -> types.FunctionType: + return self.fn + + def get_code(self) -> types.CodeType: + return self.fn.__code__ + + def python_type(self) -> type: + return types.FunctionType + + def has_self(self) -> bool: + return getattr(self.fn, "__self__", None) is not None + + def get_globals(self) -> dict[str, Any]: + return self.fn.__globals__ + + def get_source(self) -> Source: + source = self.source + + if source and isinstance(self, variables.UserMethodVariable): + source = self.source_fn # type: ignore[assignment] + return source # type: ignore[return-value] + + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + """ + Assume `args` and `kwargs` are VariableTracker arguments for a call to + this function, create new bindings for initial locals. + """ + assert not self.is_constant + + fn: types.FunctionType = self.fn + + if not isinstance(fn, FunctionType): + raise TypeError("Only supports regular Python functions.") + root_tx = parent.output.root_tx + + source = self.get_source() + result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type] + + init_cellvars(parent, result, fn.__code__) + closure = self.fn.__closure__ or () + assert len(closure) == len(self.fn.__code__.co_freevars) + for idx, name, cell in zip( + itertools.count(), self.fn.__code__.co_freevars, closure + ): + # TODO refactor these 3 branches. + side_effects = parent.output.side_effects + if cell in side_effects: + cell_var = side_effects[cell] + + elif source: + closure_cell = GetItemSource(ClosureSource(source), idx) + closure_cell_contents = AttrSource(closure_cell, "cell_contents") + try: + contents_var = VariableTracker.build( + parent, cell.cell_contents, closure_cell_contents + ) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing( + closure_cell, cell, contents_var + ) + + else: + # TODO figure out why source isn't available here, and whether + # we can fix that and remove this branch. + try: + contents_var = VariableTracker.build(parent, cell.cell_contents) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing(None, cell, contents_var) + + result[name] = cell_var + + return result + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + source = self.get_source() + return fn_var_getattr(tx, self.fn, source, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + result = hasattr(self.fn, name) + return variables.ConstantVariable.create(result) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Handle patch_dynamo_config call + if self.fn is torch._dynamo.patch_dynamo_config: + try: + args_const = [arg.as_python_constant() for arg in args] + kwargs_const = { + key: val.as_python_constant() for key, val in kwargs.items() + } + changes = torch._dynamo.patch_dynamo_config( + *args_const, **kwargs_const + ).changes + return variables.DynamoConfigPatchVariable(changes) + except AsPythonConstantNotImplementedError as e: + raise RuntimeError( + "Cannot convert patch_dynamo_config args/kwargs to constants. " + "Please fix your call to patch_dynamo_config by using simpler inputs. " + f"args: {args}, kwargs: {kwargs}" + ) from e + elif self.fn is torch._dynamo.error_on_graph_break: + try: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + error_on_graph_break = bound.arguments[ + "error_on_graph_break" + ].as_python_constant() + assert isinstance(error_on_graph_break, bool) + return variables.ErrorOnGraphBreakVariable(error_on_graph_break) + except Exception as e: + raise RuntimeError( + "Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). " + f"args: {args}, kwargs: {kwargs}" + ) from e + # Handle a `nonstrict_trace(fn)` call + elif self.fn is torch._dynamo.nonstrict_trace: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + fn_var = bound.args[0] + if not isinstance(fn_var, BaseUserFunctionVariable): + typ = fn_var.python_type() + msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" + unimplemented( + gb_type="TypeError from user code", + context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] + explanation=msg, + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + if not isinstance(fn_var, UserFunctionVariable): + fn_name = fn_var.get_name() + msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950 + unimplemented( + gb_type="Limitation of `nonstrict_trace", + context=f"{self}", + explanation=msg, + hints=[ + f"make sure definition of {fn_name} is outside ", + "`torch.compile` region", + ], + ) + # pyrefly: ignore[missing-attribute] + fn = fn_var.fn + return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) + + if self.is_constant: + return invoke_and_store_as_constant( + tx, self.fn, self.get_name(), args, kwargs + ) + + if ( + not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + and self.fn + is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer + ): + with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer( + tx + ): + return super().call_function(tx, args, kwargs) + + if ( + getattr(tx.output.current_tracer, "description", None) + == "torch.utils.checkpoint.checkpoint" + and not tx.output.current_tracer.allow_side_effects_in_hop + ): + try: + from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState + except Exception: + FSDPState = None # type: ignore[assignment, misc] + if FSDPState is not None and self.fn in [ + FSDPState._pre_forward, + FSDPState._post_forward, + ]: + with torch._dynamo.side_effects.allow_side_effects_in_hop(tx): + return super().call_function(tx, args, kwargs) + + tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs) + if tree_map_result is not None: + return tree_map_result + + return super().call_function(tx, args, kwargs) + + def _maybe_call_tree_map_fastpath( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[VariableTracker]: + rewrite = self._rewrite_tree_map_only_call(tx, args, kwargs) + if rewrite is not None: + tree_map_fn, tree_map_args, tree_map_kwargs = rewrite + else: + tree_map_fn = self + tree_map_args = args + tree_map_kwargs = kwargs + + if not ( + isinstance(tree_map_fn, UserFunctionVariable) + and tree_map_fn._is_tree_map_function() + and not ({*tree_map_kwargs} - _SUPPORTED_TREE_MAP_KWARGS) + and len(tree_map_args) >= 2 + ): + return None + + map_fn = tree_map_args[0] + first_tree = tree_map_args[1] + rest = tree_map_args[2:] + return first_tree.call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _is_tree_map_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _is_tree_map_only_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map_only" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _rewrite_tree_map_only_call( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[ + tuple[ + "UserFunctionVariable", + Sequence[VariableTracker], + dict[str, VariableTracker], + ] + ]: + if not self._is_tree_map_only_function(): + return None + + if len(args) != 3: + return None + if {*kwargs} - _TREE_MAP_ONLY_SUPPORTED_KWARGS: + return None + + type_selector, map_fn, tree_arg = args + allowed_types = self._extract_tree_map_only_types(type_selector) + if allowed_types is None: + return None + + tree_map_callable = self._lookup_tree_map_function() + if tree_map_callable is None: + return None + + wrapped_map_fn = TreeMapOnlyFunctionVariable( + allowed_types, + map_fn, + source=getattr(map_fn, "source", None), + ) + tree_map_variable = variables.UserFunctionVariable(tree_map_callable) + return tree_map_variable, [wrapped_map_fn, tree_arg], dict(kwargs) + + def _lookup_tree_map_function(self) -> Optional[types.FunctionType]: + module_name = getattr(self.fn, "__module__", None) + if not module_name: + return None + module = sys.modules.get(module_name) + if module is None: + return None + tree_map = getattr(module, "tree_map", None) + if isinstance(tree_map, types.FunctionType): + return tree_map + return None + + def _extract_tree_map_only_types( + self, selector: VariableTracker + ) -> Optional[tuple[type, ...]]: + if not selector.is_python_constant(): + return None + try: + raw_value = selector.as_python_constant() + except NotImplementedError: + return None + + flattened = self._flatten_type_spec(raw_value) + if not flattened: + return None + if not all(isinstance(typ, type) for typ in flattened): + return None + return tuple(dict.fromkeys(flattened)) + + def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: + if isinstance(value, type): + return [value] + if isinstance(value, tuple): + collected: list[type] = [] + for entry in value: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + union_type = getattr(types, "UnionType", None) + if union_type is not None and isinstance(value, union_type): + collected = [] + for entry in value.__args__: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + return None + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.fn) + + def is_python_equal(self, other): + return isinstance(other, variables.UserFunctionVariable) and self.fn is other.fn + + +class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "allowed_types", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + allowed_types: tuple[type, ...], + map_fn: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.allowed_types = allowed_types + self.map_fn = map_fn + + def python_type(self) -> type: + return FunctionType + + def _matches_allowed_type(self, node: VariableTracker) -> bool: + try: + node_type = node.python_type() + except NotImplementedError: + return False + return any(issubclass(node_type, allowed) for allowed in self.allowed_types) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not args: + return self.map_fn.call_function(tx, args, kwargs) + leaf = args[0] + if self._matches_allowed_type(leaf): + return self.map_fn.call_function(tx, args, kwargs) + if len(args) != 1 or kwargs: + # Defer to the original map function so we fall back to normal + # tracing instead of triggering a graph break. + return self.map_fn.call_function(tx, args, kwargs) + return leaf + + +class BuiltinMethodVariable(BaseUserFunctionVariable): + def __init__( + self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + assert isinstance(fn, types.BuiltinMethodType) + self.fn = fn + + @staticmethod + def is_supported_builtin_method(obj: Any) -> bool: + method_self = obj.__self__ + method_name = obj.__name__ + + # TODO(anijain2305) - Add support for more builtin methods + # Supports tuple.__new__ and frozenset({....}).__contains__ + return (method_self is tuple and method_name == "__new__") or ( + type(method_self) is frozenset and method_name == "__contains__" + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + method_self = self.fn.__self__ + name = self.fn.__name__ + obj_source = self.source and AttrSource(self.source, "__self__") + obj_vt = VariableTracker.build(tx, method_self, obj_source) + return obj_vt.call_method(tx, name, args, kwargs) + + +class LocalGeneratorObjectVariable(VariableTracker): + def __init__( + self, + code: types.CodeType, + f_globals: dict[str, Any], + inline_tracer: "InliningGeneratorInstructionTranslator", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.code = code + self.f_globals = f_globals + self.inline_tracer = inline_tracer + + def get_code(self) -> types.CodeType: + return self.code + + def get_filename(self) -> str: + return self.get_code().co_filename + + def get_name(self) -> str: + return self.get_code().co_name + + def get_function(self) -> Never: + raise NotImplementedError + + def has_self(self) -> bool: + return False + + def __name__(self) -> str: + return self.get_name() + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.get_name()})" + + __repr__ = __str__ + + def reconstruct(self, codegen: "PyCodegen") -> None: + from torch._dynamo.side_effects import disallow_side_effects_in_generator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + save_and_restart_speculation_log, + temporarely_allow_writes_to_output_graph, + ) + + tx = InstructionTranslator.current_tx() + save = save_and_restart_speculation_log(tx) + disallow = disallow_side_effects_in_generator(tx) + temp = temporarely_allow_writes_to_output_graph(tx) + + with save, disallow, temp: + tracer = self.inline_tracer + if not tracer.generator_exhausted: + self.remaining_items = self.force_unpack_var_sequence(tx) + variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) + + def bind_args( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + return self.vt.bind_args(tx, args, kwargs) # type: ignore[attr-defined] + + def get_globals(self) -> dict[str, Any]: + return self.f_globals + + def python_type(self) -> type: + return types.GeneratorType + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + tracer = self.inline_tracer + + if self._is_generator_exhausted(): + raise_observed_exception(StopIteration, tx) + + try: + # Hierarchically, tx can be seen as the parent of the inline tracer + # created on call_function. Any exception needs to be propagated to tx + # for Dynamo to behave correctly + return tracer.inline_call_() + except ObservedException as e: + tracer.generator_exhausted = True + raise e + except InfiniteGeneratorError: + # test/dynamo/test_misc.py::test_iterator_limit + raise + except Unsupported as e: + torch._dynamo.eval_frame.skip_code(self.get_code()) + raise SkipFrame from e + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if name in self.python_type().__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return False + + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[VariableTracker], Any] + ) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + # no nested graph breaks in generators + def should_allow_nested_graph_breaks(self): + return False + + def _setup_exception( + self, tx: "InstructionTranslator", exc: VariableTracker + ) -> None: + tracer = self.inline_tracer + try: + tracer._raise_exception_variable(exc) + except ObservedException as e: + # if no handler is available (i.e. user code doesn't catch it), the + # exception is raised again. + tracer.exception_handler(e) + + def _is_generator_just_started(self) -> bool: + return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + + def _is_generator_exhausted(self) -> bool: + return getattr(self.inline_tracer, "generator_exhausted", False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + # iter(gen) returns itself + return self + elif name == "send": + # Sends a value into the generator function. Returns the next value + # yielded by the generator, or raises StopIteration if the generator + # exits without yielding another value + if self._is_generator_just_started() and len(args): + # can't send non-None value to a just-started generator + # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen + if not all(arg.is_constant_none() for arg in args): + raise_observed_exception(TypeError, tx) + tracer = self.inline_tracer + tracer.push_many(args) + return self.next_variable(tx) + elif name == "close": + # * Raises a GeneratorExit at the point where the generator function was paused. + # * If the generator function catches the exception and returns a + # value, this value is returned from close() - Python 3.13+ + # * If the generator function is already closed, or raises GeneratorExit + # (by not catching the exception), close() returns None. + # * If the generator yields a value, a RuntimeError is raised. + # * If the generator raises any other exception, it is propagated to the caller. + # * If the generator has already exited due to an exception or normal + # exit, close() returns None and has no other effect. + + # Return None if close is called on a just-started generator + # See test GeneratorCloseCpythonTests::test_close_not_started + + tracer = self.inline_tracer + if self._is_generator_just_started() or self._is_generator_exhausted(): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + # Raise GeneratorExit to see if user code catches it. Any other exception + # is propagated to the parent frame. + try: + self._setup_exception( + tx, variables.ExceptionVariable(GeneratorExit, ()) + ) + # There's an extra block on Python 3.12+ to handle StopIteration + # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 + # + # 1 0 RETURN_GENERATOR + # 2 POP_TOP + # 4 RESUME 0 + + # 2 6 LOAD_CONST 1 (1) + # 8 YIELD_VALUE 1 + # 10 RESUME 1 + # 12 POP_TOP + # 14 RETURN_CONST 0 (None) + # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + # 18 RERAISE 1 + # ExceptionTable: + # 4 to 14 -> 16 [0] lasti + if ( + sys.version_info >= (3, 12) + and tracer.next_instruction.opname == "CALL_INTRINSIC_1" + ): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedGeneratorExit: + # If it doesn't catch, we just return None, as per the text above + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + try: + # Raise RuntimeError if the generator yields any other value + if self.next_variable(tx): + raise_observed_exception(RuntimeError, tx) + except ObservedGeneratorExit: + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedUserStopIteration: + # In Python 3.13+, one can capture GeneratorExit and return a value + # See test_generator.py::test_close_capture_GeneratorExit_return + # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26 + # https://github.com/python/cpython/pull/104771 + assert tracer.symbolic_result is not None + return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + # Setup the exception table and jump target in case of try...finally + tracer = self.inline_tracer + try: + # In Python 3.9, the exception is represented as a triple (typ, val, tb) + # In such cases, we re-raise the exception object given to avoid + # creating a new object, so that IS_OP works. + # See: https://github.com/pytorch/pytorch/pull/146496 + self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) + except ObservedException: # noqa: TRY203 + # propagate the exception back to the parent caller + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let's walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... + # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + else: + raise_observed_exception(RuntimeError, tracer) + return retval + + return super().call_method(tx, name, args, kwargs) + + +class ContextlibContextManagerLocalGeneratorObjectVariable( + LocalGeneratorObjectVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + + It is a special case of a generator function as we do not allow return a context manager + from a torch.compile function. + """ + + +class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): + """functions that behaves like iterators + + .. note:: + + This is a wrapper around (Nested)UserFunctionVariable + """ + + def __init__( + self, + vt: VariableTracker, + *, + generator_cls: type = LocalGeneratorObjectVariable, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.vt = vt + self.generator_cls = generator_cls + + def __getattr__(self, name): + if name in self.__class__.__dict__: + return getattr(self, name) + return getattr(self.vt, name) + + def get_globals(self) -> dict[str, Any]: + return self.vt.get_globals() # type: ignore[attr-defined] + + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InliningInstructionTranslator": + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + return InliningInstructionTranslator.build_inline_tracer( + tx, + self, + args, + kwargs, + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] + unimplemented( + gb_type="non-generator contextlib.contextmanager", + context=str(self.vt.get_code()), # type: ignore[attr-defined] + explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" + ", i.e. does not use `yield`", + hints=[ + "Use `yield` in the function body instead of `return`.", + "Remove the `@contextlib.contextmanager` decorator.", + ], + ) + + inline_tracer = self._build_inline_tracer(tx, list(args), kwargs) + code = self.vt.get_code() # type: ignore[attr-defined] + f_globals = self.vt.get_globals() # type: ignore[attr-defined] + + # calling a generator returns a generator object + return self.generator_cls( + code, + f_globals, + inline_tracer, # type: ignore[arg-type] + source=self.source, + ) + + +class FunctionDecoratedByContextlibContextManagerVariable( + LocalGeneratorFunctionVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + """ + + def __init__(self, vt: VariableTracker, **kwargs: Any): + super().__init__( + vt, + generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, + **kwargs, + ) + + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InliningGeneratorInstructionTranslator": + # NOTE: This only exists to not break support for context manager when + # config.enable_faithful_generator_behavior = False and + # config.enable_trace_contextlib = True. In case the former is false, + # Dynamo should still be able to trace through @contextmanager functions + tracer = super()._build_inline_tracer(tx, args, kwargs) + assert isinstance( + tracer, + torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator, + ) + tracer.is_generator_from_ctx_manager = True + return tracer + + +class UserMethodVariable(UserFunctionVariable): + """Some unsupported user-defined method""" + + def __init__( + self, + fn: Callable[..., Any], + obj: VariableTracker, + source_fn: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(fn=fn, **kwargs) # type: ignore[arg-type] + self.obj = obj + self.source_fn = source_fn + # Note on source and source_fn + # Be careful with `source` when delegating to UserFunctionVariable + # (base-class) methods. In this __init__, `source` is a *bound method* + # object, but the base class expects the underlying *function* object. + # One way is to simplly use `__func__` to unwrap it. + # + # For recursive dict-tag optimizations, it can be faster to fetch the + # function directly from `cls.__dict__`; that's why we pass on + # `source_fn`. Whenever it is possible to access the function from + # cls.__dict__, we pass that on to `source_fn`. Because bind_args + # operates on the unbound function, most guards should target + # `source_fn` rather than the original `source`. + if source_fn is None and kwargs.get("source") is not None: + self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.fn}, {self.obj})" + + def self_args(self) -> list[VariableTracker]: + return [self.obj] + + def python_type(self) -> type[types.MethodType]: + return types.MethodType + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NOTE this is to handle methods annotated by `nonstrict_trace`. + # a `nonstrict_trace`-ed function will be wrapped by + # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, + # but in the case of method, we manually wrap it with `UserMethodVariable` + # inside `UserDefinedObjectVariable.var_getattr`. + # + # We might be able to simplify this away by canonicalizing the + # function/method wrapping code paths. + from ..trace_rules import is_nonstrict_trace_callable + + if is_nonstrict_trace_callable(self.fn): + call_args = [*self.self_args(), *args] + var = variables.TorchInGraphFunctionVariable( + self.fn, nonstrict_traceable=True + ) + return var.call_function(tx, call_args, kwargs) + + # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution + # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method + # since we ensure `forward` of allowed modules can be traced by AOT safely. + # Note this is not only for allowed modules, as user customized modules can extend from + # allowed modules but using parent's `forward` method, which is also covered by this branch. + + # If we are tracing the higher order op, we want Dynamo to step inside + # the module call so that Dynamo can see the underlying parameters and + # buffers and raise them as inputs to the graph. The is_root_tracer + # check bypasses the if condition for non-root tracers and directly + # calls the super().call_function at the end, which is basically + # equivalent of inlining the method. + if tx.output.is_root_tracer() and isinstance( + self.obj, variables.NNModuleVariable + ): + module_attr = getattr(self.fn, "__module__", "") + # inline torch.nn.utils.parametrize + if ( + module_attr is not None + and module_attr.startswith("torch.nn.") + and module_attr != "torch.nn.utils.parametrize" + or self.is_constant + ): + return self.obj.call_method( + tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant + ) + elif ( + _fsdp_param_group is not None + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined] + ): + return variables.TorchCtxManagerClassVariable(self.fn).call_function( + tx, (self.obj, *args), kwargs + ) + if self.is_constant: + fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined] + return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) + return super().call_function(tx, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__self__": + return self.obj + if name == "__func__": + # We might have a better way to access the function object, this + # information is stored in self.source_fn, use that to construct the + # variable tracker. + return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] + return super().var_getattr(tx, name) + + +class WrappedUserMethodVariable(UserMethodVariable): + def __init__( + self, + wrapped: UserMethodVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn", None) + kwargs.pop("obj", None) + super().__init__(wrapped.fn, wrapped.obj, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrappedUserFunctionVariable(UserFunctionVariable): + def __init__( + self, + wrapped: UserFunctionVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn", None) + super().__init__(wrapped.fn, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +def invoke_and_store_as_constant( + tx: "InstructionTranslator", + fn: Callable[..., Any], + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> VariableTracker: + def convert(x: VariableTracker) -> Any: + if x.is_tensor(): + return cast("TensorVariable", x).get_real_value() + return x.as_python_constant() + + args = [convert(x) for x in args] + kwargs = {k: convert(v) for k, v in kwargs.items()} + res = fn(*args, **kwargs) + return tx.output.register_attr_or_module( + res, + name, + source=ConstantSource(name), + ) + + +class NestedUserFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "f_globals", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + fn_name: VariableTracker, + code: VariableTracker, + f_globals: dict[str, Any], + defaults: Optional[VariableTracker], + kwdefaults: Optional[VariableTracker], + annotations: Optional[VariableTracker], + closure: Optional[VariableTracker], + # This is present when this function is created by + # `functools.wrap(wrapped_fn)(this_fn)`. + wrapped_fn: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + if kwargs.get("mutation_type") is None: + kwargs.update(mutation_type=AttributeMutationNew()) + super().__init__(**kwargs) + assert isinstance(fn_name.as_python_constant(), str) + assert isinstance(code.as_python_constant(), types.CodeType) + assert isinstance(f_globals, dict) + self.fn_name = fn_name + self.code = code + self.f_globals = f_globals + self.defaults = defaults + self.kwdefaults = kwdefaults + self.annotations = annotations + self.closure = closure + self.wrapped_fn: Optional[VariableTracker] = wrapped_fn + + def self_args(self) -> list[VariableTracker]: + return [] + + def get_code(self) -> types.CodeType: + return self.code.as_python_constant() + + def python_type(self) -> type: + return types.FunctionType + + def get_function(self) -> types.FunctionType: + if self.closure: + raise NotImplementedError + func = types.FunctionType( + self.code.as_python_constant(), + self.f_globals, + self.fn_name.as_python_constant(), + ) + if self.defaults: + func.__defaults__ = self.defaults.as_python_constant() + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.as_python_constant() + if self.annotations: + annotations = self.annotations.as_python_constant() + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert isinstance(annotations, dict) + func.__annotations__ = annotations + return func + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ) -> VariableTracker: + tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined] + return ConstantVariable(None) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__setattr__": + return self.call_setattr(tx, *args) + return super().call_method(tx, name, list(args), kwargs) + + def has_closure(self) -> bool: + return self.closure is not None + + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + if name == "__name__": + return self.get_name() + if name == "__code__": + return self.get_code() + if name == "__defaults__": + d = getattr(self, "defaults", None) + return d.as_python_constant() if d else None + return super().const_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if name == "__code__": + return variables.ConstantVariable.create(hasattr(self, "code")) + if name == "__defaults__": + return variables.ConstantVariable.create(hasattr(self, "defaults")) + return super().call_obj_hasattr(tx, name) + + def has_self(self) -> bool: + return False + + def get_globals(self) -> dict[str, Any]: + return self.f_globals + + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + code = self.get_code() + func = types.FunctionType( + code, + self.f_globals, + self.fn_name.as_python_constant(), + tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined] + tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), + ) + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined] + bound = inspect.signature(func).bind(*args, **kwargs) + bound.apply_defaults() + result = dict(bound.arguments.items()) + wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type] + init_cellvars(parent, result, code) + + for idx, name in enumerate(code.co_freevars): + assert name not in result + cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr] + result[name] = cell + + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from(__name__, "_create_nested_fn") + ) + codegen(self.code) + codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) + codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined] + + if self.defaults: + codegen(self.defaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.closure: + codegen(self.closure) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.kwdefaults: + codegen(self.kwdefaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.annotations: + try: + annotations = self.annotations.as_python_constant() + codegen.extend_output( + [codegen.create_load_const_unchecked(annotations)] + ) + except NotImplementedError: + codegen(self.annotations) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + codegen.extend_output(create_call_function(7, False)) + + if self.wrapped_fn: + codegen.add_push_null( + lambda: codegen.load_import_from("functools", "wraps") + ) + codegen(self.wrapped_fn) + codegen.extend_output(create_call_function(1, False)) + codegen.extend_output(create_rot_n(2)) + codegen.extend_output(create_call_function(1, True)) + + # codegen attributes + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if tx.output.side_effects.has_pending_mutation(self): + for name, value in tx.output.side_effects.store_attr_mutations[ + self + ].items(): + codegen.dup_top() + codegen(value) + codegen.extend_output(create_rot_n(2)) + codegen.store_attr(name) + + +class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): + def __init__( + self, + wrapped: Any, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("fn_name", None) + kwargs.pop("code", None) + kwargs.pop("f_globals", None) + kwargs.pop("defaults", None) + kwargs.pop("kwdefaults", None) + kwargs.pop("annotations", None) + kwargs.pop("closure", None) + kwargs.pop("wrapped_fn", None) + super().__init__( + wrapped.fn_name, + wrapped.code, + wrapped.f_globals, + wrapped.defaults, + wrapped.kwdefaults, + wrapped.annotations, + wrapped.closure, + wrapped.wrapped_fn, + ) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class SkipFunctionVariable(VariableTracker): + _nonvar_fields = { + "value", + "reason", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value: Any, reason: Optional[str] = None, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + self.reason = reason + + def as_python_constant(self) -> Any: + return self.value + + @classmethod + def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": + # Use closure match guard (i.e. guard on __code__ object instead of + # function id) to avoid guarding on nested functions. + if inspect.getattr_static(value, "_torchdynamo_disable", False): + # For torch._dynamo.disable function, ensure that the original + # function is guarded. Otherwise, the else branch will guard on the + # _dynamo.disable.__code__ + guard_on_source = source + guard_on_value = value + + while getattr(guard_on_value, "_torchdynamo_orig_callable", False): + guard_on_value = guard_on_value._torchdynamo_orig_callable + guard_on_source = AttrSource( + guard_on_source, "_torchdynamo_orig_callable" + ) + + guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH) + elif inspect.isbuiltin(value): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + elif not is_wrapper_or_member_descriptor(value): + # These descriptors are not guaranteed to return the same object on + # attribute lookup. They are unlikely to be changed, so we can skip + # guarding them. + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls(value, source=source) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) + unimplemented( + gb_type="Skip calling `torch.compiler.disable()`d function", + context=str(self.value), + explanation=f"Skip calling function `{self.value}` since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", + hints=[ + "Remove the `torch.compiler.disable` call", + ], + ) + elif self.value is torch._dynamo.graph_break: + graph_break_msg = kwargs.get("msg") + if graph_break_msg: + graph_break_msg = graph_break_msg.as_python_constant() + unimplemented( + gb_type="Call to `torch._dynamo.graph_break()`", + context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + explanation=f"User-inserted graph break. Message: {graph_break_msg}", + hints=[ + "Remove the `torch._dynamo.graph_break()` call.", + ], + ) + elif self.value is torch._dynamo.skip_frame: + skip_frame_msg = kwargs.get("msg") + if skip_frame_msg: + skip_frame_msg = skip_frame_msg.as_python_constant() + else: + skip_frame_msg = "" + raise SkipFrame( + format_skip_frame_message( + tx.f_code, + f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}", + ) + ) + elif self.value is torch._dynamo.step_unsupported: + raise StepUnsupported + else: + if config.dont_skip_tracing: + from .builder import SourcelessBuilder + + # re-build the function, attempting to not skip + rebuilt_fn = SourcelessBuilder.create(tx, self.value) + # if we still get SkipFunctionVariable, then we *really* should skip this function + if not isinstance(rebuilt_fn, SkipFunctionVariable): + return rebuilt_fn.call_function(tx, args, kwargs) + qualname = getattr(self.value, "__qualname__", "") + module_or = getattr(self.value, "__module__", None) + module_name = "" if module_or is None else str(module_or) + try: + path = inspect.getfile(self.value) + explanation = ( + f"Dynamo developers have intentionally marked that the function `{qualname}` " + f"in file `{path}` should not be traced." + ) + hints = [ + f"Avoid calling the function `{qualname}`.", + ] + # TODO improve trace_rules reasoning to provide better hints. + # How do we tell that a function/file should NOT be removed from skip files? + # Do a very basic check for now. + if "_dynamo" not in path: + hints += [ + f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` " + "to force tracing into the function. " + "More graph breaks may occur as a result of attempting to trace into the function.", + "Please file an issue to PyTorch.", + ] + except TypeError: + known_python_builtin_modules = {"_abc", "_warnings"} + if module_or in known_python_builtin_modules: + explanation = ( + f"Dynamo does not know how to trace the Python builtin " + f"`{module_name}.{qualname}`." + ) + hints = [ + "If you are attempting to call a logging function (e.g. `_warnings.warn`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please file an issue on GitHub " + "so the PyTorch team can add support for it. ", + ] + elif module_or is not None and module_or.startswith("optree"): + explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}." + hints = [ + " Consider using torch.utils._pytree - " + "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py" + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + else: + explanation = ( + f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` " + f"This function is either a Python builtin (e.g. _warnings.warn) " + f"or a third-party C/C++ Python extension (perhaps created with pybind)." + ) + hints = [ + "If it is a Python builtin, please file an issue on GitHub " + "so the PyTorch team can add support for it and see the next case for a workaround.", + "If it is a third-party C/C++ Python extension, please " + "either wrap it into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html " + "for more details) or, if it is traceable, use " + "`torch.compiler.allow_in_graph`.", + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + if qualname == "allow_in_graph": + explanation = ( + "Found an allow_in_graph decorator to a function which " + "is created inside the parent function that is getting " + "compiled. This is not supported for now." + ) + hints = [] + reason = self.reason if self.reason else "" + unimplemented( + gb_type="Attempted to call function marked as skipped", + context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}", + explanation=explanation, + hints=hints, + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(self.value, name)) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + + return fn_var_getattr(tx, self.value, self.source, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class WrappedSkipFunctionVariable(SkipFunctionVariable): + def __init__( + self, + wrapped: VariableTracker, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: + kwargs.pop("value", None) + kwargs.pop("reason", None) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) # type: ignore[attr-defined] + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrapperUserFunctionVariable(VariableTracker): + """ + Used to represent a wrapper object that contains the actual callable as an + attribute. For example, torch.jit.script/trace have the original function at + their _torchdynamo_inline attribute. Similarly, functions with + __script_if_tracing_wrapper have the original attr at "__original_fn". + """ + + def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.wrapper_obj = wrapper_obj + self.attr_to_trace = attr_to_trace + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == self.attr_to_trace: + val = getattr(self.wrapper_obj, self.attr_to_trace) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, val, source) + + return super().var_getattr(tx, name) + + def self_args(self) -> list[VariableTracker]: + return [] + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if hasattr(self.wrapper_obj, "cache_info"): + target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) + module_name = getattr(target_fn, "__module__", "") or "" + + if module_name.split(".", maxsplit=1)[0] != "torch": + msg = ( + "Dynamo detected a call to a `functools.lru_cache`-wrapped " + "function. Dynamo ignores the cache wrapper and directly " + "traces the wrapped function. Silent incorrectness is only " + "a *potential* risk, not something we have observed. " + 'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.' + ) + + torch._dynamo.utils.warn_once(msg) + + dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo") + if dynamo_logger.isEnabledFor(logging.DEBUG): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack = get_stack_above_dynamo() + user_stack + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n" + user_stack_trace += str(user_stack_formatted) + dynamo_logger.debug(user_stack_trace) + + all_args = self.self_args() + list(args) + return variables.UserFunctionVariable( + polyfills.getattr_and_trace # type: ignore[arg-type] + ).call_function( + tx, + [self, variables.ConstantVariable(self.attr_to_trace), *all_args], + kwargs, + ) + + +class WrapperUserMethodVariable(WrapperUserFunctionVariable): + """ + Similar to WrapperUserFunctionVariable, but for methods. The only delta is + saving the vt for `self` object of the method which is then used by + WrapperUserFunctionVariable in `call_function` method. + """ + + def __init__( + self, + wrapper_obj: Any, + attr_to_trace: str, + self_obj: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(wrapper_obj, attr_to_trace, **kwargs) + self.obj = self_obj + + def self_args(self) -> list[VariableTracker]: + return [self.obj] + + +def _traceable_collective_remaps() -> dict[Any, Any]: + # We can't rely on importing from distributed, since it's not always built + if torch.distributed.is_available(): + from torch.distributed._functional_collectives import ( + traceable_collective_remaps, + ) + + return traceable_collective_remaps + return {} + + +def _traceable_collectives_source( + tx: "InstructionTranslator", fn: Callable[..., Any] +) -> AttrSource: + assert torch.distributed.is_available(), "Illegal invocation." + assert fn in _traceable_collective_remaps().values() + + inner_name = fn.__name__ + path_source = tx.import_source("torch.distributed._functional_collectives") + return AttrSource(path_source, inner_name) + + +class CollectiveFunctionRewriteVariable(UserFunctionVariable): + """ + Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives. + + This class provides both a way to check if a function is remappable, and perform the remapping. + + In the case that a function is 'remappable' but only for some combinations of call-time arguments, + we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse + than status-quo as we currently graph-break on all distributed.* collectives. + """ + + def __init__( + self, + fn: Callable[..., Any], + *, + replacement_var: UserFunctionVariable, + **kwargs: Any, + ) -> None: + super().__init__(fn, **kwargs) # type: ignore[arg-type] + assert isinstance(replacement_var, UserFunctionVariable) + self.replacement_var = replacement_var + + @staticmethod + def create( + tx: "InstructionTranslator", + old_fn: Callable[..., Any], + source: Source, + **options: Any, + ) -> "CollectiveFunctionRewriteVariable": + new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) + return CollectiveFunctionRewriteVariable( + old_fn, + replacement_var=UserFunctionVariable(new_fn, source=new_source, **options), + source=source, + **options, + ) + + @staticmethod + def can_rewrite(variable: Any) -> bool: + return ( + inspect.isfunction(variable) and variable in _traceable_collective_remaps() + ) + + @staticmethod + def rewrite( + tx: "InstructionTranslator", fn: Callable[..., Any] + ) -> tuple[Any, AttrSource]: + new_fn = _traceable_collective_remaps()[fn] + return new_fn, _traceable_collectives_source(tx, new_fn) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # call_function must check any unsupported arguments and graph-break. + # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, + # since that's the contract for putting a mapping in `traceable_collective_remaps` + import torch.distributed as dist + from torch.distributed._functional_collectives import REDUCE_OP_TO_STR + + # Merge args into kwargs so positional and keyword args + # can be processed the same way. + signature = inspect.signature(self.fn) + kwargs = dict(signature.bind(*args, **kwargs).arguments) + args = () + + if "async_op" in kwargs and kwargs["async_op"].as_python_constant(): + unimplemented( + gb_type="async_op=True for distributed collectives", + context=f"{self.fn}, {args=}, {kwargs=}", + explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if self.fn in ( + dist.all_reduce, + dist.reduce_scatter_tensor, + dist._reduce_scatter_base, + ): + reduce_op_var = kwargs.get("op") + reduce_op = ( + reduce_op_var.value # type: ignore[attr-defined] + if reduce_op_var is not None + else signature.parameters["op"].default + ) + if reduce_op not in REDUCE_OP_TO_STR: + raise ValueError(f"Unsupported all_reduce op: {reduce_op}") + kwargs["op"] = variables.ConstantVariable.create( + REDUCE_OP_TO_STR[reduce_op] + ) + return self.replacement_var.call_function(tx, args, kwargs) + + +class FunctoolsWrapsVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not kwargs and len(args) == 1: + + def wraps(fn: Any) -> VariableTracker: + if isinstance(fn, variables.NestedUserFunctionVariable): + return fn.clone(wrapped_fn=args[0]) + unimplemented( + gb_type="functools.wraps", + context=f"{fn}", + explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.LambdaVariable(wraps) + + return super().call_function(tx, args, kwargs) + + +class CollectionsNamedTupleFunction(UserFunctionVariable): + def as_python_constant(self) -> Any: + return self.fn + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + constant_args = check_constant_args(args, kwargs) + if constant_args: + try: + value = self.fn( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + except TypeError as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + return variables.UserDefinedClassVariable( + # pyrefly: ignore[unbound-name] + value, + mutation_type=ValueMutationNew(), + ) + unimplemented( + gb_type="namedtuple construction", + context=f"{args=}, {kwargs=}", + explanation="`torch.compile` only support certain input types for namedtuple", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FunctoolsPartialVariable(VariableTracker): + def __init__( + self, + func: VariableTracker, + args: Sequence[VariableTracker], + keywords: dict[str, VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.func = func + assert isinstance(args, list) + self.args = args + assert isinstance(keywords, dict) + self.keywords = keywords + # fake_value is used for id calculation. Creating this value and id'ng + # on it is sufficient for the tracing purposes. + self.fake_value = functools.partial(identity) + + def python_type(self) -> type: + return functools.partial + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) + codegen(self.func) + if self.args: + codegen.foreach(self.args) + if not self.keywords: + codegen.extend_output(create_call_function(len(self.args) + 1, False)) + return + + codegen.foreach(self.keywords.values()) + keys = tuple(self.keywords.keys()) + codegen.extend_output( + codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) + ) + + def get_function(self) -> Any: + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + merged_args = self.args + list(args) + merged_kwargs = {**self.keywords, **kwargs} + return self.func.call_function(tx, merged_args, merged_kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + # functools.partial uses slots, so attributes are constant + return variables.ConstantVariable.create( + hasattr(functools.partial(identity), name) + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + # Handle __slots__ + if name == "func": + return self.func + if name == "args": + return variables.ListVariable(self.args, source=source) + if name == "keywords": + items = {ConstantVariable.create(k): v for k, v in self.keywords.items()} + return variables.ConstDictVariable(items, source=source) + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + raise_observed_exception(AttributeError, tx) + + def as_python_constant(self) -> Any: + return functools.partial( + self.func.as_python_constant(), + *[arg.as_python_constant() for arg in self.args], + **{k: v.as_python_constant() for k, v in self.keywords.items()}, + ) + + def guard_as_python_constant(self) -> Any: + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + return functools.partial( + self.func.guard_as_python_constant(), + *[v.guard_as_python_constant() for v in self.args], + **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, + ) + + def is_python_hashable(self) -> bool: + return ( + self.func.is_python_hashable() + and all(arg.is_python_hashable() for arg in self.args) + and all(value.is_python_hashable() for value in self.keywords.values()) + ) + + def get_python_hash(self): + func_hash = self.func.get_python_hash() + args_hash = (arg.get_python_hash() for arg in self.args) + values_hash = (value.get_python_hash() for value in self.keywords.values()) + return hash((func_hash, *args_hash, *values_hash)) + + def is_python_equal(self, other): + return ( + self.func.is_python_equal(other.func) + and all( + arg_a.is_python_equal(arg_b) + for (arg_a, arg_b) in zip(self.args, other.args) + ) + and all( + value_a.is_python_equal(value_b) + for (value_a, value_b) in zip( + self.keywords.values(), other.keywords.values() + ) + ) + ) + + +class PolyfilledFunctionVariable(VariableTracker): + _nonvar_fields = { + "fn", + "wrapped_fn", + "traceable_fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + @functools.cache + def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: + return {} + + @classmethod + def create_with_source( + cls, value: Any, source: Source + ) -> "PolyfilledFunctionVariable": + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + + return cls(value, source=source) + + def __init__(self, fn: _F, **kwargs: Any) -> None: + super().__init__(**kwargs) + # pyrefly: ignore[invalid-type-var] + self.fn: _F = fn + + handler = self._get_polyfill_handlers().get(fn, fn) + traceable_fn = None + assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" + for candidate_attr in ( + "__torch_dynamo_polyfill__", # registered polyfill + "__python_implementation__", # self handler from third-party libraries + ): + candidate = getattr(handler, candidate_attr, None) + if candidate: + assert callable(candidate) + traceable_fn = candidate + break + else: + raise RuntimeError( + f"Polyfill handler {handler} does not have a traceable function" + ) + # pyrefly: ignore[invalid-type-var] + self.wrapped_fn = handler + # pyrefly: ignore[invalid-type-var] + self.traceable_fn: _F = traceable_fn + + @property + def polyfill_fn(self) -> Callable[..., Any]: + return self.traceable_fn + + def can_constant_fold_through(self) -> bool: + return getattr( + self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False + ) + + def get_function(self) -> Any: + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + result = ( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + return VariableTracker.build(tx, result) + + # Special case for sum on tuple/list of ints + if ( + self.fn is builtins.sum + and len(args) == 1 + and not kwargs + and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) + and all( + (x.is_python_constant() and isinstance(x.as_python_constant(), int)) + or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) + for x in args[0].items + ) + ): + return variables.SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", + torch.sym_sum, + (tuple(a.as_proxy() for a in args[0].items),), + {}, + ), + sym_num=torch.sym_sum( + [ + ( + x.as_python_constant() + if x.is_python_constant() + else x.sym_num # type: ignore[attr-defined] + ) + for x in args[0].items + ] + ), + ) + + traceable_function_variable = VariableTracker.build(tx, self.traceable_fn) + return traceable_function_variable.call_function(tx, args, kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + if not (method or is_function(method)): + raise_type_error_exc(tx, f"Cannot find callable {name} in {self.fn}") + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + # pyrefly: ignore[bad-specialization] + polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) + return polyfilled_method_variable.call_function(tx, args, kwargs) + + def as_python_constant(self) -> Any: + return self.fn + + +class TracebackVariable(VariableTracker): + # We don't track traceback. A call to any function in this module is a no-op + def call_function( # type: ignore[empty-body] + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ... + + +class SysFunctionVariable(VariableTracker): + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": + if len(tx.exn_vt_stack): + exn = tx.exn_vt_stack[-1] + typ = exn.exc_type # type: ignore[union-attr] + tb = None + items = [ + VariableTracker.build(tx, typ), + exn, + VariableTracker.build(tx, tb), + ] + else: + items = [ + variables.ConstantVariable(None), + variables.ConstantVariable(None), + variables.ConstantVariable(None), + ] + return variables.TupleVariable(items) # type: ignore[arg-type] + + def exception(self, tx: "InstructionTranslator") -> VariableTracker: + return self.exc_info(tx).items[1] + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is sys.exc_info: + return self.exc_info(tx) + assert self.value is sys.exception + return self.exception(tx) + + +from torch._higher_order_ops.triton_kernel_wrap import ( + create_tma_experimental_metadata, + create_tma_stable_metadata, + TMADescriptorMetadata, + TritonHOPifier, +) + + +class DynamoTritonHOPifier(TritonHOPifier): + def raise_unsupported(self, msg: str) -> Never: + unimplemented( + gb_type="triton kernel unsupported feature", + context="", + explanation=f"Encountered triton kernel unsupported feature: {msg}", + hints=[], + ) + + def is_callable(self, maybe_callable: VariableTracker) -> bool: + return isinstance( + maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) + ) + + def get_value(self, val: VariableTracker) -> Any: + return val.value # type: ignore[attr-defined] + + def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]: + from .lists import BaseListVariable + + if isinstance(grid, BaseListVariable): + return grid.as_proxy() + else: + unimplemented( + gb_type="unsupported grid type for triton hop check_grid", + context=f"grid type = {type(grid)}", + explanation="`torch.compile` only supports list-like grid for check_grid", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_grid( + self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator" + ) -> Any: + meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta_var], {}) + return grid + + # We use this function to wrap call_prune_configs + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + tx: Optional["InstructionTranslator"], + variable: Any, + ) -> VariableTracker: + from .builder import SourcelessBuilder + + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type] + result = wrapped_user_function.call_function(tx, args, kwargs) + return result + + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Any, + name: str, + ) -> VariableTracker: + from .builder import VariableBuilder + + wrapped_user_obj = VariableBuilder( + tx, AttrSource(variable.kernel_source, f"{name}") + )._wrap(user_obj) + return wrapped_user_obj + + def maybe_unpack_configs( + self, configs: Any, tx: Optional["InstructionTranslator"] + ) -> list[Any]: + # unpack the list of configs + configs = configs.unpack_var_sequence(tx) + + # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed. + configs = [config.guard_as_python_constant() for config in configs] + + return configs + + def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: + if not result.is_python_constant(): + self.raise_unsupported( + "@triton.heuristics must return constant values because configs can only contain constant values." + ) + + return result.guard_as_python_constant() + + # We need to override call_getitem here so that we can add the source in the case + # where we call the triton kernel with a grid + def call_getitem( # type: ignore[override] + self, + variable: "TritonKernelVariable", + args: Sequence[Any], + ) -> "TritonKernelVariable": + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if variable.grid is not None or len(args) != 1: + self.raise_unsupported( + "Triton kernels should be called with only a single grid" + ) + return type(variable)( + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=args[0], + kernel_source=variable.source, + ) + + def call_HOP( + self, + variable: "TritonKernelVariable", + grids: Any, + combined_args_raw: dict[str, Any], + tx: "InstructionTranslator", + ) -> "variables.ConstantVariable": + from .dicts import ConstDictVariable + + # as we can only pass tensors as non-const args in fx graph, + # here we replace TMA descriptors + # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable + # instances) with the underlying tensors, while moving the + # TMA descriptor-related metadata to a separate argument, + # so that we can reconstruct the TMA descriptors downstream + tma_descriptor_metadata: TMADescriptorMetadata = {} + for k in list(combined_args_raw.keys()): + v = combined_args_raw[k] + if isinstance( + v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable) + ): + tma_descriptor_metadata[k] = v.to_metadata() + combined_args_raw[k] = v.get_tensor() + + combined_args = { + variables.ConstantVariable.create(k): v + for k, v in combined_args_raw.items() + } + + from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_mutation, + ) + + # Combine args and kwargs and pass as a dict so that if user defined triton + # kernel uses variables as 'grid' or 'kernel', it does not conflict with + # parameters of the wrapper function + constant_args = { + k: v.as_python_constant() + for k, v in combined_args_raw.items() + if isinstance(v, VariableTracker) and v.is_python_constant() + } + non_constant_args = { + k: v + for k, v in combined_args.items() + if not (isinstance(v, VariableTracker) and v.is_python_constant()) + } + + for v in non_constant_args.values(): + v = v.realize() + if not (v.is_tensor() or v.is_symnode_like()): + self.raise_unsupported( + f"Unexpected argument type for a Triton kernel: {repr(v)}." + ) + + constant_args_idx = kernel_side_table.add_constant_args(constant_args) + meta = ConstDictVariable(non_constant_args, dict) + tx.output.create_proxy( + "call_function", + triton_kernel_wrapper_mutation, + (), + { + "kernel_idx": variable.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grids, + "tma_descriptor_metadata": tma_descriptor_metadata, + "kwargs": meta.as_proxy(), + }, + ) + + return variables.ConstantVariable( + None, + ) + + +dynamo_triton_hopifier_singleton = DynamoTritonHOPifier() + + +class TritonKernelVariable(VariableTracker): + grid: "TritonGridType" + kernel: "TritonKernelType" + kernel_idx: Optional[int] + kernel_source: "AttrSource" + + def __init__( + self, kernel: Any, kernel_idx: Optional[int], grid: Any, **kwargs: Any + ) -> None: + self.kernel_source = kwargs.pop("kernel_source", None) + super().__init__(**kwargs) + dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value] + self, args, kwargs, tx + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + return dynamo_triton_hopifier_singleton.call_getitem(self, args) + elif name == "run": + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] + + # Bail out to parent's implementation + return super().call_method(tx, name, args, kwargs) + + def specialize_symbolic(self, arg: Any) -> Any: + from .constant import ConstantVariable + from .tensor import SymNodeVariable + + # See [Note: Specialize tl.constexpr args in user-defined triton kernels] + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + return arg + + +class TMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + data_ptr: "variables.DataPtrVariable", + dims: list[VariableTracker], + block_dims: list[VariableTracker], + element_size: VariableTracker, + **kwargs: Any, + ) -> None: + assert isinstance(data_ptr, variables.DataPtrVariable) + super().__init__(**kwargs) + self.data_ptr = data_ptr + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + + def to_metadata(self) -> Any: + return create_tma_experimental_metadata( + [dim.as_proxy() for dim in self.dims], + [dim.as_proxy() for dim in self.block_dims], + self.element_size.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.experimental_descriptor", + f"create_{len(self.dims)}d_tma_descriptor", + ) + ) + self.data_ptr.reconstruct(codegen) + args = [*self.dims, *self.block_dims, self.element_size] + codegen.foreach(args) + codegen.call_function(len(args) + 1, False) + + def get_tensor(self) -> VariableTracker: + return self.data_ptr.from_tensor + + +class TMADescriptorStableVariable(VariableTracker): + def __init__( + self, + tensor: "TensorVariable", + block_shape: "ListVariable", + **kwargs: Any, + ) -> None: + assert tensor.is_tensor() + super().__init__(**kwargs) + self.tensor = tensor + self.block_shape = block_shape + + def to_metadata(self) -> Any: + return create_tma_stable_metadata( + self.block_shape.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.tensor_descriptor", + "TensorDescriptor", + ) + ) + codegen.load_method("from_tensor") + self.tensor.reconstruct(codegen) + codegen(self.block_shape) + codegen.call_method(2) + + def get_tensor(self) -> Any: + return self.tensor + + +class CreateTMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + rank: int, + **kwargs: Any, + ) -> None: + assert rank in (1, 2) + super().__init__(**kwargs) + self.rank = rank + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] + + if not isinstance(ptr, variables.DataPtrVariable): + raise Unsupported( + "Please ensure there were no graph breaks between " + f"create_{self.rank}d_tma_descriptor and the upstream " + ".data_ptr() call." + ) + + if self.rank == 1: + if len(args) + len(kwargs) != 4: + raise_type_error_exc( + tx, + f"TMA metadata rank=1 requires exactly 4 arguments, got {len(args) + len(kwargs)}", + ) + dims = [ + kwargs["dim"] if "dim" in kwargs else args[1], + ] + block_dims = [ + kwargs["block_dim"] if "block_dim" in kwargs else args[2], + ] + else: + if len(args) + len(kwargs) != 6: + raise_type_error_exc( + tx, + f"TMA metadata rank=2 requires exactly 6 arguments, got {len(args) + len(kwargs)}", + ) + dims = [ + kwargs["dim1"] if "dim1" in kwargs else args[1], + kwargs["dim0"] if "dim0" in kwargs else args[2], + ] + block_dims = [ + kwargs["block_dim1"] if "block_dim1" in kwargs else args[3], + kwargs["block_dim0"] if "block_dim0" in kwargs else args[4], + ] + element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] + + return TMADescriptorExperimentalVariable( + data_ptr=ptr, + dims=dims, + block_dims=block_dims, + element_size=element_size, + ) + + +class CreateTMADescriptorStableVariable(VariableTracker): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] + block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] + + return TMADescriptorStableVariable( + tensor=tensor, # type: ignore[arg-type] + block_shape=block_shape, # type: ignore[arg-type] + ) + + +class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. + + def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if len(args) != 1: + raise_type_error_exc( + tx, + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + ) + type_source = None + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) + type_source = TypeSource(args[0].source) + python_type = args[0].python_type() + if is_namedtuple_class(python_type): + type_source = AttrSource(CollectionsSource(), "namedtuple") + return VariableTracker.build(tx, namedtuple, type_source) + return VariableTracker.build(tx, python_type, source=type_source) + + +class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. + + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + When is_leaf is None (the common case), we can optimize by not tracing into the function. + When is_leaf is not None, we fall back to regular tracing since it requires executing user code. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # tree_is_leaf(tree, is_leaf=None) + if len(args) < 1 or len(args) > 2: + raise_type_error_exc( + tx, + f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + ) + + # Check if is_leaf parameter is provided + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) + if len(args) == 2: + is_leaf = args[1] + + if not is_leaf.is_constant_none(): + return super().call_function(tx, args, kwargs) + + # Optimize the case where is_leaf is None + # return _get_node_type(tree) not in SUPPORTED_NODES + tree = args[0] + node_type_var = PyTreeGetNodeTypeFunctionVariable( + torch.utils._pytree._get_node_type + ).call_function(tx, [tree], {}) + + # If the SUPPORTED_NODES was seen earlier and mutated, there would be a + # source and that will give us the mutated SUPPORTED_NODES. + supported_nodes_var = VariableTracker.build( + tx, + torch.utils._pytree.SUPPORTED_NODES, + source=get_pytree_SUPPORTED_NODES_source(), + ) + out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) + return ConstantVariable.create(not out.value) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..253386a94eeee02876fc0ce2fc7ca7036f170aaf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py @@ -0,0 +1,4827 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for handling higher-order operators in Dynamo. +It provides functionality for tracing and transforming control flow constructs like +conditions (torch.cond), loops (torch.while_loop), maps (torch.ops.higher_order.map), +and other higher-order operations. + +The module includes specialized VariableTracker classes for different types of +higher-order operations, along with utilities for: +- Speculating and capturing subgraphs +- Managing control flow +- Handling autograd function applications +- Supporting function transformations +- Processing activation checkpoints + +These classes work together to enable Dynamo to correctly trace and compile code +containing complex control flow patterns and higher-order functions while preserving +their semantic behavior. +""" + +import contextlib +import functools +import inspect +import itertools +import logging +import types +import warnings +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Literal, Optional, TYPE_CHECKING + +import torch._C +import torch.fx +import torch.nn +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import get_fake_value +from torch._dynamo.variables.builtin import BuiltinVariable +from torch._dynamo.variables.constant import ConstantVariable +from torch._dynamo.variables.ctx_manager import RepararametrizeModuleContextVariable +from torch._dynamo.variables.functions import UserFunctionVariable +from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable +from torch._dynamo.variables.tensor import SymNodeVariable, TensorVariable +from torch._guards import Source +from torch._ops import HigherOrderOperator +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree + +from .. import graph_break_hints, variables +from ..exc import ( + ObservedException, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) +from ..source import AttrSource, DictGetItemSource +from ..utils import proxy_args_kwargs, set_example_value +from .base import VariableTracker +from .dicts import ConstDictVariable +from .lazy import LazyVariableTracker +from .lists import ListVariable, TupleVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + + +@dataclass +class OutputSpec: + """ + Contains the treespec of the output of the speculated subgraph, and the + information to mask out the constant values from the output during + flattening and inserting them back during unflattening. Cleaning up + constants from the graph makes the graph simpler for AOTDispatcher and + Inductor. + """ + + treespec: pytree.TreeSpec + # list of True/False to identify the locations of const values in the + # subgraph output. True means that value at that index is a constant. + masks_to_filter_const_values: Optional[list[bool]] = None + # The actual constant values that were present in the subgraph output. Note + # that this is the same length as the mask, we just look at the indices + # where mask is True. + const_values: Optional[list[Any]] = None + # Number of intermediate nodes that are also made subgraph outputs. + num_intermediate_nodes_as_outputs: int = 0 + + def __post_init__(self): + if ( + self.masks_to_filter_const_values is not None + or self.const_values is not None + ): + assert len(self.masks_to_filter_const_values) == len(self.const_values) + + +def raise_hard_error_if_graph_break(reason): + def deco(fn): + @functools.wraps(fn) + def graph_break_as_hard_error(*args, **kwargs): + try: + return fn(*args, **kwargs) + except (Unsupported, ObservedException) as e: + import sys + + if isinstance(e, Unsupported): + exc = UncapturedHigherOrderOpError( + f"{reason} Got {e.msg}", e.real_stack + ) + else: + msg = e.msg if hasattr(e, "msg") else type(e) + real_stack = e.real_stack if hasattr(e, "real_stack") else None + exc = UncapturedHigherOrderOpError( + f"{reason} Got {msg}", real_stack + ) + raise exc.with_traceback(sys.exc_info()[2]) from None + + return graph_break_as_hard_error + + return deco + + +# This function is a syntax sugar for creating a dummy new subtracer so that +# newly added nodes are added to a separate subgraph in this subtracer instead of affecting +# the main graph. This is useful for creating sample inputs for tracing the subgraph. +# For example, in FlexAttentionHigherOrderVariable, we want to create several scalars +# to trace the score_mod function but we don't want the operators that creates the scalar to +# show up in the graph, we could this function to discard the graph changes. +# Example usage: +# with discard_graph_changes(): +# sample_input= create_sample_inputs() +# speculate_subgraph(tx, f, sample_inputs, {}) +@contextlib.contextmanager +def discard_graph_changes(tx): + ctx = tx.output.subtracer("subgraph_wrapper", None) + try: + ctx.__enter__() + yield + finally: + ctx.__exit__(None, None, None) + + +def check_meta_consistency_vt( + vars1: list[VariableTracker], + vars2: list[VariableTracker], + lhs_name: str, + rhs_name: str, + include_contiguity: bool = True, +) -> None: + from torch._higher_order_ops.utils import check_meta_consistency + + def _unwrap_var(var): + if var.is_tensor(): + return var.proxy.node.meta["example_value"] + elif isinstance(var, SymNodeVariable): + return var.sym_num + elif var.is_python_constant(): + return var.as_python_constant() + else: + unimplemented( + gb_type="cannot unwrap variable for check_meta_consistency", + context=str(var), + explanation=f"Expected {var} to be TensorVariable, SymNodeVariable, or ConstantVariable", + hints=[], + ) + + unwrapped1 = [_unwrap_var(var) for var in vars1] + unwrapped2 = [_unwrap_var(var) for var in vars2] + + return check_meta_consistency( + unwrapped1, + unwrapped2, + lhs_name, + rhs_name, + include_contiguity=include_contiguity, + ) + + +@contextlib.contextmanager +def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): + from . import GradModeVariable + + org_value = torch.is_grad_enabled() + try: + GradModeVariable.create(tx, enable, initialized=True) + yield + finally: + GradModeVariable.create(tx, org_value, initialized=True) + + +@contextlib.contextmanager +def dynamo_allow_side_effects_in_hop(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.allow_side_effects_in_hop + try: + tx.output.current_tracer.allow_side_effects_in_hop = True + yield + finally: + tx.output.current_tracer.allow_side_effects_in_hop = orig_val + + +def find_mismatched_vars(var, types, allow_none=False): + """ + Recursively finds variables whose type is not an instance of the specified types. + Args: + var: The variable to check. + types: A tuple of allowed types. + allow_none (bool): Whether to allow None values. Defaults to False. + Returns: + A set of variables whose type is not an instance of the specified types. + """ + mismatched_vars = set() + if isinstance(var, (list, tuple)): + for item in var: + mismatched_vars.update(find_mismatched_vars(item, types, allow_none)) + elif isinstance(var, (TupleVariable, ListVariable)): + for item in var.items: + mismatched_vars.update(find_mismatched_vars(item, types, allow_none)) + elif isinstance(var, ConstDictVariable): + for value in var.items.values(): + mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) + else: + if not isinstance(var, types) and not (allow_none and var.is_constant_none()): + mismatched_vars.add(var) + return mismatched_vars + + +def only_consist_of(var, types, allow_none=False): + mismatch_vars = find_mismatched_vars(var, types, allow_none=allow_none) + return len(mismatch_vars) == 0 + + +# A more read-able syntax sugar for creating a UserFunctionVariable for f +# and run call_function on it. Make it return a function to preserve the calling +# convention of the original f. +def _make_inlined(tx: "InstructionTranslator", f): + assert callable(f), "Expect f to be a python callable." + + def inline_call(*args, **kwargs): + return UserFunctionVariable(f).call_function(tx, args, kwargs) + + return inline_call + + +def _call_function_with_auto_output_flattening( + tx: "InstructionTranslator", + fn: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + flat_example_value: Any, + body_r: Optional[VariableTracker], + graph_output_vts: VariableTracker | tuple[VariableTracker, ...], +) -> Optional[VariableTracker]: + """ + Create HOP call node and reproxify output VTs for HOPs with auto output semantics. + + This function is used by HOPs with auto output semantics (see speculate_subgraph_with_auto_output_flattening) + to create the actual HOP call in the FX graph and properly handle the output variable trackers. + + The key operation is "reproxifying" - updating the proxies of the original tensor VTs + (from body_r) to point to the HOP call outputs, ensuring the outer graph correctly + references the HOP outputs while allowing body_r to contain arbitrary Python objects. + + Args: + tx: The instruction translator + fn: The HOP function to call + args: Arguments for the HOP call (typically includes the subgraph node) + kwargs: Keyword arguments for the HOP call + flat_example_value: Example value for the HOP output + body_r: The output VT structure that Dynamo continues tracing with (may be None) + graph_output_vts: Tensor/symint VTs that were actual graph outputs + + Returns: + The body_r VT (unchanged), which Dynamo will continue tracing with + """ + from .builder import wrap_fx_proxy + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # wrap_fx_proxy creates fresh variable trackers. However, the main program + # after the speculate subgraph can still use the original tensor vts that + # are still pointing to the nodes present in the subgraph. So, we reproxify + # the original tensor vts with the subgraph outputs. This way, whenever the + # outer graph uses an original vt, it uses the subgraph output. + # + # This is critical for maintaining the separation between: + # - `body_r`: The output VT structure that Dynamo continues tracing (may + # contain non-proxyable objects, nested structures, etc.) + # - `graph_output_vts`: Only the tensor/symint VTs that were actual graph + # outputs from speculate_subgraph + # + # By overwriting the proxies of VTs in `body_r` with the proxies from the + # HOP call, we ensure the outer graph correctly references the HOP outputs + # while still allowing `body_r` to contain arbitrary Python objects. + if body_r is not None: + for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items): + if orig_vt.is_tensor() or isinstance(orig_vt, SymNodeVariable): + assert subgraph_vt.is_tensor() or isinstance( + subgraph_vt, SymNodeVariable + ) + orig_vt.proxy = subgraph_vt.proxy + return body_r + + +def _call_function_and_unflatten_output( + tx, fn, args, kwargs, flat_example_value, ret_spec, body_r +): + from .builder import wrap_fx_proxy + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # wrap_fx_proxy creates fresh variable trackers. However, the main program + # after the speculate subgraph can still use the original tensor vts that + # are still pointing to the nodes present in the subgraph. So, we reproxify + # the original tensor vts with the subgraph outputs. This way, whenever the + # outer graph uses an original vt, it uses the subgraph output. + if body_r is not None: + for orig_vt, subgraph_vt in zip(body_r.items, flat_variable.items): + if orig_vt.is_tensor() or isinstance(orig_vt, SymNodeVariable): + assert subgraph_vt.is_tensor() or isinstance( + subgraph_vt, SymNodeVariable + ) + orig_vt.proxy = subgraph_vt.proxy + + if ret_spec.num_intermediate_nodes_as_outputs: + # The treespec was computed w/o any extra intermediate outputs. At this + # point, it is safe to just get rid of the extra outputs + flat_variable = TupleVariable( + flat_variable.items[: -ret_spec.num_intermediate_nodes_as_outputs] + ) + + if ret_spec.masks_to_filter_const_values: + from torch._dynamo.external_utils import insert_const_values_with_mask + + # During flattening, we removed the constant values. To ensure Dynamo + # can trace correctly, insert back the constant values in the output. + flat_variable = _make_inlined(tx, insert_const_values_with_mask)( + flat_variable, ret_spec.masks_to_filter_const_values, ret_spec.const_values + ) + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {}) + return ( + _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_spec.treespec) + if ret_spec.treespec + else flat_variable + ) + + +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = { + id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor) + } + output_tensor_ids = { + id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor) + } + assert input_tensor_ids.isdisjoint(output_tensor_ids), ( + "inputs to function body cannot alias outputs" + ) + + +def get_tensor_storages(tensor: torch.Tensor) -> set[StorageWeakRef]: + """ + Get storage references from a tensor. + + Handles regular tensors. Raises NotImplementedError for sparse tensors + and traceable wrapper subclasses. + + Args: + tensor: The tensor to extract storages from + + Returns: + Set of StorageWeakRef objects for the tensor's storage(s) + """ + from torch.multiprocessing.reductions import StorageWeakRef + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + storages: set[StorageWeakRef] = set() + + if not isinstance(tensor, torch.Tensor): + return storages + + if tensor.is_sparse or tensor.is_sparse_csr: + raise NotImplementedError("get_tensor_storages does not support sparse tensors") + + if is_traceable_wrapper_subclass(tensor): + raise NotImplementedError( + "get_tensor_storages does not support traceable wrapper subclasses" + ) + else: + storages.add(StorageWeakRef(tensor._typed_storage())) + + return storages + + +class StorageAliasingTracker: + """ + Tracks storage references to detect aliasing between tensors. + + This class encapsulates the logic for collecting storages from tensors + and checking for aliasing conflicts. Used to filter intermediate outputs + that would create input-output or output-output aliasing. + """ + + def __init__(self): + self.excluded_storages: set = set() + + def _collect_storages_from_tensor(self, example_value): + self.excluded_storages.update(get_tensor_storages(example_value)) + + def collect_from_inputs(self, tx): + """Collect storages from graph input placeholders.""" + from torch._higher_order_ops.utils import _collect_fake_inputs + + for node in tx.output.graph.nodes: + if node.op == "placeholder": + example_value = _collect_fake_inputs([node])[0] + if isinstance(example_value, torch.Tensor): + self._collect_storages_from_tensor(example_value) + else: + break + + def collect_from_outputs(self, graph_output_vts): + """Collect storages from existing graph outputs.""" + from torch._higher_order_ops.utils import _collect_fake_inputs + + for vt in graph_output_vts: + proxy = vt.as_proxy() + example_value = _collect_fake_inputs([proxy.node])[0] + if isinstance(example_value, torch.Tensor): + self._collect_storages_from_tensor(example_value) + + def check_and_track(self, proxy_node) -> bool: + """ + Check if a tensor can be added as a subgraph output without causing aliasing issues. + + Given a proxy node, extracts its example tensor value and checks if its storage + aliases with any previously tracked storages (from inputs or other outputs). + If there's no aliasing conflict, the tensor's storage is added to the tracked set. + + Args: + proxy_node: An FX proxy node whose example_value is the tensor to check. + + Returns: + True if the tensor doesn't alias with tracked storages (safe to add as output), + False if it aliases (should be filtered out). + """ + from torch._higher_order_ops.utils import _collect_fake_inputs + from torch.multiprocessing.reductions import StorageWeakRef + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + example_value = _collect_fake_inputs([proxy_node])[0] + + # Non-tensor outputs (e.g., symints) don't have aliasing concerns + if not isinstance(example_value, torch.Tensor): + return True + + # Check if any storage aliases with existing inputs/outputs + tensor_storages = get_tensor_storages(example_value) + if tensor_storages & self.excluded_storages: + return False + + # Track this tensor's storage (for wrapper subclasses, inner storages were already checked) + if not is_traceable_wrapper_subclass(example_value): + if not (example_value.is_sparse or example_value.is_sparse_csr): + self.excluded_storages.add( + StorageWeakRef(example_value._typed_storage()) + ) + + return True + + +def collect_intermediate_outputs( + tx, subtracer, graph_output_vts, filter_aliased_intermediates=False +): + extra_outputs = [] + existing_out_proxies = {vt.as_proxy() for vt in graph_output_vts} + + # Build the aliasing tracker if we're filtering + tracker = None + if filter_aliased_intermediates: + tracker = StorageAliasingTracker() + tracker.collect_from_inputs(tx) + tracker.collect_from_outputs(graph_output_vts) + + for out in subtracer.tracked_tensor_or_symint_vt: + proxy = out.as_proxy() + + # Skip if already in output + if proxy in existing_out_proxies: + continue + + # TODO floats are not supported in HOP input/output + if isinstance(out, SymNodeVariable) and out.python_type() is float: + continue + + if not filter_aliased_intermediates: + extra_outputs.append(out) + else: + # Filter out intermediates that alias with inputs or outputs. + # This is needed for HOPs like invoke_subgraph that don't support aliasing. + # TODO: If a filtered intermediate is captured by side effects (e.g., appended + # to a list), it will fail later with "does not belong to this Graph" error + # when the outer graph tries to use it. See test_side_effect_with_aliased_intermediate. + if tracker.check_and_track(proxy.node): + extra_outputs.append(out) + + return extra_outputs + + +def _check_all_tensorvariable(args): + if not all(type(a.realize()) is TensorVariable for a in args): + unimplemented( + gb_type="HOP: non torch.Tensor leaf", + context=f"args types: {[type(a.realize()) for a in args]}", + explanation="Expected all leaves to be of torch.Tensor type.", + hints=[], + ) + + +def _check_supported_callable_arg( + tx: "InstructionTranslator", func_var: VariableTracker, arg_name +): + is_callable = ( + BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant() + ) + if not is_callable: + unimplemented( + gb_type="HOP: non-callable variable", + context=f"arg name: {arg_name}, func_var type: {str(func_var)}", + explanation=f"{arg_name} should be a callable but is of type {str(func_var)}.", + hints=[], + ) + + +def _call_while_loop( + self: VariableTracker, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + stack_output: bool, +) -> VariableTracker: + from torch._higher_order_ops.while_loop import _create_unbacked_symint + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + cond_fn, body_fn, operands, additional_inputs = args + + # Input checks + for i, k in enumerate(["cond_fn", "body_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + if kwargs or len(args) != 4: + unimplemented( + gb_type="torch.while_loop: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.while_loop expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: while_loop(cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # cond_fn and body_fn input check + _check_supported_callable_arg(tx, cond_fn, "cond_fn") + _check_supported_callable_arg(tx, body_fn, "body_fn") + + # operands input check + operands_seq = operands.unpack_var_sequence(tx) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.while_loop: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) + + with discard_graph_changes(tx): + # Note: this must be run under discard graph changes. + def unspecialize_carried_inputs(tx, carry) -> VariableTracker: + # See NOTE [unspecialize int carry with unbacked symints] + if ( + carry.is_python_constant() + and isinstance(carry.as_python_constant(), int) + ) or isinstance(carry, SymNodeVariable): + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value + ) + return SymNodeVariable.create(tx, proxy, example_value) + else: + # See NOTE [unspecialize constant tensor carry] + assert carry.is_tensor() + cloned_carry = carry.clone() + cloned_carry.proxy.node.meta["example_value"].constant = None + return cloned_carry + + # clone inputs across subgraphs, to avoid unbacked memoization in fake prop + cond_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if carry.is_tensor() + else carry + ), + ) + for carry in operands_seq + ] + body_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if carry.is_tensor() + else carry + ), + ) + for carry in operands_seq + ] + + # create cond subgrpahs + ( + (cond_r, _cond_treespec), + cond_graph, + cond_lifted_freevars, + ) = speculate_subgraph( + tx, + cond_fn, + cond_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + # NOTE [why we cannot use "automatic" for while_loop]: + # The reason is that we want to enforce + # the ordering of inputs and outputs to be consistent and the ordering + # of cond_fn and body_fn to the consistent. + # e.g. suppose we use "automatic" and we have: + # + # def body_fn(ph1, ph2): + # new_a, new_b = ph2.cos(), ph1.sin() + # return new_a, new_b + # + # a, b = torch.randn(3), torch.randn(3) + # new_a, new_b = body_fn(a, b) + # + # Using automatic, the ordering of arguments will be the order that they're + # used. In this example, the capture graph looks like: + # + # def captured_body(ph1, ph2): + # new_a, new_b = ph1.cos(), ph2.add_(1) + # return new_a, new_b + # + # This is fine when we change the calling convention of captured_body to be + # new_a, new_b = captured_body(b, a). + # But for while_loop, the next iteration's input is previous iteration output + # we'll end up feeding captured_body(new_a, new_b) instead. + # So it's best we always enforce the ordering of carried_inputs the same as outputs + # with "flatten_manual". + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + remove_consts_from_outputs=False, + ) + cond_nn_modules = dict(tx.output.nn_modules) + validate_subgraph_output_types(cond_r) + if cond_r.is_tensor(): + cond_r_meta = _extract_tensor_metadata( + cond_r.proxy.node.meta["example_value"], include_contiguity=False + ) + if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]): + unimplemented( + gb_type="torch.while_loop: unsupported cond_fn return type", + context=str(cond_r), + explanation=f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + elif cond_r.is_python_constant(): + # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False + pred = cond_r.as_python_constant() + if pred: + unimplemented( + gb_type="torch.while_loop: infinite loop detected", + context=str(cond_r), + explanation=f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + else: + return operands + + # create body subgraph + ( + (body_r, body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + body_fn, + body_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + supports_input_mutation=False, + supports_aliasing=False, + remove_consts_from_outputs=False, + ) + validate_subgraph_output_types(body_r) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + body_r.unpack_var_sequence(tx), + operands_seq, + "body_fn_output", + "carried_inputs", + include_contiguity=False, + ) + + ( + cond_graph, + body_graph, + cond_shared, + _body_shared, + cond_unique, + body_unique, + ) = _merge_graph_inputs( + cond_graph, + cond_lifted_freevars, + "cond_fn", + body_graph, + body_lifted_freevars, + "body_fn", + ) + + # Note: cond_shared and body_shared refer to the same proxy in parent graph + # so using either of them is OK. Use cond_shared as it doesn't matter. + additional_lifted_inputs = cond_shared + cond_unique + body_unique + + body_nn_modules = dict(tx.output.nn_modules) + + cond_gm = torch.fx.GraphModule(cond_nn_modules, cond_graph) + body_gm = torch.fx.GraphModule(body_nn_modules, body_graph) + cond_name = tx.output.install_subgraph("cond_fn", cond_gm) + body_name = tx.output.install_subgraph("body_fn", body_gm) + + cond_node = make_attr(tx, cond_name) + body_node = make_attr(tx, body_name) + + operands_proxy = tuple(operand.as_proxy() for operand in operands_seq) + additional_inputs_proxy = tuple( + [inp.as_proxy() for inp in additional_inputs_seq] + additional_lifted_inputs + ) + p_args = ( + cond_node, + body_node, + operands_proxy, + additional_inputs_proxy, + ) + return _call_function_and_unflatten_output( + tx, + self.value, + p_args, + {}, + None, + body_treespec, + body_r, + ) + + +def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode): + from torch._subclasses._fake_tensor_utils import _CacheKeyState + from torch._subclasses.fake_tensor import extract_tensor_metadata + + # Maps the equivalent nodes from a to b + node_map = {} + + def check_all_args(a_nodes, b_nodes): + for arg_a, arg_b in zip(a_nodes, b_nodes): + if isinstance(arg_a, torch.fx.Node): + if node_map[arg_a] != arg_b: + return False + elif isinstance(arg_a, slice): + if not isinstance(arg_b, slice): + return False + if not check_all_args( + (arg_a.start, arg_a.stop, arg_a.step), + (arg_b.start, arg_b.stop, arg_b.step), + ): + return False + elif arg_a != arg_b: + # This is a catch-all for everything else. `slice` was a + # surprise but can there be other data structures that can + # contain fx.Nodes in them? + return False + return True + + for a_node, b_node in zip(a_mod.graph.nodes, b_mod.graph.nodes): + if a_node.op != b_node.op: + return False + + if a_node.op == "placeholder": + a_value = a_node.meta["example_value"] + b_value = b_node.meta["example_value"] + + if isinstance(a_value, torch.Tensor): + if not isinstance(b_value, torch.Tensor): + return False + # Extract fake tensor metadata for a and b and then compare + a_result = [] + state = _CacheKeyState(fake_mode.shape_env) + a_metadata = extract_tensor_metadata(a_value) + a_metadata._flatten_into(a_result, fake_mode, state) + + b_result = [] + state = _CacheKeyState(fake_mode.shape_env) + b_metadata = extract_tensor_metadata(b_value) + b_metadata._flatten_into(b_result, fake_mode, state) + if a_result != b_result: + return False + elif isinstance(a_value, torch.SymInt): + if not isinstance(b_value, torch.SymInt): + return False + if a_value is not b_value: + return False + elif a_node.op == "call_function": + if a_node.target is not b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_function): %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "call_method": + if a_node.target != b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_method) : %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "output": + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug("%s: Graph comparison failed at the output node", fn_name) + return False + elif a_node.op == "get_attr": + a_attr = getattr(a_mod, a_node.target) + b_attr = getattr(b_mod, b_node.target) + if isinstance(a_attr, torch.fx.GraphModule): + if not isinstance(b_attr, torch.fx.GraphModule): + return False + # This is an example of a HOP inside a HOP + if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode): + return False + else: + # TODO - write an example with tensor as a graph attribute in + # the Fx graph + raise NotImplementedError(f"get_attr with {type(a_attr)}") + else: + # TODO - call_module is not supported because Dynamo Fx graph does + # not install a call_module + raise NotImplementedError(f"Graph equivalence check saw a {a_node.op}") + + # Two nodes are equal - add them to them map + node_map[a_node] = b_node + + return True + + +def validate_args_and_maybe_create_graph_inputs( + sub_args, + tracer, + tx, + set_subgraph_inputs, + description, + sub_args_names=None, +): + from . import AutogradFunctionContextVariable + from .builder import wrap_fx_proxy_cls + + assert tracer.parent is not None + + if set_subgraph_inputs == "flatten_manual": + flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)( + ListVariable(sub_args) + ).unpack_var_sequence(tx) + + flat_inputs = validate_args_and_maybe_create_graph_inputs( + flat_args.unpack_var_sequence(tx), + tracer, + tx, + set_subgraph_inputs="manual", + description=description, + ) + + return _make_inlined(tx, pytree.tree_unflatten)( + ListVariable(flat_inputs), tree_spec + ).unpack_var_sequence(tx) + else: + if sub_args_names is not None: + # Can be greater if user passes some args as kwargs + assert len(sub_args_names) >= len(sub_args) + args = [] + for idx, a in enumerate(sub_args): + assert isinstance(a, VariableTracker) + if set_subgraph_inputs == "automatic": + args.append(a) + continue + elif set_subgraph_inputs == "automatic_with_forced_inputs": + if isinstance(a, variables.TensorVariable): + node = a.maybe_fx_node() + example_value = node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + example_value = node.meta.get("example_value", None) + a = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + elif set_subgraph_inputs == "semi_automatic": + if isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + example_value = node.meta.get("example_value", None) + a = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + args.append(a) + continue + + if a.is_python_constant(): + # This arg is not used in the body of the higher order op. + # Currently, this new input is added to make the calls + # happy, which expect a fixed number of arguments. In + # future, we can clean this up. + arg_name = ( + "const_unused" + if sub_args_names is None + else f"const_unused_{sub_args_names[idx]}" + ) + tracer.create_graph_input( + arg_name, a.python_type(), a.as_python_constant() + ) + new_arg = a + # Weird special case, we probably want to delete it or fold it + # into the next case (of `a` being placeable into a graph) + elif isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + new_arg = a + # If `a` can be put into a graph + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = node.meta.get("example_value", None) + arg_name = node.name if sub_args_names is None else sub_args_names[idx] + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + new_arg = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + # If `a` cannot be put into a graph + else: + # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). + unimplemented( + gb_type="HOP body taking non-Tensor as input", + context=str(sub_args), + explanation=f"{description} with body that accepts non-Tensors as input. " + f"Got type {a.python_type()} at index {idx}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + args.append(new_arg) + return args + + +# This helper function is used to make sure two graphs share the same input signature. For example, +# in torch.cond, two branches might lift different set of tensors as inputs. This function helps to +# dedup the inputs and modify the graphs to take the same set of inputs. +def _merge_graph_inputs( + l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name +): + def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars): + # The nn module attributes are guaranteed to be registered into the top-level graph module during + # higher order op speculation. Therefore, get_attr nodes in two branches with the same + # target refer to the same attribute and we can safely deduplicate them with their target. + # + # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But + # true_branch and false_branch belong to two separate tracing contexts, they may register the same + # attribute to top level separately. This creates two get_attr proxies for the same attribute + # that have different meta data such as stack_trace (one stack trace for the true_branch, + # and the other for false_branch). It seems better to discard the proxy explicitly in cond + # than make dynamo create a single proxy for the same get_attr target. + def shared_getattrs(l_lifted_proxies, r_lifted_proxies): + true_targets = { + proxy.node.target: proxy + for proxy in l_lifted_proxies + if proxy.node.op == "get_attr" + } + l_shared_getattrs = {} + r_shared_getattrs = {} + + for false_proxy in r_lifted_proxies: + if ( + false_proxy.node.op == "get_attr" + and false_proxy.node.target in true_targets + ): + true_proxy = true_targets[false_proxy.node.target] + l_shared_getattrs[true_proxy] = true_proxy + r_shared_getattrs[false_proxy] = true_proxy + return l_shared_getattrs, r_shared_getattrs + + l_shared_getattrs, r_shared_getattrs = shared_getattrs( + l_lifted_freevars.keys(), r_lifted_freevars.keys() + ) + + l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + l_shared_getattrs.keys() + ) + r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + r_shared_getattrs.keys() + ) + unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars + unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars + + def _sort_by_name(vars): + return sorted(vars, key=lambda var: var.node.name) + + return ( + list(_sort_by_name(list(l_shared_freevars))), + list(_sort_by_name(list(r_shared_freevars))), + list(_sort_by_name(list(unique_l_freevars))), + list(_sort_by_name(list(unique_r_freevars))), + ) + + (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars( + l_lifted_freevars, r_lifted_freevars + ) + + # Let's say we capture cond(pred, true_fn, false_fn, (x,)) + # With set_graph_input set to automatic, + # true_fn has lifted variables x, a, b, c + # false_fn has lifted variables x, a, b, d + # Then fixup_branch_inps make sure both branches have the same signature, i.e.: + # - true_fn(x, a, b, c_true_branch, d_false_branch) + # - false_fn(x, a, b, c_true_branch, d_false_branch) + # + # More formally, the signature has three parts in the following order: + # 1. used in both branches: x, a, b + # 2. only used in true branches: c, suffixed with _true_branch + # 3. only used in false branches: d, suffixed with _false_branch + # Within each part, we re-order the nodes by name to have a derterministic ordering for testing. + def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): + def _insert_or_replace_phs(new_args, name_suffix): + for arg in new_args: + new_ph = graph.placeholder(arg.node.name + name_suffix) + new_ph.meta = arg.node.meta + # Override with new_ph if there exists a old placeholder. + if arg in lifted_freevars: + old_ph = lifted_freevars[arg].node + old_ph.replace_all_uses_with(new_ph) + # replace_all_uses_with doesn't clean users. Clean it manually so that we could erase it. + old_ph.users = {} + graph.erase_node(old_ph) + + first_not_ph_node = next( + node for node in graph.nodes if node.op != "placeholder" + ) + with graph.inserting_before(first_not_ph_node): + _insert_or_replace_phs(shared, "") + _insert_or_replace_phs(unique_l, "_" + l_name) + _insert_or_replace_phs(unique_r, "_" + r_name) + + fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r) + fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r) + return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r + + +# NOTE: [HigherOrderOperator subgraph input ordering] +# The input ordering of the higher order ops is determined by the order of +# the creation of the placeholder. +# Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before +# speculating subgraph. +# During subgraph speculation, we may lift closured tensors and free symbols as inputs, +# their ordering is determined by the time they are lifted: earlier lifted ones precede later +# lifted ones. +# +# Suppose the placeholders are +# O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs +# The following code re-order the placeholders to +# O1, O2, O3, O4, O5, X1, X2, X3 +def move_lifted_freevars_phs_to_end( + graph: torch.fx.Graph, lifted_freevars: dict[Any, torch.fx.Node] +): + lifted_ph_set = {child_p.node for child_p in lifted_freevars.values()} + + prev_phs = [n for n in graph.nodes if n.op == "placeholder"] + + # No need to reorder when graph doesn't have args or doesn't + # have lifted freevars or all inputs are lifted freevars. + if ( + len(prev_phs) == 0 + or len(lifted_ph_set) == 0 + or len(prev_phs) == len(lifted_ph_set) + ): + return + + # Step 1: find first X1 + for x1 in prev_phs: + if x1 in lifted_ph_set: + break + + assert x1 is not None and x1.op == "placeholder" + # Step 2: starting from the X1, skip Xs and prepend Os before X1. + cand_x = x1.next + while cand_x is not None and cand_x.op == "placeholder": + if cand_x in lifted_ph_set: + cand_x = cand_x.next + else: + nxt = cand_x.next + cand_x._remove_from_list() + x1.prepend(cand_x) + cand_x = nxt + + # Step 3: assert that all placeholders are in the correct order as . + # in lifted_freevars + after_phs = [node for node in graph.nodes if node.op == "placeholder"][ + -len(lifted_freevars) : + ] + assert len(after_phs) == len(lifted_freevars) + for child_proxy, ph in zip(lifted_freevars.values(), after_phs): + assert child_proxy.node is ph, ( + "The order of placeholders is different from the order of lifted_freevars" + ) + + graph.lint() + + +def check_aliasing_and_input_mutation( + subtracer, graph, supports_input_mutation, supports_aliasing, source_target +): + if not supports_input_mutation: + mutation_info = subtracer.has_input_mutation() + if mutation_info.has_mutation: + context = f"{mutation_info.msg} in\n {graph}" + unimplemented( + gb_type="Encountered input mutation during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support input mutation. Found in {source_target.name}", + hints=[ + "Consider using the debug context to change user code to avoid mutation.", + "Please open an issue.", + ], + ) + + if not supports_aliasing: + aliasing_info = subtracer.has_aliasing() + if aliasing_info.has_aliasing: + context = f"{aliasing_info.msg} in\n {graph}" + unimplemented( + gb_type="Encountered aliasing during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support aliasing. Found in {source_target.name}", + hints=[ + "Replace `return input` with `return input.clone()` to avoid aliasing.", + "Consider using the debug context to change user code to avoid aliasing.", + "Please open an issue.", + ], + ) + + +def trace_hop_function( + f, + tx, + subtracer, + enable_grad, + restore_side_effects, + args, + sub_kwargs, +): + # For autograd.Function and other legacy HOPs, we do NOT couple + # restore_side_effects with allow_side_effects_in_hop. + # This preserves the old behavior where: + # - restore_side_effects=False means ctx mutations persist + # - But non-ctx side effects still cause graph breaks (under_activation_checkpoint was False) + enable_side_effects_with_extra_outputs = False + + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if enable_side_effects_with_extra_outputs + else contextlib.nullcontext() + ) + + # For handling side effects, we can make an argument that we don't + # have to do anything here. The side effects infra does a good job + # of graph breaking if we mutate any nonlocal or global variable + # while subtracing. As a result if tracing succeeds, side effects + # data structure will only contain read-only data structures that + # are put there for tracking purposes. + # But on the other hand, there is an argument that if we ever write + # a new side effect in Dynamo which does not go through the side + # effect infra, we can end up in bad state. + # Therefore we restore the side effects after tracing. The catch is + # that we have to special handle tensor variables. If we have seen a + # nonlocal variable tensor during subtracing, we want to keep a + # track of that tensor, so that later subtracing or the root tracer + # itself does not create a new proxy for the already observed tensor + # variable. + if restore_side_effects: + prev_side_effects = tx.output.side_effects.clone() + + with autograd_ctx, side_effects_ctx: + output = f.call_function(tx, args, sub_kwargs) + + if restore_side_effects: + new_side_effects = tx.output.side_effects.clone() + prev_side_effects.track_runahead_tensor_and_symvar_side_effects( + new_side_effects + ) + tx.output.side_effects = prev_side_effects + return output + + +def trace_hop_function_with_auto_output_flattening( + f, + tx, + subtracer, + enable_grad, + allow_side_effects, + args, + sub_kwargs, +): + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + side_effects_ctx = ( + dynamo_allow_side_effects_in_hop(tx) + if allow_side_effects + else contextlib.nullcontext() + ) + + with autograd_ctx, side_effects_ctx: + output = f.call_function(tx, args, sub_kwargs) + + return output + + +def get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description +): + sub_args_names = maybe_positional_arg_names(f) + # User mismatch in the number of args. Will eventually lead to an error. + if sub_args_names is not None and len(sub_args_names) < len(sub_args): + sub_args_names = None + args = validate_args_and_maybe_create_graph_inputs( + sub_args, + subtracer, + tx, + set_subgraph_inputs, + description, + sub_args_names, + ) + + validate_args_and_maybe_create_graph_inputs( + sub_kwargs.values(), + subtracer, + tx, + set_subgraph_inputs="automatic", + description=description, + ) + return args + + +# TODO - The eventual goal is to replace +# speculate_subgraph_with_auto_output_flattening with speculate_subgraph or +# merge them two into one. We are following a staged approach because of +# existing implementation complexity for control flow ops. +def speculate_subgraph_with_auto_output_flattening( + tx: "InstructionTranslator", + f: VariableTracker, + sub_args: Sequence[VariableTracker], + sub_kwargs: Optional[dict[str, VariableTracker]], + description: str, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target: Optional[HigherOrderOperator] = None, + enable_grad: Optional[bool] = None, + # TODO - We can probably just make everyone use automatic for wrap_semantics + set_subgraph_inputs: Literal[ + "automatic", "semi_automatic", "flatten_manual", "manual" + ] = "automatic", + # If True, exposes intermediates to subgraph outputs to allow later tensor ops to + # access intermediates from the subgraph, this is useful for mutation + allow_side_effects: bool = False, + # Controls whether to filter aliased intermediates when collecting extra outputs. + # This is only relevant when allow_side_effects=True. + # - True: Filter out intermediates that alias with inputs or outputs (strict, for invoke_subgraph) + # - False: Allow aliased intermediates (for checkpoint/autograd.Function which get desugared/inlined) + # + # Example where filtering is needed: + # + # @invoke_subgraph + # def gn(x): + # view = x.view(2, 4) # intermediate that aliases input x + # y = torch.sin(view) + # return torch.cos(view) + # + # def fn(x): + # res = gn(x) + # return res + 4 + # + # In this case, if we don't filter `view`, we would later error because some HOPs + # have strict aliasing checks on inputs/outputs. + # + # This does however introduce a subtle issue when we do something like: + # + # captured = [] + # + # @invoke_subgraph + # def gn(x): + # view = x.view(2, 4) # intermediate that aliases input x + # y = torch.sin(view) + # captured.append(view) + # return torch.cos(view) + # + # def fn(x): + # res = gn(x) + # return res + captured[0] + # + # In this case, we will not replay the side effect on `captured` in the graph, + # which fails with a not-so-nice error. We will address this in a follow-up PR + # because this case is rare. This is not a regression because side effects were + # never supported for invoke_subgraph anyway. + filter_aliased_intermediates: bool = False, + # TODO - supports input_mutation and aliasing should be False by default for strictness + supports_input_mutation: bool = True, + supports_aliasing: bool = True, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer: Optional["torch._dynamo.output_graph.SubgraphTracer"] = None, +) -> tuple[ + VariableTracker, # output: The VT that Dynamo continues tracing with + torch.fx.Graph, # graph: The FX graph representing the subgraph computation + dict[ + torch.fx.Proxy, torch.fx.Proxy + ], # lifted_freevars: Free variables lifted as inputs + VariableTracker + | tuple[ + VariableTracker, ... + ], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs +]: + """ + Speculate subgraph for Higher-Order Operators (HOPs) with automatic output flattening. + + ## Automatic output flattening + + For many HOPs, the representation exists only as a container for the + subgraph. In later compiler stages or at runtime, the HOP is desugared and + simply executes the subgraph directly, as if it were inlined. For such hops, + we follow automatic output flattening. + For example: + - invoke_subgraph + - activation checkpointing (torch.utils.checkpoint.checkpoint) + - autograd.Function + - nested_compile_region + + This is in contrast to control flow HOPs which do not follow this desugaring: + - torch.cond (conditional execution based on predicate) + - torch.while_loop (iterative execution) + - torch.map (parallel execution over batch dimension) + + For control flow HOPs, the HOP behavior is fundamentally different from just + running the body function once. + + ## Key Advantage: Disentangling VTs from Graph Outputs + + Desugaring simplify HOP processing by allowing us to disentangle the output + variable trackers (VTs) from the HOP subgraph outputs. This mirrors typical + Dynamo processing where: + - VTs "run ahead" representing the program state for continued tracing + - The graph is a side data structure tracking computation seen so far + + This separation is crucial for HOPs with non-proxyable outputs (e.g., custom + user-defined objects containing tensors). The function may return complex Python + objects for Dynamo to continue tracing, but only the tensor/symint VTs need to + be registered as actual FX graph outputs. + + Example: + class Foo: + def __init__(self, a, b): + self.a = a # tensor + self.b = b # tensor + + def gn(x): + return Foo(torch.sin(x), torch.cos(x)) + + result = some_hop(gn, x) # Returns Foo instance + out = result.a + result.b # Dynamo can continue tracing + + Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but + `graph_output_vts` contains only the tensor VTs (a and b) that should be + actual FX graph outputs. This allows Dynamo to continue tracing with the + Foo object while the graph only needs to output the constituent tensors. + + ## Return Values + + Unlike `speculate_subgraph`, this function returns: + - output: The VT that Dynamo continues tracing with (may be complex Python objects) + - graph: The FX graph representing the subgraph computation + - lifted_freevars: Free variables lifted as inputs to the subgraph + - graph_output_vts: Only the tensor/symint VTs that are actual FX graph outputs + + The key difference is `graph_output_vts` instead of `treespec`, which gives more + flexibility for handling non-proxyable outputs. + """ + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "semi_automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented( + gb_type="invalid set_subgraph_inputs and sub_kwargs settings", + context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + hints=[ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + *graph_break_hints.USER_ERROR, + ], + ) + + try: + # ensure guards on args get installed in parent subgraph + f, sub_args, sub_kwargs = LazyVariableTracker.realize_all( + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer, description) as subtracer: + args = get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description + ) + + # Special case - if users uses + # `traced_with_externally_visible_side_effects`, we still need to + # return the intermediates as outputs. However, this API gets + # triggered during the hop tracing, and we don't know at this point + # of time, if the API will take into effect. To handle this, we have + # a flag traced_with_externally_visible_side_effects (default=False) + # that is set to True anytime + # `traced_with_externally_visible_side_effects` is set. We reset it + # with the old value after the hop is traced out. + old_value = ( + tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + + output = trace_hop_function_with_auto_output_flattening( + f, + tx, + subtracer, + enable_grad, + allow_side_effects, + args, + sub_kwargs, + ) + + # NOTE: [Separation of graph outputs and output VTs] + # In Dynamo (outside of speculate_subgraph), VTs and the graph are + # separate concepts: + # - VTs (VariableTrackers) can "run ahead" and continue Dynamo tracing + # - The graph is just a side data structure tracking computation seen so far + # + # This separation is crucial for HOPs with non-proxyable outputs (e.g., + # custom user-defined objects containing tensors). The function may return + # complex Python objects for Dynamo to continue tracing, but only the + # tensor/symint VTs need to be registered as actual graph outputs. + # + # Example: + # class Foo: + # def __init__(self, a, b): + # self.a = a # tensor + # self.b = b # tensor + # + # def gn(x): + # return Foo(torch.sin(x), torch.cos(x)) + # + # Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but + # `graph_output_vts` contains only the tensor VTs (a and b) that should + # be actual FX graph outputs. + # Collect only tensor and symint VTs that should be graph outputs. + # We walk the output structure and extract proxyable VTs. + graph_output_vts = [] + + def visit(vt): + if vt.is_tensor() or isinstance(vt, SymNodeVariable): + graph_output_vts.append(vt) + + VariableTracker.visit(visit, output) + graph_output_vts = tuple(graph_output_vts) + + # NOTE - [Return subgraph intermediates as subgraph outputs] + # This helps HOPs which allow side effects. Consider the + # following example + # + # def gn(x, z): + # o = torch.matmul(x, x) @ x + # out = x.sin() + # z.append(out) + # return torch.cos(torch.sin(o)) + + # def fn(x): + # z = [] + # out1 = torch.utils.checkpoint.checkpoint( + # gn, + # x, + # z, + # use_reentrant=False, + # ) + # return out1, z[0] + # + # In this example, list `z` is in outer scope and gets appended + # in the subgraph with `out`. But `out` is not an output of the + # subgraph. This can cause issue because later on when the outer + # graph returns `z[0]` it needs to have access to the graph node + # `out`. To solve this problem, we just return all intermediates + # from the subgraph. + + # TODO - Today this is supported only for AC. AC HOP gets + # desugared in AOTDispatcher so even though subgraph has extra + # unused outputs in Dynamo, its ok even if we don't DCE them in + # Dynamo. As AOTDispatcher desugars/inlines the subgraph, the + # subgraph boundary disappears. And even for AC, today this only + # works when the skip_fwd_side_effects_in_bwd_under_checkpoint + # flag is True, i.e., only when we allow side-effects. But, we + # want this to be supported for other Hops as well, specifically + # nested_compile_region and autograd.Function. Today, its safe + # because we error out on seeing a side-effect. + + allow_side_effects = ( + allow_side_effects + or tx.output.current_tracer.traced_with_externally_visible_side_effects + ) + if allow_side_effects: + extra_outputs = collect_intermediate_outputs( + tx, subtracer, graph_output_vts, filter_aliased_intermediates + ) + graph_output_vts = graph_output_vts + tuple(extra_outputs) + + tx.output.current_tracer.traced_with_externally_visible_side_effects = ( + old_value + ) + + validate_subgraph_output_types(graph_output_vts) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + # output_proxies = output.as_proxy() + if isinstance(graph_output_vts, tuple): + output_proxies = [a.as_proxy() for a in graph_output_vts] + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + output_proxies = tuple(output_proxies) + else: + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + if len(lifted_freevars) > 0: + move_lifted_freevars_phs_to_end(graph, lifted_freevars) + + check_aliasing_and_input_mutation( + subtracer, + graph, + supports_input_mutation, + supports_aliasing, + source_target, + ) + # Return both the output VT and the graph output VTs separately: + # - `output`: The VT that Dynamo continues tracing with (may be + # complex Python objects, tuples, dicts, etc.) + # - `graph`: The FX graph representing the subgraph computation + # - `lifted_freevars`: Free variables lifted as inputs to the subgraph + # - `graph_output_vts`: Only the tensor/symint VTs that are actual + # FX graph outputs (basically the vts associated with graph outputs) + return ( + output, + graph, + lifted_freevars, + graph_output_vts, + ) + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." + ) + log.info(msg) + log.info(ex) # noqa: G200 + raise ex + + +# See NOTE [HigherOrderOperator tracing design] for details of the design +def speculate_subgraph( + tx, + f, + sub_args, + sub_kwargs, + description, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target=None, + always_restore=False, + enable_grad=None, + # NOTE [argument `set_subgraph_inputs`] + # set_subgraph_inputs controls what how to construct subgraphs' placeholders from sub_args. + # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended). + # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended). + # If sub_args contain Pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first. + # Then the flattened args are manually set as subgraph's placeholders. + # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders e.g. AutogradFunctionContextVariable + # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the + # restriction that user need to manually control how to create placeholders and VariableTrackers for the args. + set_subgraph_inputs="automatic", + restore_side_effects=True, + should_flatten_outputs=False, + # if should_flatten_outputs is True, `remove_consts_from_outputs` remove the + # const outputs from the subgraph output. + remove_consts_from_outputs=True, + # TODO - supports input_mutation and aliasing should be False by default for strictness + supports_input_mutation=True, + supports_aliasing=True, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer=None, +): + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "semi_automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented( + gb_type="invalid set_subgraph_inputs and sub_kwargs settings", + context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + hints=[ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + *graph_break_hints.USER_ERROR, + ], + ) + + try: + # ensure guards on args get installed in parent subgraph + f, sub_args, sub_kwargs = LazyVariableTracker.realize_all( + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer, description) as subtracer: + args = get_hop_args( + tx, f, subtracer, sub_args, sub_kwargs, set_subgraph_inputs, description + ) + + output = trace_hop_function( + f, + tx, + subtracer, + enable_grad, + restore_side_effects, + args, + sub_kwargs, + ) + + treespec = None + masks_to_filter_const_values = None + const_values = None + if should_flatten_outputs: + from torch._dynamo.external_utils import filter_out_const_values + + # Flatten the speculated subgraph output. + output, treespec = _make_inlined(tx, pytree.tree_flatten)( + output + ).unpack_var_sequence(tx) + + # Actually, transform the list (returned by flatten) into a tuple + # for dynamo consistency. + output = BuiltinVariable(tuple).call_function(tx, [output], {}) + + if remove_consts_from_outputs: + # Filter out the constants and save them into a spec. Filtering + # out constants makes the graph simpler for the backends. We + # need to ensure that after unflattening the constants are + # inserted back at the right positions for the Dynamo tracing to + # continue. This is done by filter_const_spec + output_proxies = output.as_proxy() + masks_to_filter_const_values = pytree.tree_map( + lambda x: not isinstance(x, torch.fx.Proxy), output_proxies + ) + const_values = pytree.tree_map( + lambda x: None if isinstance(x, torch.fx.Proxy) else x, + output_proxies, + ) + output = _make_inlined(tx, filter_out_const_values)( + output, masks_to_filter_const_values + ) + + # TODO - clean up num_intermediate_nodes_as_outputs - we do not need + # after AC moved to auto_output_flattening + num_intermediate_nodes_as_outputs = 0 + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support pytree output + # We check always_restore because we dont use the output or side effects of always_restore code, + # like bwd. + if always_restore: + # Nothing left to do here + return ( + ( + output, + OutputSpec( + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, + ), + ), + tx.output.graph, + subtracer.lifted_freevars, + ) + else: + validate_subgraph_output_types(output) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + if len(lifted_freevars) > 0: + move_lifted_freevars_phs_to_end(graph, lifted_freevars) + + check_aliasing_and_input_mutation( + subtracer, + graph, + supports_input_mutation, + supports_aliasing, + source_target, + ) + + return ( + ( + output, + OutputSpec( + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, + ), + ), + graph, + lifted_freevars, + ) + + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." + ) + log.info(msg) + log.info(ex) # noqa: G200 + raise ex + + +def make_attr(tx: "InstructionTranslator", name): + node = tx.output.create_proxy( + "get_attr", + name, + (), + {}, + ) + return node + + +class TorchHigherOrderOperatorVariable(VariableTracker): + def __init__( + self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.value = value + self.source = source + + @staticmethod + def make(value, source=None, **kwargs): + variable_class = _hop_name_to_variable_class.get(value.__name__) + if variable_class is not None: + return variable_class(value, source, **kwargs) + + from torch._higher_order_ops import BaseHOP + + if isinstance(value, BaseHOP): + return BaseHOPVariable(value, source, **kwargs) + unimplemented( + gb_type="unsupported HigherOrderOperator", + context=str(value), + explanation=f"Unable to create higher order operator variable for {value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if can_dispatch_torch_function(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + + return self._call_function(tx, args, kwargs) + + def _call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + unimplemented( + gb_type="unsupported HigherOrderOperator function call", + context=str(self.value), + explanation=f"Unable to trace calling higher order operator variable for {self.value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + def as_python_constant(self): + return self.value + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Wraps torch._functorch.autograd_function.custom_function_call + """ + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return torch._dynamo.variables.UserMethodVariable( + self.value.__call__.__func__, + torch._dynamo.variables.UserDefinedObjectVariable( + self.value, source=self.source + ), + source=AttrSource(self.source, "__call__"), + ).call_function(tx, args, kwargs) + + +class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="Cond doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ListVariable + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + # TODO(voz): Support fake tensor dispatch for recursive + # ops - see torch/dispatch/_dispatcher.py + if len(args) != 4 or kwargs: + unimplemented( + gb_type="torch.cond: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: cond(pred, cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Specialize into one of the branches since pred is constant + pred, true_fn, false_fn, operands = args + if type(args[0]) is ConstantVariable: + warnings.warn( + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." + " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool.", + UserWarning, + ) + if pred.as_python_constant(): + return true_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + else: + return false_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + + # predicate + if type(pred.realize()) not in ( + ConstantVariable, + TensorVariable, + SymNodeVariable, + ): + unimplemented( + gb_type="torch.cond: improper predicate", + context=str(pred), + explanation="Expected `pred` to be a bool or a boolean tensor with a single item " + f"but got {str(type(pred))} with original python type {str(pred.python_type())}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # operands + if not isinstance(operands, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.cond: improper operands", + context=str(operands), + explanation="Expected `operands` to be a list/tuple " + f"but got {operands.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + operands_seq = operands.unpack_var_sequence(tx) + if not only_consist_of( + operands, (TensorVariable, ConstantVariable, SymNodeVariable) + ): + unimplemented( + gb_type="torch.cond: improper operands contents", + context=str(operands), + explanation="Expected `operands` to be a list/tuple of pytrees that only consists of tensor leaves.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # branches + _check_supported_callable_arg(tx, true_fn, "true_fn") + _check_supported_callable_arg(tx, false_fn, "false_fn") + + # Our strategy for tracing the true/false branches of cond + # are to checkpoint our graphstate, run the true branch, + # roll it back to the checkpoint, and run the false + # branch, and then merge the graphstates. Well, perhaps + # "merge" is too strong a word: we mostly assert that + # the resulting graphstates have to be the same. + # + # We only permit guards to diverge (we union the guards from + # both branches). In particular, this means that side + # effects are NOT permitted inside true/false branches; this + # would be difficult to implement, because of the path + # explosion problem. + + def speculate_branch(branch): + # NB: 0 is predicate + ix = 1 if branch else 2 + # TODO: Support kwargs + ( + (ret_val, ret_spec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[ix], + operands_seq, + {}, + "cond", + source_target=self.value, + should_flatten_outputs=True, + # TODO - removing consts from control flow ops need more work + remove_consts_from_outputs=False, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # need to ensure we increase epoch so we don't memoize unbacked bindings + # across different subgraphs which can interfere with runtime assertion + # generation. + tx.fake_mode.epoch += 1 + + if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)): + unimplemented( + gb_type="torch.cond: unsupported branch return type", + context=str(ret_val), + explanation="Expected branches to return a possibly nested pytree of tensors or constant ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + for ret in ret_val.unpack_var_sequence(tx): + if ret.is_python_constant() and not isinstance( + ret.as_python_constant(), int + ): + unimplemented( + gb_type="torch.cond: unsupported branch return type (constant non-int)", + context=str(ret_val), + explanation="Constants returned from branches must be ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + return ret_val, ret_spec, ret_graph, ret_lifted_freevars + + (true_r, true_spec, true_graph, true_lifted_freevars) = speculate_branch(True) + true_nn_modules = dict(tx.output.nn_modules) + + ( + false_r, + false_spec, + false_graph, + false_lifted_freevars, + ) = speculate_branch(False) + false_nn_modules = dict(tx.output.nn_modules) + + same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)( + true_spec.treespec, false_spec.treespec + ).as_python_constant() + # 3.14: NotImplemented cannot be converted to bool + if same_spec is not NotImplemented and not same_spec: + unimplemented( + gb_type="torch.cond: differing branch outputs", + context=f"true_spec: {true_spec.treespec}, false_spec: {false_spec.treespec}, same_spec: {same_spec}", + explanation="Expected branches to return the same pytree structure.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + ( + true_graph, + false_graph, + true_shared, + _false_shared, + unique_true, + unique_false, + ) = _merge_graph_inputs( + true_graph, + true_lifted_freevars, + "true_branch", + false_graph, + false_lifted_freevars, + "false_branch", + ) + + true_name = tx.output.install_subgraph( + "cond_true", + torch.fx.GraphModule(true_nn_modules, true_graph), + ) + false_name = tx.output.install_subgraph( + "cond_false", + torch.fx.GraphModule(false_nn_modules, false_graph), + ) + + true_node = make_attr(tx, true_name) + false_node = make_attr(tx, false_name) + + p_args = ( + pred.as_proxy(), + true_node, + false_node, + # We pick true_shared but it shouldn't matter + tuple(true_shared + unique_true + unique_false), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.cond, + p_args, + {}, + None, + true_spec, + true_r, + ) + + +class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable): + def __init__(self, hop, source, script_obj_var, method_name) -> None: + super().__init__(hop, source) + self.script_obj_var = script_obj_var + self.method_name = method_name + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + args_proxy = [arg.as_proxy() for arg in args] + kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple( + [self.script_obj_var.as_proxy(), self.method_name] + args_proxy + ), + kwargs=kwargs_proxy, + ), + ) + + +def validate_subgraph_output_types(output: VariableTracker): + """Verify that that the output of the subgraph is a tensor, + int, bool, SymBool, or SymInt. + """ + from . import TensorVariable + + if non_tensor_output := find_mismatched_vars( + output, TensorVariable, allow_none=True + ): + for out in non_tensor_output: + if ( + isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) + ) or ( + out.is_python_constant() + and isinstance(out.as_python_constant(), (int, bool)) + ): + continue + unimplemented( + gb_type="HOP body output unsupported", + context=f"non-tensor outputs: {non_tensor_output}", + explanation="HigherOrderOperator body's output must consist of tensors or ints/bools only " + f"but got {out.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + +class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="while_loop doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return _call_while_loop(self, tx, args, kwargs, stack_output=False) + + +class WhileLoopStackOutputHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="while_loop_stack_output doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return _call_while_loop(self, tx, args, kwargs, stack_output=True) + + +class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="associative_scan must be captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + def arg_extractor(combine_fn, xs, additional_inputs): + return combine_fn, xs, additional_inputs + + combine_fn, xs, additional_inputs = arg_extractor(*args, **kwargs) + + if args[0].python_type() is functools.partial: + # This is the standard case when the user calls the frontend + # and the frontend invokes dynamo + if len(args) != 2: + unimplemented( + gb_type="torch.associative_scan: improper args", + context=f"args: {args}", + explanation=f"torch.associative_scan expects 2 positional arguments (got {len(args)}) " + "Usage: associative_scan(combine_fn, xs)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + xs_treespec = args[0].keywords["spec"] + + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + else: + # This case is hit during re-tracing, for example in export tests + # In this case, the combine_fn is a callable and not a functools.partial + xs_treespec = _make_inlined(tx, pytree.tree_structure)(xs) + + _check_supported_callable_arg(tx, combine_fn, "combine_fn") + + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.associative_scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + xs_vars = xs.unpack_var_sequence(tx) + _check_all_tensorvariable(xs_vars) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.associative_scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + _check_all_tensorvariable(additional_inputs_vars) + + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented( + gb_type="torch.associative_scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Trace the subgraph + # The sub_args is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args shape will be (4, ). + with discard_graph_changes(tx): + sub_args = [ + _make_inlined(tx, first_slice_copy)(leaf) + for leaf in itertools.chain(xs_vars, xs_vars) + ] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args + sub_args_additional_inputs + ( + (combine_result, _combine_spec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="associative_scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + + # Collect the results from the combine_fn + results, _combine_treespec = _make_inlined(tx, pytree.tree_flatten)( + combine_result + ).unpack_var_sequence(tx) + + # Check whether the combine_fn returns one child tree for the output. + if _combine_treespec.as_python_constant().num_leaves < 1: + unimplemented( + gb_type="torch.associative_scan: combine_fn improper number of leaves", + context=str(_combine_treespec.as_python_constant()), + explanation="combine_fn needs to produce one pytree for the output " + f"but combine_fn produces the pytree {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # Check whether the outs produced by combine_fn has the same treespec as xs + # We need to have this check this way, because in case init is a TreeSpec and carry + # but carry is only a LeafSpec, these two cannot be compared correctly. + if ( + xs_treespec.as_python_constant().is_leaf() + != _combine_treespec.as_python_constant().is_leaf() + ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( + xs_treespec, _combine_treespec + ).as_python_constant(): + unimplemented( + gb_type="torch.associative_scan: mismatched input/output tree structure", + context=f"xs: {xs_treespec.as_python_constant()}, output: {_combine_treespec.as_python_constant()}", + explanation="The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " + f"xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + [_make_inlined(tx, first_slice_copy)(t) for t in xs_vars], + results.items, + "initial_xs", + "combine_fn_output", + include_contiguity=False, + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + + # Compute the proxies for the input check + proxy_vars_inputcheck = ( + tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy + ) + + from torch._higher_order_ops.utils import _maybe_fake_tracing + from torch._inductor.utils import is_pointwise_use + + with tx.fake_mode: + sub_args_fake = [ + ( + leaf.node.meta["example_value"].clone() + if hasattr(leaf.node.meta["example_value"], "clone") + else leaf.node.meta["example_value"] + ) + for leaf in pytree.tree_leaves(proxy_vars_inputcheck) + ] + pre_dispatch = False + + fx = _maybe_fake_tracing( + combine_gm, sub_args_fake, pre_dispatch=pre_dispatch + ) + + for node in fx.graph.nodes: + # Check that the combine_fn is pointwise, if combine_mode='pointwise' + if not all( + is_pointwise_use(use) or use.op == "output" for use in node.users + ): + raise RuntimeError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + + combine_fn_name = tx.output.install_subgraph( + "associative_scan_combine_fn", combine_gm + ) + + # Compute the proxies + xs_proxy = xs.as_proxy() + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + additional_inputs_proxy = additional_inputs.as_proxy() + combine_freevars_proxy + + p_args = ( + make_attr(tx, combine_fn_name), + xs_proxy, + additional_inputs_proxy, + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.associative_scan, + p_args, + {}, + None, + OutputSpec(xs_treespec), + None, + ) + + +class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="scan must be captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.scan import _extract_carry_and_out + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + # combine_fn input check + def _check_combine_fn_is_normalized(combine_fn_var): + if not isinstance( + combine_fn_var, + ( + variables.nn_module.NNModuleVariable, + variables.nn_module.UnspecializedNNModuleVariable, + variables.FunctoolsPartialVariable, + ), + ): + unimplemented( + gb_type="torch.scan: improper combine_fn", + context=str(combine_fn_var), + explanation="Expected combine_fn to be wrapped as functools.partial in scan user-facing api " + f"or a graph module if we're re-exporting but got {combine_fn_var.python_type()}.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + return isinstance( + combine_fn_var, + ( + variables.nn_module.NNModuleVariable, + variables.nn_module.UnspecializedNNModuleVariable, + ), + ) + + def arg_extractor(combine_fn, init, xs, additional_inputs): + return combine_fn, init, xs, additional_inputs + + combine_fn, init, xs, additional_inputs = arg_extractor(*args, **kwargs) + init_vars = init.unpack_var_sequence(tx) + xs_vars = xs.unpack_var_sequence(tx) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + + # combine_fn input check + combine_fn_is_normalized = _check_combine_fn_is_normalized(combine_fn) + if combine_fn_is_normalized: + combine_gm = combine_fn.value + assert isinstance(combine_gm, torch.fx.GraphModule), ( + combine_fn, + combine_gm, + ) + else: + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + # init input check + if not isinstance(init, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper init", + context=str(init), + explanation=f"Expected init to be a list/tuple with at least one element but got {init.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + if len(init_vars) == 0: + unimplemented( + gb_type="torch.scan: no init leaves", + context="", + explanation="Expected init leaves.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + gb_type="torch.scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + # scan_length check + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented( + gb_type="torch.scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) + _check_all_tensorvariable(init_vars) + _check_all_tensorvariable(xs_vars) + _check_all_tensorvariable(additional_inputs_vars) + + with discard_graph_changes(tx): + sub_args_init = [ + ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init_vars + ] + # The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args_inp shape will be (4, ). + sub_args_inp = [_make_inlined(tx, first_slice_copy)(inp) for inp in xs_vars] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs + ( + (combine_result, _combine_spec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + combine_freevars_proxy = list(combine_lifted_freevars.keys()) + combine_result_vars = combine_result.unpack_var_sequence(tx) + + if combine_fn_is_normalized: + carry_vars, out_vars = _extract_carry_and_out( + combine_result_vars, len(init_vars) + ) + else: + if len(combine_result_vars) != 2: + unimplemented( + gb_type="torch.scan: improper combine_fn number of returns", + context=str(combine_result_vars), + explanation=f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + carry_tree, out_vars = combine_result_vars + carry_vars, _ = _make_inlined(tx, pytree.tree_flatten)( + carry_tree + ).unpack_var_sequence(tx) + carry_vars = carry_vars.unpack_var_sequence(tx) + out_vars = _make_inlined(tx, pytree.tree_leaves)( + out_vars + ).unpack_var_sequence(tx) + + # additional output checking + _combine_spec = OutputSpec( + _make_inlined(tx, pytree.tree_structure)(combine_result) + ) + + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + ) + + # Check meta data of carries and inits. If we pass this stage, we are sure that the init and carries + # have the same tree structure. + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + include_contiguity=False, + ) + + xs_proxy = xs.as_proxy() + init_proxy = init.as_proxy() + additional_inputs_proxy = list(additional_inputs.as_proxy()) + list( + combine_freevars_proxy + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm) + + p_args = ( + make_attr(tx, combine_fn_name), + init_proxy, + xs_proxy, + additional_inputs_proxy, + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.scan, + p_args, + {}, + None, + _combine_spec, + None, + ) + + +def non_single_tensor_return_unsupported(api, ret): + if not ret.is_tensor(): + unimplemented( + gb_type="non-single Tensor return unsupported", + context=f"api: {api}, ret: {ret}", + explanation=f"{api} over function that returns something other than one Tensor.", + hints=[], + ) + + +class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="map doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if len(kwargs) > 0: + unimplemented( + gb_type="torch.map: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.map expects no keyword arguments (got {len(kwargs)})", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + _check_supported_callable_arg(tx, args[0], "map_fn") + + # args = f, flat_xs, flat_args + assert isinstance(args[1], (ListVariable, TupleVariable)), args[1] + assert isinstance(args[2], (ListVariable, TupleVariable)), args[2] + unpacked_xs = args[1].unpack_var_sequence(tx) + unpacked_args = args[2].unpack_var_sequence(tx) + + sample_shape = get_fake_value(unpacked_xs[0].as_proxy().node, tx).size() + + if len(sample_shape) < 1 or sample_shape[0] == 0: + unimplemented( + gb_type="torch.map: improper inputs", + context=str(sample_shape), + explanation="torch.map doesn't support scalar or non-zero sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + # To get the example output from map() we will need to provide at least one sample to + # the loop body. In our case we will always use xs[0], and our map() won't support zero + # sized tensor during tracing. + with discard_graph_changes(tx): + sliced_xs = [ + xs.call_method( + tx, + "select", + args=(VariableTracker.build(tx, 0), VariableTracker.build(tx, 0)), + kwargs={}, + ) + for xs in unpacked_xs + ] + + # TODO: Support kwargs + ( + (body_r, body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + [ + *sliced_xs, + *unpacked_args, + ], + {}, + "torch.ops.higher_order.map", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + # TODO - removing consts from control flow ops need more work + remove_consts_from_outputs=False, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Check all outputs of map are tensors. + # For map, outputting None is OK, thus ignore None values in the check + body_r_vars = body_r.unpack_var_sequence(tx) + none_mask = [x.is_constant_none() for x in body_r_vars] + _check_all_tensorvariable( + [br for bm, br in zip(none_mask, body_r_vars) if not bm] + ) + + body_nn_modules = dict(tx.output.nn_modules) + + body_name = tx.output.install_subgraph( + "map_body", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + p_args = ( + body_node, + [xs.as_proxy() for xs in unpacked_xs], + [arg.as_proxy() for arg in unpacked_args] + + list(body_lifted_freevars.keys()), + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec, body_r + ) + + +class PrintHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + args_proxy = [arg.as_proxy() for arg in args] + kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(args_proxy), + kwargs=kwargs_proxy, + ), + ) + + +class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + # This is operator for delegation within Executorch which calls a + # specific function in the given lowered module with the given + # operators. The actual operator is defined in the Executorch codebase. + # This is a bad hierarchical violation since + # executorch_call_delegate sits at a higher level than dynamo, but + # there's no real solution to this issue yet. + if len(kwargs) > 0: + unimplemented( + gb_type="executorch_call_delegate: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"executorch_call_delegate expects no keyword arguments (got {len(kwargs)})", + hints=[], + ) + if isinstance(args[0], variables.NNModuleVariable): + lowered_module = tx.output.get_submodule(args[0].module_key) + lowered_node = make_attr(tx, args[0].module_key) + elif isinstance(args[0], variables.UnspecializedNNModuleVariable): + # This nn module is special sa delegated by executorch. Just + # install it as a attr in the graph. + lowered_module = args[0].value + lowered_node = tx.output.register_static_attr_and_return_proxy( + "delegate", lowered_module + ) + + p_args = tuple(arg.as_proxy() for arg in args[1:]) + real_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args + ) + + with tx.fake_mode: + example_value = lowered_module.original_module.module()(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. + _assert_tensors_nonaliasing(real_sub_args, example_value) + + p_args = (lowered_node,) + p_args + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class FunctorchHigherOrderVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return super().call_function(tx, args, kwargs) + + def should_allow_nested_graph_breaks(self): + return False + + +class FunctionalCallVariable(FunctorchHigherOrderVariable): + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + if not torch._dynamo.config.inline_inbuilt_nn_modules: + unimplemented( + gb_type="torch.func.functional_call capture is disabled", + context="", + explanation="torch.func.functional_call capture is disabled", + hints=[ + "Set `torch._dynamo.config.inline_inbuilt_nn_modules=True` to enable.", + ], + ) + return super().call_function(tx, args, kwargs) + + +class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + ctx_manager_vt = super().call_function(tx, args, kwargs) + return RepararametrizeModuleContextVariable(ctx_manager_vt, args[0]) + + +class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = True + supports_aliasing = True + allow_side_effects = False + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return tx.output.install_subgraph( + f"{attr_name}", + body_gmod, + ) + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + fn_vt, + fn_args_vt, + kwargs, + description, + *, + subgraph_name="wrap_body", + ): + # See NOTE [HigherOrderOperator tracing design] for more details + ( + body_r, + body_graph, + body_lifted_freevars, + body_graph_output_vts, + ) = speculate_subgraph_with_auto_output_flattening( + tx, + fn_vt, + fn_args_vt, + kwargs, + description, + source_target=self.value, + allow_side_effects=self.allow_side_effects, + filter_aliased_intermediates=getattr( + self, "filter_aliased_intermediates", False + ), + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = self.install_subgraph_in_output_graph( + tx, + fn_vt, + fn_args_vt, + kwargs, + body_gmod, + attr_name=subgraph_name, + ) + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. + lifted_args = tuple(arg for arg in body_lifted_freevars) + + proxy_args = (body_node,) + lifted_args + + example_value = pytree.tree_map_only( + torch.fx.Node, + lambda a: a.meta["example_value"], + body_graph.find_nodes(op="output")[0].args[0], + ) + + return ( + proxy_args, + {}, + example_value, + body_r, + body_gmod, + body_name, + body_graph_output_vts, + ) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + _example_value, + body_r, + _, + _, + body_graph_output_vts, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") + + if len(p_kwargs) > 0: + unimplemented( + gb_type="WrapHigherOrderVariable: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + return _call_function_with_auto_output_flattening( + tx, + self.value, + tuple(p_args), + p_kwargs, + _example_value, + body_r, + body_graph_output_vts, + ) + + +class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + gb_type="wrap_with_set_grad_enabled: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_set_grad_enabled expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + grad_enabled, fn_var, *rest_args = args + + if not grad_enabled.is_python_constant(): + unimplemented( + gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", + context=str(grad_enabled), + explanation="wrap_with_set_grad_enabled expects grad_enabled argument to be a constant.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + _check_supported_callable_arg(tx, fn_var, "enable_grad_fn") + + with torch.set_grad_enabled(grad_enabled.as_python_constant()): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_set_grad_enabled", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + gb_type="wrap_with_set_grad_enabled: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_set_grad_enabled expects no freevars.", + hints=[], + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + grad_enabled.as_python_constant(), + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec, body_r + ) + + +class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + gb_type="wrap_with_autocast: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_autocast expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args + + for arg in [device_type, dtype, enabled, cache_enabled]: + if not arg.is_python_constant(): + unimplemented( + gb_type="wrap_with_autocast: expected constant arg", + context=str(args), + explanation="wrap_with_autocast expects device_type, dtype, enabled, " + "and cache_enabled arguments to be constants.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + _check_supported_callable_arg(tx, fn_var, "autocast") + + python_constants = [ + arg.as_python_constant() + for arg in [device_type, dtype, enabled, cache_enabled] + ] + + with torch.autocast(*python_constants): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_autocast", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + gb_type="wrap_with_autocast: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_autocast expects no freevars.", + hints=[], + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + *python_constants, + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec, body_r + ) + + +class HintsWrapperHigherOrderVariable(WrapHigherOrderVariable): + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return tx.output.install_subgraph( + "hints_wrapper_body", + body_gmod, + ) + + @raise_hard_error_if_graph_break( + reason="hints_wrapper doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + _check_supported_callable_arg(tx, args[0], "body_fn") + + # inputs + if ( + len(args) != 3 + or not isinstance(args[1], (ListVariable, TupleVariable)) + or not isinstance(args[2], ConstDictVariable) + or len(kwargs) != 1 + or "hints" not in kwargs + ): + unimplemented( + gb_type="hints_wrapper: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"hints_wrapper expects 3 positional arguments (got {len(args)}) " + f"and 1 keyword argument (got {len(kwargs)}). " + "Usage: hints_wrapper(body_fn, args, kwargs, hints=...). " + "args is expected to be list/tuple and kwargs is expected to be a dict.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + operands = args[1].unpack_var_sequence(tx) + fn_kwargs = args[2].as_python_constant() + + # Use create_wrapped_node from WrapHigherOrderVariable + ( + p_args, + _, + example_value, + body_r, + body_gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[0], # function + operands, + fn_kwargs, + "hints_wrapper", + ) + + # hints_wrapper expects (body_node, args, kwargs) as positional args + # So we need to restructure p_args from (body_node, *lifted_args) + # to (body_node, lifted_args_tuple, {}) + body_node = p_args[0] + lifted_args = p_args[1:] + p_args = (body_node, tuple(lifted_args), {}) + + # add hints into p_kwargs + p_kwargs = {} + p_kwargs["hints"] = kwargs["hints"].as_python_constant() + + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + if len(kwargs) > 0: + unimplemented( + gb_type="out_dtype: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"out_dtype expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + p_args = tuple(arg.as_proxy() for arg in args) + op = p_args[0] + output_dtype = p_args[1] + fake_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:] + ) + # This is a simplified implementation of this operator just for tracing. + # Actual implementation may also first promote the arguments + example_value = op(*fake_sub_args).to(dtype=output_dtype) + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unpacked_sequence = args[1].unpack_var_sequence(tx) + # TODO (tmanlaibaatar) support pytree here + for arg in unpacked_sequence: + if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): + unimplemented( + gb_type="strict_mode: improper args", + context=f"args: {args}, kwargs: {kwargs}", + explanation="strict_mode higher order op expects flat inputs (list/tuple/dict)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + if kwargs: + unimplemented( + gb_type="strict_mode: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"strict_mode higher order op expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + ( + (ret_val, ret_spec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + unpacked_sequence, + {}, + "strict_mode", + source_target=self.value, + should_flatten_outputs=True, + ) + + strict_mode_nn_modules = dict(tx.output.nn_modules) + + strict_mode_name = tx.output.install_subgraph( + "strict_mode_body", + torch.fx.GraphModule(strict_mode_nn_modules, ret_graph), + ) + + strict_mode_node = make_attr(tx, strict_mode_name) + p_args = ( + strict_mode_node, + tuple(ret_lifted_freevars.keys()), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + ret_val.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.strict_mode, + p_args, + {}, + flat_example_value, + ret_spec, + ret_val, + ) + + +class CheckpointHigherOrderVariable(WrapHigherOrderVariable): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.allow_side_effects = ( + torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint + ) + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.wrap import TagActivationCheckpoint + from torch.utils.checkpoint import noop_context_fn + + context_fn = None + if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn: + ctx = kwargs.pop("context_fn") + if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable): + context_fn = ctx.fn + elif isinstance( + ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + context_fn = ctx.guard_as_python_constant() + else: + raise NotImplementedError( + f"checkpoint not implemented for {type(ctx)} context_fn" + ) + + checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs) + + # Here we use checkpoint_kwargs (and not gmod kwargs). gmod_kwargs are + # already flattened above and managed inside the fx graph. + ( + p_args, + _, + example_value, + _body_r, + checkpointed_gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[0], + args[1:], + gmod_kwargs, + "torch.utils.checkpoint.checkpoint", + ) + if context_fn is not None: + checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn + + _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs) + + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + checkpoint_kwargs, + example_value, + _body_r, + body_graph_output_vts, + ) + + +class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable): + def __init__(self, hop, source) -> None: + super().__init__(hop, source) + + def _call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + func_var = args[0] + + if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable): + func = func_var.fn + elif isinstance( + func_var, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + func = func_var.as_python_constant() + else: + raise RuntimeError( + f"DynamoBypassingWrapperHigherOrderVariable: Unsupported function {type(func_var)}" + ) + ( + p_args, + _, + example_value, + _body_r, + gmod, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, + args[1], + args[2:], + kwargs, + str(func), + ) + + # Alternatively, we could've stored only the function's fqn and + # reconstructed, but that requires the function to be a global. + gmod_meta_key = "_dynamo_bypassing_wrapper_fn" + gmod.meta[gmod_meta_key] = func + + return _call_function_with_auto_output_flattening( + tx, + self.value, + (gmod_meta_key,) + tuple(p_args), + {}, + example_value, + _body_r, + body_graph_output_vts, + ) + + +class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): + def proxy_submod(self, tx, arg): + assert isinstance(arg.source.base, DictGetItemSource) + submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value) + p_submod = make_attr(tx, submod_name) + set_example_value(p_submod.node, arg.value) + return p_submod + + def to_proxy(self, tx, arg): + if isinstance(arg, UnspecializedNNModuleVariable): + return self.proxy_submod(tx, arg) + elif isinstance(arg, (ListVariable, TupleVariable)): + return arg.python_type()( + self.to_proxy(tx, nested_arg) for nested_arg in arg.items + ) + else: + return arg.as_proxy() + + def _call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + try: + p_args = tuple(self.to_proxy(tx, arg) for arg in args) + p_kwargs = {key: self.to_proxy(tx, arg) for key, arg in kwargs.items()} + except (NotImplementedError, Unsupported) as err: + unimplemented( + gb_type="failed to handle argument for FlexAttentionBackward HOP", + context=f"args: {args}, kwargs: {kwargs}", + explanation="Missing Dynamo support for FlexAttentionBackward HOP argument.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + from_exc=err, + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace + by unwrapping the higher order op and inlining through it. This op + is created by dynamo to survive through AotAutograd, then unwrapped + here in the call to dynamo from compiled autograd. + """ + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + kwargs = dict(kwargs) + fn = kwargs.pop("fn") + return fn.call_function(tx, args, kwargs) + + +class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable): + @staticmethod + def normalize_to_args(args, kwargs): + # input signature is (query, key, value, score_mod, block_mask, *other_buffers), + # block_mask is a tuple, and we don't want to flatten it. + # only flatten kwargs into lists + flat_kwargs = pytree.tree_flatten(kwargs)[0] + + # Combine the flattened lists + all_args = args + flat_kwargs + return all_args + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + query: "VariableTracker", + fn: "VariableTracker", + fn_name: str, + ): + from .._trace_wrapped_higher_order_op import TransformGetItemToIndex + + def create_scalar(): + return query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + { + "dtype": VariableTracker.build(tx, torch.int32), + }, + ) + + with discard_graph_changes(tx): + bhmn = [create_scalar() for _ in range(4)] + if fn_name == "score_mod": + scores_require_grad: bool = query.requires_grad + score = query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + {"requires_grad": VariableTracker.build(tx, scores_require_grad)}, + ) + new_args = [score, *bhmn] + else: + assert fn_name == "mask_fn", "Illegal function name: " + fn_name + new_args = [*bhmn] + + with TransformGetItemToIndex(): + ( + (_body_output, _body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn, + new_args, + {}, # expect only args no kwargs for now + description=fn_name, + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + body_name = tx.output.install_subgraph( + fn_name, + torch.fx.GraphModule(tx.output.nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + # It is possible that the score-mod function captures some free variables that are not + # passed in as arguments. In this case, we need to lift them, which is handled by speculate_subgraph. + # We then need to create proxies for this + the inputs. + + lifted_args = tuple(arg for arg in body_lifted_freevars) + + proxy_args = (body_node, lifted_args) + + return proxy_args + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + ) = self.normalize_to_args(args, kwargs) + + score_mod_node, score_mod_lifted_args = self.create_wrapped_node( + tx, query, score_mod, "score_mod" + ) + mask_fn = block_mask.items[-1] # type: ignore[attr-defined] + if mask_fn.is_python_constant() and mask_fn.as_python_constant() is None: + mask_fn = UserFunctionVariable( + torch.nn.attention.flex_attention.noop_mask, + source=mask_fn.source, + ) + mask_fn_node, mask_fn_lifted_args = self.create_wrapped_node( + tx, query, mask_fn, "mask_fn" + ) + + proxied_args = [ + query, + key, + value, + TupleVariable(block_mask.items[:-1], source=block_mask.source), + scale, + kernel_options, + ] + + # Store the invocation as a call + # Norm_kwargs contains the score_function and we dont want to proxy this because + # Proxying user defined functions is not supported. + inp_args, _ = proxy_args_kwargs(proxied_args, {}) + + # Compose the ordered HOO args: + # - inp_args: [query, key, value, block_mask, scale, kernel_options] + # - subgraph node: [score_mod, mask_fn_node] + # - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers] + _, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args + block_mask = tuple(inp_arg_block_mask + (mask_fn_node,)) + with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value): + proxy = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=inp_args[:3] + + ( + score_mod_node, + block_mask, + inp_arg_scale, + inp_arg_kernel_options, + score_mod_lifted_args, + mask_fn_lifted_args, + ), + kwargs={}, + ), + example_value=None, + ) + return proxy + + +class AutogradFunctionApplyVariable(VariableTracker): + def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None: + super().__init__(**kwargs) + self.fwd_graph = fwd_graph + self.bwd_graph = bwd_graph + self.parent_source = parent_source + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + AutogradFunctionContextVariable, + UserDefinedClassVariable, + UserFunctionVariable, + UserMethodVariable, + ) + from .builder import wrap_fx_proxy + + """ + Consider the following: + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + @staticmethod + def backward(ctx, grad): + x, = ctx.saved_tensors + return grad * x.cos() + We want the resulting graphs to look like: + def fwd(ctx, x): + # (output, saved tensors / attrs) + return (x.sin(), [x]) + # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs) + def bwd(ctx, grad, x): + return grad * x.cos() + To accomplish this, we're going to: + 1. Construct a ctx object + 2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True) + 3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting + the ctx and grad inputs. + 4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph) + Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is + just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward + doesn't capture any arguments. + All these steps work if MySin.backward doesn't capture any values. This is a + limitation in general that we should check for. + """ + + prev_side_effects = tx.output.side_effects.clone() + fwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=tx.output.current_tracer, + source_target="autograd.Function", + ) + + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + with discard_graph_changes(tx): + # A little hacky, but we need a dummy ctx proxy for speculate_subgraph. + # We should clean this up at some point. + proxy = tx.output.create_proxy( + "call_function", torch.autograd.function.FunctionCtx, (), {} + ) + set_example_value(proxy.node, ctx.value) + ctx.proxy = proxy + + if isinstance(self.fwd_graph, types.FunctionType): + fwd_fn = UserFunctionVariable(self.fwd_graph) + fwd_args = [ctx, *args] + elif isinstance(self.fwd_graph, types.MethodType): + fwd_fn = UserMethodVariable( + self.fwd_graph.__func__, + UserDefinedClassVariable(self.fwd_graph.__class__), + ) + fwd_args = [fwd_fn.obj, ctx, *args] + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method forward", + context=str(self.fwd_graph), + explanation="Expected forward function to be a function or method.", + hints=[], + ) + + # Speculate subgraph on the fwd + (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph( + tx, + fwd_fn, + fwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="semi_automatic", + restore_side_effects=False, + tracer=fwd_tracer, + ) + + if ctx in tx.output.side_effects.store_attr_mutations: + if ( + "_materialize_non_diff_grads" + in tx.output.side_effects.store_attr_mutations[ctx] + ): + unimplemented( + gb_type="autograd.Function.apply: _materialize_non_diff_grads mutation", + context="", + explanation="Mutations to autograd.Function.ctx._materialize_non_diff_grads are not supported.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + + # Speculate subgraph on the backward. We make the + # bwd tracer a child of the fwd tracer, because backward may rely on + # tensors/attrs created in the fwd tracer. + + if isinstance(fwd_out, variables.BaseListVariable): + bwd_args = [ctx, *fwd_out.items] + else: + bwd_args = [ctx, fwd_out] + + bwd_src = AttrSource(self.parent_source, member="backward") + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + self.bwd_graph.__func__, + UserDefinedClassVariable(self.bwd_graph.__class__), + source=bwd_src, + ) + bwd_args = [bwd_fn.obj, *bwd_args] + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method backward", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or method.", + hints=[], + ) + + def is_strict_for(v: VariableTracker): + if v.is_tensor(): + # we can be more lax for stuff from forward + return v.proxy.tracer is not fwd_tracer + return True + + with ( + tx.output.subtracer(fwd_fn, fwd_tracer), + tx.strict_translation_mode(is_strict_for), + ): + try: + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + except torch._dynamo.exc.Unsupported as e: + if isinstance( + e, torch._dynamo.exc.UnknownPropertiesDuringBackwardTrace + ): + from unittest import mock + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + from .._trace_wrapped_higher_order_op import ( + autograd_function_backward_rewritten, + ) + + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable( + autograd_function_backward_rewritten(self.bwd_graph) + ) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + autograd_function_backward_rewritten( + self.bwd_graph.__func__ + ), + UserDefinedClassVariable(self.bwd_graph.__class__), + ) + else: + unimplemented( + gb_type="autograd.Function.apply: non-function or method backward (2)", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or method.", + hints=[], + ) + + with mock.patch( + "torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops", + [], + ): + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + else: + raise e + + # TODO: assert that bwd_graph didn't capture values that were + # not created inside fwd_graph. + + # TODO(oulgen): Ideally, we would not do a linear search for output + # node but as things currently are there could be nodes after the + # output node + # This is bug prone as if there's code after the output node, then + # graph.output will append the output at the very end + # This might be a behavior difference + + # If users call ctx.mark_non_differentiable, we should capture these output tensors who + # are marked as non-differentiable and pass them to ApplyTemplate + # at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction. + non_differentiable_idx = [] + if ctx.non_differentiable is not None: + non_differentiable_set = set(ctx.non_differentiable) + assert isinstance(fwd_out, variables.BaseListVariable) + for i, x in enumerate(fwd_out.items): + if x.is_tensor() and x.as_proxy() in non_differentiable_set: + non_differentiable_idx.append(i) + + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) + for node in fwd_graph.find_nodes(op="output"): + fwd_graph.erase_node(node) + break + + # Because we lift the bwd_freevars as inputs of the bwd_graph, + # we have to manually add the bwd_freevars as output of fwd_graph. + # However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph, + # we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output. + fwd_proxy_of_bwd_freevars = [] + for k in bwd_freevars: + if k in fwd_freevars: + fwd_proxy_of_bwd_freevars.append(fwd_freevars[k]) + else: + fwd_proxy_of_bwd_freevars.append(k) + + def unwrap_proxy(x): + if isinstance(x, torch.fx.Proxy): + return x.node + else: + assert variables.ConstantVariable.is_literal(x), ( + f"Only constant is allowed. Got {x}" + ) + return x + + new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars) + new_fwd_graph_outputs = pytree.tree_map(unwrap_proxy, new_fwd_graph_outputs) + fwd_graph.output(new_fwd_graph_outputs) + fwd_graph.lint() + + # Store fwd_body + fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + fwd_name = tx.output.install_subgraph( + "fwd_body", + torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph), + ) + + fwd_node = make_attr(tx, fwd_name) + + # The type of original args can be arbitrary, but we only support basic type in FX graph. + # So the speculated subgraph input includes original tensor args and the lifted freevars. + # We need to filter out the original tensor args and concat them with the lifted freevars + # to generate the proxy args for the FX call_function node. + filtered_args = [] + # A boolean list to mark if the type of corresponding argument is tensor. + # This is used to determine if a FX node's argument should be an argument of + # ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward + # at torch._functorch.autograd_function.AutogradFunctionApply. + args_tensor_mask = [False] * len(args) + for i, arg in enumerate(args): + if arg.is_tensor() or isinstance(arg, SymNodeVariable): + filtered_args.append(arg) + args_tensor_mask[i] = True + + # Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args. + new_bwd_graph_outputs = None + for node in bwd_graph.find_nodes(op="output"): + bwd_graph.erase_node(node) + break + + # The same as the above fwd proxies, we need to use the bwd proxies in the bwd_graph + # if some of the output is from fwd_freevars. + bwd_out_proxy = bwd_out.as_proxy() + bwd_proxy_of_fwd_freevars = [] + if isinstance(bwd_out_proxy, (tuple, list)): + for k in bwd_out_proxy: + if k in bwd_freevars: + bwd_proxy_of_fwd_freevars.append(bwd_freevars[k]) + else: + bwd_proxy_of_fwd_freevars.append(k) + else: + if bwd_out_proxy in bwd_freevars: + bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy] + else: + bwd_proxy_of_fwd_freevars = bwd_out_proxy + + # Remove bwd output for non-Tensor args. + output_proxy = bwd_proxy_of_fwd_freevars + if isinstance(output_proxy, (tuple, list)): + new_bwd_graph_outputs = () + for x, mask in zip(output_proxy, args_tensor_mask): + if mask: + new_bwd_graph_outputs = new_bwd_graph_outputs + (x,) + else: + assert x is None, f"Grad of non-Tensor arg {x} is not None." + else: + new_bwd_graph_outputs = output_proxy + + # Update the bwd graph output. + new_bwd_graph_outputs = pytree.tree_map( + lambda x: None if x is None else x.node, new_bwd_graph_outputs + ) + bwd_graph.output(new_bwd_graph_outputs) + bwd_graph.lint() + + # Store bwd_body + bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + bwd_name = tx.output.install_subgraph( + "bwd_body", + torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph), + ) + + bwd_node = make_attr(tx, bwd_name) + + tx.output.side_effects = prev_side_effects + + p_args = ( + fwd_node, + bwd_node, + *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())), + ) + kwargs = { + "args_tensor_mask": args_tensor_mask, + "non_differentiable_idx": non_differentiable_idx, + } + + # Store the invocation as a call + from torch._functorch.autograd_function import autograd_function_apply + + # We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does. + # The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes + # (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing. + # Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it. + with enable_python_dispatcher(), tx.output.fake_mode: + fake_args = ( + tx.output.nn_modules[fwd_node.node.name], + tx.output.nn_modules[bwd_node.node.name], + *( + [ + _get_fake_value(arg) + for arg in filtered_args + list(fwd_freevars.keys()) + ] + ), + ) + example_value = autograd_function_apply(*fake_args, **kwargs) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + autograd_function_apply, + args=p_args, + kwargs=kwargs, + ), + example_value=example_value, + ) + + +def _get_fake_value(x): + if isinstance(x, variables.VariableTracker): + return x.as_proxy().node.meta["example_value"] + elif isinstance(x, torch.fx.Proxy): + return x.node.meta["example_value"] + else: + return x + + +def maybe_positional_arg_names(func): + result = [] + if not hasattr(func, "get_function"): + return None + try: + fn = func.get_function() + except (Unsupported, NotImplementedError): + return None + try: + sig = inspect.signature(fn) + except ValueError: + return None + for name, param in sig.parameters.items(): + if param.kind is inspect.Parameter.VAR_POSITIONAL: + return None + if ( + param.kind is inspect.Parameter.POSITIONAL_ONLY + or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + if name == "self": + # FX graphs can't have a placeholder named self + result.append("self_") + else: + result.append(name) + return result + + +class BaseHOPVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + def python_type(self): + return type(self.value) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + ( + p_args, + p_kwargs, + example_value, + body_r, + _, + _, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, args[0], args[1:], {}, self.value._name, subgraph_name="subgraph" + ) + assert len(p_kwargs) == 0 + + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + return _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): + supports_input_mutation = True + supports_aliasing = False + allow_side_effects = True + # invoke_subgraph is NOT desugared in AOTAutograd, so the HOP input/output + # shouldn't alias. For checkpoint HOP, we inline it so we don't need + # alias analysis as functionalization would just work on the flat graph. + filter_aliased_intermediates = True + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name + ): + # Check if the subgraph from speculate_subgraph (body_gmod) and the fake + # inputs have already been seen before. If yes, the subgraph is already + # installed in the output graph and we can just access the subgraph + # using the saved attr name. + + if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): + unimplemented( + gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", + context=str(fn_vt), + explanation="invoke_subgraph does not support non user function variable", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + invoke_subgraph_cache = ( + tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + ) + + if isinstance(fn_vt, UserFunctionVariable): + fn_id = id(fn_vt.get_function()) + fn_name = fn_vt.get_function().__name__ + else: + assert isinstance(fn_vt, UnspecializedNNModuleVariable) + fn_id = id(fn_vt.value.forward.__func__) + fn_name = fn_vt.value.forward.__name__ + previously_installed_submodules = [] + if invoke_subgraph_cache: + previously_installed_submodules = ( + invoke_subgraph_cache.get_dynamo_installed_submodules(fn_id) + ) + current_mod = body_gmod + # NB - reverse is more likely to cause a hit sooner because first + # graph can have requires_grad=False for a few inputs + for submodule_name in reversed(previously_installed_submodules): + assert submodule_name in tx.output.nn_modules + previous_mod = tx.output.nn_modules[submodule_name] + if are_same_graph_modules( + fn_name, previous_mod, current_mod, tx.fake_mode + ): + return submodule_name + + body_name = super().install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" + ) + hc_log.debug( + "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", + fn_name, + body_name, + fn_name, + len(previously_installed_submodules) + 1, + ) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name) + + return body_name + + @raise_hard_error_if_graph_break( + reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + example_value, + body_r, + _, + body_name, + body_graph_output_vts, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph") + + if len(p_kwargs) > 0: + unimplemented( + gb_type="invoke_subgraph: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + p_args = ( + p_args[0], + body_name, + *p_args[1:], + ) + return _call_function_with_auto_output_flattening( + tx, + torch._higher_order_ops.invoke_subgraph, + tuple(p_args), + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + +class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + # Subclasses aren't supported by speculate_subgraph yet + # So this HOP is only usable with plain tensors + _enabled = False + + @classmethod + @contextlib.contextmanager + def enable(cls): + """Context manager to temporarily enable local map wrapping. + Will be removed when speculate_subgraph supports subclass inputs: + https://github.com/pytorch/pytorch/issues/161456. + + Usage: + with LocalMapWrappedHigherOrderVariable.enable_wrapping(): + # Code where should_wrap_in_hop will return True + pass + """ + old_value = cls._enabled + cls._enabled = True + try: + yield + finally: + cls._enabled = old_value + + @classmethod + def should_wrap_in_hop(cls, value): + if not torch.distributed.is_available(): + return False + + from torch.distributed.tensor.experimental._func_map import _local_map_wrapped + + # check is important to avoid subclass dispatch + if type(value) is not type(_local_map_wrapped): + return False + + return value is _local_map_wrapped and cls._enabled + + @staticmethod + def build(**options): + return TorchHigherOrderOperatorVariable.make( + torch._higher_order_ops.local_map_hop, + **options, + ) + + def python_type(self): + return type(self.value) + + def _call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + """ + Goal of this function is to rewrite local_map usage as a HOP: + local_map(func, ...) -> local_map_hop(gm, ...) + """ + + ( + user_func, + out_placements, + in_placements, + in_grad_placements, + device_mesh, + redistribute_inputs, + *user_args, + ) = args + + # None placements are used to pass non-Tensors into the local_map function. + # Containers passed this way can not hold tensors. Thus, Dynamo would have inlined + # into them, and we handle None placements by assuming they will be desugared away. + # This will need to be adjusted for dynamic shapes support. + def check_none_last(placements): + seen_none = 0 + for p in placements: + if p is None: + seen_none += 1 + else: + assert seen_none == 0, ( + "Tracing local_map is only currently supported with None placements last." + ) + return seen_none + + inputs_none_placements = check_none_last(in_placements.value) + output_none_placements = check_none_last(out_placements.value) + + local_map_kwargs = { + "out_placements": out_placements.value, + "in_placements": in_placements.value, + "redistribute_inputs": redistribute_inputs.value, + "in_grad_placements": in_grad_placements.value, + "device_mesh": device_mesh.value, + } + assert local_map_kwargs["device_mesh"] is not None, ( + "Not yet implemented, please manually provide a device_mesh to local_map." + ) + mesh = local_map_kwargs["device_mesh"] + + # For Autoparallel, the initial trace is done with global shapes, then we decide model weights sharding, + # and reuse the graph. Since the sharding decision is after the initial trace, we can't trace with local shapes. + # For local_map however, since we specify all placements, we can trace with local shapes. + + # Step 1: Validate the annotated function matches the input_placements (i.e. that it can run in eager) + template = ( + "Expecting {expected} {inputs_or_outputs} to local_map function based on placements" + ", but found {actual}. Please ensure the count matches for eager. " + ) + assert len(in_placements.value) == len(user_args), template.format( + expected=len(in_placements.value), + inputs_or_outputs="inputs", + actual=len(user_args), + ) + + from torch._higher_order_ops.local_map import ( + redistribute_fw_inputs, + redistribute_fw_outputs, + ) + + # Step 2: Convert inputs to local shapes + priors = {} + for placements, vt in zip(in_placements.value, user_args): + if isinstance(vt, variables.lazy.LazyVariableTracker): + vt = variables.lazy.LazyVariableTracker.realize_all(vt) + + if not vt.is_tensor(): + assert placements is None + continue + + global_tensor = vt.as_proxy().node.meta["example_value"] + # NOTE: We don't support local_map region relying on exact grad_fn information + # This is okay since accessing grad_fn is a graph break. + local_tensor = redistribute_fw_inputs( + (global_tensor,), + (placements,), + mesh, + ) + local_tensor = local_tensor[0] + + priors[vt] = global_tensor + vt.as_proxy().node.meta["example_value"] = local_tensor + vt.synchronize_attributes(tx) + + # Step 3: Trace local_map subgraph with local tensors + ( + p_args, + p_kwargs, + example_value, + body_r, + body_gmod, + body_name, + body_graph_output_vts, + ) = self.create_wrapped_node( + tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph" + ) + + # Step 4: Validate traced graph signature still matches placement information + expected_num_inputs = len(in_placements.value) - inputs_none_placements + actual_num_inputs = len(body_gmod.graph.find_nodes(op="placeholder")) + expected_num_outputs = len(out_placements.value) - output_none_placements + assert len(body_gmod.graph.find_nodes(op="output")) == 1 + actual_num_outputs = len(body_gmod.graph.find_nodes(op="output")[0].args[0]) + + template = ( + "Expecting {expected} {inputs_or_outputs} to local_map function based on placements" + ", but found {actual}. If the count matches for eager, " + "Dynamo may have flattened {inputs_or_outputs} to the function or found additional " + "tensors used via closures. " + "Please adjust the input placements to match what the traced graph sees: \n{gm_str}." + ) + + def make_error_msg(*args): + expected_num, actual_num, inputs_or_outputs = args + gm_str = body_gmod.print_readable(print_output=False) + return template.format( + expected=expected_num, + inputs_or_outputs=inputs_or_outputs, + actual=actual_num, + gm_str=gm_str, + ) + + if expected_num_inputs != actual_num_inputs: + raise AssertionError( + make_error_msg(expected_num_inputs, actual_num_inputs, "inputs") + ) + if expected_num_outputs != actual_num_outputs: + raise AssertionError( + make_error_msg(expected_num_outputs, actual_num_outputs, "outputs") + ) + + if inputs_none_placements > 0: + expected_input_nodes = [ + arg.as_proxy().node for arg in user_args[:-inputs_none_placements] + ] + else: + expected_input_nodes = [arg.as_proxy().node for arg in user_args] + actual_input_nodes = [proxy.node for proxy in p_args] + assert actual_input_nodes[0].op == "get_attr" + assert "subgraph" in actual_input_nodes[0].target + assert len(expected_input_nodes) == len(actual_input_nodes) - 1 + for expected_order, actual_order in zip( + expected_input_nodes, actual_input_nodes[1:] + ): + assert expected_order == actual_order, ( + "Dynamo changed the order of inputs to the local_map function, please adjust " + f"the order of inputs and input_placements from {expected_input_nodes}, to: {actual_input_nodes[1:]}" + ) + assert len(p_kwargs) == 0 + + # Step 5: Install local_map subgraph + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + out = _call_function_with_auto_output_flattening( + tx, + self.value, + p_args, + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + ) + + # Step 6: Restore inputs and outputs to global shapes + for vt, global_tensor in priors.items(): + vt.as_proxy().node.meta["example_value"] = global_tensor + vt.synchronize_attributes(tx) + + outs = out.items if isinstance(out, TupleVariable) else [out] + assert len(outs) == len(out_placements.value) + for placements, vt in zip(out_placements.value, outs): + if not vt.is_tensor(): + assert placements is None + continue + + local_tensor = vt.as_proxy().node.meta["example_value"] + + # NOTE: We don't support code after the local_map region relying on exact grad_fn information + # This is okay since accessing grad_fn is a graph break. + global_tensor = redistribute_fw_outputs( + (local_tensor,), + (placements,), + mesh, + num_activations=0, # this is not the joint + ) + global_tensor = global_tensor[0] + + vt.as_proxy().node.meta["example_value"] = global_tensor + vt.synchronize_attributes(tx) + + # TODO: Figure out how to handle output order diverging from eager + + # Treat as const, so we don't have to deal with Placement types in fx IR + # Guarded with EQUALS_MATCH on local_map call's arguments + body_gmod.meta["local_map_kwargs"] = { + "out_placements": out_placements.value[:expected_num_outputs], + "in_placements": in_placements.value[:expected_num_inputs], + "redistribute_inputs": redistribute_inputs.value, + "in_grad_placements": in_grad_placements.value, + "device_mesh": device_mesh.value, + } + + return out + + +# Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() +_hop_name_to_variable_class = { + "cond": CondHigherOrderVariable, + "while_loop": WhileLoopHigherOrderVariable, + "while_loop_stack_output": WhileLoopStackOutputHigherOrderVariable, + "map_impl": MapHigherOrderVariable, + "executorch_call_delegate": ExecutorchCallDelegateHigherOrderVariable, + "out_dtype": OutDtypeHigherOrderVariable, + "wrap": WrapHigherOrderVariable, + "hints_wrapper": HintsWrapperHigherOrderVariable, + "flex_attention": FlexAttentionHigherOrderVariable, + "flex_attention_backward": FlexAttentionBackwardHighOrderVariable, + "wrap_activation_checkpoint": CheckpointHigherOrderVariable, + "tag_activation_checkpoint": CheckpointHigherOrderVariable, + "_export_tracepoint": ExportTracepointHigherOrderVariable, + "trace_wrapped": TraceWrappedHigherOrderOperatorVariable, + "strict_mode": StrictModeHigherOrderVariable, + "run_with_rng_state": RunWithRNGStateHigherOrderVariable, + "associative_scan": AssociativeScanHigherOrderVariable, + "scan": ScanHigherOrderVariable, + "call_torchbind": CallTorchbindHigherOrderVariable, + "print": PrintHigherOrderVariable, + "wrap_with_set_grad_enabled": WrapWithSetGradEnabledHigherOrderVariable, + "wrap_with_autocast": WrapWithAutocastHigherOrderVariable, + "dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable, + "auto_functionalized": AutoFunctionalizeHigherOrderVariable, + "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable, + "invoke_subgraph": InvokeSubgraphHigherOrderVariable, + "custom_function_call": CustomFunctionHigherOrderOperatorVariable, + "local_map_hop": LocalMapWrappedHigherOrderVariable, +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3c0247add1b44329a2555ce49341fe75602ba2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/iter.py @@ -0,0 +1,620 @@ +""" +This module provides iterator-related variable tracking functionality for Dynamo. +It implements variable classes for handling Python iterators and itertools functions +during symbolic execution and tracing. + +The module includes: +- Base iterator variable classes for tracking iterator state +- Implementations of built-in iterators (zip, map, filter) +- Support for itertools functions (product, accumulate, combinations, etc.) +- Mutation tracking and reconstruction capabilities for iterator operations + +These classes integrate with Dynamo's variable tracking system to enable proper +handling of iterator operations during code transformation and optimization. +""" + +import itertools +import sys +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import ( + create_build_tuple, + create_call_function, + create_call_function_ex, + create_instruction, +) +from ..exc import ( + handle_observed_exception, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, + UserError, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +MAX_ITERATOR_LIMIT = 100 * 1024 # 100k + + +class ItertoolsVariable(VariableTracker): + def __init__(self, value: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.value = value + + def __repr__(self) -> str: + return f"ItertoolsVariable({self.value})" + + def as_python_constant(self) -> Any: + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence["VariableTracker"], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # See also: module `torch._dynamo.polyfills.itertools` + + if self.value is itertools.product: + if any(kw != "repeat" for kw in kwargs): + unimplemented( + gb_type="Unsupported kwargs for itertools.product", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'repeat', but got " + f"{','.join(set(kwargs.keys()) - {'repeat'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + if "repeat" in kwargs: + r = kwargs["repeat"].as_python_constant() + else: + r = 1 + seqs = [arg.force_unpack_var_sequence(tx) for arg in args] + items = [ + variables.TupleVariable(list(item)) + for item in itertools.product(*seqs, repeat=r) + ] + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif ( + self.value is itertools.combinations + and not kwargs + and len(args) == 2 + and args[0].has_unpack_var_sequence(tx) + and args[1].is_python_constant() + ): + iterable = args[0].unpack_var_sequence(tx) + r = args[1].as_python_constant() + + items = [] + for item in itertools.combinations(iterable, r): + items.append(variables.TupleVariable(list(item))) + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif self.value is itertools.groupby: + if any(kw != "key" for kw in kwargs): + unimplemented( + gb_type="Unsupported kwargs for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'key', but got " + f"{','.join(set(kwargs.keys()) - {'key'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + def retrieve_const_key(key: VariableTracker) -> Any: + if isinstance(key, variables.SymNodeVariable): + return key.evaluate_expr() + elif key.is_python_constant(): + return key.as_python_constant() + else: + unimplemented( + gb_type="Unsupported key type for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with key type: {str(type(key))}. " + "We only support grouping keys that are constants (int, float, str, etc.)", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + else: + unimplemented( + gb_type="Unsupported arguments for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with args: {args} and kwargs: {kwargs}. " + "itertools.groupby expects an iterable to group and an " + "optional key function to determine groupings.", + hints=[ + "Make sure the arguments to itertools.groupby are correct.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if "key" in kwargs: + + def keyfunc(x: VariableTracker) -> Any: + return retrieve_const_key( + kwargs.get("key").call_function(tx, [x], {}) # type: ignore[union-attr] + ) + + else: + + def keyfunc(x: VariableTracker) -> Any: + return retrieve_const_key(x) + + result = [] + try: + # pyrefly: ignore [unbound-name] + for k, v in itertools.groupby(seq, key=keyfunc): + result.append( + variables.TupleVariable( + [ + ( + variables.ConstantVariable.create(k) + if variables.ConstantVariable.is_literal(k) + else k + ), + variables.ListIteratorVariable( + list(v), mutation_type=ValueMutationNew() + ), + ], + mutation_type=ValueMutationNew(), + ) + ) + except Exception as e: + unimplemented( + gb_type="Unexpected failure during itertools.groupby() iteration", + context=f"call_function {self} {args} {kwargs}", + explanation="Unexpected failure in invoking function during groupby", + hints=[*graph_break_hints.SUPPORTABLE], + from_exc=e, + ) + return variables.ListIteratorVariable( + result, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + elif self.value is itertools.repeat: + if len(args) < 2: + return variables.RepeatIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.repeat), args, kwargs + ) + elif self.value is itertools.count: + return variables.CountIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + elif ( + self.value is itertools.permutations + and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) + and not kwargs + ): + if len(args) == 2: + r = args[1].as_python_constant() + else: + r = None + items = [ + variables.TupleVariable(list(item)) + for item in itertools.permutations( + args[0].force_unpack_var_sequence(tx), r + ) + ] + return variables.ListIteratorVariable( + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), + ) + else: + return super().call_function(tx, args, kwargs) + + +class IteratorVariable(VariableTracker): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + unimplemented( + gb_type="Unimplemented next() call", + context=f"next({self})", + explanation="This abstract method must be implemented", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[Any], Any] + ) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + if name == "__iter__" or name == "__next__": + return variables.ConstantVariable.create(True) + return super().call_obj_hasattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__iter__": + return self + elif name == "__next__": + return self.next_variable(tx) + return super().call_method(tx, name, args, kwargs) + + +class ObjectIteratorVariable(IteratorVariable): + """ + VariableTracker for iter(obj) that implements the iterator protocol (i.e., + has a `__next__` method). + + We use this class to track the state of the iterator and handle the case + when the iterator is exhausted: + + Example usage: + > b = iter(obj) + > list(b) # exhaust the iterator + > list(b) # empty list + """ + + def __init__(self, obj: VariableTracker, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.obj = obj + self.generator_exhausted = False + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + if self.generator_exhausted: + raise_observed_exception(StopIteration, tx) + + try: + return self.obj.next_variable(tx) + except ObservedUserStopIteration: + # Do not rely on the object to always return StopIteration once it + # is exhausted. + self.generator_exhausted = True + raise + + +class RepeatIteratorVariable(IteratorVariable): + def __init__(self, item: VariableTracker, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.item = item + + # Repeat needs no mutation, clone self + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + return self.item + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + + +class CountIteratorVariable(IteratorVariable): + def __init__( + self, + item: Union[int, VariableTracker] = 0, + step: Union[int, VariableTracker] = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not isinstance(item, VariableTracker): + item = ConstantVariable.create(item) + if not isinstance(step, VariableTracker): + step = ConstantVariable.create(step) + self.item = item + self.step = step + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + old_item = self.item + tx.output.side_effects.mutation(self) + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: list[VariableTracker], + strict: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self) -> type[zip]: # type: ignore[type-arg] + return zip + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + + if len(self.iterables) == 0: + raise_observed_exception(StopIteration, tx) + + old_index = self.index + args = [] + + def get_item( + it: Union[list[VariableTracker], VariableTracker], + ) -> VariableTracker: + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx) + return it[old_index] + else: + return it.next_variable(tx) + + idx: int | None = None + try: + for idx, it in enumerate(self.iterables): # noqa:B007 + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, # type: ignore[arg-type] + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen: "PyCodegen") -> None: + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output(create_build_tuple(len(remaining_items))) + else: + codegen(it) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output(create_build_tuple(len(self.iterables))) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + *create_call_function_ex(True, False), + ] + ) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self) -> type: + return map + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return False + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) # type: ignore[attr-defined] + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.append_output(create_build_tuple(len(self.iterables) + 1)) + if self.strict: + assert sys.version_info >= (3, 14), ( + "Unexpected bug: map(strict=True) requires Python 3.14+" + ) + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + *create_call_function_ex(True, False), + ] + ) + else: + codegen.extend_output(create_call_function_ex(False, False)) + + +class FilterVariable(IteratorVariable): + """ + Represents filter(fn, iterable) + """ + + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + fn: VariableTracker, + iterable: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.fn = fn + self.iterable = iterable + self.index = 0 + + def python_type(self) -> type: + return filter + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence( + tx + ) + + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + it = None + if isinstance(self.iterable, list): + it = self.iterable[self.index :] + else: + it = self.iterable.unpack_var_sequence(tx) + filtered = self.fn.call_function(tx, it, {}) + return [variables.TupleVariable([filtered])] + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + def _next() -> VariableTracker: + old_index = self.index + if isinstance(self.iterable, list): + if old_index >= len(self.iterable): + raise_observed_exception(StopIteration, tx) + return self.iterable[old_index] + else: + return self.iterable.next_variable(tx) + + # A do-while loop to find elements that make fn return true + while True: + item = _next() + self.index += 1 + if self.fn.is_constant_none(): + res = item + else: + res = self.fn.call_function(tx, [item], {}) + pred_res = variables.UserFunctionVariable( + polyfills.predicate # type: ignore[arg-type] + ).call_function(tx, [res], {}) + if pred_res.as_python_constant(): + return item + + def reconstruct_items(self, codegen: "PyCodegen") -> None: + if isinstance(self.iterable, list): + remaining_items = self.iterable[self.index :] + codegen.foreach(remaining_items) + codegen.append_output(create_build_tuple(len(remaining_items))) + else: + codegen(self.iterable) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output(create_call_function(2, False)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..74609e0884cb284f4d9e286696e2cdde4e7d8e1f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import collections +import functools +import inspect +from typing import Any, TYPE_CHECKING + +from ..utils import is_function_or_wrapper +from .base import VariableTracker, VariableTrackerMeta + + +if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import Self + + from .tensor import SymNodeVariable + + +class LazyCache: + """Container to cache the real VariableTracker""" + + def __init__(self, value: Any, source: Any) -> None: + if not isinstance(value, LazySymNodeFormatString): + assert source + self.value = value + self.source = source + self.name_hint: str | None = None + self.vt: VariableTracker | None = None + + def realize(self) -> None: + assert self.vt is None + from ..symbolic_convert import InstructionTranslator + from . import builder + + tx = InstructionTranslator.current_tx() + + if isinstance(self.value, LazySymNodeFormatString): + self.vt = builder.SourcelessBuilder.create(tx, self.value) + else: + self.vt = builder.VariableBuilder(tx, self.source)(self.value) + + if self.name_hint is not None: + # pyrefly: ignore [missing-attribute] + self.vt.set_name_hint(self.name_hint) + + del self.value + del self.source + del self.name_hint + + +class LazyVariableTracker(VariableTracker, metaclass=VariableTrackerMeta): + """ + A structure that defers the creation of the actual VariableTracker + for a given underlying value until it is accessed. + + The `realize` function invokes VariableTracker.build() to produce the real object. + Once a LazyVariableTracker has been realized, internal bookkeeping will + prevent double realization. + + This object should be utilized for processing containers, or objects that + reference other objects where we may not want to take on creating all the + VariableTrackers right away. + """ + + # Flag to prevent implicit realization in isinstance checks (inherited by subclasses) + _no_implicit_realize = True + _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields} + + @staticmethod + def create(value: Any, source: Any, **options: Any) -> LazyVariableTracker: + return LazyVariableTracker(LazyCache(value, source), source=source, **options) + + def __init__(self, _cache: LazyCache, **kwargs: Any) -> None: + assert isinstance(_cache, LazyCache) + super().__init__(**kwargs) + self._cache = _cache + + def realize(self) -> VariableTracker: + """Force construction of the real VariableTracker""" + if self._cache.vt is None: + self._cache.realize() + assert self._cache.vt is not None + return self._cache.vt + + def lazy_isinstance(self, cls: type) -> bool: + """Check isinstance after realizing, used by ImplicitRealizingVariableTrackerMeta""" + return type.__instancecheck__(cls, self.realize()) + + def unwrap(self) -> VariableTracker | Self: + """Return the real VariableTracker if it already exists""" + if self.is_realized(): + assert self._cache.vt is not None + return self._cache.vt + return self + + def is_realized(self) -> bool: + return self._cache.vt is not None + + def clone(self, **kwargs: Any) -> VariableTracker: + assert kwargs.get("_cache", self._cache) is self._cache + if kwargs.get("source", self.source) is not self.source: + self.realize() + return VariableTracker.clone(self.unwrap(), **kwargs) + + def peek_type(self) -> type[Any]: + assert not self.is_realized() + return type(self._cache.value) + + def peek_value(self) -> Any: + assert not self.is_realized() + return self._cache.value + + def set_name_hint(self, name: str) -> None: + if self.is_realized(): + self._cache.vt.set_name_hint(name) # type: ignore[union-attr] + else: + self._cache.name_hint = name + + def __str__(self) -> str: + variable_info = "LazyVariableTracker(" + if self.is_realized(): + variable_info += f"realized: {repr(self.unwrap())})" + else: + variable_info += f"unrealized: {self.peek_type()})" + + return variable_info + + def __getattr__(self, item: str) -> Any: + return getattr(self.realize(), item) + + # most methods are auto-generated below, these are the ones we want to exclude + visit = VariableTracker.visit # type: ignore[assignment] + __repr__ = __str__ + + @classmethod + def realize_all( + cls, + value: Any, + cache: dict[int, tuple[Any, Any]] | None = None, + ) -> Any: + """ + Walk an object and realize all LazyVariableTrackers inside it. + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return cache[idx][0] + + value_cls = type(value) + if issubclass(value_cls, LazyVariableTracker): + result = cls.realize_all(value.realize(), cache) + elif issubclass(value_cls, VariableTracker): + # update value in-place + result = value + value_dict = value.__dict__ + nonvars = value._nonvar_fields + for key in value_dict: + if key not in nonvars: + value_dict[key] = cls.realize_all(value_dict[key], cache) + elif value_cls is list: + result = [cls.realize_all(v, cache) for v in value] + elif value_cls is tuple: + result = tuple(cls.realize_all(v, cache) for v in value) + elif value_cls in (dict, collections.OrderedDict): + result = {k: cls.realize_all(v, cache) for k, v in list(value.items())} + else: + result = value + + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = (result, value) + return result + + def is_hashable(self) -> bool: + # Checks that the underlying value is hashable without realizing the VT. + # This is used by ConstDictVariable tracker to find if the key LazyVT + # can be hashed. + def _helper(value: Any) -> bool: + # TODO: Add support for more types + return ( + inspect.isbuiltin(value) + or issubclass(type(value), type) + or is_function_or_wrapper(value) + ) + + assert not self.is_realized() + value = self._cache.value + if isinstance(value, tuple): + return all(_helper(v) for v in value) + return _helper(value) + + def original_value(self) -> Any: + # Returns the value without realizing the VT. + assert not self.is_realized() + return self._cache.value + + def original_source(self) -> Any: + # Returns the source without realizing the VT. + assert not self.is_realized() + return self._cache.source + + +class LazySymNodeFormatString: + def __init__( + self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker + ) -> None: + from .constant import ConstantVariable + + self.sym_node_var = sym_node_variable + self.fmt_var = ConstantVariable.create( + "{:" + fmt_spec_var.as_python_constant() + "}" + ) + + def __repr__(self) -> str: + return str.format( + self.fmt_var.as_python_constant(), + str(self.sym_node_var.evaluate_expr()), + ) + + +def _create_realize_and_forward( + name: str, +) -> Callable[[LazyVariableTracker, Any, Any], Any]: + @functools.wraps(getattr(VariableTracker, name)) + def realize_and_forward( + self: LazyVariableTracker, *args: Any, **kwargs: Any + ) -> Any: + return getattr(self.realize(), name)(*args, **kwargs) + + return realize_and_forward + + +def _populate() -> None: + for name, value in VariableTracker.__dict__.items(): + if name not in LazyVariableTracker.__dict__: + if callable(value): + setattr(LazyVariableTracker, name, _create_realize_and_forward(name)) + + +_populate() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..734d30a76380d350e615da34929ec56d6d4bae7d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/lists.py @@ -0,0 +1,1821 @@ +""" +Variable tracking implementations for list-like data structures in Dynamo. + +This module provides specialized variable tracking for various collection types: +- Lists and list subclasses (including torch.nn.ModuleList, ParameterList) +- Tuples and named tuples +- Ranges and slices +- Collections.deque +- torch.Size with special proxy handling + +The implementations support both mutable and immutable collections, iteration, +and common sequence operations. Each collection type has a dedicated Variable +class that handles its unique behaviors while integrating with Dynamo's +variable tracking system. +""" + +import collections +import inspect +import operator +import sys +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.fx + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import ( + create_build_tuple, + create_call_function, + create_instruction, + create_rot_n, +) +from ..exc import raise_observed_exception, unimplemented +from ..source import AttrSource, NamedTupleFieldsSource +from ..utils import ( + cmp_name_to_op_mapping, + cmp_name_to_op_str_mapping, + get_fake_value, + guard_if_dyn, + iter_contains, + Lit, + namedtuple_fields, + odict_values, + raise_args_mismatch, + range_iterator, + set_example_value, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class BaseListVariable(VariableTracker): + @staticmethod + def cls_for_instance(obj: Any) -> type["BaseListVariable"]: + return BaseListVariable.cls_for(type(obj)) + + @staticmethod + def cls_for(obj: Any) -> type: + return { + iter: ListIteratorVariable, + list: ListVariable, + slice: SliceVariable, + torch.Size: SizeVariable, + tuple: TupleVariable, + odict_values: ListVariable, + torch.nn.ParameterList: ListVariable, + torch.nn.ModuleList: ListVariable, + collections.deque: DequeVariable, + }[obj] + + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + assert all(isinstance(x, VariableTracker) for x in items) + self.items: list[VariableTracker] = items + + def _as_proxy(self) -> list[Any]: + return [x.as_proxy() for x in self.items] + + def modified( + self, items: list[VariableTracker], **kwargs: Any + ) -> "BaseListVariable": + return type(self)(items, **kwargs) + + @property + def value(self) -> Any: + return self.as_python_constant() + + def debug_repr_helper(self, prefix: str, suffix: str) -> str: + return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix + + def as_python_constant(self) -> Any: + return self.python_type()([x.as_python_constant() for x in self.items]) + + def as_proxy(self) -> Any: + assert self.python_type() is not SizeVariable + return self.python_type()(self._as_proxy()) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + if index.step == 0: + msg = ConstantVariable.create("slice step cannot be zero") + raise_observed_exception(ValueError, tx, args=[msg]) + # Set source to None because slicing a list gives a new local + return self.clone( + items=self.items[index], + source=None, + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + else: + assert isinstance(index, (int, torch.SymInt)) + try: + return self.items[index] + except IndexError: + raise_observed_exception( + IndexError, tx, args=["list index out of range"] + ) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return list(self.items) + + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not isinstance(self, (ListVariable, TupleVariable)): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + + other_lists: list[BaseListVariable] = [] + for candidate in rest: + if ( + not isinstance(candidate, BaseListVariable) + or len(candidate.items) != len(self.items) + or self.python_type() != candidate.python_type() + ): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_lists.append(candidate) + + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [candidate.items[idx] for candidate in other_lists] + new_items.append( + item.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + ) + + return self.clone( + items=new_items, + source=None, + mutation_type=ValueMutationNew(), + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if args[0].is_tensor(): + value = get_fake_value(args[0].as_proxy().node, tx) + if value.constant is not None and value.constant.numel() == 1: + value = variables.ConstantVariable.create(value.constant.item()) + else: + unimplemented( + gb_type="Indexing list with non-scalar tensor", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=( + "Attempted to index list-like object with tensor with > 1 element." + ), + hints=[*graph_break_hints.USER_ERROR], + ) + else: + value = args[0] + + if value.python_type() not in (int, slice): + msg = f"indices must be integers or slices, not {value.python_type()}" + raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)]) + + return self.getitem_const(tx, value) + elif name == "__contains__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return iter_contains(self.unpack_var_sequence(tx), args[0], tx) + elif name == "index": + if not len(args): + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.index), + [self] + list(args), + kwargs, + ) + elif name == "count": + if len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return VariableTracker.build(tx, operator.countOf).call_function( + tx, + [self, args[0]], + kwargs, + ) + elif name in ("__add__", "__iadd__"): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if type(self) is not type(args[0]): + tp_name = self.python_type_name() + other = args[0].python_type_name() + msg_vt = ConstantVariable.create( + f'can only concatenate {tp_name} (not "{other}") to {tp_name}' + ) + raise_observed_exception(TypeError, tx, args=[msg_vt]) + + if name == "__add__": + return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined] + else: + self.items += args[0].items # type: ignore[attr-defined] + return self + elif name in ("__mul__", "__imul__"): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if not (args[0].is_python_constant() and args[0].python_type() is int): + msg_vt = ConstantVariable.create( + f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'" + ) + raise_observed_exception(TypeError, tx, args=[msg_vt]) + + val = args[0].as_python_constant() + + if name == "__mul__": + return type(self)(self.items * val, source=self.source) + else: + self.items *= val + return self + elif name in cmp_name_to_op_mapping: + if len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + left = self + right = args[0] + # TODO this type check logic mirrors the following + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007 + # But we should probably move it up the stack to so that we don't + # need to duplicate it for different VTs. + if not isinstance(left, BaseListVariable) or not isinstance( + right, BaseListVariable + ): + if name == "__eq__": + return variables.BuiltinVariable(operator.is_).call_function( + tx, (left, right), {} + ) + elif name == "__ne__": + return variables.BuiltinVariable(operator.is_not).call_function( + tx, (left, right), {} + ) + else: + op_str = cmp_name_to_op_str_mapping[name] + left_ty = left.python_type_name() + right_ty = right.python_type_name() + msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'" + raise_observed_exception(TypeError, tx, args=[msg]) + + return variables.UserFunctionVariable(polyfills.list_cmp).call_function( + tx, + [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right], + {}, + ) + elif name == "__iter__": + return ListIteratorVariable(self.items, mutation_type=ValueMutationNew()) + + return super().call_method(tx, name, args, kwargs) + + +class RangeVariable(BaseListVariable): + def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: + items_to_map = items + start = variables.ConstantVariable.create(0) + stop = None + step = variables.ConstantVariable.create(1) + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + def maybe_as_int(x: VariableTracker) -> VariableTracker: + return ( + ConstantVariable.create(int(x.as_python_constant())) + if x.is_python_constant() + else x + ) + + # cast each argument to an integer + start = maybe_as_int(start) + step = maybe_as_int(step) + stop = maybe_as_int(stop) + + assert stop is not None + super().__init__([start, stop, step], **kwargs) + + def debug_repr(self) -> str: + return self.debug_repr_helper("range(", ")") + + def python_type(self) -> type: + return range + + def start(self) -> Any: + return self.items[0].as_python_constant() + + def stop(self) -> Any: + return self.items[1].as_python_constant() + + def step(self) -> Any: + return self.items[2].as_python_constant() + + def range_length(self) -> int: + lo = self.start() + hi = self.stop() + step = self.step() + + assert step != 0 + if step > 0 and lo < hi: + return 1 + (hi - 1 - lo) // step + elif step < 0 and lo > hi: + return 1 + (lo - 1 - hi) // (0 - step) + else: + return 0 + + def _get_slice_indices(self, length: int, slice: slice) -> list[int]: + step_is_negative = 0 + + if slice.step is None: + step = 1 + step_is_negative = False + else: + step = slice.step + step_is_negative = slice.step < 0 + + # Find lower and upper bounds for start and stop. + if step_is_negative: + lower = -1 + upper = length + lower + else: + lower = 0 + upper = length + + # Compute start + if slice.start is None: + start = upper if step_is_negative else lower + else: + start = slice.start + + if start < 0: + start += length + if start < lower: + start = lower + else: + if start > upper: + start = upper + + # Compute stop. + if slice.stop is None: + stop = lower if step_is_negative else upper + + else: + stop = slice.stop + + if stop < 0: + stop += length + if stop < lower: + stop = lower + else: + if stop > upper: + stop = upper + + return [start, stop, step] + + def apply_index(self, index: int) -> VariableTracker: + length = self.range_length() + if index < 0: + index = length + index + + if index < 0 or index >= length: + tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() + raise_observed_exception( + IndexError, + tx, + args=[ConstantVariable("range object index out of range")], + ) + + return variables.ConstantVariable.create(self.start() + (index * self.step())) + + def apply_slice(self, slice: slice) -> "RangeVariable": + (slice_start, slice_stop, slice_step) = self._get_slice_indices( + self.range_length(), slice + ) + + def compute_item(index: int) -> int: + return self.start() + (index * self.step()) + + sub_step = self.step() * slice_step + sub_start = compute_item(slice_start) + sub_stop = compute_item(slice_stop) + + result = RangeVariable( + [ + variables.ConstantVariable.create(x) + for x in [sub_start, sub_stop, sub_step] + ], + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + return result + + def as_python_constant(self) -> range: + return range(*[x.as_python_constant() for x in self.items]) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c + index = arg.as_python_constant() + + if isinstance(index, slice): + return self.apply_slice(index) + elif isinstance(index, int): + return self.apply_index(index) + else: + msg = ConstantVariable("range indices must be integers or slices") + raise_observed_exception(TypeError, tx, args=[msg]) + + def as_proxy(self) -> range: + return self.python_type()(*self._as_proxy()) + + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] = None + ) -> list[VariableTracker]: + return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert "range" not in codegen.tx.f_globals + codegen.add_push_null( + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] + ) + codegen.foreach(self.items) + codegen.extend_output(create_call_function(3, False)) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is range: + return variables.ConstantVariable.create(name in range.__dict__) + return super().call_obj_hasattr(tx, name) + + def range_equals(self, other: "RangeVariable") -> bool: + r0, r1 = self, other + if ( + self.range_length() != r1.range_length() + or self.range_length() == 0 + or r0.start() != r1.start() + ): + return False + + if self.range_length() == 1: + return True + + return r0.step() == r1.step() + + def range_count(self, x: VariableTracker) -> int: + # Based on CPython + # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486 + x = x.as_python_constant() + if type(x) not in (bool, int, float): + return 0 + + start, stop, step = self.start(), self.stop(), self.step() + + if step == 0: + return 0 + + in_range = (start <= x < stop) if step > 0 else (stop < x <= start) + + if in_range: + re = ((x - start) % step) == 0 + return int(re) + return 0 + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__iter__": + if not all(var.is_python_constant() for var in self.items): + # Can't represent a `range_iterator` without well defined bounds + return variables.misc.DelayGraphBreakVariable( + msg="Cannot create range_iterator: bounds (start, stop, step) must be fully defined as concrete constants.", + ) + return RangeIteratorVariable( + self.start(), self.stop(), self.step(), self.range_length() + ) + elif name == "__len__": + length = self.range_length() + if length > sys.maxsize: + raise_observed_exception(OverflowError, tx) + return ConstantVariable.create(self.range_length()) + elif name in ("count", "__contains__"): + return ConstantVariable(self.range_count(*args)) + elif name == "__getitem__": + return self.getitem_const(tx, *args) + elif name in cmp_name_to_op_mapping: + other = args[0] + pt = other.python_type() + if name not in ("__eq__", "__ne__"): + # ranges are only comparable to other ranges + msg = f"{name} not supported between instances of 'range' and '{pt}'" + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create(msg)], + ) + + if pt is not range: + return ConstantVariable.create(NotImplemented) + + if isinstance(other, RangeVariable): + cmp = self.range_equals(other) + else: + cmp = False + + # Two ranges are equal if they produce the same sequence of values + if name == "__eq__": + return ConstantVariable(cmp) + else: + return ConstantVariable(not cmp) + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + fields = ["start", "stop", "step"] + if name in fields: + return self.items[fields.index(name)] + return super().var_getattr(tx, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + l = self.range_length() + start = self.start() + step = self.step() + return hash((l, start, step)) + + def is_python_equal(self, other): + if not isinstance(other, variables.RangeVariable): + return False + + return ( + self.start() == other.start() + and self.step() == other.step() + and self.stop() == other.stop() + ) + + +class CommonListMethodsVariable(BaseListVariable): + """ + Implement methods common to List and other List-like things + """ + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "append" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + (arg,) = args + tx.output.side_effects.mutation(self) + self.items.append(arg) + return ConstantVariable.create(None) + elif name == "extend" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if not args[0].has_force_unpack_var_sequence(tx): + msg = ConstantVariable.create(f"{type(args[0])} object is not iterable") + raise_observed_exception(TypeError, tx, args=[msg]) + + (arg,) = args + arg.force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "append", [item], {}) + ) + return ConstantVariable.create(None) + elif name == "insert" and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + idx, value = args + if isinstance(idx, SymNodeVariable): + const_idx = idx.evaluate_expr() + else: + const_idx = idx.as_python_constant() + tx.output.side_effects.mutation(self) + self.items.insert(const_idx, value) + return ConstantVariable.create(None) + elif name == "pop" and self.is_mutable(): + if kwargs or len(args) > 1: + raise_args_mismatch( + tx, + name, + "at most 1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if len(self.items) == 0: + msg = ConstantVariable.create("pop from empty list") + raise_observed_exception(IndexError, tx, args=[msg]) + + if len(args): + idx = args[0].as_python_constant() + if idx > len(self.items): + msg = ConstantVariable.create("pop index out of range") + raise_observed_exception(IndexError, tx, args=[msg]) + tx.output.side_effects.mutation(self) + return self.items.pop(*[a.as_python_constant() for a in args]) + elif name == "clear" and self.is_mutable(): + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif ( + name == "__setitem__" + and self.is_mutable() + and args + and ( + args[0].is_python_constant() + or isinstance(args[0], SymNodeVariable) + or ( + isinstance(args[0], SliceVariable) + and all( + s.is_python_constant() or isinstance(s, SymNodeVariable) + for s in args[0].items + ) + ) + ) + ): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SymNodeVariable): + self.items[key.evaluate_expr()] = value + elif isinstance(key, SliceVariable): + if key.is_python_constant(): + self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined] + else: + items_slice = slice( + *[ + ( + s.evaluate_expr() + if isinstance(s, SymNodeVariable) + else s.as_python_constant() + ) + for s in key.items + ] + ) + self.items[items_slice] = list(value.items) # type: ignore[attr-defined] + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "__delitem__" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + tx.output.side_effects.mutation(self) + if args[0].is_python_constant() and isinstance( + args[0].as_python_constant(), (int, slice) + ): + if isinstance(args[0], SymNodeVariable): + idx = args[0].evaluate_expr() + else: + idx = args[0].as_python_constant() + + try: + self.items.__delitem__(idx) + except (IndexError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + else: + msg = ConstantVariable.create( + f"list indices must be integers or slices, not {args[0].python_type_name()}" + ) + raise_observed_exception(TypeError, tx, args=[msg]) + return ConstantVariable.create(None) + elif name == "copy": + # List copy() doesn't have args and kwargs + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items_lst: list[VariableTracker] = list(self.items) + return self.modified(items_lst, mutation_type=ValueMutationNew()) + elif name == "reverse" and self.is_mutable(): + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.items.reverse() + tx.output.side_effects.mutation(self) + return ConstantVariable.create(None) + elif name == "remove" and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + idx = self.call_method(tx, "index", args, kwargs) + self.call_method(tx, "pop", [idx], {}) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + + +class ListVariable(CommonListMethodsVariable): + def python_type(self) -> type: + return list + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self) -> str: + return self.debug_repr_helper("[", "]") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items))) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if name == "__setitem__" and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + key, value = args + + if not key.is_python_constant(): + # probably will graph-break + super().call_method(tx, name, args, kwargs) + + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + if not value.has_force_unpack_var_sequence(tx): + msg = ConstantVariable.create("can only assign an iterable") + raise_observed_exception(TypeError, tx, args=[msg]) + + key_as_const = key.as_python_constant() + if key_as_const.step == 0: + msg = ConstantVariable.create("slice step cannot be zero") + raise_observed_exception(ValueError, tx, args=[msg]) + + value_unpack = value.force_unpack_var_sequence(tx) + try: + self.items[key_as_const] = value_unpack + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + else: + if isinstance(key, SymNodeVariable): + key = key.evaluate_expr() + else: + key = key.as_python_constant() + + try: + self.items[key] = value + except (IndexError, TypeError) as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + return ConstantVariable.create(None) + + if name == "sort" and self.is_mutable(): + if len(args) != 0: + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + key_fn_var = kwargs.pop("key", ConstantVariable.create(None)) + reverse = kwargs.pop( + "reverse", ConstantVariable.create(False) + ).as_python_constant() + if len(kwargs) != 0: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + + if key_fn_var.is_constant_none(): + keys = self.items.copy() + else: + keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items] + + if not all(k.is_python_constant() for k in keys): + first_non_constant_key = None + for k in keys: + if not k.is_python_constant(): + first_non_constant_key = k + assert first_non_constant_key is not None + + try: + python_type = str(first_non_constant_key.python_type()) + except NotImplementedError: + python_type = "unknown" + + unimplemented( + gb_type="sort with non-constant keys", + context=str(first_non_constant_key), + explanation=( + f"Cannot perform sort with non-constant key. " + f"First non-constant key type: {python_type}. " + f"Most notably, we cannot sort with Tensor or SymInt keys, but we can " + f"sort ints." + ), + hints=["Use something else as the key."], + ) + + tx.output.side_effects.mutation(self) + sorted_items_with_keys = sorted( + ( + ( + x, + k.as_python_constant(), + -i if reverse else i, # extra key to ensure stable sort + ) + for i, (k, x) in enumerate(zip(keys, self.items)) + ), + key=operator.itemgetter(1, 2), + reverse=reverse, + ) + self.items[:] = [x for x, *_ in sorted_items_with_keys] + return ConstantVariable.create(None) + + if name == "__init__" and self.is_mutable(): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + if len(args) == 0: + return ConstantVariable.create(None) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + (arg,) = args + tx.output.side_effects.mutation(self) + self.items[:] = arg.force_unpack_var_sequence(tx) + return ConstantVariable.create(None) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is list: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is not list: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr([], name)) + + def is_python_hashable(self): + return False + + +class DequeVariable(CommonListMethodsVariable): + def __init__( + self, + items: list[VariableTracker], + maxlen: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: + if maxlen is None: + maxlen = ConstantVariable.create(None) + assert maxlen.is_python_constant(), ( + f"maxlen must be a constant, got: {maxlen.debug_repr()}" + ) + self.maxlen = maxlen + items = list(items) + if self.maxlen.as_python_constant() is not None: + items = items[-maxlen.as_python_constant() :] + super().__init__(items, **kwargs) + + def python_type(self) -> type: + return collections.deque + + def debug_repr(self) -> str: + if self.maxlen.as_python_constant() is None: + return self.debug_repr_helper( + "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")" + ) + return self.debug_repr_helper("deque([", "])") + + def as_python_constant(self) -> collections.deque[Any]: + return self.python_type()( + [x.as_python_constant() for x in self.items], + maxlen=self.maxlen.as_python_constant(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_python_module(collections.deque) # type: ignore[arg-type] + ) + ) + codegen.foreach(self.items) + codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))]) + codegen(self.maxlen) + codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False)) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "maxlen": + return self.maxlen + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + key, value = args + assert key.is_python_constant() + assert isinstance(key.as_python_constant(), int) + tx.output.side_effects.mutation(self) + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + + maxlen = self.maxlen.as_python_constant() + if maxlen is not None: + slice_within_maxlen = slice(-maxlen, None) + else: + slice_within_maxlen = None + + if ( + name == "extendleft" + and self.is_mutable() + and len(args) > 0 + and args[0].has_force_unpack_var_sequence(tx) + ): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # NOTE this is inefficient, but the alternative is to represent self.items + # as a deque, which is a more intrusive change. + args[0].force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "appendleft", [item], {}) + ) + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "popleft" and self.is_mutable(): + if kwargs or len(args) > 0: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + result, *self.items[:] = self.items + elif name == "appendleft" and len(args) > 0 and self.is_mutable(): + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + tx.output.side_effects.mutation(self) + self.items[:] = [args[0], *self.items] + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "insert" and len(args) > 0 and self.is_mutable(): + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if maxlen is not None and len(self.items) == maxlen: + raise_observed_exception( + IndexError, tx, args=["deque already at its maximum size"] + ) + result = super().call_method(tx, name, args, kwargs) + else: + result = super().call_method(tx, name, args, kwargs) + + if ( + slice_within_maxlen is not None + and maxlen is not None + and len(self.items) > maxlen + ): + self.items[:] = self.items[slice_within_maxlen] + return result + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is collections.deque: + return variables.ConstantVariable.create(name in collections.deque.__dict__) + return super().call_obj_hasattr(tx, name) + + +class TupleVariable(BaseListVariable): + def python_type(self) -> type[tuple]: # type: ignore[type-arg] + return tuple + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self) -> str: + return self.debug_repr_helper("(", ")") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_build_tuple(len(self.items))) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is tuple: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is not tuple: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr((), name)) + + def is_python_hashable(self): + return all(item.is_python_hashable() for item in self.items) + + def get_python_hash(self): + items = tuple(x.get_python_hash() for x in self.items) + return hash(items) + + def is_python_equal(self, other): + return isinstance(other, variables.TupleVariable) and all( + a.is_python_equal(b) for (a, b) in zip(self.items, other.items) + ) + + +class SizeVariable(TupleVariable): + """torch.Size(...)""" + + _nonvar_fields = { + "proxy", + *TupleVariable._nonvar_fields, + } + + def __init__( + self, + items: list[VariableTracker], + proxy: Optional[torch.fx.Proxy] = None, + **kwargs: Any, + ) -> None: + self.proxy = proxy + super().__init__(items, **kwargs) + + def debug_repr(self) -> str: + return self.debug_repr_helper("torch.Size([", "])") + + def python_type(self) -> type: + return torch.Size + + def as_proxy(self) -> Any: + if self.proxy is not None: + return self.proxy + + # torch.Size needs special handling. Normally, we pun a list-like + # container to directly contain Proxy/Node objects from FX, and FX + # knows to look inside containers (via map_aggregate). But torch.Size + # is weird; although it subclasses from tuple, it doesn't allow + # members which aren't int-like (rejecting Proxy and Node). This + # means we can't use the normal representation trick + # torch.Size([proxy0, proxy1]). I looked into seeing if I could + # relax torch.Size in PyTorch proper, but if torch.Size constructor + # sees a type that it doesn't recognize, it will try to call + # __index__() on it, so there is no BC way to actually change this + # behavior (though it occurs to me that I could have just added a + # YOLO no checking alternate constructor.) + # + # To work around this problem, I represent a torch.Size proxy as + # a straight up proxy, that would have been constructed by taking + # the constituent proxies as arguments. This trick can be generally + # used for any construct that we need a proxy for but we can't + # directly represent as an aggregate; I don't see very many examples + # of this in torchdynamo though! + + # Look for a proxy. If there are none, do the legacy behavior + tracer = None + proxies = self._as_proxy() + for proxy in proxies: + if isinstance(proxy, torch.fx.Proxy): + tracer = proxy.tracer + break + + if tracer is None: + return torch.Size(proxies) + + proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {}) + set_example_value( + proxy.node, + torch.Size( + [ + p.node.meta["example_value"] if not isinstance(p, int) else p + for p in proxies + ] + ), + ) + return proxy + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size")) + codegen.foreach(self.items) + build_torch_size = [ + create_build_tuple(len(self.items)), + ] + create_call_function(1, False) + codegen.extend_output(build_torch_size) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return list(self.items) + + def numel(self, tx: "InstructionTranslator") -> VariableTracker: + from .builtin import BuiltinVariable + from .tensor import SymNodeVariable + + const_result = 1 + sym_sizes = [] + + for v in self.items: + if v.is_python_constant(): + const_result *= v.as_python_constant() + else: + assert isinstance(v, SymNodeVariable), type(v) + # Delay proxy calls until we know it will be necessary + sym_sizes.append(v) + + result = ConstantVariable.create(const_result) + if sym_sizes and const_result == 1: + # Skip multiplying by 1 + result, *sym_sizes = sym_sizes + + if not sym_sizes or const_result == 0: + return result + + mul = BuiltinVariable(operator.mul) + for v in sym_sizes: + result = mul.call_function(tx, [result, v], {}) + return result + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + out = self.get_item_dyn(tx, args[0]) + return out + elif name == "numel": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return self.numel(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_item_dyn( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + return SizeVariable(self.items[index]) + else: + assert isinstance(index, (int, torch.SymInt)) + return self.items[index] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(torch.Size, name)) + + +class NamedTupleVariable(TupleVariable): + _nonvar_fields = { + "tuple_cls", + "dynamic_attributes", + *TupleVariable._nonvar_fields, + } + + def __init__( + self, + items: list[VariableTracker], + tuple_cls: type, + dynamic_attributes: Optional[dict[str, VariableTracker]] = None, + **kwargs: Any, + ) -> None: + super().__init__(items, **kwargs) + self.tuple_cls = tuple_cls + self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} + + def is_namedtuple(self) -> bool: + return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( + getattr(self.tuple_cls, "_make", None) + ) + + def is_structseq(self) -> bool: + return not self.is_namedtuple() + + def fields(self) -> tuple[str, ...]: + return namedtuple_fields(self.tuple_cls) + + def debug_repr(self) -> str: + if self.is_structseq(): + # StructSequenceType(iterable) + return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) + # NamedTupleType(*iterable) + return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) + + def python_type(self) -> type: + return self.tuple_cls + + def as_python_constant(self) -> Any: + if self.is_structseq(): + # StructSequenceType(iterable) + result = self.python_type()([x.as_python_constant() for x in self.items]) + else: + # NamedTupleType(*iterable) + result = self.python_type()(*[x.as_python_constant() for x in self.items]) + + # Apply dynamic attributes if any were set + if self.dynamic_attributes: + for attr_name, attr_value in self.dynamic_attributes.items(): + # Convert VariableTracker to Python constant if needed + if hasattr(attr_value, "as_python_constant"): + python_value = attr_value.as_python_constant() + else: + raise NotImplementedError( + "Can not convert dynamic attribute without python constant value to python constant." + ) + setattr(result, attr_name, python_value) + + return result + + def as_proxy(self) -> Any: + assert self.python_type() is not SizeVariable + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()(self._as_proxy()) + # NamedTupleType(*iterable) + return self.python_type()(*self._as_proxy()) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Always reconstruct the NamedTuple normally first + # Constructors: + # StructSequenceType(iterable) + # NamedTupleType(*iterable) + # NamedTupleType._make(iterable) + if self.is_structseq(): + create_fn = self.tuple_cls + else: + create_fn = self.tuple_cls._make # type: ignore[attr-defined] + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(create_fn) + ) + ) + codegen.foreach(self.items) + codegen.extend_output( + [ + create_build_tuple(len(self.items)), + ] + + create_call_function(1, False) + ) + + for name, value in self.dynamic_attributes.items(): + codegen.dup_top() + codegen(value) + codegen.extend_output(create_rot_n(2)) + codegen.store_attr(name) + + def _is_method_overridden(self, method_name: str) -> bool: + """Checks if a method is overridden in the NamedTuple subclass. + + Args: + method_name (str): The name of the method to check. + + Returns: + bool: True if the method is overridden in the subclass, False otherwise. + + Raises: + ValueError: If the NamedTuple class does not inherit from both Tuple and Object. + """ + if len(self.tuple_cls.__mro__) < 3: + raise ValueError("NamedTuple should inherit from Tuple and Object.") + if getattr(self.tuple_cls, method_name, None) == getattr( + self.tuple_cls.__mro__[-3], method_name, None + ): + return False + return True + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__setattr__": + if kwargs or len(args) != 2: + raise_args_mismatch( + tx, + name, + "2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + attr, value = args + attr = attr.as_python_constant() + if ( + # structseq is immutable + self.is_structseq() + # namedtuple directly created by `collections.namedtuple` is immutable + or self.tuple_cls.__bases__ == (tuple,) + # fields are immutable + or attr in self.fields() + ): + raise_observed_exception(AttributeError, tx) + # Subclass of namedtuple type can have dynamic attributes + tx.output.side_effects.mutation(self) + if self.source: + tx.output.side_effects.store_attr(self, attr, value) + self.dynamic_attributes[attr] = value + return ConstantVariable.create(None) + elif name == "_replace": + # NamedTuple._replace should create a new instance with replaced fields + if args: + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + + # Get the field names for validation + fields = self.fields() + + # Start with current items (copy them) + new_items = list(self.items) + + # Replace fields specified in kwargs + for field_name, new_value in kwargs.items(): + if field_name not in fields: + raise_observed_exception( + ValueError, + tx, + args=[ + ConstantVariable.create( + f"Got unexpected field name: '{field_name}'" + ) + ], + ) + + # Replace the item at the field's index + field_index = fields.index(field_name) + new_items[field_index] = new_value + + return NamedTupleVariable(new_items, self.tuple_cls) + + return super().call_method(tx, name, args, kwargs) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + if isinstance(arg, SliceVariable): + # slicing a namedtuple produces a tuple + return TupleVariable( + self.items[arg.as_python_constant()], + source=None, + ) + return super().getitem_const(tx, arg) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def check_and_create_method() -> Optional[VariableTracker]: + method = inspect.getattr_static(self.tuple_cls, name, None) + if isinstance(method, classmethod): + # We need the unbounded cls method to avoid the inline __self__ + return UserMethodVariable( + method.__func__, + variables.UserDefinedClassVariable(self.tuple_cls), + ) + elif isinstance(method, staticmethod): + # pyrefly: ignore[bad-argument-type] + return UserFunctionVariable(method.__func__) + elif inspect.isfunction(method): + return UserMethodVariable(method, self) + else: + return None + + # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten. + if ( + name == "_replace" + and not self._is_method_overridden("_replace") + and not self._is_method_overridden("__getattr__") + ): + # Return a BuiltinVariable for the _replace method + # Get the actual _replace method from the tuple class + actual_replace_method = getattr(self.tuple_cls, "_replace", None) + if actual_replace_method: + from ..source import AttrSource + + source = AttrSource(self.source, name) if self.source else None + return variables.GetAttrVariable(self, name, source=source) + # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples) + return super().var_getattr(tx, name) + + if name == "_fields": + result_source = NamedTupleFieldsSource(self.source) if self.source else None + return VariableTracker.build(tx, self.fields(), source=result_source) + + if name in self.dynamic_attributes: + return self.dynamic_attributes[name] + + fields = self.fields() + if name not in fields: + method = check_and_create_method() + if not method: + return super().var_getattr(tx, name) + return method + return self.items[fields.index(name)] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create( + name in self.dynamic_attributes or hasattr(self.tuple_cls, name) + ) + + +class SliceVariable(VariableTracker): + def __init__( + self, + items: Sequence[VariableTracker], + tx: Optional["InstructionTranslator"] = None, + **kwargs: Any, + ) -> None: + items_to_map = items + start, stop, step = [variables.ConstantVariable.create(None)] * 3 + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + # Convert TensorVariable to SymIntVariable by calling .item() + # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level + if start.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + start = start.call_method(tx, "item", [], {}) + if stop.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + stop = stop.call_method(tx, "item", [], {}) + if step.is_tensor(): + assert tx is not None, ( + "tx is required when slice indices are TensorVariables" + ) + step = step.call_method(tx, "item", [], {}) + + self.items = (start, stop, step) + + super().__init__(**kwargs) + + def debug_repr(self) -> str: + return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")" + + def as_proxy(self) -> slice: + return slice(*[x.as_proxy() for x in self.items]) + + def python_type(self) -> type: + return slice + + def as_python_constant(self) -> slice: + return slice(*[guard_if_dyn(x) for x in self.items]) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented( + gb_type="Unsupported attribute for slice() object", + context=f"var_getattr {self} {name}", + explanation=f"Expected attribute to be one of {','.join(fields)} " + f"but got {name}", + hints=[*graph_break_hints.USER_ERROR], + ) + return self.items[fields.index(name)] + + +class ListIteratorVariable(IteratorVariable): + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, items: list[VariableTracker], index: int = 0, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + # Removing this check as it slows things down too much + # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 + + # assert all(isinstance(x, VariableTracker) for x in items) + self.items = items + self.index = index + self.is_exhausted = False + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + assert self.is_mutable() + old_index = self.index + if old_index >= len(self.items) or self.is_exhausted: + self.is_exhausted = True + raise_observed_exception(StopIteration, tx) + + tx.output.side_effects.mutation(self) + self.index += 1 + return self.items[old_index] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return variables.ConstantVariable.create(hasattr(iter([]), name)) + + def python_type(self) -> type: + return type(iter([])) + + def as_python_constant(self) -> Any: + if self.index > 0: + raise NotImplementedError + return iter([x.as_python_constant() for x in self.items]) + + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + return True + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + if self.is_exhausted: + return [] + self.is_exhausted = True + return list(self.items[self.index :]) + + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + return self.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen") -> None: + if not self.is_exhausted: + remaining_items = self.items[self.index :] + else: + remaining_items = [] + codegen.foreach(remaining_items) + codegen.extend_output( + [ + create_build_tuple(len(remaining_items)), + create_instruction("GET_ITER"), + ] + ) + + +class TupleIteratorVariable(ListIteratorVariable): + pass + + +class RangeIteratorVariable(IteratorVariable): + # only needed for isinstance(..., range_iterator) to work + _nonvar_fields = { + "iter_obj", + } + + def __init__( + self, start: int, stop: int, step: int, len_: int, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.start = start + self.stop = stop + self.step = step + self.len = len_ + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + return self + return super().call_method(tx, name, args, kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is range_iterator: + ri = iter(range(0)) + return ConstantVariable(hasattr(ri, name)) + return super().call_obj_hasattr(tx, name) + + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + if self.len <= 0: + raise_observed_exception(StopIteration, tx) + + self.len -= 1 + current = self.start + self.start += self.step + return ConstantVariable.create(current) + + def python_type(self) -> type: + return range_iterator + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] + ) + codegen.append_output(codegen.create_load_const(self.start)) + codegen.append_output(codegen.create_load_const(self.stop)) + codegen.append_output(codegen.create_load_const(self.step)) + codegen.extend_output(create_call_function(3, False)) + codegen.append_output(create_instruction("GET_ITER")) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..95816b81fa199d8427c24a9b50cbd74e24b81f24 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/misc.py @@ -0,0 +1,2129 @@ +# mypy: ignore-errors + +""" +This module contains miscellaneous variable tracker implementations for various Python types +and features used in Dynamo's symbolic execution. These classes help track and propagate +information about different kinds of variables during graph capture. + +Key classes include: +- SuperVariable: Handles super() calls and method resolution +- ExceptionVariable: Tracks exception objects +- RandomVariable: Manages random number generators +- GetAttrVariable: Tracks attribute access +- MethodWrapperVariable: Handles method wrappers +- PythonModuleVariable: Tracks Python modules +- NumpyVariable: Handles numpy functions and types +- StringFormatVariable: Manages string formatting +- DebuggingVariable: Handles print and logging +""" + +import dataclasses +import enum +import functools +import inspect +import itertools +import random +import re +import sys +import types +import warnings +from typing import Optional, TYPE_CHECKING + +import torch._C +import torch._numpy as tnp +import torch.utils._pytree as pytree + +from .. import config, graph_break_hints, trace_rules, variables +from ..bytecode_transformation import ( + create_call_function, + create_call_function_ex, + create_instruction, +) +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init +from ..source import ( + AttrSource, + GenericAttrSource, + GetItemSource, + TypeMROSource, + TypeSource, + WeakRefCallSource, +) +from ..utils import ( + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + identity, + is_tensor_base_attr_getter, + istype, + list_methods, + proxy_args_kwargs, + raise_args_mismatch, + tuple_methods, +) +from .base import ( + AsPythonConstantNotImplementedError, + raise_type_error_exc, + VariableTracker, +) +from .constant import ConstantVariable +from .functions import NestedUserFunctionVariable, UserFunctionVariable +from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class NO_SUCH_SUBOBJ: + pass + + +class SuperVariable(VariableTracker): + _nonvar_fields = { + *VariableTracker._nonvar_fields, + } + + def __init__(self, typevar, objvar=None, **kwargs) -> None: + super().__init__(**kwargs) + # typevar is the first argument to super(). In the case where no argument + # is provided to super(), it is the __class__ object where + # the super() function is being called + self.typevar = typevar + # objvar here must be an instance or subtype of typevar. + # In the case where super() is called without arguments, it is the first argument + # to the current function where super() is called from (self for regular method, + # cls for a classmethod) + self.objvar = objvar + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) + codegen(self.typevar) + if self.objvar is not None: + codegen(self.objvar) + codegen.extend_output(create_call_function(2, False)) + else: + codegen.extend_output(create_call_function(1, False)) + + def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): + if not self.objvar: + unimplemented( + gb_type="1-arg super not implemented", + context="", + explanation=f"Dynamo failed to trace attribute `{name}` accessed " + f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) " + "because one-argument of super() is not supported.", + hints=[ + "Use two-argument super(type, object_or_type).", + ], + ) + search_type = self.typevar.as_python_constant() + + # The rest of this function does two things: + # - Walk the mro to find where the attribute comes from to be + # able to provide accurate source + # - Call the getattr to get the object + + # Find the class object, where the function lives. + # When objvar is "self", use type(self), when objvar is "cls", use it as-is + type_to_use = self.objvar.python_type() + type_to_use_source = ( + TypeSource(self.objvar.source) if self.objvar.source else None + ) + if issubclass(type_to_use, type): + type_to_use = self.objvar.value + type_to_use_source = self.objvar.source + + source = None + search_mro = type_to_use.__mro__ + + try: + start_index = search_mro.index(search_type) + 1 + except ValueError: + # Corner case where the typevar is not in the mro of the objvar + # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 + return getattr(super(search_type, type_to_use), name), None + # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 + # super has its getattro implementation. The key point is that instead of calling getattr, it checks the + # attribute in the class __dict__ + for index in range(start_index, len(search_mro)): + # Dont call getattr, just check the __dict__ of the class + if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): + if resolved_getattr is not NO_SUCH_SUBOBJ: + # Equivalent of something like type(L['self']).__mro__[1].attr_name + if type_to_use_source: + source = AttrSource( + GetItemSource(TypeMROSource(type_to_use_source), index), + name, + ) + return resolved_getattr, source + + unimplemented( + gb_type="Unable to resolve super getattr", + context="", + explanation=f"Dynamo failed to trace attribute `{name}` accessed " + f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) " + "because the resolved attribute type is not supported.", + hints=[ + "Ensure the attribute exists in the parent class.", + "Check the arguments passed to `super()`.", + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # Check if getattr is a constant. If not, delay the actual work by + # wrapping the result in GetAttrVariable. Mostly super is called with a + # method, so most of the work is delayed to call_function. + # + # We could have just implemented a const_getattr. However, super is + # special when it comes to finding sources. Compared to other VTs, super + # requires the attr name to walk the mro and find the actual source (and + # not just AttrSource). + value, source = self._resolved_getattr_and_source(self, name) + if not variables.ConstantVariable.is_literal(value): + return GetAttrVariable(self, name) + if source: + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). + if inner_fn is object.__init__: + return LambdaVariable(identity) + elif inner_fn is torch.nn.Module.__init__: + objvar = self.objvar + from ..side_effects import AttributeMutationNew + + if ( + isinstance(objvar, variables.UserDefinedObjectVariable) + and isinstance(objvar.mutation_type, AttributeMutationNew) + and not (args or kwargs) + ): + with do_not_convert_to_tracable_parameter(): + fn_vt = VariableTracker.build( + tx, unpatched_nn_module_init, source=source + ) + return fn_vt.call_function(tx, [self.objvar] + args, kwargs) + else: + unimplemented( + gb_type="Unsupported super().__init__() call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered a super().__init__() call " + f"on {objvar} that resolved to a `torch.nn.Module.__init__()` " + "call that we cannot trace.", + hints=[*graph_break_hints.DIFFICULT], + ) + elif ( + self.objvar.source + and hasattr(inner_fn, "__name__") + and inner_fn.__name__ == "__new__" + and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn) + ): + user_cls = inner_fn.__self__ + if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins": + user_cls_vt = variables.BuiltinVariable(user_cls) + else: + user_cls_source = source.member + user_cls_vt = variables.UserDefinedClassVariable( + user_cls, source=user_cls_source + ) + return user_cls_vt.call_method(tx, "__new__", args, kwargs) + elif isinstance(inner_fn, staticmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source) + return fn_vt.call_function(tx, args, kwargs) + elif isinstance(inner_fn, classmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + if isinstance(self.objvar, variables.UserDefinedClassVariable): + # super().classmethod is called from a classmethod itself. So, + # super was converted to super(__class__, cls) in bytecode and + # therefore we have to propagate the cls. + cls_variable = self.objvar + else: + # current function is an instance method, therefore super was + # converted to super(__class__, self). We have to find + # type(self) to bind the cls to the parent classmethod. + # Note that it can't be the self.typevar because __class__ is + # the class where the method is defined, which could be + # different from type(self) with polymorphism. + cls_source = None + if self.objvar.source: + cls_source = TypeSource(self.objvar.source) + cls_variable = VariableTracker.build( + tx, self.objvar.value_type, cls_source + ) + + fn_vt = VariableTracker.build( + tx, inner_fn.__func__, source=AttrSource(source, "__func__") + ) + return fn_vt.call_function(tx, [cls_variable, *args], kwargs) + elif isinstance(inner_fn, types.FunctionType): + fn_vt = VariableTracker.build(tx, inner_fn, source=source) + return fn_vt.call_function(tx, [self.objvar] + args, kwargs) + elif isinstance(inner_fn, types.MethodType): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source + ).call_function(tx, args, kwargs) + elif is_standard_setattr(inner_fn) and isinstance( + self.objvar, UserDefinedObjectVariable + ): + return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError as exc: + unimplemented( + gb_type="Non-constant attribute given to `super().__delattr__()`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the attribute name passed to " + "`super().__delattr__(...)` to be a constant (string).", + hints=[ + "Ensure the attribute name is a string literal or a constant variable." + ], + from_exc=exc, + ) + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented( + gb_type="Attempted super().__delattr__() on an object without mutation tracking", + context=f"call_method {self} {name}", + explanation="Dynamo needs to track mutations on an object " + "before `super().__delattr__` can be used on it. But the " + f"object ({self.objvar}) doesn't have attribute mutation " + "tracking enabled.", + hints=[ + "Ensure the object is tracked by Dynamo's side effect system.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) + elif ( + isinstance(self.objvar, variables.UserDefinedDictVariable) + and inner_fn in self.objvar._dict_methods + ): + return self.objvar._dict_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedSetVariable) + and inner_fn in self.objvar._set_methods + ): + return self.objvar._set_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedTupleVariable) + and inner_fn in tuple_methods + ): + return self.objvar._tuple_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedListVariable) + and inner_fn in list_methods + ): + return self.objvar._list_vt.call_method(tx, name, args, kwargs) + elif inner_fn is object.__getattribute__: + # object.__getattribute__ has no side-effects. We can directly call + # __getattribute__ to access the attribute. + attr_name = args[0].value + if tx.output.side_effects.has_pending_mutation_of_attr( + self.objvar, attr_name + ): + result = tx.output.side_effects.load_attr( + self.objvar, attr_name, deleted_ok=True + ) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception(AttributeError, tx) + return result + + try: + # NB - use object.__getattribute__ to prevent running any user code + attr_value = object.__getattribute__(self.objvar.value, attr_name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + + attr_source = None + if self.objvar.source is not None: + # setup a object.__getattribute__(self.objvar, name) source + attr_source = GenericAttrSource(self.objvar.source, attr_name) + return VariableTracker.build(tx, attr_value, attr_source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. + # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) + + unimplemented( + gb_type="Attempted to call a super() attribute that is " + "not a function or method", + context=f"call_method {self} {name}", + explanation="Dynamo does not know how to trace the call " + f"`super().{name}()` because `super().{name}` is not a " + "function or method attribute.", + hints=[ + "Ensure the attribute accessed via `super()` is a standard method or function.", + ], + ) + + +class ExceptionVariable(VariableTracker): + # The ExceptionVariable corresponds to the BaseException class in Python + def __init__( + self, exc_type, args, init_kwargs=None, source=None, mutation_type=None + ) -> None: + super().__init__(source=source, mutation_type=mutation_type) + self.exc_type = exc_type + self.args = args + if init_kwargs: + unimplemented( + gb_type="Keyword args passed to exception constructor", + context=f"{self} with kwargs {init_kwargs}", + explanation="Dynamo does not know how to handle keyword args passed to an exception constructor", + hints=[*graph_break_hints.SUPPORTABLE], + ) + # When raising a new exception while another exception is already being + # handled, the new exception's __context__ attribute is automatically + # set to the handled exception. + self.__context__ = ConstantVariable(None) + # Set when user raised an exception from another: + # raise ... from ... + self.__cause__ = ConstantVariable(None) + # Boolean flag that controls whether the __context__ attribute is set + self.__suppress_context__ = ConstantVariable(False) + # Contains the call stack where the exception was raised. Dynamo does + # not track traceback. So, this variable is always set to None + self.__traceback__ = ConstantVariable(None) + + def set_context(self, context: "ExceptionVariable"): + self.__context__ = context + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", self.exc_type.__name__) + ) + codegen.foreach(self.args) + codegen.call_function(len(self.args), False) + + def codegen_attr(name: str) -> None: + attr = getattr(self, name) + if istype(attr, ConstantVariable): + assert attr.value in (True, False, None), attr + else: + codegen.dup_top() + codegen(attr) + codegen.extend_output(codegen.rot_n(2)) + codegen.store_attr(name) + + codegen_attr("__context__") + codegen_attr("__cause__") + codegen_attr("__suppress_context__") + + def python_type(self): + return self.exc_type + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ): + def raise_error(msg): + raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)]) + + name = name_var.as_python_constant() + if name == "__context__": + self.set_context(val) + elif name == "__cause__": + if val.is_constant_none() or isinstance( + val, + ( + variables.BuiltinVariable, + variables.ExceptionVariable, + variables.UserDefinedExceptionClassVariable, + variables.UserDefinedExceptionObjectVariable, + ), + ): + self.__cause__ = val + self.__suppress_context__ = variables.ConstantVariable(True) + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__suppress_context__": + if val.is_constant_match(True, False): + self.__suppress_context__ = val + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__traceback__": + if val.is_constant_none(): + self.__traceback__ = val + else: + unimplemented( + gb_type="Set Exception object `__traceback__` attribute to not-`None`", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + "'__traceback__' on tracked exception objects to anything " + "other than None.", + hints=[ + "Avoid setting '__traceback__' on exception objects " + "within traced code, or set it to None." + ], + ) + else: + unimplemented( + gb_type="Unsupported attribute assignment on Exception object", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + f"'{name}' on tracked exception objects. Only `__context__`, " + "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + return variables.ConstantVariable(None) + + def call_method(self, tx, name, args, kwargs): + if name == "__setattr__": + return self.call_setattr(tx, *args) + elif name == "with_traceback": + [tb] = args + self.call_setattr(tx, ConstantVariable("__traceback__"), tb) + return self + else: + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "__context__": + return self.__context__ + elif name == "__cause__": + return self.__cause__ + elif name == "__suppress_context__": + return self.__suppress_context__ + elif name == "__traceback__": + return variables.ConstantVariable(None) + elif name == "args": + return variables.ListVariable(self.args, source=self.source) + return super().var_getattr(tx, name) + + def __str__(self): + return f"{self.__class__.__name__}({self.exc_type})" + + __repr__ = __str__ + + +class UnknownVariable(VariableTracker): + """ + It could be anything! + """ + + +class DelayGraphBreakVariable(UnknownVariable): + """ + Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. + """ + + def __init__(self, msg=None, **kwargs): + super().__init__(**kwargs) + self.msg = msg + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented( + gb_type="Unsupported function call (delayed)", + context=f"source: {self.source}", + explanation="Dynamo determined that a graph break should occur " + f"when calling `{self.source.name}`. Reason: {self.msg}", + hints=[], + ) + + +class ComptimeVariable(VariableTracker): + """ + This variable is special, it lets you execute arbitrary code at + Dynamo compile time + """ + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError("comptime is special form") + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from ..comptime import comptime + + # To support the comptime.print_graph convenience accessors + return VariableTracker.build( + tx, getattr(comptime, name), source=AttrSource(self.source, name) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..comptime import ComptimeContext + + # TODO: support an expression form as well + # Second argument is runtime lambda, ignored + if kwargs or len(args) > 2: + raise_args_mismatch( + tx, + "comptime()", + "at most 2 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + fn = args[0] + if isinstance(fn, UserFunctionVariable): + fn.get_function()(ComptimeContext(tx)) + elif isinstance(fn, NestedUserFunctionVariable): + # We have to manually bind the freevars ourselves + code = fn.get_code() + if fn.closure: + raise_type_error_exc( + tx, + f"comptime function must not have free variables, but these variables were free: {code.co_freevars}", + ) + func = types.FunctionType( + code, + fn.f_globals, + fn.fn_name.as_python_constant(), + tuple(fn.defaults.items) if fn.defaults else None, + # We could automatically promote free variables into + # ComptimeVar but this is confusing if you access + # a free variable that we actually DO have the runtime + # value for + # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items) + (), + ) + func(ComptimeContext(tx)) + else: + raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") + + return variables.ConstantVariable.create(None) + + +class CellVariable(VariableTracker): + # If the cell existed before Dynamo tracing started, this will be the + # VariableTracker that represents the cell content. + # + # Note that all mutation to the cell (i.e., its content) will be buffered in + # SideEffects, rather than being reflected here. One can think of + # `CellVariable` as a special case for `UserDefinedObjectVariable`. + pre_existing_contents: Optional[VariableTracker] + + # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the + # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`). + local_name: Optional[str] = None + + def __init__( + self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.pre_existing_contents = pre_existing_contents + + +class NewGlobalVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + +def produce_trampoline_autograd_apply(fn_cls): + def trampoline_autograd_apply(*args, **kwargs): + return fn_cls.apply(*args, **kwargs) + + trampoline_autograd_apply._origin = produce_trampoline_autograd_apply + return trampoline_autograd_apply + + +class AutogradFunctionVariable(VariableTracker): + """represents a torch.autograd.Function subclass""" + + _nonvar_fields = { + "fn_cls", + *VariableTracker._nonvar_fields, + } + + def __init__(self, fn_cls, **kwargs) -> None: + super().__init__(**kwargs) + self.fn_cls = fn_cls + + def call_apply(self, tx: "InstructionTranslator", args, kwargs): + requires_grad = False + + def visit(vt): + nonlocal requires_grad + if vt.is_tensor(): + if vt.requires_grad is not False: + requires_grad = True + if isinstance(vt, variables.NNModuleVariable): + if vt.is_training(tx): + requires_grad = True + + VariableTracker.visit(visit, (args, kwargs)) + + if requires_grad and torch.is_grad_enabled(): + if config.capture_autograd_function is False: + warnings.warn( + "The config.capture_autograd_function flag is deprecated, it's now always true." + ) + + from torch._functorch.autograd_function import ( + autograd_function_forward_rewritten, + ) + from torch.autograd.function import _is_setup_context_defined + + forward_fn = self.fn_cls.forward + + is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context) + if is_setup_ctx_defined: + # If setup_context is defined, we generate a new forward function which includes + # the original forward and setup_context function, and trace the new forward function. + forward_fn = autograd_function_forward_rewritten( + self.fn_cls.forward, self.fn_cls.setup_context + ) + + vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] + if vjp_fn is not torch.autograd.Function.vjp: + unimplemented( + gb_type="Unsupported custom vjp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `vjp` method.", + hints=[ + "Remove the custom `vjp` method if possible.", + "Use standard `backward` instead if applicable.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] + if jvp_fn is not torch.autograd.Function.jvp: + unimplemented( + gb_type="Unsupported custom jvp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `jvp` method.", + hints=[ + "Remove the custom `jvp` method if possible.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + from .higher_order_ops import AutogradFunctionApplyVariable + + source = self.source + if source is None: + source = AttrSource( + tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__ + ) + + val = AutogradFunctionApplyVariable( + forward_fn, + self.fn_cls.backward, + source, + source=AttrSource(source, member="apply"), + ).call_function(tx, args, kwargs) + # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping + # the forward function, as we don't want to generate guards for new_forward.__closure__ + # if forward is rewritten by autograd_function_forward_rewritten. + # But we still need to generate correct guards for the original forward and setup_context + # functions, so we have to add guards manually. + if self.source: + fwd_src = AttrSource(self.source, "forward") + install_guard(fwd_src.make_guard(GuardBuilder.CLOSURE_MATCH)) + if is_setup_ctx_defined: + setup_ctx_src = AttrSource(self.source, "setup_context") + install_guard(setup_ctx_src.make_guard(GuardBuilder.CLOSURE_MATCH)) + + return val + + if self.source: + source = AttrSource(self.source, "forward") + else: + source = None + + fn = self.fn_cls.forward + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + args = [ctx, *args] + if isinstance(fn, types.FunctionType): + sig = inspect.signature(fn) + if len(args) - 1 == len(sig._parameters): + args = args[1:] # Don't use context + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, args, kwargs) + elif isinstance(fn, types.MethodType): + return variables.UserMethodVariable( + fn.__func__, + variables.UserDefinedClassVariable(self.fn_cls), + source=source, + ).call_function(tx, args, kwargs) + else: + unimplemented( + gb_type="Non-function or method in subclass of torch.autograd.Function", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo requires the `forward` attribute of a " + "`torch.autograd.Function` subclass to be a standard Python " + f"function or method. Found type `{type(fn).__name__}` instead.", + hints=[ + "Ensure the `forward` method is defined as a regular " + "function or instance method." + ], + ) + + def call_backward(self, tx: "InstructionTranslator", args, kwargs): + fn = self.fn_cls.backward + assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction + assert isinstance(fn, types.FunctionType) + + fn_source = AttrSource(self.source, "backward") + fn_vt = VariableTracker.build(tx, fn, source=fn_source) + return fn_vt.call_function(tx, args, kwargs) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + return AutogradFunctionVariable(self.fn_cls) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + from .builder import wrap_fx_proxy + + if name == "apply": + if trace_rules.is_callable_allowed(self.fn_cls): + trampoline_autograd_apply = produce_trampoline_autograd_apply( + self.fn_cls + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + trampoline_autograd_apply, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + return self.call_apply(tx, args, kwargs) + + elif name == "backward": + return self.call_backward(tx, args, kwargs) + else: + source = AttrSource(self.source, name) if self.source is not None else None + try: + obj = inspect.getattr_static(self.fn_cls, name) + except AttributeError: + obj = None + + if isinstance(obj, staticmethod): + func = obj.__get__(self.fn_cls) + if source is not None: + return ( + trace_rules.lookup(func) + .create_with_source(func, source=source) + .call_function(tx, args, kwargs) + ) + else: + return trace_rules.lookup(func)(func).call_function( + tx, args, kwargs + ) + elif isinstance(obj, classmethod): + return variables.UserMethodVariable( + obj.__func__, self, source=source + ).call_function(tx, args, kwargs) + else: + unimplemented( + gb_type="Unsupported autograd.Function method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` directly on the `torch.autograd.Function` " + "instance. Supported methods include `apply`, `backward`, " + "static methods, and class methods.", + hints=[ + "Ensure the method is decorated with `@staticmethod` " + "or `@classmethod` if it's meant to be called on the class.", + ], + ) + + +@dataclasses.dataclass +class SavedTensorBox: + tensors: list[VariableTracker] = dataclasses.field(default_factory=list) + + +class AutogradFunctionContextVariable(UserDefinedObjectVariable): + """ + Tracks an autograd.Function() context using mutation tracking in side_effects.py + """ + + _nonvar_fields = { + "proxy", + "inference", + "saved_tensors", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value, + value_type=None, + inference=False, + saved_tensors=None, + needs_input_grad=None, + non_differentiable=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + self.inference = inference + self.saved_tensors = saved_tensors + self.needs_input_grad = needs_input_grad + self.non_differentiable = non_differentiable + + @staticmethod + def create(tx: "InstructionTranslator", args=None, kwargs=None): + needs_input_grad = None + if args and not kwargs: + needs_input_grad = tuple(x.is_tensor() and x.requires_grad for x in args) + out = tx.output.side_effects.track_object_new( + None, + torch.autograd.function.FunctionCtx, + functools.partial( + AutogradFunctionContextVariable, + inference=True, + saved_tensors=SavedTensorBox(), + needs_input_grad=needs_input_grad, + ), + {}, + ) + return out + + def as_proxy(self): + if self.proxy is None: + unimplemented( + gb_type="proxy not set", + context=f"as_proxy {self}", + explanation="Dynamo requires the autograd.Function context " + "to be initialized with a proxy.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + return self.proxy + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__setattr__": + return super().call_method(tx, name, args, kwargs) + elif name == "mark_non_differentiable": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + self.non_differentiable = proxy_args_kwargs(args, {})[0] + return variables.ConstantVariable.create(None) + + if name != "save_for_backward": + unimplemented( + gb_type="Unsupported autograd.Function context method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` on `autograd.Function` context objects. Supported " + "methods are `__setattr__`, `save_for_backward` and " + "`mark_non_differentiable`.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if self.saved_tensors is None: + unimplemented( + gb_type="Unsupported autograd.Function context `save_for_backward`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the `saved_tensors` attribute " + "to be initialized on the `autograd.Function` context object.", + hints=[ + "Ensure that the `saved_tensors` attribute is properly " + "initialized before calling `save_for_backward`. " + "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.", + ], + ) + + if not self.inference: + if kwargs or not self.source: + raise_type_error_exc( + tx, "save_for_backward() requires a source and no keyword arguments" + ) + tx.output.side_effects.track_save_for_backward(self, args) + + # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls. + if len(self.saved_tensors.tensors) > 0: + self.saved_tensors.tensors = [] + for arg in args: + self.saved_tensors.tensors.append(arg) + return variables.ConstantVariable.create(None) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name in ["save_for_backward", "mark_non_differentiable"]: + return LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + if name == "saved_tensors" and self.saved_tensors is not None: + return variables.TupleVariable(list(self.saved_tensors.tensors)) + if name == "needs_input_grad": + if self.needs_input_grad is not None: + return variables.ConstantVariable.create(self.needs_input_grad) + if self.source: + source = AttrSource(self.source, "needs_input_grad") + return VariableTracker.build(tx, self.value.needs_input_grad, source) + + return super().var_getattr(tx, name) + + +class AutogradEngineVariable(UserDefinedObjectVariable): + """ + Represents a torch._C._ImperativeEngine instance. + """ + + def __init__( + self, + value, + value_type=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "queue_callback": + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + assert tx.one_graph or tx.error_on_graph_break, ( + "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + ) + # queue_callback is a method-wrapper, no need to insert a guard. + fn_vt = VariableTracker.build( + tx, + torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, + ) + return fn_vt.call_function( + tx, + (tx.output.side_effects.get_ca_final_callbacks_var(), *args), + kwargs, + ) + else: + unimplemented( + gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()", + context=f"call_method {self} {name}", + explanation="queue_callback() is only supported when " + "Compiled Autograd is enabled with fullgraph=True.", + hints=[], + ) + else: + unimplemented( + gb_type="Unsupported torch._C._ImperativeEngine method", + context=f"call_method {self} {name}", + explanation="Dynamo only supports the `queue_callback` method " + f"on a torch._C._ImperativeEngine instance, but found: `{name}`.", + hints=[], + ) + + +class LambdaVariable(VariableTracker): + def __init__(self, fn, **kwargs) -> None: + super().__init__(**kwargs) + self.fn = fn + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.fn(*args, **kwargs) + + +class GetAttrVariable(VariableTracker): + _nonvar_fields = { + "name", + "py_type", + *VariableTracker._nonvar_fields, + } + + def __init__(self, obj, name, py_type=None, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(obj, VariableTracker) + assert isinstance(name, str) + self.obj = obj + self.name = name + self.py_type = py_type # In some cases we know the type (ex. tensor methods) + + def python_type(self): + if self.py_type is not None: + return self.py_type + else: + return super().python_type() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.obj}, {self.name})" + + @staticmethod + def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): + return getattr(base_proxy, attr) + + def as_proxy(self): + return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) + + def as_python_constant(self): + constant = self.obj.as_python_constant() + try: + return getattr(constant, self.name) + except AttributeError: + raise NotImplementedError(f"{self} is not a constant") from None + + def const_getattr(self, tx: "InstructionTranslator", name): + if not isinstance(self.obj, variables.NNModuleVariable): + raise NotImplementedError + step1 = tx.output.get_submodule(self.obj.module_key) + if self.name not in step1.__dict__: + raise NotImplementedError + step2 = inspect.getattr_static(step1, self.name) + if name not in step2.__dict__: + raise NotImplementedError + return inspect.getattr_static(step2, name) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.obj) + codegen.extend_output(codegen.create_load_attrs(self.name)) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.obj.call_method(tx, self.name, args, kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if ( + name in ("__getitem__", "get") + and self.name == "__dict__" + and not kwargs + and args[0].is_python_constant() + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + # redirect to var_getattr on the original obj + return obj.var_getattr(tx, key) + + # Return the default value for get + if name == "get": + if len(args) == 2: + return args[1] + else: + return variables.ConstantVariable(None) + + elif ( + name == "__contains__" + and self.name == "__dict__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + return variables.ConstantVariable(True) + else: + return variables.ConstantVariable(False) + + elif name == "__setitem__" and self.name == "__dict__" and not kwargs: + if isinstance(self.obj, variables.UserDefinedObjectVariable): + # Bypass any custom setattr as we are updating the `__dict__` itself + return self.obj.method_setattr_standard( + tx, args[0], args[1], directly_update_dict=True + ) + if isinstance(self.obj, variables.NNModuleVariable): + # This matches how `setattr` is handled for NNModuleVariable + self.obj.convert_to_unspecialized(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_forwarded_dict(self, tx): + assert ( + self.name == "__dict__" + and isinstance(self.obj, variables.UserDefinedClassVariable) + and not tx.output.side_effects.has_pending_mutation(self.obj) + ) + self.obj.ban_mutation = True + return VariableTracker.build(tx, self.obj.value.__dict__, self.source) + + +class MethodWrapperVariable(VariableTracker): + def __init__(self, method_wrapper, **kwargs) -> None: + super().__init__(**kwargs) + self.method_wrapper = method_wrapper + self._builtin_fns = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( + args[0], variables.TensorVariable + ): + if not (len(args) == 1 and len(kwargs) == 0): + raise_type_error_exc( + tx, "tensor attribute getter takes exactly one argument" + ) + + return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) + + # method-wrapper variables are common in __init__ calls. For example, + # str("foo").__init__ is a method-wrapper. These method wrappers point + # to C functions. Here we intercept if these method-wrappers are from + # builtins and then call the function counterpart directly by obtaining + # the self object. + self_obj = self.method_wrapper.__self__ + wrapper_name = self.method_wrapper.__name__ + # TODO(dynamo-team) - We can perhaps expand the scope to more names and + # more builtins. + if wrapper_name == "__init__": + fn_obj = type(self_obj).__init__ + if fn_obj is object.__init__: + return variables.BuiltinVariable(object).call_method( + tx, wrapper_name, [self_obj, *args], kwargs + ) + elif ( + sys.version_info >= (3, 14) + # for some reason, even if the below check passes, + # self.method_wrapper may not be the same as type.__dict__["__annotations__"].__get__ + and self_obj is type.__dict__["__annotations__"] + and wrapper_name == "__get__" + ): + from .builder import SourcelessBuilder + + if len(args) == 1 and not kwargs: + try: + return SourcelessBuilder.create( + tx, self.method_wrapper(args[0].as_python_constant()) + ) + except AttributeError: + raise_observed_exception(AttributeError, tx) + except AsPythonConstantNotImplementedError: + pass + + unimplemented( + gb_type="unsupported type.__dict__['__annotations__'].__get__ call", + context=f"call_function {self}, args: {args}, kwargs: {kwargs}", + explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ " + "on a single constant argument (i.e. a type).", + hints=[ + "Make sure your call to type.__dict__['__annotations__'] only has " + "one positional argument (no keyword arguments).", + "Make sure the argument to type.__dict__['__annotations__'] is a constant " + "(i.e. type). For example, `object`, `int`, `MyCustomClass`.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return super().call_function(tx, args, kwargs) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.method_wrapper + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class GetSetDescriptorVariable(VariableTracker): + def __init__(self, desc, **kwargs) -> None: + super().__init__(**kwargs) + self.desc = desc + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "__get__" and self.source: + source = AttrSource(self.source, "__get__") + return VariableTracker.build(tx, self.desc.__get__, source) + else: + return super().var_getattr(tx, name) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.desc + + +class PythonModuleVariable(VariableTracker): + _nonvar_fields = { + "value", + "is_torch", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value: types.ModuleType, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") + + def python_type(self): + return types.ModuleType + + def as_python_constant(self): + return self.value + + def __repr__(self) -> str: + return f"PythonModuleVariable({self.value})" + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def var_getattr(self, tx: "InstructionTranslator", name): + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.output.side_effects.load_attr(self, name) + + if self.is_torch or name not in self.value.__dict__: + try: + attr_value = getattr(self.value, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + else: + attr_value = self.value.__dict__[name] + + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) + + +class TypingVariable(VariableTracker): + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # Create a new typing variable, e.g., `List[int]` + if name == "__getitem__" and len(args) == 1: + new_typing = self.value[args[0].as_python_constant()] + return TypingVariable(new_typing) + unimplemented( + gb_type="unsupported method call on `typing` variable", + context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}", + explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.", + hints=[ + f"Avoid calling the {name} method on {self.value}.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str): + from .builder import SourcelessBuilder, VariableBuilder + + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.side_effects.load_attr(self, name) + + value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(value) + else: + return SourcelessBuilder.create(tx, value) + + def as_python_constant(self): + return self.value + + def reconstruct(self, codegen: "PyCodegen") -> None: + if not isinstance(self.value, types.GenericAlias): + return super().reconstruct(codegen) + # We're just trying to load the type here. Reconstructing the type from + # scratch is tricky - for a type like `typing.List[int]` we'd need to + # deconstruct the origin and args. The origin for `List[int]` is `list` + # and the args is `(int,)`. When we recombine those we get the parts + # back and need to emit code for: + # + # `typing.List[int]` + # + # But it's # worse than that - what if `typing` isn't in the globals (or + # was loaded like `import typing as _typing ; _typing.List[int]`?) so we + # really need to do something like: + # + # `sys.modules["typing"].List[int]` + # + # Argh - but what if they rewrote the global `int`? So we have to do: + # + # `sys.modules["typing"].List[sys.modules["builtins"].int]` + # + # But where do we get `sys`? What if they never imported it or have + # something ELSE called `sys`? + # + # Let's skip all that noise and just emit it as a simple const. + # + codegen.append_output(codegen.create_load_const(self.value)) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +@functools.lru_cache(maxsize=1) +def get_np_to_tnp_map(): + """ + This generates a mapping from numpy modules to their torch._numpy + modules equivalents. + """ + from ..utils import NP_TO_TNP_MODULE + + np_fn_to_tnp_fn = {} + + for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): + for fn_name, tnp_fn in tnp_mod.__dict__.items(): + if callable(tnp_fn): + # some internal details do leak from tnp + # which are not part of numpy API. + if np_fn := getattr(np_mod, fn_name, None): + np_fn_to_tnp_fn[np_fn] = tnp_fn + + return np_fn_to_tnp_fn + + +@functools.lru_cache(maxsize=1) +def get_tnp_to_np_map(): + """ + This is just the reverse mapping of get_np_to_tnp_map() - mapping from + torch._numpy modules to numpy equivalents. + """ + m = get_np_to_tnp_map() + return {v: k for k, v in m.items()} + + +class NumpyVariable(VariableTracker): + """ + Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. + """ + + constant_fold_functions = (tnp.issubdtype,) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def can_constant_fold_through(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return fn in cls.constant_fold_functions + + @classmethod + def get_constant_collection_for_func(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return np_constant_collections_map.get(fn) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not config.trace_numpy: + unimplemented( + gb_type="attempted to trace numpy function with config.trace_numpy=False", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation=f"Attempted to trace numpy function {self.value} " + "while `torch._dynamo.config.trace_numpy` was set to False.", + hints=[ + "Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.", + ], + ) + + from ..utils import numpy_to_tensor_wrapper + from .tensor import NumpyNdarrayVariable + + func = get_np_to_tnp_map().get(self.value) + if func is None: + unimplemented( + gb_type="attempted to trace numpy function unsupported by PyTorch", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) + if ( + collection_variable_typ := self.get_constant_collection_for_func(func) + ) is not None: + try: + return collection_variable_typ( + self.value( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + except AsPythonConstantNotImplementedError: + unimplemented( + gb_type="numpy function that produces a const collection type encountered non-const arguments", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"numpy function {self.value} that produces a const collection type " + "(e.g. np.dtype, np.iinfo/np.finfo) " + "received arguments that are not constant.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + else: + if ( + func.__module__ == "torch._numpy.random" + and config.use_numpy_random_stream + ): + unimplemented( + gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` " + "is set to True.", + hints=[ + "Set `torch._dynamo.config.use_numpy_random_stream` to False.", + f"Avoid calling {self.value}.", + ], + ) + + args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) + + if self.can_constant_fold_through(func) and ( + check_unspec_or_constant_args(args, kwargs) + ): + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + # TODO Add all the functions that go from constants to constants to can_constant_fold_through + proxy = tx.output.create_proxy( + "call_function", + numpy_to_tensor_wrapper(func), + *proxy_args_kwargs(args, kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented( + gb_type="attempted to trace numpy.* function as a method", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="Tracing numpy.* functions as methods is not supported.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + + def as_python_constant(self): + return self.value + + def as_proxy(self): + if config.trace_numpy: + # Can replace with EnumType once we drop 3.10 support + if isinstance(self.value, enum.EnumMeta): + # This is mostly for np._CopyMode + return self.value + if isinstance(self.value, type): + # This handles numpy dtype attributes such as np.float32 + # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph + # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does + return self.value.__name__ + + return super().as_proxy() + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.as_python_constant()) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls +class NullVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "NullVariable" + + def reconstruct(self, codegen: "PyCodegen"): + if sys.version_info < (3, 11): + unimplemented( + gb_type="cannot reconstruct NullVariable in Python < 3.11", + context="", + explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; " + "where this instruction does not exist.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + codegen.append_output(create_instruction("PUSH_NULL")) + + +class DeletedVariable(VariableTracker): + """Marker used to implement delattr()""" + + +class StringFormatVariable(VariableTracker): + """ + Represents a call to str.format(), we delay calling format until after the graph. + """ + + _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} + + @classmethod + def create(cls, format_string, sym_args, sym_kwargs): + if all( + x.is_python_constant() + for x in itertools.chain(sym_args, sym_kwargs.values()) + ): + return variables.ConstantVariable.create( + format_string.format( + *[v.as_python_constant() for v in sym_args], + **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, + ) + ) + return cls(format_string, list(sym_args), dict(sym_kwargs)) + + def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(format_string, str) + self.format_string = format_string + self.sym_args = sym_args + self.sym_kwargs = sym_kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_const(self.format_string), + codegen.create_load_attr("format"), + ] + ), + call_function_ex=True, + ) + codegen(variables.TupleVariable(self.sym_args)) + kwargs = { + variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() + } + codegen(variables.ConstDictVariable(kwargs)) + codegen.extend_output(create_call_function_ex(True, False)) + + +class DebuggingVariable(VariableTracker): + """ + Represents a call to a debugging function like print(), or something + registered to config.reorderable_logging_functions. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @staticmethod + def is_reorderable_logging_function(obj): + return ( + callable(obj) + and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) + and obj in torch._dynamo.config.reorderable_logging_functions + ) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + + if not self.can_reorder_logs(self.value, args, kwargs): + unimplemented( + gb_type="attempted to reorder a debugging function that can't actually be reordered", + context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="`torch.compile` can only reorder functions where the arguments " + "are Tensors, constants, or string formatters.", + hints=[ + f"Avoid calling the logging function {self.value} with args that are not supported.", + ], + ) + + tx.debug_locals.append((self, list(args))) + + def reconstruct(self, codegen: "PyCodegen"): + return self.source.reconstruct(codegen) + + @staticmethod + def can_reorder_logs(fn, args, kwargs) -> True: + """ + Run some additional checks for what sort of function calls can we + actually reorder. + """ + + allowed_input_types = ( + variables.TensorVariable, + variables.ConstantVariable, + StringFormatVariable, + ) + + flat_args = pytree.tree_leaves([args, kwargs]) + for arg in flat_args: + if not isinstance(arg, allowed_input_types): + return False + + return True + + +class LoggingLoggerVariable(VariableTracker): + """ + Represents a call to any of logging.Logger methods + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + method = getattr(self.value, name, None) + function = getattr(method, "__func__", None) + if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): + return variables.ConstantVariable.create(None) + unimplemented( + gb_type="logging.Logger method not supported for non-export cases", + context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}", + explanation="logging.Logger methods are not supported for non-export cases.", + hints=[ + "Add the logging method to `torch._dynamo.config.ignore_logger_methods.", + ], + ) + + +class ConstantLikeVariable(VariableTracker): + """self.value is a compile-time constant, but not a literal""" + + try: + from numpy import ( + dtype as np_dtype, + floating as np_floating, + generic as np_generic, + ) + except ImportError: + np_floating = type("invalid_type", (), {}) + np_dtype = type("invalid_type", (), {}) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @property + def _error_prefix(self): + """Dynamically compute the prefix from the value's type""" + t = type(self.value) + + # For builtins (int, str, etc.), just return the name + if t.__module__ == "builtins": + return t.__qualname__ + + return f"{t.__module__}.{t.__qualname__}" + + def as_python_constant(self): + return self.value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + # we only support constant propagation for methods + cargs = [x.as_python_constant() for x in args] + ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + unimplemented( + gb_type="constant-like method call with non-constant args", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})", + explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.", + hints=[ + "Ensure that the args to the method call are constant (int, str, etc.).", + ], + ) + + result = getattr(self.value, name)(*cargs, **ckwargs) + + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + if isinstance(result, re.Match): + return ConstantLikeVariable(result) + + unimplemented( + gb_type="constant-like method call with unsupported return type", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}", + explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + result = getattr(self.value, name) + if isinstance(result, self.np_floating): + result = float(result) + if isinstance(result, self.np_dtype): + return NumpyDTypeVariable(result) + if isinstance(result, type) and issubclass(result, self.np_generic): + # things like x.dtype.type + return NumpyVariable(result) + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + return GetAttrVariable(self, name) + + +class TorchVersionVariable(ConstantLikeVariable): + _error_prefix = "torch.__version__" + + def __init__(self, **kwargs) -> None: + kwargs.setdefault("value", torch.__version__) + assert kwargs["value"] is torch.__version__ + super().__init__(**kwargs) + + +class NumpyDTypeVariable(ConstantLikeVariable): + def as_proxy(self): + """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: + + np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. + This also handles unsupported things nicely (i.e. structured arrays and object arrays). + """ + return self.value.type.__name__ + + +np_constant_collections_map = { + tnp.finfo: ConstantLikeVariable, + tnp.iinfo: ConstantLikeVariable, + tnp.dtype: NumpyDTypeVariable, +} + + +class RandomClassVariable(VariableTracker): + """random.Random""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if len(args) > 1 or kwargs: + unimplemented( + gb_type="random.Random() with improper arguments", + context=f"args: {args}, kwargs: {kwargs}", + explanation="random.Random() with > 1 arg or with kwargs is not supported.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] + return RandomVariable( + seed=seed, mutation_type=variables.base.ValueMutationNew() + ) + + +class RandomVariable(VariableTracker): + """random.Random() + + Implemented by wrapping a VariableTracker around a random.Random object. + The supported methods for the random.Random object cannot be overridden. + Assumes that random objects behave the same given a set seed or state. + """ + + _nonvar_fields = { + "random", + *VariableTracker._nonvar_fields, + } + + _supported_fn_names = { + "random", + "randint", + "randrange", + "uniform", + } + + def __init__( + self, + rand: Optional[random.Random] = None, + seed: Optional[VariableTracker] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if rand is not None: + assert self.is_supported_random_obj(rand) + self.random = random.Random() + self.random.setstate(rand.getstate()) + else: + seed = seed.as_python_constant() if seed is not None else None + self.random = random.Random(seed) + + def python_type(self): + return random.Random + + def as_python_constant(self): + return self.random + + @staticmethod + def is_supported_random_obj(val): + if type(val) is not random.Random: + return False + for name in itertools.chain( + RandomVariable._supported_fn_names, ("seed", "getstate", "setstate") + ): + if not hasattr(val, name): + return False + meth = getattr(val, name) + if inspect.isbuiltin(meth): + # e.g. random.Random.random + if meth != getattr(random.Random, name).__get__(val): + return False + else: + if getattr(meth, "__func__", None) is not getattr(random.Random, name): + return False + return True + + @staticmethod + def check_state(state): + assert type(state) is tuple + assert type(state[0]) is int + assert type(state[1]) is tuple + assert all(type(x) is int for x in state[1]) + assert state[2] is None or type(state[2]) is float + + @staticmethod + def wrap_state(state): + RandomVariable.check_state(state) + return variables.TupleVariable( + [ + variables.ConstantVariable.create(state[0]), + variables.TupleVariable( + [variables.ConstantVariable.create(x) for x in state[1]] + ), + variables.ConstantVariable.create(state[2]), + ] + ) + + @staticmethod + def unwrap_state(state): + state_obj = state.as_python_constant() + RandomVariable.check_state(state_obj) + return state_obj + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "seed": + tx.output.side_effects.mutation(self) + self.random.seed( + *[x.as_python_constant() for x in args], + **{key: val.as_python_constant() for key, val in kwargs.items()}, + ) + return variables.ConstantVariable.create(None) + elif name == "getstate": + return self.wrap_state(self.random.getstate()) + elif name == "setstate": + tx.output.side_effects.mutation(self) + self.random.setstate(self.unwrap_state(args[0])) + return variables.ConstantVariable.create(None) + elif name in self._supported_fn_names: + tx.output.side_effects.mutation(self) + state = self.random.getstate() + + def call_random_meth(*args, **kwargs): + r = random.Random() + r.setstate(state) + return getattr(r, name)(*args, **kwargs) + + # self.random state not actually updated by call_random_meth, so update here + # by calling the method + getattr(self.random, name)( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + + return call_random_fn(tx, call_random_meth, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(random), + codegen.create_load_attr("Random"), + ] + ) + ) + codegen.call_function(0, False) + # NOTE using add_push_null may result in NULL being duplicated + # so defer the push_null to call_function + codegen.dup_top() + codegen.load_attr("setstate") + codegen(self.wrap_state(self.random.getstate())) + codegen.call_function(1, True) + codegen.pop_top() + + +class WeakRefVariable(VariableTracker): + @staticmethod + def build(tx, weakref_value, **options): + source = options.get("source") + callback = weakref_value.__callback__ + callback_source = source and AttrSource(source, "__callback__") + callback_vt = VariableTracker.build(tx, callback, callback_source) + referent = weakref_value() + source = source and WeakRefCallSource(source) + referent_vt = VariableTracker.build(tx, referent, source) + options["source"] = source + return WeakRefVariable(referent_vt, callback_vt, **options) + + def __init__(self, referent_vt, callback_vt, **options): + super().__init__(**options) + self.referent_vt = referent_vt + self.callback_vt = callback_vt + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.referent_vt + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) + codegen(self.referent_vt) + codegen(self.callback_vt) + codegen.extend_output(create_call_function(2, False)) + + def is_python_hashable(self): + return self.referent_vt.is_python_hashable() + + def get_python_hash(self): + # weakref relies on the referent's hash + return self.referent_vt.get_python_hash() + + def is_python_equal(self, other): + return self.referent_vt.is_python_equal(other.referent_vt) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3b2b792215ccdec807f20a31d06e9fdd937e49 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py @@ -0,0 +1,1378 @@ +""" +This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing. + +It provides specialized handling for different types of nn.Module instances through several key classes: + +- NNModuleVariable: Handles instance-specific module tracing, specializing on module id() and placing + parameters directly on the torch.fx.GraphModule. This creates one graph per module instance. + +- UnspecializedNNModuleVariable: Provides class-level module tracing, treating nn.Modules like other + user-defined objects and passing parameters as inputs to the FX graph. This creates one graph per + module class. + +- UnspecializedBuiltinNNModuleVariable: Specifically handles built-in PyTorch modules (e.g. nn.Linear) + with appropriate optimizations. + +- FSDPManagedNNModuleVariable: Special handling for FSDP-wrapped modules with modified guarding behavior + and parameter handling. + +The module integrates with Dynamo's broader tracing functionality to handle module method calls, +parameter access, hooks, and other nn.Module behaviors while maintaining proper scoping and guarding +of module state. +""" + +import functools +import inspect +import itertools +import re +import types +from collections.abc import Iterable, Sequence +from contextlib import contextmanager, nullcontext +from typing import Any, Optional, TYPE_CHECKING + +import torch.nn +from torch._guards import Source + +from .. import graph_break_hints, trace_rules, variables +from ..exc import raise_observed_exception, unimplemented, UnspecializeRestartAnalysis +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import GenerationTracker +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + FSDPNNModuleSource, + GetItemSource, + NNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + get_custom_getattr, + get_fake_value, + is_lazy_module, + is_namedtuple, + is_safe_constant, + istensor, + istype, + nnmodule_has_hooks, + object_has_getattribute, + proxy_args_kwargs, + raise_args_mismatch, + set_example_value, + unpatched_nn_module_call, + unpatched_nn_module_call_impl, +) +from .base import raise_type_error_exc, typestr, ValueMutationNew, VariableTracker +from .functions import invoke_and_store_as_constant +from .lazy import LazyVariableTracker +from .lists import SliceVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .constant import ConstantVariable + + +def initialize_lazy_module( + tx: "InstructionTranslator", + mod: torch.nn.Module, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> None: + """ + Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable. + + Used to cause lazy module to be initialized (and delete its init hook) before tracing. Especially + useful now that 'allowed' modules graph-break on hooks, calling this first ensures there is no hook + by the time we trace __call__ and thus no graph-break for lazy allowed modules. + """ + if hasattr(mod, "_initialize_hook"): + + def convert_to_fake(x: Any) -> Any: + if is_namedtuple(x): + return type(x)(*(convert_to_fake(elem) for elem in x)) + elif isinstance(x, dict): + return {k: convert_to_fake(v) for k, v in x.items()} # type: ignore[misc] + elif isinstance(x, (list, tuple, set)): + return type(x)(convert_to_fake(elem) for elem in x) + elif isinstance(x, torch.fx.Proxy): + return get_fake_value(x.node, tx) + else: + return x + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + fake_args = [convert_to_fake(arg) for arg in proxy_args] + fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} + try: + mod._infer_parameters(mod, fake_args, fake_kwargs) # type: ignore[operator] + except AttributeError as e: + # Re-raise with the original error message from the AttributeError + raise_observed_exception( + AttributeError, + tx, + args=[ + str(e) + if str(e) + else "AttributeError during lazy module initialization" + ], + ) + + +@contextmanager +def record_nn_module_stack( + module_key: str, source: Source, tx: "InstructionTranslator", mod: torch.nn.Module +) -> Any: + fully_qualified_name = source.name + # Remove redundant namings + fully_qualified_name = re.sub( + r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", + r".\2", + fully_qualified_name, + ) + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key + try: + tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 + yield + finally: + del tx.nn_module_stack[module_key] + + +def guard_to_detect_forward_monkeypatching( + source: Optional[Source], mod: torch.nn.Module +) -> None: + # Users sometimes patch the forward method of a nn module instance to + # perform optimizations like quantization. Though this is not a good + # software practice, but python allows this and Dynamo needs to detect + # this patching. + # + # One way to do this is to add an ID_MATCH guard on every function + # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But + # this increased guard overhead by around 20%. + # + # To keep the guard overhead down, we just guard on the `forward` being + # not present in the mod __dict__. The common case of patching forward + # method adds `forward` in the instance __dict__, whereas the unpatched + # `forward` sits in the type(mod).__dict__ + if source: + if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]): + # Monkeypatched forward method, add an ID_MATCH guard on forward function + fwd = mod.__dict__["forward"] + forward_source = AttrSource(source, "forward") + if type(fwd) is types.MethodType: + forward_source = AttrSource(forward_source, "__func__") + install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH)) + else: + # Common case - check that the forward key is absent in mod __dict__ + install_guard( + source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward" + ) + ) + ) + + +class NNModuleVariable(VariableTracker): + _nonvar_fields = { + "module_type", + "module_key", + "value", + "nn_module_stack_source", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs: Any + ) -> None: + super().__init__(**kwargs) + self.module_type = module_type + self.module_key = module_key + self.value = value + # pyrefly: ignore[bad-override] + # NOTE: Don't remove this; better than adding suppressions + # everywhere else with asserts + self.source: Source = self.source + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res + + def set_nn_module_stack_source(self, source: Source) -> None: + self.nn_module_stack_source = source + + def python_type(self) -> type: + return self.module_type + + def _wrap_submodule( + self, + tx: "InstructionTranslator", + source: Source, + submod: torch.nn.Module, + *key_extra: Any, + **options: Any, + ) -> None: + return + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + # implement list/iter/tuple/etc calls + base = tx.output.get_submodule(self.module_key) + result: list[VariableTracker] = [] + if isinstance(base, torch.nn.ModuleDict): + for name, submod in base.items(): + name_var = variables.ConstantVariable.create(name) + tx.output.register_attr_or_module( + submod, + self.module_key, + name, + source=NNModuleSource(GetItemSource(self.source, name)), # type: ignore[arg-type] + ) + result.append(name_var) + return result + + assert isinstance( + base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential) + ), typestr(base) + for idx, submod in enumerate(base): + result.append( + tx.output.register_attr_or_module( + submod, + self.module_key, + idx, + source=NNModuleSource(GetItemSource(self.source, idx)), + ) + ) + return result + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + mod = tx.output.get_submodule(self.module_key) + result = hasattr(mod, name) + install_guard( + NNModuleSource(AttrSource(self.source, name)).make_guard( + GuardBuilder.HASATTR + ) + ) + return variables.ConstantVariable.create(result) + + def is_training(self, tx: "InstructionTranslator") -> bool: + mod = tx.output.get_submodule(self.module_key) + return getattr(mod, "training", False) + + def convert_to_unspecialized(self, tx: "InstructionTranslator") -> None: + """Restart analysis treating this module as an UnspecializedNNModuleVariable""" + mod = tx.output.get_submodule(self.module_key) + GenerationTracker.tag(mod) + + # Mark the class dynamic unless its module initialization + if tx.f_code.co_name != "__init__": + GenerationTracker.mark_class_dynamic(type(mod)) + raise UnspecializeRestartAnalysis + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool: + base = tx.output.get_submodule(self.module_key) + + if object_has_getattribute(base): + unimplemented( + gb_type="Custom __getattribute__ in nn.Module dict key check", + context=f"has_key_in_generic_dict {self} {key}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + base_dict = object.__getattribute__(base, "__dict__") + return key in base_dict + + def _custom_getattr_fallback( + self, + base: torch.nn.Module, + tx: "InstructionTranslator", + name: str, + obj_source: Source, + ) -> Optional[VariableTracker]: + """Check for a __getattr__ and handle it specially if it is implemented""" + if object_has_getattribute(base): + unimplemented( + gb_type="Custom __getattribute__ in nn.Module attribute access", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) + if getattr_fn is None: + return None + + if not isinstance(getattr_fn, types.FunctionType): + unimplemented( + gb_type="torch.nn.Module with a non-function custom __getattr__", + context=f"var_getattr {self} {name}", + explanation=( + "Dynamo detected a nn.Module object with a custom " + "`__getattr__` method, but this method is not a standard " + "Python function (e.g., it might be implemented in C/C++). " + "Dynamo cannot currently trace into such non-standard " + "`__getattr__` methods." + ), + hints=[ + "Avoid using objects with non-standard __getattr__ methods " + "within the compiled region. If possible, implement " + "__getattr__ as a standard Python function.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + options = {"source": AttrSource(obj_source, "__getattr__")} + # pyrefly: ignore[bad-argument-type] + return variables.UserMethodVariable(getattr_fn, self, **options).call_function( + tx, [variables.ConstantVariable.create(name)], {} + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + + base = tx.output.get_submodule(self.module_key) + base_dict = object.__getattribute__(base, "__dict__") + object_member = True + all_class_attribute_names = set() + for x in inspect.getmro(base.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + if not self.source: + unimplemented( + gb_type="getattr with no source", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not know how to access an attribute " + "on an `nn.Module` instance that lacks a source. This is " + "usually an internal error in Dynamo.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + if name == "__dict__": + return variables.GetAttrVariable(self, name, source=source) + + subobj = None + if name in base_dict: + subobj = base_dict[name] + elif ( + "_modules" in base_dict + and name in base_dict["_modules"] + and name not in all_class_attribute_names + ): + subobj = base_dict["_modules"][name] + elif "_parameters" in base_dict and name in base_dict["_parameters"]: + subobj = base_dict["_parameters"][name] + elif "_buffers" in base_dict and name in base_dict["_buffers"]: + subobj = base_dict["_buffers"][name] + else: + try: + subobj = inspect.getattr_static(base, name) + object_member = False + except AttributeError: + # see if we can fallback to __getattr__, which is not checked by getattr_static + result = self._custom_getattr_fallback( + base=base, tx=tx, name=name, obj_source=self.source + ) + if result is not None: + return result + # if we can't find a __getattr__, we can't parse this, raise attribute error + raise_observed_exception( + AttributeError, + tx, + args=[f"'{type(base).__name__}' object has no attribute '{name}'"], + ) + + if name == "forward": + guard_to_detect_forward_monkeypatching(self.source, base) + + if name == "__class__" and not object_member: + return variables.UserDefinedClassVariable(base.__class__, source=source) + + if object_member: + out = VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] + + if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + else: + if istype(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = AttrSource(AttrSource(self.source, "__class__"), name) + # Get the getter function + source = AttrSource(source, "fget") + return variables.UserFunctionVariable( + subobj.fget, # pyrefly: ignore[bad-argument-type] + source=source, + ).call_function(tx, [(self)], {}) + elif istype(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, + variables.UserDefinedObjectVariable(type(base)), + source=source, + ) + elif istype(subobj, staticmethod): + return variables.UserFunctionVariable( + # pyrefly: ignore[bad-argument-type] + subobj.__get__(base), + source=source, + ) + elif istype(subobj, types.FunctionType): + return variables.UserMethodVariable(subobj, self, source=source) + elif is_safe_constant(subobj) or istensor(subobj): + # Support possibly common cases of class members + return VariableTracker.build(tx, subobj, NNModuleSource(source)) # type: ignore[arg-type] + else: + unimplemented( + gb_type="Unsupported nn.Module attribute type", + context=f"nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", + explanation=f"Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", + hints=[ + f"Refactor your code so that `{name}` (type `{typestr(subobj)}`) is not an attribute of `{typestr(base)}`", + "Currently supported attribute types are methods, classmethods, staticmethods, " + "properties, constants, and tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.GetAttrVariable(self, name, source=source) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + mod = tx.output.get_submodule(self.module_key) + + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, mod + ): + is_lazy = is_lazy_module(mod) + if ( + isinstance(mod, torch.nn.Sequential) + and mod.__class__.forward is torch.nn.Sequential.forward + ): + if nnmodule_has_hooks(mod): + # We do not want to unroll sequential if it has hooks, since evaporating it + # will cause hooks to not fire! + # This terminates and restart the tracing process + self.convert_to_unspecialized(tx) + + # Unroll sequential + assert not is_lazy, ( + "Expected lazy sequential isn't a valid combination?" + ) + if kwargs: + raise_args_mismatch( + tx, + "torch.nn.Module.Sequential", + "0 kwargs", + f"{len(kwargs)} kwargs", + ) + (arg,) = args + # TODO: Use named_children when it supports remove_duplicate=False. + for child_name, submod in mod._modules.items(): + tx.call_function( + tx.output.register_attr_or_module( + submod, + self.module_key, + child_name, + source=NNModuleSource(AttrSource(self.source, child_name)), # type: ignore[arg-type] + ), + [arg], + {}, + ) + arg = tx.pop() + return arg + + if is_lazy: + # The module type will change after it is called + if mod.cls_to_become is not None: + self.module_type = mod.cls_to_become # type: ignore[assignment] + + # The pre-hook runs to initialize the module shapes, then deletes itself. After this, + # the module is more or less not lazy and can be treated as a normal module regardless of + # is_allowed or other variations. + initialize_lazy_module(tx, mod, args, kwargs) + + # If we are tracing the higher order op, we want Dynamo to step + # inside the module call so that Dynamo can see the underlying + # parameters and buffers and raise them as inputs to the graph. + # + # NB: torch.nn.utils.parametrize changes the class type of a + # parametrized module such that its __module__ points to + # "torch.nn.utils.parametrize". + if ( + tx.output.is_root_tracer() + and mod.__module__.startswith(("torch.nn.", "torch.ao.")) + and mod.__module__ != "torch.nn.utils.parametrize" + # this basically means we are using the new strict export tracer which wraps the + # user callable, so we shouldn't directly proxy in the fx graph + and not isinstance( + mod, torch.ao.quantization.pt2e.export_utils._WrapperModule + ) + ): + if nnmodule_has_hooks( + mod, check_forward_hooks=True, check_backward_hooks=True + ): + # End of fn, this bubbles up and restarts tracing. + self.convert_to_unspecialized(tx) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_module", + self.module_key, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + if isinstance(mod, torch.fx.GraphModule): + # TODO: do we want to support __call__ for GM's? + # If so at least some changes are needed, we don't allow inlining + # the call_wrapped currently, and maybe other issues too + fn = mod.forward + fn_source = AttrSource(self.source, "forward") + else: + fn = mod._call_impl + fn_source = AttrSource(self.source, "_call_impl") + if istype(fn, types.MethodType): + fn = fn.__func__ + fn_source = AttrSource(fn_source, "__func__") + args = [self] + list(args) + else: + assert istype(fn, types.FunctionType) + return tx.inline_user_function_return( + # pyrefly: ignore[bad-argument-type] + variables.UserFunctionVariable(fn, source=fn_source), + args, + kwargs, + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + constant: bool = False, + ) -> VariableTracker: + from . import ConstantVariable, ListIteratorVariable, TupleVariable + + key = self.module_key + module = tx.output.get_submodule(key) + + def generic_call_method_helper(name: str) -> VariableTracker: + # Helper function to put a `call_method` node in FX graph, + # with nn.Module as the first arg. + mod_proxy = tx.output.create_proxy( + "get_attr", + self.module_key, + (), + {}, + ) + set_example_value(mod_proxy.node, module) + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + name, + args=(mod_proxy, *proxy_args), + kwargs=proxy_kwargs, + ), + ) + + if name in ["_call_impl", "_wrapped_call_impl"]: + # Example: `self.layer.__call__(x)` + # This is used for explicit calling `__call__` in a forward function. + # Dynamo inlines `__call__`, includes hooks. + return self.call_function(tx, args, kwargs) + elif name == "forward": + # Example: `self.layer.forward(x)` + # This is used for explicit calling `forward` in a forward function. + # Dynamo puts `call_method` node in FX, doesn't trigger hooks. + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, module + ): + return generic_call_method_helper(name) + + if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( + inspect.getfile(module.__class__._check_input_dim) # type: ignore[union-attr] + ): + return ConstantVariable.create(True) + + if name == "_get_item_by_idx": + if not args[1].is_python_constant(): + raise_type_error_exc( + tx, + f"``nn.Module`` {module}'s call method {name} requires a constant index argument", + ) + if not isinstance(args[0], TupleVariable): + raise_type_error_exc( + tx, + f"``nn.Module`` {module}'s call method {name} requires a tuple as first argument", + ) + mod_var = args[0].items[args[1].value] # type: ignore[attr-defined] + if isinstance(mod_var, UnspecializedNNModuleVariable): + return mod_var + key = mod_var.module_key # type: ignore[attr-defined] + submod = tx.output.get_submodule(key) + return tx.output.register_attr_or_module( + submod, + key, + key, + source=NNModuleSource(GetItemSource(self.source, key)), + ) + + if constant: + fn = getattr(module, name) + name = f"{module.__class__.__name__}_{name}_result" + return invoke_and_store_as_constant(tx, fn, name, args, kwargs) + + def assert_all_args_kwargs_const() -> None: + if not all( + x.is_python_constant() for x in itertools.chain(args, kwargs.values()) + ): + unimplemented( + gb_type="non-const argument in nn.Module method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support calling " + f"method `{name}` of ``nn.Module`` {module} with non-constant arguments.", + hints=[], + ) + + def get_kwargs(*names: str) -> dict[str, Any]: + assert_all_args_kwargs_const() + fn = getattr(module, name) + bound_args = inspect.signature(fn).bind( + *([x.as_python_constant() for x in args]), + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + bound_args.apply_defaults() + bound_args = bound_args.arguments + return {k: bound_args[k] for k in names} + + def wrap_values( + items: Iterable[tuple[Any, Any]], + ) -> "variables.ListIteratorVariable": + result = [] + for name, submod in items: + result.append( + tx.output.register_attr_or_module( + submod, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ) + ) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) + + def named_embed(name: str, obj: Any) -> "variables.TupleVariable": + return TupleVariable( + [ + ConstantVariable.create(name), + tx.output.register_attr_or_module( + obj, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ), + ] + ) + + def gen_source(source: Source, name: str) -> Source: + name_split = name.split(".") + if name_split[0] == "": + return source + while len(name_split) > 0: + x = name_split.pop(0) + source = AttrSource(source, x) + return source + + if name == "named_children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + named_children: list[VariableTracker] = [] + for name, submod in module.named_children(): + named_children.append(named_embed(name, submod)) + return ListIteratorVariable( + named_children, mutation_type=ValueMutationNew() + ) + elif name == "named_parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + named_parameters: list[VariableTracker] = [] + for name, param in module.named_parameters( + **get_kwargs("prefix", "recurse") + ): + named_parameters.append(named_embed(name, param)) + return ListIteratorVariable( + named_parameters, mutation_type=ValueMutationNew() + ) + elif name == "named_buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + named_buffers: list[VariableTracker] = [] + for name, buffer in module.named_buffers( + **get_kwargs("prefix", "recurse", "remove_duplicate") + ): + named_buffers.append(named_embed(name, buffer)) + return ListIteratorVariable(named_buffers, mutation_type=ValueMutationNew()) + elif name == "named_modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + named_modules_list: list[VariableTracker] = [] + for name, submod in module.named_modules( + **get_kwargs("memo", "prefix", "remove_duplicate") + ): + named_modules_list.append(named_embed(name, submod)) + return ListIteratorVariable( + named_modules_list, mutation_type=ValueMutationNew() + ) + elif name == "children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return wrap_values(module.named_children()) + elif name == "modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + return wrap_values(module.named_modules()) + elif name == "parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + return wrap_values(module.named_parameters(**get_kwargs("recurse"))) + elif name == "buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + return wrap_values(module.named_buffers(**get_kwargs("recurse"))) + elif name == "keys": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + result = [] + # pyrefly: ignore[not-iterable] + for tmp in module: + result.append(ConstantVariable.create(tmp)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "values": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return wrap_values(module.items()) # type: ignore[operator] + elif name == "items": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items_result: list[VariableTracker] = [] + for name, submod in module.items(): # type: ignore[operator] + items_result.append(named_embed(name, submod)) + return ListIteratorVariable(items_result, mutation_type=ValueMutationNew()) + elif name == "__len__": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return ConstantVariable.create(len(module)) # type: ignore[arg-type] + elif name == "__iter__": + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + elif ( + name == "__contains__" + and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict)) + and args + and args[0].is_python_constant() + ): + return ConstantVariable.create( + args[0].as_python_constant() in module._modules + ) + elif name == "__getitem__": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + builtin_supported = ( + torch.nn.ModuleDict.__getitem__, + torch.nn.ModuleList.__getitem__, + torch.nn.ParameterDict.__getitem__, + torch.nn.ParameterList.__getitem__, + torch.nn.Sequential.__getitem__, + ) + # pyrefly: ignore[missing-attribute] + if type(module).__getitem__ not in builtin_supported: + if not ( + args[0].is_python_constant() + and isinstance(args[0].as_python_constant(), (str, int)) + ): + unimplemented( + gb_type="Invalid or non-const argument in nn.Module __getitem__", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support calling " + f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.", + hints=[ + "Use constant arguments of type str or int for __getitem__" + ], + ) + fn = getattr(module, name).__func__ + + assert isinstance(fn, types.FunctionType) + + src = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=src), + [self] + list(args), + kwargs, + ) + + if isinstance(args[0], SliceVariable): + # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is + # enabled for export. + if tx.output.export: + # Build a TupleVariable of NNModules + result = [] + + # Turn the slice into the list of integers + keys = list(range(len(module)))[args[0].as_python_constant()] # type: ignore[arg-type] + for idx, submod in enumerate(module[args[0].as_python_constant()]): # type: ignore[arg-type] + key = keys[idx] + src = NNModuleSource(GetItemSource(self.source, key)) + result.append( + tx.output.register_attr_or_module( + submod, + key, + source=src, + ) + ) + + new_module = module[args[0].as_python_constant()] # type: ignore[index] + new_module_variable = tx.output.register_attr_or_module( + new_module, + f"{self}.__getitem__(slice)", + source=NNModuleSource( + GetItemSource(self.source, args[0].as_python_constant()) + ), + ) + return new_module_variable + else: + # slice on nn module results in a creation of new module instance, so we need to make it sourceless. + # Convert to unspecialized so that UnspecializedNNModule variable can take care of it. + self.convert_to_unspecialized(tx) + + from .tensor import SymNodeVariable + + key_value = 0 + if isinstance(args[0], SymNodeVariable): + key_value = args[0].evaluate_expr(tx.output) + elif args[0].is_python_constant(): + key_value = args[0].as_python_constant() + else: + unimplemented( + gb_type="Unsupported key type for nn.Module.__getitem__", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support getitem on " + "`nn.Module` with non-constant key.", + hints=[], + ) + + submod = module[key_value] # type: ignore[index] + return tx.output.register_attr_or_module( + submod, + self.module_key, + key_value, + source=NNModuleSource(GetItemSource(self.source, key_value)), + ) + elif ( + name == "_get_abs_string_index" + or ( + isinstance(module, torch.nn.modules.conv._ConvNd) + and name == "_conv_forward" + ) + or ( + isinstance(module, torch.nn.modules.conv._ConvTransposeNd) + and name == "_output_padding" + ) + ): + # Inline the function + fn = getattr(module, name).__func__ + fn_source = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + [self] + list(args), + kwargs, + ) + # A loose heuristic, but seems to be generally good before we drop into the + # manual handling of inputs + elif ( + name in module.__class__.__dict__ + and callable(module.__class__.__dict__[name]) + and all(x.is_tensor() for x in itertools.chain(args, kwargs.values())) + ): + return generic_call_method_helper(name) + else: + return super().call_method(tx, name, list(args), kwargs) + + +class UnspecializedNNModuleVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "value_type", + "is_state_mutated", + "nn_module_stack_source", + *UserDefinedObjectVariable._nonvar_fields, + } + + """ + The above class will specialize on the id() of a module and place + parameters on the torch.fx.GraphModule. Giving one graph per + module instance. This version treats nn.Modules() like other user + defined objects and will pass parameters into the FX graph as inputs. + Giving one graph per module class. + """ + + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: + if type(value) is torch.jit._script.RecursiveScriptModule: + unimplemented( + gb_type="UnspecializedNNModuleVariable wrapped around ScriptModules unsupported", + context=str(value), + explanation="ScriptModules aren't supported in UnspecializedNNModuleVariable" + " because their .forward function isn't a static member of their type.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + if "value_type" in kwargs: + lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) + if type(value) is lazy_value_to_become: + # We may have cloned a variabletracker for a LazyModule earlier (e.g. tracking side-effects) + # and then later we called and mutated the LazyModule into a MaterializedModule. + # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only + # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation. + kwargs["value_type"] = type(value) + + super().__init__(value=value, **kwargs) + self.is_state_mutated = False + # nn_module_stack_source is used to ensure BC for nn_module_stack. + # Downstream users prefer mod.linear instead of mod._modules['linear'] + # as the module stack. When Dynamo inlines the __getattr__ method, we + # cannot use self.source for nn_module_stack because it will be similar + # to mod._modules['linear']. In these cases, we set the + # nn_module_stack_source appropriately to resemble mod.linear. + self.nn_module_stack_source = self.source + + def _wrap_source(self, attr_source: Source) -> Source: + # the vt is already wrapped with UnspecializedNNModuleSource + return attr_source + + def get_nn_module_stack_source(self) -> Source: + res = self.nn_module_stack_source or self.source + assert res + return res + + def set_nn_module_stack_source(self, source: Source) -> None: + self.nn_module_stack_source = source + + @staticmethod + @functools.cache + def _nn_module_method_ids() -> set[int]: + # Allow __setattr__ to fall through to base class handler + supported = { + torch.nn.Module.__setattr__, + torch.nn.Module.__init__, + torch.nn.Module.__delattr__, + } + return { + id(x.__code__) + for x in torch.nn.Module.__dict__.values() + if hasattr(x, "__code__") and x not in supported + } + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + try: + fn = inspect.getattr_static(self.value_type, "__iter__") + except AttributeError as e: + raise NotImplementedError from e + + if fn in ( + torch.nn.ModuleList.__iter__, + torch.nn.ParameterList.__iter__, + torch.nn.Sequential.__iter__, + ): + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + VariableTracker.build(tx, fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) + + return super().unpack_var_sequence(tx) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + mod = self.value + # see comment on lazy module handling in NNModuleVariable.call_function for context + if is_lazy_module(mod): # type: ignore[arg-type] + if mod.cls_to_become is not None: # type: ignore[attr-defined] + self.value_type = mod.cls_to_become # type: ignore[attr-defined,assignment] + initialize_lazy_module(tx, mod, args, kwargs) # type: ignore[arg-type] + + if not isinstance(mod, torch.fx.GraphModule): + name = "__call__" + fn = getattr(self.value_type, name) + else: + name = "_call_impl" + fn = getattr(self.value_type, name) + + # Check if we can short circuit nn.Module._call_impl to the forward + # method. NB - This is done to reduce the compile time of Dynamo. + if ( + istype(mod.__call__, types.MethodType) # type: ignore[operator] + and istype(mod._call_impl, types.MethodType) # type: ignore[attr-defined] + and mod.__call__.__func__ is unpatched_nn_module_call # type: ignore[operator] + and mod._call_impl.__func__ is unpatched_nn_module_call_impl # type: ignore[attr-defined] + and "forward" not in mod.__dict__ + ): + forward_method = inspect.getattr_static(mod, "forward") + if isinstance(forward_method, types.FunctionType): + globals_vt = tx.nn_modules_globals_vt + if not ( + self.var_getattr(tx, "_backward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_backward_pre_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_hooks").realize().len() # type: ignore[attr-defined] + or self.var_getattr(tx, "_forward_pre_hooks").realize().len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] + ): + name = "forward" + fn = self.value_type.forward + + if self.source: + source = self.get_source_by_walking_mro(name) + else: + source = None + + guard_to_detect_forward_monkeypatching(self.source, mod) # type: ignore[arg-type] + + ctx = ( + record_nn_module_stack( + str(id(mod)), + self.get_nn_module_stack_source(), + tx, + mod, # type: ignore[arg-type] + ) + if self.source + else nullcontext() + ) + with ctx: + if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, [self] + list(args), kwargs) + else: + # Ideally we would have just used VariableTracker.build(tx, fn, + # source=source) but that introduces guard on the + # `forward.__code__` object. Given that we already guard on the + # forward not present in generic dict, we dont need this guard. + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["_call_impl", "_wrapped_call_impl"]: + fn = getattr(self.value_type, name) + if self.source: + source = self.get_source_by_walking_mro(name) + else: + source = None + + fn_vt = VariableTracker.build(tx, fn, source=source) + return fn_vt.call_function(tx, [self] + list(args), kwargs) + + if name not in getattr(self.value, "__dict__", {}): + try: + method = inspect.getattr_static(type(self.value), name) + except AttributeError: + method = None + + if isinstance(method, staticmethod): + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") + fn_vt = VariableTracker.build(tx, method.__func__, source=source) + return fn_vt.call_function(tx, args, kwargs) + + if ( + hasattr(method, "__code__") + and id(method.__code__) in self._nn_module_method_ids() + ): + unimplemented( + gb_type="UnspecializedNNModuleVariable missing method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not support tracing method {name} of nn.Module {self.value}", + hints=[ + "Dynamo does not really define unspecialized nn.Module very well.", + *graph_break_hints.DIFFICULT, + ], + ) + + # "_parameters" in self.value.__dict__ checks that module is initialized + if name == "__setattr__" and "_parameters" in self.value.__dict__: + # Record if mutations happens on parameters/buffers/modules. The + # mutations on these are not tracked by base class + # UserDefinedObject vt. This will be used later to graph break + # on seeing a parameters() and family calls. + # TODO(anijain2305) - This might not be needed if we let Dynamo + # inline both getattr and setattr. In that case, it should see + # the lowest level dicts - _parameters and family and + # automatically track mutations on those. Investigate if that + # can be done. + attr_name = args[0].as_python_constant() + value = args[1] + + # This is reverse engineered by looking at nn module __setattr__ + # logic. + if ( + value.is_tensor() and value.python_type() is torch.nn.Parameter + ) or attr_name in self.value.__dict__["_parameters"]: + # Handle parameters + self.is_state_mutated = True + elif attr_name in self.value.__dict__["_buffers"]: + # Handle buffers + self.is_state_mutated = True + elif ( + isinstance( + value, + ( + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + ), + ) + or attr_name in self.value.__dict__["_modules"] + ): + # Handle submodules + self.is_state_mutated = True + + if ( + method is torch.nn.Module.__setattr__ + and isinstance(args[1], variables.DeletedVariable) + ) or method is torch.nn.Module.__delattr__: + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + fn_vt = VariableTracker.build(tx, torch.nn.Module.__delattr__) + return fn_vt.call_function(tx, [self, args[0]], kwargs) + + return super().call_method(tx, name, list(args), kwargs) + + def getattr_helper( + self, tx: "InstructionTranslator", field: str, name_vt: VariableTracker + ) -> Optional[VariableTracker]: + dict_vt = self.var_getattr(tx, field) + if isinstance(dict_vt, variables.ConstDictVariable): + return dict_vt.maybe_getitem_const(name_vt) + return None + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # Allow skipping of empty hook dict guards on inbuilt nn modules + if name in ( + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_pre_hooks", + ): + # For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us to control the installation of empty + # hooks guard via skip_nnmodule_hook_guards + if not tx.output.side_effects.has_pending_mutation_of_attr(self, name): + hooks_dict = getattr(self.value, name) + if isinstance(hooks_dict, dict) and len(hooks_dict) == 0: + if self.source: + hooks_source = AttrSource(self.source, name) + install_guard( + hooks_source.make_guard( + GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT + ) + ) + return variables.ConstDictVariable({}) + + # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable. + # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for + # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand why + # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a + # NNModuleHooksDictVariable (a subclass of ConstDictVariable) to avoid any guard on the keys. + if ( + self.source + and name + in ( + "_forward_pre_hooks", + "_forward_hooks", + ) + and not tx.output.side_effects.has_pending_mutation_of_attr(self, name) + ): + hooks_dict = getattr(self.value, name) + hooks_dict_source = AttrSource(self.source, name) + install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + tx.output.guard_on_key_order.add(hooks_dict_source) + + def build_key_value( + i: int, k: Any, v: Any + ) -> tuple[VariableTracker, VariableTracker]: + # Make key sourceless to avoid any guard on it + key = variables.ConstantVariable.create(k) + + # Instead of using dict[key] to access the value, use a dict[dict.keys()[index]] to access the + # value. This removes the reliance on the actual key value. + source_key = ConstDictKeySource(hooks_dict_source, i) + source_value = DictGetItemSource(hooks_dict_source, source_key) + value = LazyVariableTracker.create(v, source_value) + return key, value + + result = dict( + build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items()) + ) + + return variables.NNModuleHooksDictVariable( + result, type(hooks_dict), source=hooks_dict_source + ) + return super().var_getattr(tx, name) + + def manually_trace_nn_module_getattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + """ + Dynamo tracing of nn.Module __getattr__ can be expensive if the model + has deep submodule hierarchy. Since the __getattr__ is stable, we can + directly look into the underlying datastructures. This saves a lot of + compilation time. + """ + name_vt = variables.ConstantVariable(name) + out = self.getattr_helper(tx, "_parameters", name_vt) + if out is None: + out = self.getattr_helper(tx, "_modules", name_vt) + if out is None: + out = self.getattr_helper(tx, "_buffers", name_vt) + if out is None: + raise_observed_exception( + AttributeError, + tx, + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], + ) + assert out is not None + return out + + +class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable): + """ + Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules. + """ + + def _wrap_source(self, attr_source: Source) -> Source: + # vt is already wrapped with the UnspecializedBuiltinNNModuleSource + return attr_source + + +class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable): + """ + Tracing behavior: trace into submodules and treat them as Unspecialized, do not + register parameters to the top-level, treat them as function inputs. + + Guards behavior: if 'skip_fsdp_guards', many guards that would be installed + by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis + that a user wrapping their model in FSDP(model) is already opting into a + requirement to not modify internal model state, which would already break FSDP without + compilation. + """ + + def __init__(self, value: torch.nn.Module, **kwargs: Any) -> None: + source = kwargs.get("source") + assert source is not None, ( + "FSDPManagedNNModule depends on having an accurate source to control guarding." + ) + + super().__init__(value=value, **kwargs) + self.source = source + + def _wrap_source(self, attr_source: Any) -> Any: + if not isinstance( + attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource) + ): + if torch._dynamo.config.skip_fsdp_guards: + return FSDPNNModuleSource(attr_source) + else: + return UnspecializedNNModuleSource(attr_source) + return attr_source diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..53d3acc0d40118acecfa4d71bbcf10486e3f3dcf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py @@ -0,0 +1,420 @@ +""" +This module implements variable tracking for PyTorch optimizers during Dynamo tracing. + +The OptimizerVariable class provides specialized handling for optimizer instances by: +- Optimizing the tracing of expensive optimizer initialization +- Managing optimizer state and parameter group tracking +- Handling tensor sources and guards for optimizer state tensors +- Supporting CUDA graph execution through static tensor address management +- Providing special handling for parameter gradients and optimizer state tensors + +Key features include: +- Efficient initialization tracing via _init_group optimization +- Automatic marking of optimizer state tensors as static for CUDA graphs +- Proper source tracking for parameter groups, gradients, and state tensors +- Guard installation for optimizer state structure +- Support for both CPU and GPU tensor handling +- Cleanup of static tensor references via finalizers + +The module integrates with Dynamo's broader tracing system while providing +optimizer-specific optimizations and safety guarantees. +""" + +import logging +import weakref +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING + +import torch +from torch._dynamo.variables.tensor import TensorVariable +from torch._guards import Source +from torch._logging import getArtifactLogger +from torch.utils._pytree import tree_map_only + +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + GetItemSource, + GlobalWeakRefSource, + GradSource, +) +from ..utils import GLOBAL_KEY_PREFIX +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import ListVariable +from .misc import GetAttrVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ArgMappingException(Exception): + pass + + +class GuardInstallException(Exception): + pass + + +perf_hint_log = getArtifactLogger(__name__, "perf_hints") + + +def _is_static_for_cudagraphs(x: torch.Tensor) -> bool: + from torch._inductor.cudagraph_trees import get_manager + + if x.is_cuda: + manager = get_manager(x.device.index, False) + is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None + if manager: + assert manager.current_node is not None + return ( + is_static_address + or manager.current_node._is_cuda_graph_recorded_tensor(x) + ) + else: + return is_static_address + else: + # Don't print a warning for non-cuda tensors + return True + + +class OptimizerVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "grad_to_source", + "tensor_to_source", + "static_tensor_names", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value: torch.optim.Optimizer, + grad_to_source: Optional[dict[Any, GradSource]] = None, + static_tensor_names: Optional[set[str]] = None, + tensor_to_source: Optional[dict[torch.Tensor, Source]] = None, + **kwargs: Any, + ) -> None: + super().__init__(value, **kwargs) + # pyrefly: ignore [bad-override] + self.value: torch.optim.Optimizer = value + self.grad_to_source = grad_to_source or {} + self.tensor_to_source = tensor_to_source or {} + self.static_tensor_names = static_tensor_names or set() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + """This is an optimization to avoid tracing the very slow initialization of the optimizer""" + if name == "_init_group": + if not hasattr(self.value, "_init_group"): + # Fallback: if the optimizer does not have _init_group, trace normally + return super().call_method(tx, name, args, kwargs) + try: + self.graph_break_if_pending_mutation(tx) + self.move_step_if_cpu() + py_args, py_kwargs = self.get_python_args(*args, **kwargs) + ret_val = self.value._init_group(*py_args, **py_kwargs) + self.map_sources_and_install_guards(tx) + self.update_list_args(tx, args, kwargs, py_args, py_kwargs) + # stash a weak_ptr to optimizer to invalidate code + # if the optimizer object dies + mangled_name = f"__optimizer_{id(self.value)}" + tx.store_global_weakref_by_id(mangled_name, self.value) + self.create_finalizer(tx) + + # This is currently safe only because the only actual `ret_val`s returned + # by the `_init_group` of existing optimizers are properties that are invariant + # to the input tensors (e.g. dtype, layout). Changing these would trigger a + # recompilation and hence never result in the wrong specialization of `ret_val`. + return ConstantVariable.create(ret_val) + except (ArgMappingException, GuardInstallException) as _: + # trace normally if we can't map args or install guards correctly + pass + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # Note: this allows us to intercept the call in call_method + # in the typical case, we return a UserMethodVariable + # which will directly inline + if name in ("_init_group"): + assert self.source + return GetAttrVariable(self, name, source=AttrSource(self.source, name)) + + if name == "param_groups": + from ..decorators import mark_static_address + + for group in self.value.param_groups: + for p in group["params"]: + mark_static_address(p, guard=True) + + self._set_capturable(tx) + + return super().var_getattr(tx, name) + + def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None: + # If there are pending mutations on a parameter (due to using closure) + # then we need to graph break to allow the python version of the parameter + # to update, so that running _init_group will initialize the states with + # the correct values + for g in self.value.param_groups: + for p in g["params"]: + side_effects = tx.output.side_effects + variable = side_effects.id_to_variable.get(id(p), None) + if variable and side_effects.has_pending_mutation(variable): + from ..exc import unimplemented + + unimplemented( + gb_type="optimizer: pending mutation on parameter", + context=f"variable: {variable}, parameter: {p}", + explanation="Pending mutations on a parameter (e.g. due to using closure) require a graph break.", + hints=[], + ) + + def _set_capturable(self, tx: "InstructionTranslator") -> None: + from . import LazyVariableTracker + + # We only set capturable if params are on cuda + # and the state is not initialized + def safe_to_set_capturable(group: dict[str, Any]) -> bool: + all_uninitialized = True + all_gpu = True + + for p in group.get("params", []): + all_gpu &= p.is_cuda or p.is_xpu + all_uninitialized &= p not in self.value.state + + return "capturable" in group and all_uninitialized and all_gpu + + # track indices to not set so we don't need to + # in the variable tracker realize the whole state + # we handle guarding the state specially + for group in self.value.param_groups: + if safe_to_set_capturable(group): + group["capturable"] = True + + source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, source) + ) + for param_group_vt in param_groups_vt.items: + key = ConstDictVariable._HashableTracker( + ConstantVariable.create("capturable") + ) + param_group_vt.items[key] = ConstantVariable.create(True) + + def get_python_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], dict[str, Any]]: + """Get python values equivalent to the variable tracker args""" + + def map_arg(arg: Any) -> Any: + if isinstance(arg, VariableTracker) and arg.is_python_constant(): + return arg.as_python_constant() + elif isinstance(arg, ListVariable) and not arg.items: + return [] + elif ( + isinstance(arg, ConstDictVariable) + and isinstance(arg.source, GetItemSource) + and isinstance(arg.source.base, AttrSource) + and arg.source.base.member == "param_groups" + ): + return self.value.param_groups[arg.source.index] + + raise ArgMappingException + + new_args = [map_arg(arg) for arg in args] + new_kwargs = {k: map_arg(v) for k, v in kwargs.items()} + + return new_args, new_kwargs + + # If users load an old state dictionary, + # it's possible that step could be on the cpu + # if this is the case, move it to the GPU + # corresponding to the parameter + # in most cases this is a no-op because the state is empty + def move_step_if_cpu(self) -> None: + for p, state in self.value.state.items(): + if "step" in state and state["step"].is_cpu: + state["step"] = state["step"].to(p.device) + + def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None: + from ..decorators import mark_static_address + from .lazy import LazyVariableTracker + + self.grad_to_source = {} + self.tensor_to_source = {} + + def mark_static(x: Any) -> None: + mark_static_address(x, guard=True) + + tree_map_only(torch.Tensor, mark_static, self.value.state) + + # Recursively realize the variable trackers for optim.state and + # optim.param_groups, which recursively install the necessary guards. + params_groups_source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, params_groups_source) + ) + + state_source = self.source and AttrSource(self.source, "state") + state_vt = VariableTracker.build(tx, self.value.state, state_source) + + # We need to realize the top level state dict to populate + # the guard locals + state_vt.realize() + assert state_source is not None + tx.output.guard_on_key_order.add(state_source) + + # Populate self.grad_to_source and self.tensor_to_source so that we can + # manually update_list_args + for group, group_vt in zip(self.value.param_groups, param_groups_vt.items): + # we assume here that all params within a param group + # are initialized similarly + if len(group["params"]) > 0: + for param in group["params"]: + if param.grad is not None: + key_index = None + for i, k in enumerate(self.value.state.keys()): + if k is param: + key_index = i + break + if key_index: + LazyVariableTracker.realize_all( + VariableTracker.build( + tx, + self.value.state[param], + DictGetItemSource( + state_source, + ConstDictKeySource(state_source, key_index), + ), + ) + ) + break + + params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params")) + all_static = True + non_static_grads = [] + for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)): + param_source = p_vt.source + self.tensor_to_source[p] = param_source + grad_source = GradSource( + param_source, + "grad", + ) + + if p.grad is not None: + self.grad_to_source[p.grad] = grad_source + if not _is_static_for_cudagraphs(p.grad): + all_static = False + non_static_grads.append(grad_source) + else: + install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + + # Note: to avoid spam logs only warn if perf hint artifact is enabled + # (NB: artifacts are only enabled at the debug or warning level) + if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): + non_static_grad_names = [src.name for src in non_static_grads] + perf_hint_log.warning( + ( + "Grad tensors %s will be copied during cudagraphs execution." + "If using cudagraphs and the grad tensor addresses will be the same across runs," + " use torch._dynamo.decorators.mark_static_address to elide this copy.", + ), + non_static_grad_names, + ) + + # We have to again iterate over the state dict to collect the + # tensor_to_source dict. This is used for the finalizer. + for idx, value in enumerate(self.value.state.values()): + p_state_source = DictGetItemSource( + state_source, ConstDictKeySource(state_source, idx) + ) + tx.output.guard_on_key_order.add(p_state_source) + for inner_idx, v in enumerate(value.values()): + if ( + isinstance(v, torch.Tensor) + and v not in self.grad_to_source + and v not in self.tensor_to_source + ): + self.tensor_to_source[v] = DictGetItemSource( + p_state_source, ConstDictKeySource(p_state_source, inner_idx) + ) + + def wrap_tensor( + self, tx: "InstructionTranslator", tensor_value: torch.Tensor + ) -> TensorVariable: + """Wrap state tensor in a TensorVariable""" + from ..decorators import mark_static_address + + # If we have a source for a tensor already use it, + # if we have not seen a tensor before, stash and use a + # global weak ref source, since it must be an optimizer tensor + # that we have missed + + if tensor_value in self.tensor_to_source: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=True) + source = self.tensor_to_source[tensor_value] + self.static_tensor_names.add(tx.output.module_key_name(source.name)) + elif tensor_value in self.grad_to_source: + source = self.grad_to_source[tensor_value] + else: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value, guard=True) + + global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) + source = GlobalWeakRefSource(global_name) + self.static_tensor_names.add(tx.output.module_key_name(source.name)) + + return VariableTracker.build(tx, tensor_value, source) + + def update_list_args( + self, + tx: "InstructionTranslator", + args: Iterable[VariableTracker], + kwargs: Any, + py_args: Iterable[Any], + py_kwargs: Any, + ) -> None: + """Update the args and kwargs to the traced optimizer call""" + for arg, py_arg in zip(args, py_args): + if isinstance(arg, ListVariable): + assert isinstance(py_arg, list), ( + "py_arg should be a list in optimizer variable" + ) + for i, val in enumerate(py_arg): + tx.output.side_effects.mutation(arg) + if isinstance(val, torch.Tensor): + arg.items.append(self.wrap_tensor(tx, val)) + else: + source = arg.source and GetItemSource(arg.source, i) + arg.items.append(VariableTracker.build(tx, val, source)) + + def create_finalizer(self, tx: "InstructionTranslator") -> None: + names_to_delete = self.static_tensor_names + value = self.value + tc = tx.output.tracing_context + + def init_finalizer(gm: torch.fx.GraphModule) -> None: + def clear_static_tensor_refs() -> None: + for name in names_to_delete: + gm._buffers.pop(name, None) + gm._parameters.pop(name, None) + if tc.params_flat: + tc.params_flat.clear() + if tc.params_flat_unwrap_subclasses: + tc.params_flat_unwrap_subclasses.clear() + + weakref.finalize(value, clear_static_tensor_refs) + + tx.output.add_graph_finalizer(init_finalizer) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7f0873e8eb0164a8671c2b6e575e8495da9d0e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/script_object.py @@ -0,0 +1,236 @@ +""" +This module implements variable tracking for TorchScript objects during Dynamo tracing. + +The TorchScriptObjectVariable class provides specialized handling for TorchScript +objects with strong safety guarantees by: +- Enforcing method-call-only access to prevent unsafe attribute manipulation +- Converting graph breaks into hard errors via _raise_hard_error_if_graph_break +- Proper proxy and source tracking for TorchScript method calls +- Integration with higher-order operators for method call handling + +Key safety features: +- Strict validation that only method calls are allowed (no direct attribute access) +- Immediate error reporting for potentially unsafe operations +- Proper source tracking for debugging and guard installation +- Safe handling of TorchScript object method calls through torchbind + +The module ensures that TorchScript objects are handled safely during tracing +by limiting operations to known-safe patterns and failing fast for unsafe usage. +""" + +import functools +from collections.abc import Callable, Iterable +from typing import Any, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch._guards import Source +from torch._library.opaque_object import ( + is_opaque_reference_type, + is_opaque_type, + is_opaque_value_type, +) +from torch.fx.proxy import Proxy + +from .. import graph_break_hints +from ..eval_frame import skip_code +from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import TupleVariable +from .user_defined import UserDefinedObjectVariable, UserDefinedVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _raise_hard_error_if_graph_break( + reason: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(fn) + def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return fn(*args, **kwargs) + except Unsupported as e: + raise UnsafeScriptObjectError(e.msg) from e + + return graph_break_as_hard_error + + return deco + + +class OpaqueObjectClassVariable(UserDefinedVariable): + """ + A variable that represents an opaque object class (not instance). + Since UserDefinedClassVariable has some special handling for side effects, + we have a separate class here which will directly return the object when + __init__ is called. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def is_python_hashable(self): + return is_opaque_value_type(type(self.value)) + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value})" + + def call_function( # pyrefly: ignore[bad-override] + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # disallow creating reference-type opaque objects in the middle of the + # program + if is_opaque_reference_type(self.value): + # Skip __init__ to prevent dynamo from tracing it during resume + skip_code(self.value.__init__.__code__) + + unimplemented( + gb_type="An opaque object was created in the middle of the program.", + context=f"Opaque object type: {self.value}.", + explanation=( + "Opaque objects cannot be created inside the torch.compile region. " + "They must be created before entering the compiled function." + ), + hints=[ + "Please create the opaque object before calling torch.compile " + "and pass it in as an argument or as a global variable." + ], + ) + + var_args = TupleVariable(list(args)) + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + opaque_obj = self.value( # pyrefly: ignore[not-callable] + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + + return TorchScriptObjectVariable.create(opaque_obj, opaque_obj) + + +class TorchScriptObjectVariable(UserDefinedObjectVariable): + _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} + + @classmethod + def is_matching_cls(cls, user_cls: type) -> bool: + return issubclass(user_cls, torch.ScriptObject) or is_opaque_type(user_cls) + + @staticmethod + def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable": + return TorchScriptObjectVariable(proxy, value, **options) + + def __init__( + self, proxy: Proxy, value: Any, source: Optional[Source] = None, **kwargs: Any + ) -> None: + super().__init__(value, **kwargs) + self.proxy = proxy + if isinstance(self.proxy, torch.fx.Proxy): + self.proxy.node.meta["example_value"] = value + self.source = source + + def as_proxy(self) -> Proxy: + return self.proxy + + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." + ) + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + from torch._higher_order_ops.torchbind import call_torchbind + + from ..source import AttrSource + from .higher_order_ops import TorchHigherOrderOperatorVariable + + if is_opaque_value_type(type(self.value)): + res = super().var_getattr(tx, name) + return res + + if hasattr(self.value, "script_class_name") and is_opaque_type( + self.value.script_class_name + ): + # For non-value opaque types, block attribute access + unimplemented( + gb_type="Attempted to access attributes/methods on an OpaqueObject", + context=f"value={self.value}, attr={name}", + explanation="Attribute/method access of OpaqueObjects is not supported.", + hints=[ + "Use custom operators instead of direct attribute/method access.", + ], + ) + + method = getattr(self.value, name, None) + if method is None: + unimplemented( + gb_type="FakeScriptObject missing method implementation", + context=f"value={self.value}, method={name}", + explanation=f"TorchScript object {self.value} doesn't define the method {name}.", + hints=[ + f"Ensure the method {name} is implemented in {self.value}.", + *graph_break_hints.USER_ERROR, + ], + ) + + if not callable(method): + unimplemented( + gb_type="Attempted to access non-callable attribute of TorchScript object", + context=f"value={self.value}, method={name}", + explanation="Attribute accesses of TorchScript objects to non-callable attributes are not supported.", + hints=[ + "Use method calls instead of attribute access.", + ], + ) + assert self.source is not None + return TorchHigherOrderOperatorVariable.make( + call_torchbind, + source=AttrSource(self.source, name), + script_obj_var=self, + method_name=name, + ) + + # We only support method calls on script objects. Interpreting the bytecodes + # should go through var_getattr then call_function instead of call_method. + # + # However, it's possible for call_method to be used directly e.g. for __setattr__. + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." + ) + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + unimplemented( + gb_type="Weird method call on TorchScript object", + context=f"value={self.value}, method={name}", + explanation=( + f"This particular method call ({name}) is not supported (e.g. calling `__setattr__`). " + "Most method calls to TorchScript objects should be supported." + ), + hints=[ + "Avoid calling this method.", + ], + ) + + def as_python_constant(self): + if is_opaque_value_type(type(self.value)): + return self.value + return super().as_python_constant() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7006f5d56ab364d91a974a3cd9e14aab6af317 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/sdpa.py @@ -0,0 +1,95 @@ +from collections.abc import Sequence +from inspect import getattr_static +from typing import Any, TYPE_CHECKING, TypeGuard + +from torch._guards import Source +from torch.backends.cuda import SDPAParams +from torch.fx.proxy import Proxy + +from ..bytecode_transformation import create_call_function +from ..exc import unimplemented +from ..source import AttrSource +from .base import VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +PARAM_NAMES = [ + "query", + "key", + "value", + "attn_mask", + "dropout", + "is_causal", + "enable_gqa", +] + + +class SDPAParamsVariable(VariableTracker): + """Represents the c++ params struct for scaled dot product attention. + This is a read-only container.""" + + @staticmethod + def create( + tx: "InstructionTranslator", value: Any, source: Source + ) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + params = [ + VariableTracker.build(tx, getattr(value, p), AttrSource(source, p)) + for p in PARAM_NAMES + ] + return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) + + def __init__( + self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any + ) -> None: + self.proxy = proxy + self.param_vars = param_vars + super().__init__(**kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.source is None + assert self.param_vars is not None + codegen.add_push_null( + lambda: codegen.load_import_from("torch._C", "_SDPAParams") + ) + codegen.foreach(self.param_vars) + codegen.extend_output(create_call_function(len(self.param_vars), False)) + + def as_proxy(self) -> Proxy: + return self.proxy + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + import torch._C + + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + getattr_static(torch._C._SDPAParams, name) + except AttributeError: + import torch._dynamo.graph_break_hints as graph_break_hints + + unimplemented( + gb_type="unsupported torch._C._SDPAParams attribute", + context=f"name: {name}", + explanation=f"Unable to fetch attribute {name} from torch._C._SDPAParams.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @staticmethod + def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]: + return value is SDPAParams diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..426f50e76d6ab918bfc1862ff6c4ff06556d9f68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/streams.py @@ -0,0 +1,549 @@ +import collections +from collections.abc import Callable +from typing import Any, Optional + +import torch +from torch._dynamo.variables.dicts import ConstDictVariable +from torch._dynamo.variables.lists import TupleVariable +from torch.fx import has_side_effect, Proxy + +from .. import graph_break_hints +from ..bytecode_transformation import create_call_function +from ..exc import TYPE_CHECKING, unimplemented +from ..graph_bytecode_inputs import ( + get_external_object_by_index, + register_graph_created_object, +) +from ..source import CurrentStreamSource +from .base import VariableTracker +from .constant import ConstantVariable +from .ctx_manager import FxTracebackAnnotateVariable +from .lazy import LazyVariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + from ..codegen import PyCodegen + +from torch._library.custom_ops import custom_op + + +Tensor = torch.Tensor + + +def new_event(*args: Any, **kwargs: Any) -> int: + event = torch.Event(*args, **kwargs) + return register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + TupleVariable([]), ConstDictVariable({}) + ), + ) + + +def new_stream(*args: tuple[Any], **kwargs: Any) -> int: + stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload] + return register_graph_created_object( + stream, + StreamVariable.make_construct_in_graph_stream_fn( + TupleVariable([]), ConstDictVariable({}) + ), + ) + + +def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None: + cg.add_push_null( + lambda: cg.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + cg(CurrentStreamSource(device)) + cg.extend_output(create_call_function(1, False)) + + +def get_current_stream(device: torch.device) -> int: + stream = torch.accelerator.current_stream(device) + return register_graph_created_object( + stream, lambda _, cg: _codegen_current_stream(device, cg) + ) + + +def _get_stream_by_index(index: int) -> torch.Stream: + stream = get_external_object_by_index(index) + assert isinstance(stream, torch.Stream), ( + f"Fork/join stream expected a stream object at index {index}" + ) + return stream + + +def _get_event_by_index(index: int) -> torch.Event: + event = get_external_object_by_index(index) + assert isinstance(event, torch.Event), ( + f"Record/wait event expected an event object at index {index}" + ) + return event + + +@custom_op("streams::fork", mutates_args=()) +def fork_stream( + from_index: int, # kept to make stream transitions clearer + to_index: int, +) -> None: + torch.accelerator.set_stream(_get_stream_by_index(to_index)) + + +@fork_stream.register_fake +def _( + from_index: int, # kept to make stream transitions clearer + to_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.fork.default) + + +@custom_op("streams::join", mutates_args=()) +def join_stream(from_index: int, to_index: int) -> None: + torch.accelerator.set_stream(_get_stream_by_index(to_index)) + + +@join_stream.register_fake +def _( + from_index: int, + to_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.join.default) + + +@custom_op("streams::record_event", mutates_args=()) +def record_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.record_event(event) + + +@record_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.record_event.default) + + +@custom_op("streams::wait_event", mutates_args=()) +def wait_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.wait_event(event) + + +@wait_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.wait_event.default) + + +@custom_op("streams::wait_stream", mutates_args=()) +def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None: + waiting = _get_stream_by_index(waiting_stream_index) + waited_on = _get_stream_by_index(waited_on_stream_index) + waiting.wait_stream(waited_on) + + +@wait_stream.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.wait_stream.default) + + +@custom_op("streams::sync_dealloc", mutates_args=()) +def sync_dealloc( + wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor +) -> None: + """An op which waits on an event and moves the last usage of to_dealloc + after the wait, so that after the sync occurs, the deallocation or + subsequent reuse of the tensor's memory will be guaranteed to happen + after a side stream is finished using it. + See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream + for more details""" + torch.ops.streams.wait_event.default(wait_event_index, src_stream_index) + + +has_side_effect(torch.ops.streams.sync_dealloc.default) + + +@custom_op("streams::record_stream", mutates_args=()) +def record_stream(tensor: torch.Tensor, stream_index: int) -> None: + tensor.record_stream(_get_stream_by_index(stream_index)) + + +@record_stream.register_fake +def _( + src_stream_index: int, + wait_event_index: int, + to_dealloc: torch.Tensor, +) -> None: + pass + + +class SymbolicStreamState: + """Track the currently entered stream if any""" + + def __init__(self) -> None: + from ..source import CurrentStreamSource + + cur_stack: list[StreamVariable] = [] + if torch.accelerator.is_available(): + stream_var = LazyVariableTracker.create( + torch.accelerator.current_stream(), + source=CurrentStreamSource(torch.accelerator.current_stream().device), + ) + cur_stack = [stream_var] # type: ignore[list-item] + + self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque( + cur_stack + ) + + def enter_stream(self, stream: "StreamVariable") -> None: + self.cur_stream_stack.append(stream) + + def exit_stream(self) -> None: + self.cur_stream_stack.pop() + + def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable": + if device is not None: + for stream in reversed(self.cur_stream_stack): + if stream.device == device: + return stream + + return self.cur_stream_stack[-1] + + def in_stream_context(self) -> bool: + return len(self.cur_stream_stack) > 0 + + +class StreamContextVariable(FxTracebackAnnotateVariable): + """This represents torch.cuda.StreamContext""" + + @staticmethod + def create( + tx: "InstructionTranslator", + stream_to_enter: "StreamVariable", + **kwargs: dict[str, Any], + ) -> "StreamContextVariable": + return StreamContextVariable( + stream_to_enter, + **kwargs, + ) + + def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None: + self.stream = stream + super().__init__( + target_values={"stream": self.get_stream().user_object_index}, + initial_values=None, + **kwargs, + ) + + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # to stream, from stream is the order of the arguments + # we are entering the target, and leaving the initial stream + tx.symbolic_stream_state.enter_stream(self.get_stream()) + return super().enter(tx) + + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + # to stream, from stream is the order of the arguments + # we are leaving the target, and entering the initial stream + tx.symbolic_stream_state.exit_stream() + return super().exit(tx, *args) + + def supports_graph_breaks(self) -> bool: + return True + + def get_stream(self) -> "StreamVariable": + assert self.stream, "Stream context should have a separate stream" + return self.stream + + +class StreamVariable(StreamContextVariable): + """Represents the device-agnostic torch.Stream class""" + + def __init__( + self, + proxy: Proxy, + value: torch.Stream, + user_object_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + # Index into the user object table + # used to pass arbitrary objects to the graph + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + + self.proxy = proxy + self.value = value + # pyrefly: ignore [read-only] + self.device = value.device + # pyrefly: ignore [read-only] + self.user_object_index = user_object_index + super().__init__(None, **kwargs) + + def python_type(self) -> type: + return torch.Stream + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + assert hasattr(self.value, name), f"no stream method found named {name}" + + from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait_stream", "synchronize", "wait_event"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name == "record_event": + return wrap_fx_proxy_cls( + target_cls=EventVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + from ..guards import GuardBuilder, install_guard + + if self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + + # NB : Checking for mutation is necessary because we compare + # constant values + other = args[0] + if not isinstance(other, StreamVariable): + return ConstantVariable.create(NotImplemented) + + if other.source: + assert self.source is not None + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type] + ) + + return super().call_method(tx, name, args, kwargs) + + def as_proxy(self) -> Proxy: + return self.proxy + + def module_name(self) -> str: + return "torch._C" + + def fn_name(self) -> str: + return "Stream" + + def reconstruct(self, codegen: "PyCodegen") -> None: + # If we got here, this stream is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + if self.user_object_index is not None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, + "get_external_object_by_index", + ) + ) + codegen.append_output(codegen.create_load_const(self.user_object_index)) + codegen.extend_output(create_call_function(1, False)) + else: + # This will support the legacy behavior + prefix = f"_stream_{self.device}" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) + + def get_stream(self) -> "StreamVariable": + return self + + @staticmethod + def make_construct_in_graph_stream_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_stream" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) + codegen.extend_output(create_call_function(1, False)) + + return fn + + +class EventVariable(VariableTracker): + def __init__( + self, + proxy: Proxy, + value: torch.Event, + user_object_index: Optional[int], + **kwargs: Any, + ) -> None: + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + self.user_object_index = user_object_index + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from ..utils import proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name == "wait": + tx.output.create_proxy( + "call_function", + torch.ops.streams.wait_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "record": + tx.output.create_proxy( + "call_function", + torch.ops.streams.record_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "synchronize": + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + else: + method_name = ( + f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}" + ) + unimplemented( + gb_type="Unsupported event method", + context=str(name), + explanation=f"Dynamo doesn't support tracing the {method_name} method. " + f"We currently support wait, record, synchronize, and query.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def as_proxy(self) -> Proxy: + return self.proxy + + @staticmethod + def _get_stream_arg( + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "StreamVariable": + stream_arg = None + if args: + stream_arg = args[0] + elif kwargs: + stream_arg = kwargs.get("stream") + + if not stream_arg: + stream_arg = tx.symbolic_stream_state.cur_stream() + + return stream_arg # type: ignore[return-value] + + @staticmethod + def make_construct_in_graph_event_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports] + "stash_graph_created_object", + ) + ) + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_event" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) + codegen.extend_output(create_call_function(1, False)) + + return fn + + def reconstruct(self, codegen: "PyCodegen") -> None: + # If we got here, this event is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there. + prefix = "_event" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..94b72200c72fa2e73a59a1bd0333d30e7ddc85f0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/tensor.py @@ -0,0 +1,1889 @@ +# mypy: ignore-errors + +""" +This module contains variable tracker classes for handling tensors and tensor-related operations in Dynamo. + +The main class is TensorVariable which represents torch.Tensor inputs and intermediate values in the FX graph. +It handles tensor operations, method calls, and maintains metadata about tensor properties like dtype, device, etc. + +Other key classes include: +- SymNodeVariable: Represents symbolic scalars (int/float/bool) used for size computation and unspecialized values +- NumpyNdarrayVariable: Handles numpy array interop through torch._numpy +- UnspecializedPythonVariable: Represents unspecialized Python numeric values as 1-element tensors +- TensorSubclassVariable: Handles tensor subclasses with __torch_function__ overrides +- UntypedStorageVariable: Represents tensor storage objects +- DataPtrVariable: Handles tensor data pointer operations + +These classes work together to track tensor operations and properties during Dynamo's tracing process. +""" + +import functools +import logging +import operator +import textwrap +import traceback +import types +from collections.abc import Sequence +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import sympy + +import torch._numpy as tnp +import torch.fx +import torch.random +from torch._dynamo import compiled_autograd +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import ( + guard_scalar, + GuardOnDataDependentSymNode, + has_free_symbols, + is_symbolic, + SymTypes, +) +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config, graph_break_hints, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..exc import ( + unimplemented, + UnknownPropertiesDuringBackwardTrace, + UserError, + UserErrorType, +) +from ..external_utils import call_hook_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import ( + fqn, + get_custom_getattr, + get_fake_value, + get_real_value, + guard_if_dyn, + object_has_getattribute, + product, + proxy_args_kwargs, + raise_args_mismatch, + set_example_value, + tensortype_to_dtype, +) +from .base import AttributeMutationNew, ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .lists import ListIteratorVariable, SizeVariable +from .user_defined import UserDefinedClassVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) + +# Ops that allow tensor tensor +supported_tensor_comparison_ops = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + "is": operator.is_, + "is not": operator.is_not, +} +# Ops that allow tensor None +supported_const_comparison_ops = { + "is": operator.is_, + "is not": operator.is_not, + "==": operator.eq, + "!=": operator.ne, +} +supported_comparison_ops = { + **supported_tensor_comparison_ops, + **supported_const_comparison_ops, +} +supported_tensor_comparison_op_values = dict.fromkeys( + supported_tensor_comparison_ops.values() +) +supported_const_comparison_op_values = dict.fromkeys( + supported_const_comparison_ops.values() +) + + +def is_bound_tensor_method(value): + return ( + callable(value) + and not torch._dynamo.utils.object_has_getattribute(value) + and hasattr(value, "__self__") + and isinstance(value.__self__, torch.Tensor) + and getattr(value.__self__, value.__name__, None) + ) + + +# instead of using inspect.getattr_static, we directly lookup the appropriate +# dicts. It is necessary to keep the torch._C.TensorBase first in the or +# operation, because the second arg takes priority in or operation when there +# are common keys. +all_tensor_attrs = torch._C.TensorBase.__dict__ | torch.Tensor.__dict__ + + +class TensorVariable(VariableTracker): + """A torch.Tensor input or an intermediate value in the FX graph""" + + _nonvar_fields = { + "proxy", + "dtype", + "device", + "layout", + "ndim", + "size", + "stride", + "requires_grad", + "is_quantized", + "is_contiguous", + "is_nested", + "is_sparse", + "class_type", + "specialized_value", + "_is_name_set", + *VariableTracker._nonvar_fields, + } + + def get_real_value(self): + """ + Get the actual value represented by this variable if computation is run + using the user-provided inputs. + NOTE: this runs actual tensor computation and may be + slow and memory-intensive. + """ + return get_real_value(self.proxy.node, self.proxy.tracer) + + def __init__( + self, + proxy: torch.fx.Proxy, + *, + dtype, + device, + layout, + ndim, + requires_grad, + is_nested, + is_quantized, + is_sparse, + class_type, + has_grad_fn, + _size=None, + stride=None, + is_contiguous=None, + _is_name_set=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.proxy = proxy + self.dtype = dtype + self.device = device + self.layout = layout + self.ndim = ndim + self._size = _size # this is accessed as a property for validation + self.stride = stride + self.requires_grad = requires_grad + self.is_quantized = is_quantized + self.is_contiguous = is_contiguous + self.is_nested = is_nested + self.is_sparse = is_sparse + self.class_type = class_type + self.has_grad_fn = has_grad_fn + if _is_name_set is None: + # no need to rename inputs + _is_name_set = self.proxy.node.op == "placeholder" + self._is_name_set: bool = _is_name_set + + def synchronize_attributes(self, tx, target_cls=None): + from .builder import get_specialized_props, infer_subclass_type + + if target_cls is None: + target_cls = type(self) + + example_value = self.proxy.node.meta.get("example_value") + specialized_props = get_specialized_props( + target_cls, tx, example_value, infer_subclass_type(example_value) + ) + for k, v in specialized_props.items(): + setattr(self, k, v) + + def debug_repr(self): + # TODO: strip off fake tensor from repr here + return repr(self.proxy.node.meta["example_value"]) + + def as_proxy(self): + return self.proxy + + def python_type(self): + return self.class_type + + def is_tensor(self) -> bool: + return True + + @staticmethod + def specialize(value: torch.Tensor): + props = { + "dtype": value.dtype, + "device": value.device, + "layout": value.layout, + "ndim": int(value.ndim), + "requires_grad": value.requires_grad, + "is_nested": value.is_nested, + "is_quantized": value.is_quantized, + "is_sparse": value.is_sparse, + "class_type": type(value), + } + try: + props["has_grad_fn"] = value.grad_fn is not None + except Exception: + # Workaround for issues with create_parameter_op in Dynamo. Reading + # grad_fn should never cause an issue. + props["has_grad_fn"] = False + + if is_sparse_any(value) and not has_free_symbols(value): + props["_size"] = tuple( + int(s) if is_symbolic(s) else s for s in value.size() + ) + elif not has_free_symbols(value): + # this is a fully static shape, and the keys on props here inform specialization. + # We have to cast to int here, because these might get accessed as ConstantVariable, which has + # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant + # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for + # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and + # I'd like to keep it around for now. + props["_size"] = tuple( + # the non is_symbolic case applies to the jagged layout + # NestedTensor case as singleton ints are not symbolic + int(s) if is_symbolic(s) else s + for s in value.size() + ) + props["stride"] = tuple(value.stride()) + if torch._C._functorch.is_batchedtensor(value): + # Batched tensors does not support contiguity patterns, so + # we refrain from computing the `is_contiguous` property + props["is_contiguous"] = None + else: + props["is_contiguous"] = tuple( + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ) + return props + + def dynamic_getattr(self, tx: "InstructionTranslator", name): + fake_val = self.proxy.node.meta["example_value"] + # For getattrs on tensors without sources, + # we can do better than the default (creating a GetAttrVariable) + # if: + # (1) the tensor is a traceable tensor subclass + # (2) We are getattr'ing an inner tensor from that subclass + if not self.source and is_traceable_wrapper_subclass(fake_val): + attrs, _ctx = fake_val.__tensor_flatten__() + proxy = getattr(self.as_proxy(), name) + example_value = getattr(fake_val, name) + if name in attrs: + # attrs returned from tensor_flatten are always tensors + assert isinstance(example_value, torch.Tensor) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value) + # any other attributes on the subclass (that are not methods) + # are assumed to be constant metadata. + elif not callable(example_value): + return VariableTracker.build(tx, example_value) + + if not (self.source and self.source.subguards_allowed()): + raise NotImplementedError + + # For local source, we associate the real value. We use this real value + # for implementing getattr fallthrough on the variable tracker base class. + + # Note - this scope construction is mirrored in guards + # A subsequent PR will introduce a util. + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + # We raise in case we get a typerror bug w/ SuperSource. + # SuperSource has bugs in it atm, and can produce code like + # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, + # L['mod'].model.model.encoder.embed_positions)", scope) + # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. + _input_associated_real_value = eval(self.source.name, scope) + except Exception as exc: + raise NotImplementedError from exc + + if _input_associated_real_value is None: + raise NotImplementedError + + if object_has_getattribute(_input_associated_real_value): + raise NotImplementedError + + if get_custom_getattr(_input_associated_real_value): + raise NotImplementedError + + real_value = getattr(_input_associated_real_value, name) + + attr_source = AttrSource(self.source, name) + + # Typically we'd want to use variable builder here + # but unfortunately id(real_value.__self__) is not id() + if is_bound_tensor_method(real_value): + # No need to install the guard because its a bound tensor method + from .misc import GetAttrVariable + + return GetAttrVariable( + self, name, source=attr_source, py_type=type(real_value) + ) + + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) + return VariableTracker.build(tx, real_value, attr_source) + + def method_attr_ndim(self, tx): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + else: + return self.call_method(tx, "dim", [], {}) + + def method_attr_dtype(self, tx): + if self.dtype is not None: + return ConstantVariable.create(self.dtype) + + def method_attr_device(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device) + + def method_attr_layout(self, tx): + if self.layout is not None: + return ConstantVariable.create(self.layout) + + def method_attr_is_cuda(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device.type == "cuda") + + def method_attr_shape(self, tx): + if self.valid_size(): + sizes = [variables.ConstantVariable.create(x) for x in self.size] + return SizeVariable(sizes) + else: + return self.call_method(tx, "size", [], {}) + + def method_attr_requires_grad(self, tx): + if self.requires_grad is not None: + return ConstantVariable.create(self.requires_grad) + + def method_attr_is_quantized(self, tx): + if self.is_quantized is not None: + return ConstantVariable.create(self.is_quantized) + + def method_attr_is_sparse(self, tx): + if self.is_sparse is not None: + return ConstantVariable.create(self.is_sparse) + + def method_attr_is_nested(self, tx): + if self.is_nested is not None: + return ConstantVariable.create(self.is_nested) + + def method_attr_retain_grad(self, tx): + unimplemented( + gb_type="Tensor.retain_grad() with AOTDispatcher", + context=f"var_getattr {self} retain_grad", + explanation="`Tensor.retain_grad()` does not work with AOTDispatcher.", + hints=[], + ) + + def method_attr_data(self, tx): + return variables.TorchInGraphFunctionVariable( + torch._C._autograd._get_data_attr + ).call_function(tx, [self], {}) + + def method_attr_grad_fn(self, tx): + if self.has_grad_fn: + unimplemented( + gb_type="Tensor with grad_fn()", + context=f"var_getattr {self} grad_fn", + explanation="Dynamo does not support tracing tensors with a grad_fn directly.", + hints=[], + ) + else: + return variables.ConstantVariable(None) + + def method_attr__version(self, tx): + from ..tensor_version_op import _tensor_version + + return variables.TorchInGraphFunctionVariable(_tensor_version).call_function( + tx, [self], {} + ) + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + from . import GetAttrVariable + from .builtin import BuiltinVariable + + # TODO - This is not a good solution but solves an accuracy issue. + # Today, var_getattr returns GetAttrVariable for both non-existent + # attributes and existing attributes. This is a bug and requires more + # deep dive. + if name in all_tensor_attrs: + return ConstantVariable(True) + + try: + var = BuiltinVariable(getattr).call_function( + tx, [self, ConstantVariable(name)], {} + ) + # in the event that TensorVariable returns NotImplemented + # BuiltinVariable.call_getattr returns GetAttrVariable + ret_val = not isinstance(var, GetAttrVariable) + except AttributeError: + ret_val = False + + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + return ConstantVariable(ret_val) + + def var_getattr(self, tx: "InstructionTranslator", name): + if self.is_strict_mode(tx): + if name in self._strict_mode_banned_ops(): + unimplemented( + gb_type="Strict mode banned op", + context=f"var_getattr {self} {name}", + explanation=f"Getattr invocation '{name}' in strict mode is not supported.", + hints=[ + f"Remove `{name}` from the list of banned ops by " + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`.", + ], + ) + elif name in self._strict_mode_conditional_banned_ops(): + raise UnknownPropertiesDuringBackwardTrace( + f"Unknown property {name} during speculating backward, dynamo will insert contiguous call ahead and speculate it again" # noqa: B950 + ) + + if name == "__class__": + return UserDefinedClassVariable(self.python_type()) + + handler = getattr(self, f"method_attr_{name}", None) + result = handler(tx) if handler is not None else None + + # Add a guard for type matching, these guards are checked before tensor guards + # In some cases, a . guard can be evaluated first, and break if + # is later changed to another type + if ( + result is not None + and self.source + and self.source.subguards_allowed() + and not ( + name not in ("grad", "requires_grad") and result.is_python_constant() + ) + ): + install_guard(self.make_guard(GuardBuilder.TYPE_MATCH)) + result.source = AttrSource(self.source, name) + + # It's hard to get inplace view (metadata mutation) on graph input work properly across + # dynamo/aot/inductor, just fall back. + if self.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags + ): + # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. + return variables.misc.DelayGraphBreakVariable( + source=AttrSource(self.source, name), + msg="Getting an inplace view on a graph input is not supported", + ) + + # For attributes (not methods) that were not caught in the special handling above, + # (e.g. tensor.real), we handle these generically, assuming that the output type is + # a tensor. + if result is None and name != "grad": + + def try_generic_attr_handling(): + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + static_attr = all_tensor_attrs.get(name, None) + if static_attr is None: + return None + + # Make sure this is an attribute, not a method. + # type(torch.Tensor.H) should be "getset_descriptor" + # This is a because of CPython implementation, see THPVariableType: + # these attributes are implemented under tp_getset, which appear + # as `getset_descriptor`s, (compared to, say, methods which appear + # as `method_descriptor`s) + if type(static_attr) is not types.GetSetDescriptorType: + return None + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + result = try_generic_attr_handling() + + if result is None: + result = self.dynamic_getattr(tx, name) + + if result is None: + raise NotImplementedError + return result + + def call_id(self, tx): + if not self.source: + unimplemented( + gb_type="Unsupported call_id() without source", + context=f"call_id {self}", + explanation="call_id() not supported for sourceless TensorVariable.", + hints=[], + ) + + # For local source, we associate the real value. We use this real value + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + _input_associated_real_value = eval(self.source.name, scope) + except Exception as exc: + unimplemented( + gb_type="Error getting associated real value", + context=f"call_id {self}", + explanation="Dynamo encountered an error while trying to " + "get the associated real value.", + hints=[], + from_exc=exc, + ) + + if _input_associated_real_value is None: + unimplemented( + gb_type="call_id() without associated real value", + context=f"call_id {self}", + explanation="Dynamo could not find an associated real value for the tensor.", + hints=[], + ) + + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + id_value = id(_input_associated_real_value) + return ConstantVariable.create(id_value) + + def has_unpack_var_sequence(self, tx): + return self.ndim > 0 + + def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): + from .builder import wrap_fx_proxy_cls + + if self.valid_size(): + size_len = len(self.size) + else: + size_var = self.call_method(tx, "size", [], {}) + assert isinstance(size_var, SizeVariable) + size_len = len(size_var.items) + # Ensure we don't unpack a scalar tensor. + assert size_len != 0, "Can't unpack scalar tensors." + + if self.valid_size(): + length = self.size[0] + else: + dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through + # symbolic_shapes, but that end up as int/sympy.Integer + assert ( + isinstance(dyn_length, SymNodeVariable) + or dyn_length.is_python_constant() + ) + if isinstance(dyn_length, SymNodeVariable): + length = dyn_length.evaluate_expr(tx.output) + else: + length = dyn_length.as_python_constant() + + if idxes is None: + idxes = range(length) + else: + assert len(idxes) == length, ( + f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements." + ) + return [ + wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i]) + for i in idxes + ] + + def call_tree_map( + self, + tx, + tree_map_fn: "UserFunctionVariable", + map_fn, + rest, + tree_map_kwargs, + ) -> "VariableTracker": + return map_fn.call_function(tx, [self, *rest], {}) + + def valid_size(self): + return self._size is not None + + @property + def size(self): + assert self._size is not None, "accessing None size in TensorVariable" + return self._size + + def _strict_mode_banned_ops(self): + return torch._dynamo.config._autograd_backward_strict_mode_banned_ops + + def _strict_mode_conditional_banned_ops(self): + return ( + torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops + ) + + def call_method( + self, + tx, + name, + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): + unimplemented( + gb_type="Illegal method invocation in strict mode", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo currently does not support this method " + f"({name}) invocation in strict mode.", + hints=[], + ) + + # Only override builtin tensor methods + # The user can manually add override handling + # with a decorator for other methods (e.g. a dispatch subclass with other methods) + static_attr = all_tensor_attrs.get(name, None) + is_base_tensor_method = static_attr is not None + + if ( + can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs) + and is_base_tensor_method + ): + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(static_attr) + else: + func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name)) + + return dispatch_torch_function( + tx, func_var, tuple([self] + list(args)), kwargs + ) + + """ + Dispatch to a method-specific handler defined below. If the + handler returns None (or doesn't exist) we put the method call + in the graph. + """ + + # This is seen in inspect signature where we check if the value is a default value + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): + return variables.ConstantVariable(False) + + # For historical reasons, these ops decompose down to syntactically + # invalid aten ops because they contain the python keyword `from`, see + # discussions in #151432 for more details. + # We graph break for now since this use case is uncommon. + if name == "random_": + unimplemented( + gb_type="Tensor.random_ op", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Use the out-of-place version of this op", + *graph_break_hints.SUPPORTABLE, + ], + ) + elif name == "uniform_" and "from" in kwargs: + unimplemented( + gb_type="Tensor.uniform_ op called with `from` keyword", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Avoid using the `from` keyword.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + try: + handler_method = getattr(self, f"method_{name}") + except AttributeError: + pass + else: + try: + result = handler_method(*args, **kwargs) + if result: + return result + except TypeError as e: + unimplemented( + gb_type="Unhandled args for method", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered an error while calling " + f"the method `{name}`.", + hints=[], + from_exc=e, + ) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def method_size(self, *args, **kwargs): + return self._method_size_stride("size", *args, **kwargs) + + def method_stride(self, *args, **kwargs): + return self._method_size_stride("stride", *args, **kwargs) + + def _method_size_stride(self, name, dim=None): + dim = guard_if_dyn(dim) + + def make_const_size_variable(x, **options): + return SizeVariable( + [ConstantVariable.create(y, **options) for y in x], **options + ) + + RetVariable = ( + make_const_size_variable if name == "size" else ConstantVariable.create + ) + + # Technically, this should not be necessary, but I'm including it + # for enhanced BC, in case example_value is sometimes not set + # (it really should always be set though!) + if name != "size": + r = getattr(self, name) + elif name == "size" and self.valid_size(): + r = self.size + else: + r = None + + if r is not None: + if dim is None: + return RetVariable(r) + else: + return ConstantVariable.create(r[dim]) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + if dim is None: + fake_r = getattr(fake, name)() + if not has_free_symbols(fake_r): + # int conversion for safety, in case a SymInt refined + # to constant + return RetVariable(tuple(int(r) for r in fake_r)) + else: + fake_r = getattr(fake, name)(dim) + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + def method_numel(self): + if self.valid_size(): + return ConstantVariable.create(product(self.size)) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + fake_r = fake.numel() + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + method_nelement = method_numel + + def method_dim(self): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + + method_ndimension = method_dim + + def method_is_floating_point(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_floating_point) + + def method_is_inference(self): + if config.fake_tensor_disable_inference_mode: + unimplemented( + gb_type="Encountered tensor.is_inference() during tracing", + context="", + explanation="tensor.is_inference() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + if (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create(fake.is_inference()) + + def method_is_complex(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_complex) + + def method_is_contiguous(self, memory_format=None): + memory_format = ( + memory_format.as_python_constant() + if memory_format is not None + else torch.contiguous_format + ) + if self.is_contiguous is not None: + return ConstantVariable.create(memory_format in self.is_contiguous) + elif (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create( + fake.is_contiguous(memory_format=memory_format) + ) + + def method_type(self, dtype=None, non_blocking=False, **kwargs): + if ( + dtype is None + and self.dtype is not None + and isinstance(self.device, torch.device) + ): + tensortype = next( + k for k, v in tensortype_to_dtype.items() if self.dtype in v + ) + if self.device.type == "cpu": + return ConstantVariable.create(f"torch.{tensortype.__name__}") + else: + return ConstantVariable.create( + f"torch.{self.device.type}.{tensortype.__name__}" + ) + elif ( + dtype is not None + and fqn(type(dtype.as_python_constant())) == "torch.tensortype" + ): + # torch.FloatTensor, etc. are all of type "torch.tensortype". + # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type. + # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args) + tensor_type = dtype.as_python_constant() + tensor_type_const = ConstantVariable.create(fqn(tensor_type)) + + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + if non_blocking: + kwargs = {"non_blocking": non_blocking, **kwargs} + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + "type", + *proxy_args_kwargs([self, tensor_type_const], kwargs), + ), + ) + + def method_as_subclass(self, cls): + if isinstance(cls, TensorSubclassVariable) and cls.source: + from ..symbolic_convert import InstructionTranslator + from .torch_function import TensorWithTFOverrideVariable + + tx = InstructionTranslator.current_tx() + py_cls = cls.as_python_constant() + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, self, py_cls, cls.source + ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + unimplemented( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def method_get_device(self): + if isinstance(self.device, torch.device): + index = self.device.index if self.device.type != "cpu" else -1 + return ConstantVariable.create(index) + + def method_element_size(self): + return ConstantVariable.create(self.dtype.itemsize) + + def method_numpy(self, *, force=False): + if not config.trace_numpy: + unimplemented( + gb_type="Tensor.numpy() with trace_numpy=False", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the `trace_numpy` " + "configuration was manually disabled.", + hints=[ + "Set `torch._dynamo.config.trace_numpy = True` to allow " + "Dynamo to trace through NumPy.", + ], + ) + if not np: + unimplemented( + gb_type="Tensor.numpy() without NumPy installed", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the NumPy library " + "is not available in the current environment.", + hints=[ + "Ensure NumPy is installed in your Python environment.", + ], + ) + if self.layout != torch.strided: + raise TypeError( + f"can't convert {self.layout} layout tensor to numpy. Use Tensor.to_dense() first" + ) + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # We don't check that the tensor is on CPU when force is False, as this + # allows us to execute NumPy code on CUDA. Same for requires_grad=True + if force and force.as_python_constant(): + # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) + t = self.call_method(tx, "detach", [], {}) + proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {}) + else: + # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable + proxy = tx.output.create_proxy( + "call_method", "view_as", *proxy_args_kwargs([self, self], {}) + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def method_tolist(self): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + def tolist(tensor, sub_proxy): + def wrap(i, sub_proxy): + return wrap_fx_proxy( + tx, + sub_proxy.item(), + ) + + if tensor.dtype not in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + unimplemented( + gb_type="Tensor.tolist() with non-integer tensor", + context=f"call_method {self} to_list", + explanation="Dynamo currently does not support tracing " + "`tolist()` on non-integer tensors.", + hints=[ + "Ensure the input tensor to `tolist()` is an integer " + "type (e.g., int8, int16, int32, int64)." + ], + ) + + if tensor.dim() == 0: + return wrap(tensor, sub_proxy) + + if tensor.dim() == 1: + return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] + + return [ + tolist(sub_tensor, sub_proxy=sub_proxy[i]) + for i, sub_tensor in enumerate(tensor) + ] + + tensor = self.as_proxy().node.meta["example_value"] + out = tolist(tensor, self.as_proxy()) + return VariableTracker.build(tx, out) + + def method_backward(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.backward() call", + context=f"call_method {self} backward {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.backward()`.", + hints=[*graph_break_hints.FUNDAMENTAL], + ) + + def method_data_ptr(self, *args, **kwargs): + return DataPtrVariable(self) + + def method_item(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # We enable capture_scalar_outputs when full_graph=True by default. + if not tx.one_graph and not config.capture_scalar_outputs: + self._warn_capture_scalar_outputs() + unimplemented( + gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False", + context=f"call_method {self} item {args} {kwargs}", + explanation="Dynamo does not support tracing `Tensor.item()` " + "with config.capture_scalar_outputs=False.", + hints=[ + "Set `torch._dynamo.config.capture_scalar_outputs = True` " + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` " + "to include these operations in the captured graph.", + ], + ) + + def method___getitem__(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + if isinstance(args[0], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. Rewrite as a regular torch op which will + # trace fine + fn, args = ( + torch.select, + [ + variables.ConstantVariable.create(0), + args[0], + ], + ) + else: + fn = operator.getitem + + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs([self] + list(args), kwargs), + ) + + return wrap_fx_proxy(tx, proxy) + + @staticmethod + @functools.cache + def _warn_capture_scalar_outputs(): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack_formatted = "".join(traceback.format_list(user_stack)) + log.warning( + textwrap.dedent( + """\ + Graph break from `Tensor.item()`, consider setting: + torch._dynamo.config.capture_scalar_outputs = True + or: + env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 + to include these operations in the captured graph. + + Graph break: from user code at: + %s + """ + ), + user_stack_formatted, + ) + + def method___len__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + + def method___iter__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + + def method_addcmul_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + from .. import polyfills + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.addcmul_inplace), + [self, tensor1, tensor2, value], + {}, + ) + + def method___setitem__(self, key, value): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + proxy = tx.output.create_proxy( + "call_function", + operator.setitem, + *proxy_args_kwargs([self, key, value], {}), + ) + + if value.is_tensor(): + # [Note: Tensor.__setitem__ and VariableTracker metadata] + # At this point, we proxied a node representing `self[key] = value` into the graph. + # When executed, this node will mutate `self`'s tensor metadata, so it's important + # even during tracing to propagate. For example: + # value.requires_grad is True => self.requires_grad becomes True + # value.requires_grad is True => self.has_grad_fn becomes True + + # Not sure if __setitem__ can ever save activations, disabling just in case + + # Ignore fresh unbacked symbols that could arise from the internal indexing (selection), + # that happen in code like t[idx] += 1 when idx is unbacked. Namely the selection + # during 'setitem'. + # When the selection happens if idx is unbacked we allocate a new unbacked symbol for the + # storage offset in select_meta, but the output of the operation 'setitem' does not depend + # on the selection. + with ( + torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), + tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols() + if tx.fake_mode and tx.fake_mode.shape_env + else nullcontext(), + ): + get_fake_value(proxy.node, tx, allow_non_graph_fake=False) + + vt = value + if isinstance(vt, variables.lazy.LazyVariableTracker): + vt = variables.lazy.LazyVariableTracker.realize_all(vt) + + self.synchronize_attributes(tx, type(vt)) + + if config.use_graph_deduplication or config.track_nodes_for_deduplication: + tx.output.region_tracker.add_node_mutation(proxy.node, 0) + + return ConstantVariable.create(None) + + def method_resize_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.resize_() call", + context=f"call_method {self} resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_()`.", + hints=[], + ) + + def method_resize_as_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.resize_as_() call", + context=f"call_method {self} resize_as_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_as_()`.", + hints=[], + ) + + def method_sparse_resize_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.sparse_resize_() call", + context=f"call_method {self} sparse_resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + hints=[], + ) + + def method_sparse_resize_and_clear_(self, *args, **kwargs): + unimplemented( + gb_type="Unsupported Tensor.sparse_resize_and_clear_() call", + context=f"call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + hints=[], + ) + + def method_set_(self, *args, **kwargs): + if len(args) > 1: + # torch.Tensor.set_() has several overloads. + # aten::set_.source_Tensor(Tensor) gets special handling + # in AOTAutograd and functionalization, because it is the most common + # overload and is used by FSDP. + # graph-breaking on aten::set_source_Tensor_storage_offset for now, + # unless we find that we need to make it work. + unimplemented( + gb_type="Unsupported Tensor.set_() call", + context=f"call_method {self} set_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.set_()` " + "overloads that include more than one argument.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def method_add_(self, other, *, alpha=None): + if alpha is not None: + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [other, alpha], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method_addcdiv_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + result = variables.TorchInGraphFunctionVariable(torch.div).call_function( + tx, [tensor1, tensor2], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, value], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method___contains__(self, arg): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # Rewrite __contains__ here so that downstream passes can trace through + # without dealing with unbacked symbool. Roughly the code we translate is: + # def __contains__(self, x): + # return (x == self).any().item() + result = variables.TorchInGraphFunctionVariable(torch.eq).call_function( + tx, [self, arg], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.any).call_function( + tx, [result], {} + ) + return result.call_method(tx, "item", [], {}) + + def method_redistribute(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def redistribute_fn_with_prim_types(x): + return x.redistribute(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + redistribute_fn_with_prim_types.__name__ = "prim_redistribute" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + redistribute_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_to_local(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + + grad_placements_vt = kwargs.get( + "grad_placements", ConstantVariable.create(None) + ) + if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable): + # grad_placement is a sequence-like structure, iterate over the value + grad_placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [grad_placements_vt], {} + ) + + if kwargs.get("grad_placements") is not None: + kwargs["grad_placements"] = grad_placements_vt + + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def to_local_fn_with_prim_types(x): + return x.to_local(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + to_local_fn_with_prim_types.__name__ = "prim_to_local" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + to_local_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_register_hook(self, *args, **kwargs): + return self._method_register_hook("register_hook", *args, **kwargs) + + def method_register_post_accumulate_grad_hook(self, *args, **kwargs): + return self._method_register_hook( + "register_post_accumulate_grad_hook", *args, **kwargs + ) + + def _method_register_hook(self, name: str, hook: VariableTracker): + # Note - do not arbitrarily add hooks here - make sure they match the same contract + # see [On tensor.register_hook] + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + if not self.source: + if not compiled_autograd.compiled_autograd_enabled: + # TODO(voz): + # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary + # python state. + # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run + # them in a compiled bwd without re-entering dynamo as compiled_autograd does. + # + # Discussion point 1 - Should we bypass this if nopython/fullgraph = True? + # No. Because this was going to be a graph break anyway - this check does not + # introduce new graph breaks where there were none. + # + # Discussion point 2 - Should we defer this check to backwards? + # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user + # would have no recourse - their forward traces just fine, but will fail at backwards unless + # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) + # then they have nothing they can do except disable compile. + unimplemented( + gb_type="Compilation of intermediate hooks requires compiled autograd", + context=f"var_getattr {self} {name}", + explanation="Dynamo must be in compiled_autograd to register hooks.", + hints=[], + ) + + hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) + + def _register_hook_trampoline(tensor, bw_state): + register_hook = getattr(tensor, name) + register_hook( + functools.partial( + trace_wrapped, + fn=call_hook_from_backward_state, + bw_state=bw_state, + hook_name=hook_name, + ) + ) + # TODO(jansel): returning None here is wrong, it should be + # RemovableHandle, but we need some extra work to support + # this properly. + return None + + from .builder import wrap_fx_proxy + + self_proxy = self.as_proxy() + self_proxy.node.meta["has_backward_hook"] = True + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + _register_hook_trampoline, + (self_proxy, bw_state_proxy), + {}, + ), + ) + + handle_variable = variables.RemovableHandleVariable( + mutation_type=variables.base.ValueMutationNew(), + ) + tx.output.side_effects.register_hook(self, hook, handle_variable, name) + return handle_variable + + def method_requires_grad_(self, requires_grad=True): + if requires_grad is not True: + requires_grad = requires_grad.as_python_constant() + + if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: + unimplemented( + gb_type="Unsupported Tensor.requires_grad_() call", + context=f"call_method {self} requires_grad_", + explanation="Dynamo does not support changes to a Tensor's " + "`requires_grad` through calling `requires_grad_()`.", + hints=[], + ) + else: + return self + + def method_new(self, *args, **kwargs): + # Convert x.new(torch.Size) into x.new_empty(torch.Size), + # as Tensor.new acts differently with a Size input versus a tuple input. + if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( + len(args) >= 1 + and all( + a.is_python_constant() and isinstance(a.as_python_constant(), int) + for a in args + ) + ): + from ..symbolic_convert import InstructionTranslator + + return self.call_method( + InstructionTranslator.current_tx(), "new_empty", args, kwargs + ) + + def method_untyped_storage(self): + return UntypedStorageVariable( + self, self.as_proxy().node.meta["example_value"].untyped_storage() + ) + + def set_name_hint(self, name: str): + if not self._is_name_set: + self.proxy.node._rename(name) + self._is_name_set = True + + def is_python_hashable(self): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. + # It'd be nice if at some point we could assert that they all have one + return self.as_proxy().node.meta["example_value"] is not None + + def get_python_hash(self): + return hash(self.as_proxy().node.meta["example_value"]) + + def is_python_equal(self, other): + a = self.as_proxy().node.meta["example_value"] + b = other.as_proxy().node.meta["example_value"] + return a is b + + +class SymNodeVariable(VariableTracker): + """ + Represents a symbolic scalar, either int, float or bool. This is most commonly used to + handle symbolic size computation, e.g., tensor.size(0), but it is also used to + handle logic like float_tensor.item() or unspecialized float inputs. + """ + + _nonvar_fields = { + "proxy", + "sym_num", + *VariableTracker._nonvar_fields, + } + + def debug_repr(self): + return repr(self.sym_num) + + @classmethod + def create(cls, tx, proxy, sym_num=None, **options): + if sym_num is None: + sym_num = get_fake_value(proxy.node, tx) + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == sym_num + set_example_value(proxy.node, sym_num) + + if isinstance(sym_num, (sympy.Integer, int, bool)): + sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num + return ConstantVariable.create(sym_num) + + out = SymNodeVariable(proxy, sym_num, **options) + if proxy.node.op != "placeholder": + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out + + def __init__(self, proxy, sym_num, **kwargs) -> None: + super().__init__(**kwargs) + self.proxy = proxy + # TODO: Should we allow non SymTypes here? Today it is allowed + self.sym_num = sym_num + self._tensor_var = None + + def python_type(self): + if isinstance(self.sym_num, SymTypes): + return self.sym_num.node.pytype + else: + return type(self.sym_num) + + def is_symnode_like(self) -> bool: + return True + + def as_proxy(self): + return self.proxy + + def as_tensor(self, tx, dtype): + if self._tensor_var is None: + self._tensor_var = VariableTracker.build( + tx, torch.scalar_tensor + ).call_function(tx, [self], {"dtype": VariableTracker.build(tx, dtype)}) + return self._tensor_var + + def evaluate_expr(self, output_graph=None): + try: + return guard_scalar(self.sym_num) + except GuardOnDataDependentSymNode as e: + if torch.fx.experimental._config.no_data_dependent_graph_break: + raise + + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + # Essentially convert the SymNode to a constant variable whenever its + # searched for a dict key. + return hash(self.evaluate_expr()) + + def is_python_equal(self, other): + if isinstance(other, SymNodeVariable): + return self.evaluate_expr() == other.evaluate_expr() + # could be constant variable as well + return self.evaluate_expr() == other.as_python_constant() + + +class NumpyNdarrayVariable(TensorVariable): + """ + Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray. + Use this for Tensor.numpy() call. + """ + + @staticmethod + def create(tx: "InstructionTranslator", proxy, **options): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=tx, + proxy=proxy, + **options, + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + # NB: This INTENTIONALLY does not call super(), because there is + # no intrinsic reason ndarray properties are related to Tensor + # properties. The inheritance here is for implementation sharing. + + from ..utils import numpy_attr_wrapper + from .builder import wrap_fx_proxy + + result = None + + example_value = self.as_proxy().node.meta["example_value"] + example_ndarray = tnp.ndarray(example_value) + + def insert_into_graph(): + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {} + ), + ) + + if name in ["T", "real", "imag"]: + proxy = tx.output.create_proxy( + "call_function", + numpy_attr_wrapper, + (self.as_proxy(), name), + {}, + ) + result = NumpyNdarrayVariable.create(tx, proxy) + + # These are awkward to implement. The standard playbook for torch._numpy + # interop is to trace a call into the torch._numpy wrapper which works for + # Tensor operations. However, we don't want to do this for calls + # that don't return Tensors, because in those cases we may not want + # to trace the attribute access into the graph at all (it is sort + # of harmless to do so, because AOTAutograd will eliminate them, + # but it's best not to trace them in to begin with.) But in any + # case, tracing these into the graph is like trying to fit a square + # peg into a round hole; best not to do it. So instead we + # painstakingly implement these by hand + # + # NB: only ALWAYS specialized attributes can go here; notably, + # size/shape not allowed! + elif name in ("ndim", "itemsize"): + return ConstantVariable.create(getattr(example_ndarray, name)) + elif name in ("shape", "stride"): + if not has_free_symbols(r := getattr(example_ndarray, name)): + return ConstantVariable.create(tuple(int(r) for r in r)) + return insert_into_graph() + elif name == "size": + if not has_free_symbols(r := example_ndarray.size): + return ConstantVariable.create(int(r)) + return insert_into_graph() + elif name in ["base", "flags", "dtype"]: + unimplemented( + gb_type="Unsupported ndarray attribute access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + elif name == "__version__": + unimplemented( + gb_type="Unsupported ndarray.__version__ access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + if result is None: + raise NotImplementedError + return result + + @staticmethod + def patch_args(name, args, kwargs): + if name == "clip": + kwargs_rename = {"a_min": "min", "a_max": "max"} + kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()} + return args, kwargs + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..exc import unimplemented + from ..utils import numpy_method_wrapper + + args, kwargs = self.patch_args(name, args, kwargs) + + if name == "astype": + from .builtin import BuiltinVariable + + dtype_arg = None + if "dtype" in kwargs: + dtype_arg = kwargs["dtype"] + elif len(args) > 0: + dtype_arg = args[0] + is_object_str = dtype_arg is not None and dtype_arg.is_constant_match("O") + is_object_type = ( + isinstance(dtype_arg, BuiltinVariable) and dtype_arg.fn is object + ) + if is_object_str or is_object_type: + unimplemented( + gb_type="ndarray.astype(object)", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=( + "`ndarray.astype('O')` or `ndarray.astype(object)` is not supported " + "by torch.compile, as there is no equivalent to object type in torch.Tensor. " + "This will be executed eagerly." + ), + hints=[*graph_break_hints.FUNDAMENTAL], + ) + if name in ["__len__", "size", "tolist", "__iter__"]: + # delegate back to TensorVariable + return super().call_method(tx, name, args, kwargs) + if name in ("tostring", "tobytes", "__delattr__"): + unimplemented( + gb_type="Unsupported ndarray method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"`ndarray.{name}()` is not modelled in `torch._numpy`.", + hints=[], + ) + proxy = tx.output.create_proxy( + "call_function", + numpy_method_wrapper(name), + *proxy_args_kwargs([self] + list(args), kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def python_type(self): + return np.ndarray + + +class UnspecializedPythonVariable(TensorVariable): + """ + This is a 1-element tensor represents unspecialized python float/int. + """ + + _nonvar_fields = { + "raw_value", + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__( + self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs + ) -> None: + super().__init__(proxy, **kwargs) + self.raw_value = raw_value + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True): + # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance. + return UnspecializedPythonVariable( + **dict(tensor_variable.__dict__), + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + + +class FakeItemVariable(TensorVariable): + """An unspecialized python variable which prevents access to the underlying raw value. + This is needed if item is called on a FakeTensor.""" + + _nonvar_fields = { + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None: + need_unwrap = kwargs.pop("need_unwrap", False) + super().__init__(proxy, **kwargs) + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable): + return FakeItemVariable(**dict(tensor_variable.__dict__)) + + +class TensorSubclassVariable(UserDefinedClassVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable + + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if len(args) == 1 and args[0].is_tensor() and len(kwargs) == 0: + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, data, self.value, self.source + ) + else: + unimplemented( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs + ) + + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) + + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + + def as_python_constant(self): + return self.value + + +class UntypedStorageVariable(VariableTracker): + _nonvar_fields = { + "example_value", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + from_tensor: TensorVariable, + example_value: torch.UntypedStorage, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + # Example_value will always have device="meta" + self.example_value = example_value + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + result = self.example_value.size() + if not has_free_symbols(result): + # avoid creating a node in the graph + return ConstantVariable.create(int(result)) + else: + from ..external_utils import untyped_storage_size + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + untyped_storage_size, + (self.from_tensor.as_proxy(),), + {}, + ), + ) + if name == "resize_" and len(args) == 1: + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + tx.output.create_proxy( + "call_function", + torch.ops.inductor.resize_storage_bytes_, + (self.from_tensor.as_proxy(), args[0].as_proxy()), + {}, + ) + return self + + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("untyped_storage") + codegen.call_method(0) + + +class DataPtrVariable(VariableTracker): + def __init__( + self, + from_tensor: TensorVariable, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("data_ptr") + codegen.call_method(0) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3c3afc551b8f0fde5527bf4adcae3689bb3b9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py @@ -0,0 +1,2183 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +""" +This module implements variable tracking for torch functions and operations during Dynamo tracing. + +It provides classes to handle different types of torch operations: + +TorchInGraphFunctionVariable: Handles torch.* functions that should be captured in the FX graph. +Provides special handling for constant folding, tensor methods, and torch function overrides. +Manages complex cases like out= variants and parameter construction. + +TorchCtxManagerClassVariable: Handles torch context managers like torch.no_grad(), autocast, etc. +Provides implementations for entering/exiting these contexts during tracing. + +DispatchKeySetVariable: Represents torch.DispatchKeySet for managing dispatch keys and +device-specific operations during tracing. + +The module includes special handling for: +- Constant folding of pure functions +- Tensor method calls +- torch.nn.Parameter construction +- __torch_function__ overrides +- Context manager state tracking +- Device and dtype management + +This is a core part of Dynamo's tracing system, translating torch operations into +traceable graph nodes while preserving correct semantics and handling edge cases. +""" + +import functools +import inspect +import logging +import math +import re +from collections.abc import Callable, Sequence +from typing import Any, Optional, TYPE_CHECKING + +import torch._C +import torch._refs +import torch.fx +import torch.nn +from torch._guards import TracingContext +from torch._logging import warning_once +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import config, graph_break_hints, polyfills, variables +from ..codegen import PyCodegen +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) +from ..device_interface import get_registered_device_interfaces +from ..exc import raise_observed_exception, unimplemented +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + CallFunctionNoArgsSource, + SyntheticLocalSource, + TorchSource, +) +from ..utils import ( + check_unspec_or_constant_args, + guard_if_dyn, + has_torch_function, + hashable, + is_wrapper_or_member_descriptor, + product, + proxy_args_kwargs, + unwrap_if_wrapper, +) +from .base import raise_type_error_exc, typestr, VariableTracker +from .ctx_manager import ( + AutocastModeVariable, + ProfilerContextVariable, + TorchFunctionDisableVariable, +) +from .dicts import ConstDictVariable +from .distributed import DistributedVariable, ProcessGroupVariable +from .functions import bind_args_cached, NestedUserFunctionVariable +from .lists import ListVariable, TupleVariable +from .torch_function import ( + can_dispatch_torch_function, + dispatch_torch_function, + TensorWithTFOverrideVariable, + TorchFunctionModeStackVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) + +supported_ctx_manager_classes = dict.fromkeys( + [ + torch.profiler.profiler.profile, + torch.autograd.forward_ad._set_fwd_grad_enabled, + torch.autograd.forward_ad.dual_level, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + torch._C.DisableTorchFunctionSubclass, + torch._C.DisableTorchFunction, + torch._functorch.vmap.vmap_increment_nesting, + torch._functorch.eager_transforms.grad_increment_nesting, + torch._functorch.eager_transforms.jvp_increment_nesting, + torch._functorch.eager_transforms.enable_inplace_requires_grad, + torch.amp.autocast_mode.autocast, + torch.autograd.grad_mode.enable_grad, + torch.autograd.grad_mode.inference_mode, + torch.autograd.grad_mode.no_grad, + torch.autograd.grad_mode.set_grad_enabled, + torch.autograd.graph.disable_saved_tensors_hooks, + torch.cpu.amp.autocast_mode.autocast, + torch.cuda.amp.autocast_mode.autocast, + torch.fx.traceback.annotate, + torch.fx.traceback.annotate.__wrapped__, # type: ignore[attr-defined] + # We'll let Dynamo inline into the contextlib part of these context + # manager instances, all the way till it invokes the wrapped function + # itself (at which point we wrap it back to special context manager + # VTs). + # + # This allows us to support calling functions decorated with these + # context managers, without much extra effort or code dup. + torch.nn.attention.sdpa_kernel.__wrapped__, # type: ignore[attr-defined] + ] +) + + +REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys( + [ + torch._shape_as_tensor, + ] +) + +constant_fold_functions_need_guards = [ + torch.accelerator.current_device_index, + torch.accelerator.current_accelerator, + torch.cuda.current_device, + torch.cuda.is_initialized, + torch.xpu.current_device, + torch.xpu.is_initialized, +] + +constant_fold_functions = [ + torch._assert, + torch._utils._get_device_index, + torch._C._get_cublas_allow_tf32, + torch._C._is_any_autocast_enabled, + torch.accelerator.is_available, + torch.cuda.get_device_properties, + torch.cuda.is_available, + torch.distributed.is_available, + torch.get_autocast_dtype, + torch.get_autocast_gpu_dtype, + torch.get_default_dtype, + torch.is_autocast_cache_enabled, + torch.is_autocast_cpu_enabled, + torch.is_autocast_enabled, + torch.is_complex, + torch.is_floating_point, + torch.nn.functional._Reduction.get_enum, # type: ignore[attr-defined] + torch.promote_types, + torch._C._get_privateuse1_backend_name, + torch.autograd._is_checkpoint_valid, + torch.xpu.get_device_properties, + torch.xpu.is_available, +] + constant_fold_functions_need_guards +if torch.distributed.is_available(): + constant_fold_functions.extend( + [ + torch.distributed.is_initialized, + torch.distributed.get_rank, + torch.distributed.get_world_size, + ] + ) +# Convert to dict for O(1) access times +constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need_guards) +constant_fold_functions = dict.fromkeys(constant_fold_functions) + + +@functools.cache +def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: + # Defined as a function to avoid circular import like torch.onnx + return { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.fx._symbolic_trace.is_fx_symbolic_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, + torch.compiler.is_exporting: True, + torch._dynamo.eval_frame._is_in_optimized_module: True, + # Look into https://github.com/pytorch/pytorch/pull/164721 why this is + # turned to True for Dynamo. + torch.nn.modules.activation._is_make_fx_tracing: True, + } + + +bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) + +dispatch_key_set_functions = { + torch._C._dispatch_keys, + torch._C._dispatch_tls_local_include_set, + torch._C._dispatch_tls_local_exclude_set, +} + + +@functools.cache +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + funcs = set(chain.from_iterable(get_overridable_functions_().values())) + more: set[Callable[..., Any]] = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs + + +class BaseTorchVariable(VariableTracker): + """common base for all torch.* functions, classes, modules and other things""" + + @classmethod + def create_with_source(cls, value, source): + if inspect.isclass(value): + install_guard(source.make_guard(GuardBuilder.CLASS_MATCH)) + elif inspect.ismodule(value): + install_guard(source.make_guard(GuardBuilder.MODULE_MATCH)) + elif inspect.isfunction(value): + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + elif inspect.isbuiltin(value) or isinstance( + value, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) + ): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + elif is_wrapper_or_member_descriptor(value) or isinstance( + value, torch._dynamo.compiled_autograd.Op + ): + # Dont need to guard on wrappers + pass + else: + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls(value, source=source) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def reconstruct(self, codegen: "PyCodegen"): + try: + name = f"{self.value.__module__}.{self.value.__name__}" + except Exception: + name = f"torch_obj_{id(self.value)}" + unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) + codegen.extend_output( + codegen.setup_globally_cached(unique_var_name, self.value) + ) + + def as_proxy(self): + return self.value + + def as_python_constant(self): + return self.value + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def can_constant_fold_through(self): + if self.value in constant_fold_functions: + return True + + if ( + self.value is torch.autograd._profiler_enabled + and config.constant_fold_autograd_profiler_enabled + ): + # The relevant flag is enabled only for export. One might wonder + # why? + # + # Actually we would like to not graph break even in the case of + # Dynamo. But there is a weird-unsolved bug with Kineto + Dynamo + # when there are distributed jobs that lead to NCCL timeouts. This + # bug is a rare edege case, but we have not been able to root cause + # it yet. See https://www.internalfb.com/sevmanager/view/560336 for + # more details. + # + # So is this safe for export? Yes, for export, we do not anticipate + # JIT tracing in distributed job training, and the weird edge-case + # interaction with Kineto is not a valid usecase. So, this is ok. + return True + + return getattr(self.value, "__module__", None) == "math" + + +class TorchCtxManagerClassVariable(BaseTorchVariable): + """Points to a context manager class in torch.* that dynamo has implementations""" + + def __repr__(self) -> str: + return f"TorchCtxManagerClassVariable({self.value})" + + @staticmethod + def is_matching_cls(value): + # Unwrap if it's a functools.lru_cache wrapper + value = unwrap_if_wrapper(value) + # We can't do isinstance(value, type) check because some ctx managers + # are implemented as a function decorated by contextlib.contextmanager, + # E.g., torch._functorch.vmap.vmap_increment_nesting. + return ( + # Context manager type or function with @contextmanager is callable + callable(value) + and ( + hashable(value) # accesses value.__hash__() + and value in supported_ctx_manager_classes + ) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + FSDPParamGroupUseTrainingStateVariable, + FxTracebackAnnotateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + StreamVariable, + VmapIncrementNestingCtxManagerVariable, + ) + + if self.value is torch.no_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, False) + return ctx.call_function(tx, args, kwargs) + else: + return GradModeVariable.create(tx, False) + elif self.value is torch.enable_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, True) + return ctx.call_function(tx, args, kwargs) + return GradModeVariable.create(tx, True) + elif self.value is torch.set_grad_enabled and len(args) == 1: + return GradModeVariable.create( + tx, args[0].as_python_constant(), initialized=True + ) + elif self.value is torch.inference_mode: + assert len(args) <= 1 and len(kwargs) == 0 + inf_mode = args[0].as_python_constant() if len(args) == 1 else True + return InferenceModeVariable.create(tx, inf_mode) + elif self.value in ( + torch.fx.traceback.annotate, + torch.fx.traceback.annotate.__wrapped__, # type: ignore[attr-defined] + ): + assert len(args) <= 1 and len(kwargs) == 0 + return FxTracebackAnnotateVariable( + args[0].as_python_constant(), source=self.source + ) + elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream): + from torch._dynamo.variables.builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + elif self.value in ( + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ): + # pyrefly: ignore [bad-argument-type] + return AutocastModeVariable.create(self.value, args, kwargs) + elif self.value in ( + # NOTE any class added here must align with the semantic + # requirements of `ProfilerContextVariable`. + torch.profiler.profile, + torch.profiler.record_function, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + ): + warning_once(log, "Profiler function %s will be ignored", self.value) + return ProfilerContextVariable() + elif ( + self.value is torch._C.DisableTorchFunctionSubclass + or self.value is torch._C.DisableTorchFunction + ): + assert not (args or kwargs) + return TorchFunctionDisableVariable.create( + tx, only_subclass=self.value is torch._C.DisableTorchFunctionSubclass + ) + elif self.value is torch._functorch.vmap.vmap_increment_nesting: + assert len(args) == 2 + return VmapIncrementNestingCtxManagerVariable.create( + tx, + args, + ) + elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting: + assert len(args) == 0 + return JvpIncrementNestingCtxManagerVariable.create(tx) + elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled: + assert len(args) == 1 + return SetFwdGradEnabledContextManager.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.forward_ad.dual_level: + assert len(args) == 0 + return DualLevelContextManager.create(tx) + elif self.value is torch._functorch.eager_transforms.grad_increment_nesting: + assert len(args) == 0 + return GradIncrementNestingCtxManagerVariable.create(tx) + elif ( + self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad + ): + assert len(args) == 1 + return GradInplaceRequiresGradCtxManagerVariable.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.graph.disable_saved_tensors_hooks: + assert len(args) == 1 + return DisabledSavedTensorsHooksVariable.create( + tx, args[0].as_python_constant() + ) + elif ( + _fsdp_param_group is not None + and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + assert len(args) == 2 + return FSDPParamGroupUseTrainingStateVariable.create( + tx, args[0], args[1].as_python_constant() + ) + elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] + name_to_arg_map = bind_args_cached( + # pyrefly: ignore[bad-argument-type] + self.value, + tx, + self.source, + args, + kwargs, + ) + backends = name_to_arg_map["backends"].as_python_constant() + set_priority = name_to_arg_map["set_priority"].as_python_constant() + return SDPAKernelVariable.create(tx, backends, set_priority) + + return super().call_function(tx, args, kwargs) + + +class TorchInGraphFunctionVariable(BaseTorchVariable): + """Points to a torch function/method that should be put in FX graph""" + + def __init__(self, value, nonstrict_traceable=None, **kwargs) -> None: + super().__init__(value, **kwargs) + from ..trace_rules import is_nonstrict_trace_callable + + if nonstrict_traceable is None: + nonstrict_traceable = is_nonstrict_trace_callable(value) + self.nonstrict_traceable = nonstrict_traceable + + def __repr__(self) -> str: + return f"TorchInGraphFunctionVariable({self.value}, nonstrict_traceable={self.nonstrict_traceable})" + + def get_function(self): + return self.value + + @staticmethod + @functools.cache + def _get_handlers(): + """Build a dict from function -> method to handle it so that we are O(1) + in terms of the number of function with special handling.""" + handlers = {} + + def register(*fns): + def _register(handler): + for fn in fns: + assert fn not in handlers, fn + handlers[fn] = handler + return handler + + assert callable(fns[0]) + return _register + + from torch.backends.cuda import SDPAParams + + from . import ( + ConstantVariable, + DeterministicAlgorithmsVariable, + GradModeVariable, + StreamContextVariable, + SymNodeVariable, + TensorVariable, + UserDefinedObjectVariable, + ) + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + @register(*tracing_state_functions()) + def handle_tracing_state_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + # See: https://github.com/pytorch/pytorch/issues/110765 + if self.value in ( + torch._utils.is_compiling, + torch._dynamo.external_utils.is_compiling, + torch.compiler.is_compiling, + torch.compiler.is_dynamo_compiling, + torch.compiler.is_exporting, + torch._dynamo.eval_frame._is_in_optimized_module, + ): + tx.mark_inconsistent_side_effects() + return ConstantVariable.create(tracing_state_functions()[self.value]) + + @register(*dispatch_key_set_functions) + def handle_dispatch_key_set_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not kwargs + if self.value is torch._C._dispatch_keys: + assert len(args) == 1 + assert args[0].is_tensor() + example_value = args[0].proxy.node.meta["example_value"] + dks = self.value(example_value) + # Remove Python and PythonTLSSnapshot from the dispatch key set, + # as they originate from FakeTensor propagation. + # This should only be done if the example_value is a FakeTensor. + # However, if tensor subclasses are present, + # it is reasonable for Python to remain in the dispatch key set. + if isinstance(example_value, torch._subclasses.FakeTensor): + dks = ( + dks + - torch._C.DispatchKeySet(torch._C.DispatchKey.Python) + - torch._C.DispatchKeySet( + torch._C.DispatchKey.PythonTLSSnapshot + ) + ) + return DispatchKeySetVariable.create(dks) + else: + assert not args + return DispatchKeySetVariable.create(self.value()) + + @register(torch.overrides.get_default_nowrap_functions.__wrapped__) + def handle_get_default_nowrap_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # [Note: __torch_function__] we return empty here because we restrict + # the set of functions that we trace __torch_function__ on to + # functions outside of the actual set. Implementing this properly will require implementing + # some variable types to track and compare tensor getset descriptors + return VariableTracker.build( + tx, torch.overrides.get_default_nowrap_functions() + ) + + @register(torch.ops.inductor.accumulate_grad_.default) + def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs): + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs + ) + + @register(math.radians) + def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs): + if not check_unspec_or_constant_args(args, kwargs): + # Use polyfill to convert math.radians(x) into math.pi * x / 180.0 + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.radians), args, kwargs + ) + + if hasattr(math, "fma"): # Python 3.13+ + + @register(math.fma) + def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) != 3 or kwargs: + return None + + if all(arg.is_tensor() for arg in args): + x, y, z = args + addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul) + return addcmul_fn.call_function(tx, [z, x, y], {}) + + # Use math.fma if constants + return None + + @register(torch.is_inference_mode_enabled) + def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"): + unimplemented( + gb_type="Encountered torch.is_inference_mode_enabled during tracing", + context="", + explanation="torch.is_inference_mode_enabled() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + + @register(torch.is_tensor, torch.overrides.is_tensor_like) + def handle_is_tensor(self, tx: "InstructionTranslator", arg): + if arg.is_tensor() or ( + self.value is torch.overrides.is_tensor_like + and isinstance(arg, UserDefinedObjectVariable) + and hasattr(arg.value, "__torch_function__") + ): + return ConstantVariable.create(True) + else: + return ConstantVariable.create(False) + + @register( + torch.is_floating_point, + torch.is_complex, + ) + def handle_is_floating_point(self, tx: "InstructionTranslator", input): + input_arg = input + if input_arg.is_tensor() and input_arg.dtype is not None: + if self.value is torch.is_floating_point: + return ConstantVariable.create(input_arg.dtype.is_floating_point) + elif self.value is torch.is_complex: + return ConstantVariable.create(input_arg.dtype.is_complex) + else: + raise AssertionError(f"calling {self.value}") + + @register(torch.numel) + def handle_numel(self, tx: "InstructionTranslator", input): + if input.is_tensor() and input.valid_size(): + return ConstantVariable.create(product(input.size)) + elif input.is_tensor(): + # Workaround dynamic shapes issue + return input.call_method(tx, "numel", [], {}) + + @register(torch.compile) + def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 1: + # torch.compile is a no-op in dynamo + return args[0] + + unimplemented( + gb_type="torch.compile call with > 1 args", + context=f"args={args}, kwargs={kwargs}", + explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.", + hints=[ + "Remove the torch.compile call or its additional args.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) + def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): + assert input.is_tensor() + return input.call_method(tx, "size", [], {}) + + @register( + torch.nn.modules.utils._single, + torch.nn.modules.utils._pair, + torch.nn.modules.utils._triple, + torch.nn.modules.utils._quadruple, + torch.nn.modules.utils._ntuple, + ) + def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs): + return self._call_ntuple(tx, args, kwargs) + + @register(torch.is_grad_enabled) + def handle_is_grad_enabled(self, tx): + install_guard(GradModeVariable._guards_singleton) + return ConstantVariable.create(torch.is_grad_enabled()) + + @register(torch.use_deterministic_algorithms) + def handle_use_deterministic_algorithms( + self, tx: "InstructionTranslator", mode, warn_only=False + ): + # pyrefly: ignore [missing-attribute] + if warn_only and warn_only.as_python_constant(): + unimplemented( + gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", + context=f"mode={mode}, warn_only={warn_only}", + explanation="Dynamo does not support this.", + hints=[ + "Remove param warn_only in function call torch.use_deterministic_algorithms.", + *graph_break_hints.SUPPORTABLE, + ], + ) + return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) + + @register(torch.are_deterministic_algorithms_enabled) + def handle_are_deterministic_algorithms_enabled(self, tx): + install_guard(DeterministicAlgorithmsVariable._guards_singleton) + return ConstantVariable.create(torch.are_deterministic_algorithms_enabled()) + + @register(torch._C._is_torch_function_enabled) + def handle_is_torch_function_enabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + # see comment on SymbolicTorchFunctionState class as to why + # this is not a bug + return ConstantVariable.create( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + + @register(torch._C._is_torch_function_all_disabled) + def handle_is_torch_function_all_disabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + return ConstantVariable.create( + not tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + @register( + torch.overrides.has_torch_function, + torch.overrides.has_torch_function_variadic, + torch.overrides.has_torch_function_unary, + ) + def handle_has_torch_function(self, tx: "InstructionTranslator", *args): + elems = ( + args[0].unpack_var_sequence(tx) + if len(args) == 1 and isinstance(args[0], TupleVariable) + else args + ) + return ConstantVariable.create( + any(has_torch_function(x) for x in elems), + ) + + @register( + *dict.fromkeys( # remove duplicates + device_interface.stream + for _, device_interface in get_registered_device_interfaces() + ) + ) + def handle_device_interface_stream(self, tx: "InstructionTranslator", stream): + return StreamContextVariable.create(tx, stream) + + @register(torch.from_numpy) + def handle_from_numpy(self, tx: "InstructionTranslator", *args): + if not config.trace_numpy: + unimplemented( + gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", + context=f"trace_numpy={config.trace_numpy}", + explanation=( + "Attempted to call `torch.from_numpy` with config " + "`torch._dynamo.config.trace_numpy` set to `False`." + ), + hints=[ + "Change `torch._dynamo.config.trace_numpy` to `True`.", + ], + ) + if not np: + unimplemented( + gb_type="`torch.from_numpy` with NumPy unavailable", + context="", + explanation="Attempted to call `torch.numpy` but NumPy could not be imported.", + hints=[ + "Check NumPy version and installation in your environment.", + *graph_break_hints.USER_ERROR, + ], + ) + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.as_tensor, + *proxy_args_kwargs(args, {}), + ), + example_value=None, + ) + + @register(torch.jit.annotate) + def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value): + return the_value + + @register(torch.backends.cudnn.is_acceptable) + def handle_cudnn_is_acceptable( + self, tx: "InstructionTranslator", tensor, *extra + ): + # is_acceptable(tensor) returns true if + # (a) tensor dtype/device are supported by cudnn + # (b) cudnn is available + # (c) some initialization has completed + # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version) + assert not extra, "Expect 1 input to cudnn.is_acceptable" + assert tensor.is_tensor(), ( + "Expect input to cudnn.is_acceptable to be a tensor" + ) + tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device) + return ConstantVariable.create( + torch.backends.cudnn.is_acceptable(tensor_inp) + ) + + @register(torch.utils.hooks.BackwardHook) + def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs): + return variables.BackwardHookVariable.create(tx, *args, **kwargs) + + @register(torch.nn.Parameter) + def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs): + return self.call_nn_parameter(tx, *args, **kwargs) + + @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int) + def handle_sym_size(self_, tx, self, dim=None): + # we see this when retracing already traced code + if dim is not None: + return self.call_method(tx, "size", [dim], {}) + + @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int) + def handle_sym_stride(self_, tx, self, dim=None): + if dim is not None: + return self.call_method(tx, "stride", [dim], {}) + + @register(torch.addcdiv) + def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 3 and "value" in kwargs and len(kwargs) == 1: + # decompose addcdiv into constituent ops, prevents a graph break due to converting + # value to a scalar + result = TorchInGraphFunctionVariable(torch.div).call_function( + tx, [*args[1:]], {} + ) + result = TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, kwargs["value"]], {} + ) + return TorchInGraphFunctionVariable(torch.add).call_function( + tx, [args[0], result], {} + ) + + @register(torch.full) + def handle_full(self, tx, size, fill_value, **kwargs): + if fill_value.is_tensor(): + # Decompose: create empty tensor and fill it + # This avoids the scalar extraction at compile time + empty_result = TorchInGraphFunctionVariable(torch.empty).call_function( + tx, [size], kwargs + ) + # Call fill_ method on the empty tensor + return empty_result.call_method(tx, "fill_", [fill_value], {}) + + @register(torch._foreach_lerp_) + def handle_inplace_foreach_lerp_scalar( + _, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_lerp_inplace), + args, + kwargs, + ) + + @register(torch._foreach_pow) + def handle_foreach_pow_scalar(_, tx: "InstructionTranslator", *args, **kwargs): + # In eager it's more performant to call item() from within the C op implementation + # in compile, it's more performant to not graph break. + if len(args) == 2 and args[0].is_tensor() and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_pow_scalar), + args, + kwargs, + ) + + @register(torch._assert) + def handle_assert(self, tx: "InstructionTranslator", condition, message): + if (condition.is_python_constant() and condition.as_python_constant()) or ( + isinstance(condition, variables.SymNodeVariable) + and condition.evaluate_expr() + ): + return ConstantVariable(None) + + @register(SDPAParams) + def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): + return wrap_fx_proxy( + tx, + proxy=tx.output.create_proxy( + "call_function", + torch._C._SDPAParams, + *proxy_args_kwargs(args, kwargs), + ), + param_vars=args, + ) + + if DistributedVariable.is_available(): + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + from torch.distributed.tensor import DTensor + + @register( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ) + def handle_constant_processgroup_functions( + self, tx: "InstructionTranslator", *args + ): + # because the input is a "ProcessGroupVariable", we'll be guarding on its + # ID_MATCH based on how it was constructed. + + # We desugar it at trace-time into ranks by directly calling util + # bake the result into the trace + if len(args) == 1: + # group or group name + assert ( + isinstance(args[0], ProcessGroupVariable) + or args[0].is_python_constant() + ) + elif len(args) == 2: + # ranks + tag + assert ( + isinstance(args[0], ListVariable) + and args[1].is_python_constant() + ) + else: + raise AssertionError( + f"Invalid group value ({args}) for constant pg " + f"function {self.value}" + ) + args_as_value = [arg.as_python_constant() for arg in args] + invocation_result = self.value(*args_as_value) + + # Note - while we *could* cook up sources around invocations, like a FunctionSource + # the space of invoking functions in the middle of the guard chain is very iffy. As such, + # guard propagation via options is the best we can do. + return VariableTracker.build(tx, invocation_result) + + @register(DTensor.from_local) + def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + placements_vt = kwargs.get("placements") + + if placements_vt is None and len(args) >= 3: + placements_vt = args[2] + + if placements_vt is None: + placements_vt = ConstantVariable.create(None) + elif isinstance(placements_vt, variables.UserDefinedObjectVariable): + placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [placements_vt], {} + ) + + new_args = list(args) + if len(new_args) >= 3: + new_args[2] = placements_vt + elif kwargs.get("placements") is not None: + kwargs["placements"] = placements_vt + + args_as_value = [x.as_python_constant() for x in new_args[1:]] + kwargs_as_value = { + k: v.as_python_constant() + for k, v in kwargs.items() + if k not in ["shape", "stride"] + } + + kwargs_to_be_proxied = { + k: kwargs[k] for k in ["shape", "stride"] if k in kwargs + } + + def fn_with_prim_types(x, shape=None, stride=None): + return self.value( + x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride + ) + + # attach the same function name for better debugging + fn_with_prim_types.__name__ = "prim " + self.value.__name__ + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_with_prim_types, + *proxy_args_kwargs( + [args[0]], + kwargs_to_be_proxied, + ), + ), + ) + + @register(torch.nested.nested_tensor) + def handle_nested_tensor( + self, + tx: "InstructionTranslator", + tensor_list=None, + *args, + layout=None, + **kwargs, + ): + from .lists import BaseListVariable + + if layout and layout.is_constant_match(torch.strided): + unimplemented( + gb_type="Attempted to use strided NestedTensor", + context=f"layout={layout}", + explanation="Dynamo does not support this.", + hints=[ + "Change layout=torch.jagged.", + *graph_break_hints.SUPPORTABLE, + ], + ) + if not isinstance(tensor_list, BaseListVariable): + unimplemented( + gb_type="Attempted to use `nested_tensor` with non-list input", + context=f"tensor_list={tensor_list}", + explanation="Dynamo does not support this.", + hints=[ + "Change `nested_tensor` with list input.", + *graph_break_hints.USER_ERROR, + ], + ) + + @register(torch.nn.functional.one_hot) + def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) == 1 or ( + len(args) == 2 and args[1].is_constant_match(-1) + ): + unimplemented( + gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", + context=f"args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Explicitly set the `num_classes` param of the function call " + "`torch.nn.functional.one_hot` to something other than -1.", + ], + ) + + @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) + def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_size_oblivious( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_true) + def handle_guard_or_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_false) + def handle_guard_or_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.statically_known_false) + def handle_statically_known_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_false( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_scalar) + def guard_scalar(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif expr.is_python_constant(): + val = expr.as_python_constant() + else: + unimplemented( + gb_type="torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", + context=f"expr: {expr}", + explanation="Expected `expr` to be a symbolic variable or constant.", + hints=[], + ) + return variables.ConstantVariable.create( + # pyrefly: ignore [bad-argument-type, unbound-name] + torch.fx.experimental.symbolic_shapes.guard_scalar(val) + ) + + @register(torch.fx.experimental.symbolic_shapes.statically_known_true) + def handle_statically_known_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_true( + expr.sym_num + ) + ) + elif expr.is_python_constant(): + return expr + + @register(torch.fx.experimental.symbolic_shapes.sym_and) + def handle_sym_and(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_and( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.sym_or) + def handle_sym_or(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_or( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.has_static_value) + def handle_has_static_value(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif expr.is_python_constant(): + val = expr.as_python_constant() + else: + return + + return variables.ConstantVariable.create( + # pyrefly: ignore [bad-argument-type] + torch.fx.experimental.symbolic_shapes.has_static_value(val) + ) + + @register(torch._C._autograd._unsafe_set_version_counter) + def handle_unsafe_set_version_counter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + from ..tensor_version_op import _unsafe_set_version_counter + + return TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [*args], kwargs) + + @register(torch._C._functorch.peek_interpreter_stack) + def handle_functorch_peek_interpreter_stack( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Wrap C++ interpreter (torch._C._functorch.CInterpreter) as UserDefinedObjectVariable, + # but Python interpreter (torch._functorch.pyfunctorch.FuncTorchInterpreter) as FuncTorchInterpreterVariable. + return UserDefinedObjectVariable( + torch._C._functorch.peek_interpreter_stack() + ) + + @register(torch._functorch.pyfunctorch.coerce_cinterpreter) + def handle_functorch_pyfunctorch_coerce_cinterpreter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + cinterpreter = args[0].value + return FuncTorchInterpreterVariable( + torch._functorch.pyfunctorch.coerce_cinterpreter(cinterpreter) + ) + + @register(torch.tensor) + def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs): + def check_any_unspec(x): + # NB: This includes UnspecializedPythonVariable + if x.is_tensor() or isinstance(x, SymNodeVariable): + return True + elif isinstance(x, (ListVariable, TupleVariable)): + return any(check_any_unspec(y) for y in x.items) + # TODO: there maybe other recursive structures you need to + # check + else: + return False + + data_arg = None + if args: + data_arg = args[0] + elif "data" in kwargs: + data_arg = kwargs["data"] + + # NB: OK to pass torch.tensor(tensor), this will trace fine + if ( + data_arg is not None + and not data_arg.is_tensor() + and check_any_unspec(data_arg) + ): + # This is slower and less canonical, so only use it if we + # have to + return TorchInGraphFunctionVariable(torch._refs.tensor).call_function( + tx, [*args], kwargs + ) + + @register(torch._C._pop_torch_function_stack) + def handle_pop_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + if not tx.symbolic_torch_function_state.mode_stack: + unimplemented( + gb_type="Attempted to pop from empty torch function mode stack", + context="", + explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.", + hints=[ + "Do not pop from empty torch function mode stack.", + *graph_break_hints.USER_ERROR, + ], + ) + TorchFunctionModeStackVariable.register_mutation(tx) + return tx.symbolic_torch_function_state.pop_torch_function_mode() + + @register(torch._C._push_on_torch_function_stack) + def handle_push_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) != 1 or kwargs: + raise_type_error_exc( + tx, + f"push_torch_function takes exactly one argument ({len(args)} given)", + ) + TorchFunctionModeStackVariable.register_mutation(tx) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + return ConstantVariable.create(None) + + @register(torch._C._len_torch_function_stack) + def handle_len_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if args or kwargs: + raise_type_error_exc(tx, "len_torch_function_stack takes no arguments") + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) + + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) != 1 or kwargs: + raise_type_error_exc( + tx, + f"get_function_stack_at takes exactly one argument ({len(args)} given)", + ) + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + + @register(torch.get_device_module.__wrapped__) + def handle_get_device_module(self, tx, *args, **kwargs): + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): + unimplemented( + gb_type="improper torch.get_device_module arguments", + context=f"args={args}, kwargs={kwargs}", + explanation="torch.get_device_module accepts 1 optional argument `device`", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + try: + if kwargs: + device = kwargs["device"].as_python_constant() + elif args: + device = args[0].as_python_constant() + else: + device = None + module = torch.get_device_module(device) + except Exception as e: + unimplemented( + gb_type="bad device argument to torch.get_device_module", + context=f"args={args}, kwargs={kwargs}", + explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + hints=[*graph_break_hints.USER_ERROR], + from_exc=e, + ) + + # need to guard only on no-arg get_device_module + # pyrefly: ignore [unbound-name] + if device is None: + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + # assumes `module` is in the form `torch.xyz` + new_source = AttrSource( + TorchSource(), + # pyrefly: ignore [unbound-name] + module.__name__.rsplit(".", maxsplit=1)[-1], + ) + # pyrefly: ignore [unbound-name] + return VariableTracker.build(tx, module, new_source) + + @register(torch.accelerator.current_stream, torch.cuda.current_stream) + def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): + unimplemented( + gb_type="unsupported arguments to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="torch.accelerator.current_stream accepts one optional argument `device`", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + try: + if kwargs: + device = torch.device(kwargs["device"].as_python_constant()) + elif args: + device = torch.device(args[0].as_python_constant()) + else: + device = None + + return tx.symbolic_stream_state.cur_stream(device) + except Exception as e: + unimplemented( + gb_type="bad device argument to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + hints=[*graph_break_hints.USER_ERROR], + from_exc=e, + ) + + @register(torch.set_default_device) + def handle_set_default_device( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Today this is inserted in the graph, once TF mode + # handling is complete, we can trace the device context + # like any other TF mode and remove this special handling + # Insert the TF mode representing the device context at + # the bottom of the stack to match the eager semantics + # Running the graph will ensure that the DeviceContext mode is + # at the correct position in the stack + TorchFunctionModeStackVariable.register_mutation(tx) + if args[0].is_constant_none(): + TorchFunctionModeStackVariable.clear_default_device(tx) + else: + TorchFunctionModeStackVariable.register_device_context_insertion(tx) + + return ConstantVariable.create(None) + + @register(torch._check) + def handle_check(self, tx: "InstructionTranslator", *args, **kwargs): + predicate_vt = None + message_vt = None + + if args: + predicate_vt = args[0] + rest_args = args[1:] + else: + rest_args = () + + if predicate_vt is None and "cond" in kwargs: + predicate_vt = kwargs.pop("cond") + + if rest_args: + message_vt = rest_args[0] + elif "message" in kwargs: + message_vt = kwargs.pop("message") + + if predicate_vt is None: + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + + message_eager = None + message_graph_proxy = None + if message_vt is not None: + if ( + not isinstance(message_vt, NestedUserFunctionVariable) + or message_vt.has_closure() + ): + unimplemented( + gb_type="Can't extract message from torch._check()", + context=str(message_vt), + explanation=( + "The second argument of torch._check() must be a function" + "defined within the torch.compile region" + "that does not reference a non-local variable." + ), + hints=[ + "Make sure the message function is defined in the torch.compile region.", + "Remove any closure variables, e.g. " + "remove references to closure variable `x` in `lambda: f'{x} failed check'`", + *graph_break_hints.SUPPORTABLE, + ], + ) + message_eager = message_vt.get_function() + + message_graph_proxy = tx.output.register_static_attr_and_return_proxy( + "_check_message", message_eager + ) + + if predicate_vt.is_python_constant(): + self.value(predicate_vt.as_python_constant(), message_eager) + return ConstantVariable.create(None) + + predicate_proxy = predicate_vt.as_proxy() + + proxy_args: tuple[Any, ...] + if message_graph_proxy is None: + proxy_args = (predicate_proxy,) + else: + proxy_args = (predicate_proxy, message_graph_proxy) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + proxy_args, + {}, + ), + ) + + return handlers + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable, SymNodeVariable + from .builder import wrap_fx_proxy + + if self.nonstrict_traceable: + return self._call_nonstrict_traceable_function(tx, args, kwargs) + + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + # constant fold functions need to be guarded. + if self.value in constant_fold_functions_need_guards: + assert self.source is not None + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) + # constant fold + try: + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + except (OverflowError, TypeError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + + if self.is_tensor_method(): + name = self.value.__name__ + # Guard against inplace view op on input tensor (not supported) + if args and args[0].is_tensor(): + tensor_var = args[0] + # Check if input tensor and inplace_view op specifically + if tensor_var.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view + in getattr(fn, fn.overloads()[0]).tags + ): + unimplemented( + gb_type="Inplace op on input tensor", + context="", + explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + "Ensure you do not modify input tensor in place.", + ], + ) + return self.call_tensor_method(tx, args, kwargs) + + special_handler = self._get_handlers().get(self.value) + if special_handler: + result = special_handler(self, tx, *args, **kwargs) + if result: + return result + + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + + all_ints_or_floats = all( + isinstance(x, SymNodeVariable) or x.is_python_constant() for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ +Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. +To support this behavior, we need to allow const-propping tensors that store symint data. +For now, dynamo will explicitly graph break when it encounters user code with this behavior. +""" + log.warning(msg) + unimplemented( + gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. " + "Dynamo does not support this." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + # TODO for each of the following check on `out=` or `requires_grad=` + # variant torch ops, the original function could come from a user + # defined `@allow_in_graph` function as well, which doesn't have the + # same semantics as the torch ops. + + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + saved_out_shapes = None + out_kwarg_vt = None + if "out" in kwargs: + out_kwarg_vt = kwargs["out"] + + # e.g., out=(t1, t2, ...) + if isinstance(out_kwarg_vt, (TupleVariable, ListVariable)): + saved_out_shapes = [] + for vt in out_kwarg_vt.items: + if vt.is_tensor(): + shape = vt.as_proxy().node.meta["example_value"].shape + else: + shape = None + saved_out_shapes.append(shape) + + # e.g., out=output_tensor + if out_kwarg_vt.is_tensor(): + saved_out_shapes = ( + out_kwarg_vt.as_proxy().node.meta["example_value"].shape + ) + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + # Handle e.g., `torch.ones(10, requires_grad=True)` + if ( + tensor_variable.is_tensor() + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + gb_type="Attempted to use tensor creation function with requires_grad=True", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Create the tensor outside the compiled region.", + "Do not set `requires_grad=True`.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Handle e.g., `torch.add(a, b, out=result)` + if saved_out_shapes is not None: + # out variants of torch operators like torch.sort and torch.sigmoid + # mutate the tensors in the out field. + # + # However, it's non-trivial to update all references of the old + # `TensorVariable` to the new one returned (`result_var`), so we + # take the conservative approach to graph break on size changes, and + # assume other cases can fall through soundly. + # + # Note that although these tensor variables would hold different + # proxies, the in-place mutation semantics is preserved in the FX + # graph, so we won't have correctness issues. + if isinstance(saved_out_shapes, list): + for out_tensor_vt, saved_out_shape in zip( + out_kwarg_vt.items, # type: ignore[union-attr] + saved_out_shapes, + ): + if saved_out_shape is None: + # This should be extremely rare, but it's kept for now + # until we invest in enforcing the `out=` kwarg for only + # torch methods. + continue + + assert out_tensor_vt.is_tensor() + fake_out = out_tensor_vt.proxy.node.meta["example_value"] + if saved_out_shape != fake_out.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented( + gb_type="Shape mismatch with out= list of tensor variants", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + gb_type="Attempted to call op with non-contiguous `out=` list of tensors", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + assert out_kwarg_vt is not None and out_kwarg_vt.is_tensor() + assert "example_value" in out_kwarg_vt.as_proxy().node.meta + fake_out = out_kwarg_vt.as_proxy().node.meta["example_value"] + if saved_out_shapes != fake_out.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented( + gb_type="Shape mismatch with out= tensor variant", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + if not torch._prims_common.is_contiguous_or_false(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + gb_type="Attempted to call op with non-contiguous `out=` tensor", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + return tensor_variable + + def _call_nonstrict_traceable_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + import torch._higher_order_ops.flat_apply as flat_apply + from torch._higher_order_ops.flat_apply import ( + func_to_graphable, + is_graphable_type, + ) + from torch._subclasses.fake_tensor import fake_tensor_tls + from torch.utils._pytree import tree_flatten + + from .base import AsPythonConstantNotImplementedError + from .builder import wrap_fx_proxy + + # 1. Convert `args, kwargs` into pytree-flattened proxy forms. + # + # Rather than reconstructing `args, kwargs` into python objects and + # then tree_flatten them, we just let Dynamo symbolically interpret + # `tree_flatten((args, kwargs))`. This saves us from having to + # worry about the reconstruction logic, side effects, and guards. + packed_input_vt = TupleVariable.build( + tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) + ) + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] + tx, [packed_input_vt], {} + ) + assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 + flat_args_vts, input_spec_vt = out_vt.items + assert isinstance(flat_args_vts, ListVariable) + + # Handle the case when the input contains a non-graphable type. + for flat_arg_vt in flat_args_vts.items: + arg_type = flat_arg_vt.python_type() + if not is_graphable_type(arg_type): + type_name = flat_arg_vt.python_type().__qualname__ + unimplemented( + gb_type="Invalid input type for nonstrict_trace-ed function", + context=f"Encountered input of type <{type_name}>.", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) " + "or pytree containers of those are allowed as inputs. The provided argument contains " + "an unsupported type." + ), + hints=[ + "Use one of the following to register the type with pytree:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + ], + ) + + # Since we checked with `is_graphable` above, `as_proxy` on the + # flat_arg VT should always work. + proxified_flat_args = [ + flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items + ] + + # The downstream `flat_apply` call requires the input spec; however, + # the spec not a graphable type, so we still have to reconstruct it + # into a python object, and store it as a constant attribute on the + # fx graph. + try: + input_spec = input_spec_vt.as_python_constant() + except AsPythonConstantNotImplementedError as e: + typ = e.vt.python_type() + type_name = typ.__qualname__ + import torch.utils._pytree as pytree + + if pytree.is_constant_class(typ): + unimplemented( + gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function with an input that contains an object " + f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object " + "was constructed _inside_ the `torch.compile` region. This is not supported." + ), + hints=[ + "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + else: + unimplemented( + gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered " + f"with a `pytree_flatten` that places an object of type <{type_name}> into the context." + ), + hints=[ + "Modifying the `pytree_flatten` to avoid placing the object into the context.", + f"Apply one of the following to <{type_name}>:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, + ) + + fn = self.value + + def patched_fn(*args, **kwargs): + # This enables reads to global/captured tensors, and we'll just + # treat them as constants in the graph. Note that after + # AOTDispatcher, this logic would disappear. + old_val = fake_tensor_tls.allow_non_fake_inputs_override + fake_tensor_tls.allow_non_fake_inputs_override = True + try: + res = fn(*args, **kwargs) + finally: # reset even when `fn` raises + fake_tensor_tls.allow_non_fake_inputs_override = old_val + return res + + # `flat_apply` wants a TreeSpec for the function input. + _, f_spec = func_to_graphable(patched_fn) + + # TreeSpec isn't graphable, so we register the function and input + # specs as attributes on the graph module. + f_spec_proxy = tx.output.register_static_attr_and_return_proxy( + f"{fn.__name__}_spec", f_spec + ) + input_spec_proxy = tx.output.register_static_attr_and_return_proxy( + fn.__name__ + "_input_spec", + # pyrefly: ignore [unbound-name] + input_spec, + ) + f_spec_proxy.node.type = type(f_spec) + # pyrefly: ignore [unbound-name] + input_spec_proxy.node.type = type(input_spec) + all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) + + # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate + # the call and wrap output into a VariableTracker. + proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) + try: + # TODO support more output types once `flat_apply` supports + # pytree-able output types. We can have Dynamo trace through an + # unflatten call (just like we traced through a flatten above) + # to rebuild the actual output VT. + out_vt = wrap_fx_proxy(tx, proxy) + except ( + # From `handle_traced_output`. + torch._dynamo.exc.Unsupported, + # From `flat_apply` assert on output type. + torch._dynamo.exc.TorchRuntimeError, + ): + unimplemented( + gb_type="Unsupported output type for nonstrict_trace-ed function", + context=f"Function: {fn.__name__}", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" + " are allowed as output. The result of this call contains an unsupported type." + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return out_vt + + def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): + """inline behavior of torch.nn.modules.utils._ntuple""" + if self.value is torch.nn.modules.utils._ntuple: + count = args[0].as_python_constant() + else: + count = self.value.__closure__[0].cell_contents + assert isinstance(count, int) + assert not kwargs + + def handle_ntuple(value): + if value.has_unpack_var_sequence(tx): + return variables.TupleVariable( + list(value.unpack_var_sequence(tx)), + ) + elif value.is_python_constant(): + # constant prop through it + return variables.ConstantVariable.create( + torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), + ) + else: + unimplemented( + gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", + context=f"value={value}", + explanation="Dynamo does not support this.", + hints=[ + "Change use of _ntuple with argument as constant or tensor.", + ], + ) + + if self.value is torch.nn.modules.utils._ntuple: + return variables.LambdaVariable(handle_ntuple) + else: + return handle_ntuple(args[0]) + + @classmethod + def call_nn_parameter(cls, tx, data=None, requires_grad=True): + """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented( + gb_type="Attempted to use `torch.nn.Parameter()` with export", + context="", + explanation="Dynamo does not support this.", + hints=[ + "Do not use `torch.nn.Parameter()` with export.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if isinstance(requires_grad, variables.VariableTracker): + try: + requires_grad = requires_grad.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`", + context=f"requires_grad={requires_grad}", + explanation="Dynamo does not support this.", + hints=[ + "Change `requires_grad` to be a bool.", + *graph_break_hints.USER_ERROR, + ], + ) + + if data is None or not data.is_tensor(): + unimplemented( + gb_type="`torch.nn.Parameter()` with unsupported data type", + context=f"data={data}", + explanation="Called `torch.nn.Parameter()` with non-Tensor argument.", + hints=[ + "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.", + *graph_break_hints.USER_ERROR, + ], + ) + + # this results in cleaner graphs, but only works for inputs + # pyrefly: ignore [missing-attribute] + if data.source: + return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + + if config.graph_break_on_nn_param_ctor: + # Need user to manually move since we cannot + unimplemented( + gb_type="Attempted to use `torch.nn.Parameter()` constructor with Dynamo", + context="", + explanation="Dynamo does not support this", + hints=[ + "Try to construct `torch.nn.Parameter()` outside the compiled region.", + "If this is not possible, turn `graph_break_on_nn_param_ctor` off", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # TODO[@lucaskabela]: Remove the behavior below since it is deprecated + if isinstance( + data, + TensorWithTFOverrideVariable, + # pyrefly: ignore [missing-attribute] + ) or is_traceable_wrapper_subclass_type(data.class_type): + unimplemented( + gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", + context=str(data), + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if not can_convert_to_tracable_parameter(): + unimplemented( + gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable", + context="", + explanation="convert_tracable_parameter is set to False.", + hints=[ + "Check usage of context manager: do_not_convert_to_tracable_parameter", + *graph_break_hints.DIFFICULT, + ], + ) + + try: + # pyrefly: ignore [missing-attribute] + shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) + # pyrefly: ignore [missing-attribute] + dtype = data.var_getattr(tx, "dtype").as_python_constant() + # pyrefly: ignore [missing-attribute] + device = data.var_getattr(tx, "device").as_python_constant() + except NotImplementedError as e: + unimplemented( + gb_type="`torch.nn.Parameter` with non-constant Tensor attributes", + context=f"data={data}", + explanation="Dynamo does not support this.", + hints=[ + "Ensure the Tensor argument's shape, dtype, and device are correct.", + *graph_break_hints.USER_ERROR, + ], + from_exc=e, + ) + + placeholder = tx.output.synthetic_graph_input( + new_parameter_placeholder, + # pyrefly: ignore [unbound-name] + [shape, dtype, device, requires_grad], + ) + # pyrefly: ignore [missing-attribute] + if data.requires_grad: + # pyrefly: ignore [missing-attribute] + data = data.call_method(tx, "detach", [], {}) + + from .builder import wrap_fx_proxy + + result = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + tracable_create_parameter, + # pyrefly: ignore [missing-attribute] + (data.as_proxy(), placeholder.as_proxy()), + {}, + ), + # In reconstruct() we should use the original parameter. The one + # returned by the graph will be an alias. + source=placeholder.source, + ) + assert result.is_tensor() + result.class_type = torch.nn.Parameter # type: ignore[union-attr] + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_parameter. It does not seem to use the right + # grad_enabled. Since this is parameter, we can just override the + # has_grad_fn field to False to workaround the issue. + result.has_grad_fn = False # type: ignore[union-attr] + + # TODO(jansel): if the new param falls out of scope, currently it won't get freed until + # the end of the graph. We should fix this. + return result + + @staticmethod + def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad): + # Alternate version if we have a .source + varname = tx.output.new_var() + + # construct the nn.Parameter before the graph save it to varname + assert tx.output.root_tx is not None + cg = PyCodegen(tx.output.root_tx) + cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter")) + cg(data.source) + cg(variables.ConstantVariable(requires_grad)) + cg.call_function(2, False) + cg.store(varname) + tx.output.pregraph_bytecode.extend(cg.get_instructions()) + + data_node = data.as_proxy().node + if data_node.op not in ("placeholder", "get_attr"): + unimplemented( + gb_type="Unexpected type of data placeholder op for parameter construction", + context=f"data_node.op={data_node.op}", + explanation="Data node op should be placeholder or get_attr.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) + + # add the newly constructed nn.Parameter as a graph input + source = SyntheticLocalSource(varname) + example_value = torch.nn.Parameter( + tx.output.example_value_from_input_node(data.as_proxy().node), + requires_grad=requires_grad, + ) + result = VariableTracker.build(tx, example_value, source) + # Realize the VT because we will delete the guards on it in the next line. + result = result.realize() + # No need to guard on this since we already guarded on `data`. + # These guards would fail since varname doesn't exist until after the function starts + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + + def call_tensor_method(self, tx, args, kwargs): + return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs) + + def is_tensor_method(self): + from ..trace_rules import get_tensor_method + + return ( + inspect.ismethoddescriptor(self.get_function()) + and hasattr(self.get_function(), "__objclass__") + and self.get_function().__objclass__ == torch._C.TensorBase + ) or self.get_function() in get_tensor_method() + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return self.as_python_constant() == other.as_python_constant() + + +class DispatchKeySetVariable(BaseTorchVariable): + """represents torch.DispatchKeySet""" + + @staticmethod + def create(value, **kwargs): + return DispatchKeySetVariable(value, **kwargs) + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.DISPATCH_KEY_SET_MATCH)) + return cls(value, source=source) + + def is_constant_fold_method(self, name): + return name == "has" + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if self.is_constant_fold_method(name) and check_unspec_or_constant_args( + args, kwargs + ): + method = getattr(self.value, name) + return variables.ConstantVariable.create( + method( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif name == "highestPriorityTypeId": + return variables.EnumVariable(self.value.highestPriorityTypeId()) + return super().call_method(tx, name, args, kwargs) + + +class FuncTorchInterpreterVariable(BaseTorchVariable): + """represents torch._functorch.pyfunctorch.FuncTorchInterpreter""" + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return cls(value, source=source) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if name == "key": + return variables.EnumVariable(self.value.key()) + elif name == "process": + return tx.inline_user_function_return( + VariableTracker.build(tx, self.value.process.__func__), + [self] + args, + kwargs, + ) + elif name in ["level", "batch_size", "randomness"]: + return variables.ConstantVariable.create(getattr(self.value, name)()) + elif name == "lower": + assert not args and not kwargs + return variables.TemporarilyPopInterpreterStackCtxManagerVariable.create( + tx, None + ) + return super().call_method(tx, name, args, kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a86eb4f017f88b36fcd7ac94c488352bc4e26f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch_function.py @@ -0,0 +1,761 @@ +"""TorchDynamo support for __torch_function__ tensor subclasses. + +This module implements support for tensor subclasses with __torch_function__ overrides. +A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles +dispatching __torch_function__ on attribute accesses, method calls, and torch API calls. + +Unsupported features: +- Triggering __torch_function__ on tensor subclass non-tensor custom attributes +- Graph breaking on mutating guardable tensor properties within a __torch_function__ context + (can cause excessive recompiles in certain cases) +- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor + argument positions of Torch API calls + +Supported features: +- Static method implementations of __torch_function__ on custom objects (triggers on torch + API calls with the object as any argument) +- Triggering __torch_function__ on torch API calls with tensor subclass arguments +- __torch_function__ calls on base tensor attribute access and method calls for tensor + subclass instances +- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object + arguments in any position + +See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w +for more information on the design. +""" + +import collections +import contextlib +import functools +import inspect +import operator +from collections.abc import Generator, Iterable, Sequence +from types import TracebackType +from typing import Any, Optional, TYPE_CHECKING + +import torch._C +import torch.utils._pytree as pytree +from torch._guards import Source +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) +from torch.utils._device import DeviceContext + +from .. import graph_break_hints +from ..exc import unimplemented +from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode +from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) +from .base import VariableTracker +from .constant import ConstantVariable +from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable +from .lazy import LazyVariableTracker +from .lists import TupleVariable +from .tensor import TensorSubclassVariable, TensorVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +bin_ops = [ + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +] + +bin_int_ops = [ + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +] + +un_int_ops = [operator.invert] + +tensor_and_int_ops = [ + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +] + +un_ops = [ + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +] + + +banned_attrs = [ + fn.__self__.__name__ # type: ignore[attr-defined] + for fn in get_default_nowrap_functions() + if is_tensor_base_attr_getter(fn) +] + + +@functools.cache +def get_prev_stack_var_name() -> str: + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool: + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the function across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__ + ) + + def __init__( + self, + value: Optional[TorchFunctionMode], + source: Optional[Source] = None, + **kwargs: Any, + ): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + # needed for BC with calling enter from CM code + self.cm_obj = value # type: ignore[assignment] + self.source = source # type: ignore[assignment] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # This shouldn't be called unless we have a source + assert self.source + self.source.reconstruct(codegen) + + def module_name(self) -> str: + return self.value.__module__ + + def fn_name(self) -> str: + return type(self.value).__name__ + + def python_type(self) -> type: + return type(self.value) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + return call_torch_function( + tx, + get_torch_function_fn(tx, self), # type: ignore[arg-type] + fn, + types, + args, + kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return False + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self) -> None: + self.stack: list[Any] = [] + + def __enter__(self) -> None: + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self) -> Generator[None, None, None]: + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + +class SymbolicTorchFunctionState: + def __init__(self, py_stack: Iterable[Any]) -> None: + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: collections.deque[TorchFunctionModeVariable] = ( + collections.deque() + ) + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type] + ) + + def in_torch_function_mode(self) -> bool: + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self) -> TorchFunctionModeVariable: + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None: + self.mode_stack.append(mode_var) + + def call_torch_function_mode( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining( + self, + ) -> Generator[TorchFunctionModeVariable, None, None]: + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment] + try: + yield self.cur_mode # type: ignore[misc] + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) # type: ignore[arg-type] + + +class TorchFunctionModeStackVariable(VariableTracker): + """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" + + # singleton value representing the global torch function mode stack + # singleton (it exists in C++) + stack_value_singleton = object() + + # offset is used to track if we have inserted/removed a + # device context which is always placed at the bottom of the stack + # if a device context is inserted, the graph will run this mutation + # so when we want to reconstruct any other modes on the stack + # their indices should be shifted right by 1 (+1) + # Conversely, if there was a device context on the stack, and the graph + # mutates the stack to remove that context (set default device to None) + # each of the indices of other modes should be shifted left by 1 (-1) + offset = 0 + + def __init__( + self, + source: Source, + symbolic_stack: collections.deque[TorchFunctionModeVariable], + ) -> None: + self.source = source + self.symbolic_stack = symbolic_stack + + @classmethod + def reset(cls) -> None: + cls.offset = 0 + + @classmethod + def register_mutation(cls, tx: "InstructionTranslator") -> None: + if cls.stack_value_singleton not in tx.output.side_effects: + var = cls( + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + ) + tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) + tx.output.side_effects.mutation(var) + + @classmethod + def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None: + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + return + else: + cls.offset += 1 + stack.insert( + 0, + TorchFunctionModeVariable( + None, source=TorchFunctionModeStackSource(-cls.offset) + ), + ) + + @classmethod + def clear_default_device(cls, tx: "InstructionTranslator") -> None: + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + stack.popleft() + cls.offset -= 1 + + @staticmethod + def is_device_context(var: TorchFunctionModeVariable) -> bool: + return isinstance(var.value, DeviceContext) or var.value is None + + @classmethod + def get_mode_index(cls, ind: int) -> int: + return ind + cls.offset + + +def _get_all_args( + args: Iterable[Any], kwargs: dict[str, Any] +) -> Iterable[VariableTracker]: + return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) + + +def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]: + from collections import deque + + from .dicts import ConstDictVariable + from .lists import ListVariable + + vts = deque(vts) + output = [] + + while vts: + vt = vts.popleft() + + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined] + vt.realize() + + if vt.is_realized(): + if isinstance(vt, ListVariable): + vts.extend(vt.items) + continue + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + continue + + output.append(vt) + + return output + + +def _get_subclass_type(var: VariableTracker) -> type: + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + return var.python_type() + + +def _get_subclass_type_var( + tx: "InstructionTranslator", var: VariableTracker +) -> VariableTracker: + if isinstance(var, TensorWithTFOverrideVariable): + return var.class_type_var(tx) + elif isinstance(var, UserDefinedObjectVariable): + source = var.source and TypeSource(var.source) + return VariableTracker.build(tx, var.python_type(), source) + else: + raise AssertionError(f"Unexpected type {type(var)}") + + +def _is_attr_overridden( + tx: "InstructionTranslator", var: VariableTracker, name: str +) -> bool: + if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)): + return False + import torch + + overridden = False + try: + attr_val = inspect.getattr_static(var.python_type(), name) + overridden |= attr_val != getattr(torch.Tensor, name) + except AttributeError: + pass + + return overridden + + +def call_torch_function( + tx: "InstructionTranslator", + torch_function_var: VariableTracker, + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: + # This emulates calling __torch_function__, which has a signature + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # + # Also notice the `cls` is not explicitly passed in the reference + # implementations: + # 1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374 # noqa: B950 + # 2. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/overrides.py#L1741-L1743 + tf_args = [ + fn, + types, + VariableTracker.build(tx, tuple(args)), + VariableTracker.build(tx, kwargs), + ] + return torch_function_var.call_function(tx, tf_args, {}) + + +def get_torch_function_fn( + tx: "InstructionTranslator", vt: VariableTracker +) -> VariableTracker: + # The underlying function could be a classmethod, staticmethod, regular + # function or a function with C-implementation. It doesn't matter as long as + # they satisfy the calling convention in `call_torch_function`. + from .builtin import BuiltinVariable + + args = [vt, ConstantVariable("__torch_function__")] + func_vt = BuiltinVariable(getattr).call_function(tx, args, {}) + return func_vt + + +def can_dispatch_torch_function( + tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any] +) -> bool: + has_overridden_args = any( + has_torch_function(arg) for arg in _get_all_args(args, kwargs) + ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) + + +def dispatch_torch_function( + tx: "InstructionTranslator", + fn: VariableTracker, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: + """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" + + all_args = _get_all_args(args, kwargs) + overloaded_args = _get_overloaded_args( + [arg for arg in all_args if has_torch_function(arg)], + _get_subclass_type, + ) + + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not res.is_constant_match(NotImplemented): + return res + + for arg in overloaded_args: + res = arg.call_torch_function( + tx, + fn, + types, + args, + kwargs, + ) + + if not res.is_constant_match(NotImplemented): + return res + + unimplemented( + gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code", + context=f"{fn=}, {args=}, {kwargs=}", + explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + +class TensorWithTFOverrideVariable(TensorVariable): + """ + Represents a tensor subclass instance with a __torch_function__ override. + """ + + @classmethod + def from_tensor_var( + cls, + tx: "InstructionTranslator", + tensor_var: VariableTracker, + class_type: type, + cls_source: Source, + ) -> "TensorWithTFOverrideVariable": + # [Note: __torch_function__] coerce `tensor_var` into a + # TensorWithTFOverrideVariable. In eager, this is just a type change. + import torch + + # This simulates shallow-copying the tensor object. + kwargs = dict(tensor_var.__dict__) + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" + ) + var = cls(class_type=class_type, **kwargs) + var.install_global(tx) + return var + + def install_global(self, tx: "InstructionTranslator") -> None: + # stash the subclass type to rewrap an output tensor if needed + # this is needed because the actual type needs to be available + # each time the compiled artifact is run and outputs a wrapped tensor. + if self.global_mangled_class_name(tx) not in tx.output.global_scope: + # Safe because global_mangled_class_name figures it out + tx.output.install_global_unsafe( + self.global_mangled_class_name(tx), self.class_type + ) + + def python_type(self) -> type: + return self.class_type + + def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker: + return TensorSubclassVariable( + self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) + ) + + def global_mangled_class_name(self, tx: "InstructionTranslator") -> str: + return get_safe_global_name( + tx, f"__subclass_{self.class_type.__name__}", self.class_type + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + # [Note: __torch_function__] We currently only support attributes that are defined on + # base tensors, custom attribute accesses will graph break. + import torch + + # I think only `_base` is breaking because we aren't modelling view + # relationship perfectly in some scenarios. + if name in banned_attrs: + unimplemented( + gb_type="Unsupported tensor subclass attribute access", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Handle non-overridden attributes inherited from `torch.Tensor`. + attr_is_overridden = _is_attr_overridden(tx, self, name) + if ( + hasattr(torch.Tensor, name) + and not attr_is_overridden + and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) + ): + args = [self] + kwargs: dict[Any, Any] = {} + if can_dispatch_torch_function(tx, args, kwargs): + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + args, + kwargs, + ) + else: + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. + try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) + if isinstance(attr, types.FunctionType): + install_guard(attr_source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = VariableTracker.build(tx, getter, source=getter_source) + return getter_var.call_function(tx, [self], {}) + + elif isinstance(attr, classmethod): + return UserMethodVariable( + attr.__func__, self.class_type_var(tx), source=attr_source + ) + + elif attr_is_overridden: + unimplemented( + gb_type="Unsupported tensor subclass overridden attribute access", + context=f"{name}", + explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + f"Renaming attribute `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return super().var_getattr(tx, name) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: + # NOTE this assumes `__torch_function__` isn't modified during tracing. + if not hasattr(self, "torch_function_fn"): + self.torch_function_fn = get_torch_function_fn(tx, self) + + return call_torch_function( + tx, + self.torch_function_fn, + fn, + types, + args, + kwargs, + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This code block implements inlining the __torch_function__ override + # of `call_method`. + tf_args = [self] + list(args) + if can_dispatch_torch_function(tx, tf_args, kwargs): + import torch + + if _is_attr_overridden(tx, self, name): + unimplemented( + gb_type="Tensor subclass overridden method call", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid calling {name} of tensor subclass in torch.compile region", + f"Renaming method `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # [Note: __torch_function__] Currently we only support methods that are defined on tensor + # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality + # We've established with the above check that the method is not overridden, so we guard that the method is the same + # as the impl defined on tensor and retrieve it + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + value = inspect.getattr_static(self.python_type(), name) + else: + source = None + value = getattr(torch.Tensor, name) + func_var = VariableTracker.build(tx, value, source) + return dispatch_torch_function(tx, func_var, tf_args, kwargs) + else: + return super().call_method(tx, name, args, kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b39b2f9b53e0ba6c04db41dc616ad3d5daea4a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py @@ -0,0 +1,2397 @@ +# mypy: ignore-errors + +""" +This module contains variable classes for handling user-defined objects in Dynamo's tracing system. + +The key classes are: +- UserDefinedVariable: Base class for representing custom Python objects +- UserDefinedClassVariable: Handles Python class objects/types +- UserDefinedObjectVariable: Fallback class for instance objects, with support for method calls, + attribute access, and other Python object behaviors. +- Specialized subclasses for common patterns: + - UserDefinedDictVariable: For dict subclasses + - UserDefinedSetVariable: For set subclasses + - UserDefinedTupleVariable: For tuple subclasses + - UserDefinedExceptionObjectVariable: For exception subclasses + - FrozenDataClassVariable: Special handling of frozen dataclasses + - MutableMappingVariable: For collections.abc.MutableMapping subclasses + +Dynamo specializes to VariableTracker subclasses like FrozenDataClassVariable if available; if no +subclass qualifies, it falls back to UserDefinedObjectVariable. + +These classes help Dynamo track and handle arbitrary Python objects during tracing, +maintaining proper semantics while enabling optimizations where possible. +""" + +import _collections +import builtins +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import itertools +import random +import sys +import threading +import types +import warnings +import weakref +from typing import TYPE_CHECKING +from typing_extensions import is_typeddict + +import torch._dynamo.config +import torch.nn +from torch._guards import TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ( + handle_observed_exception, + ObservedAttributeError, + ObservedKeyError, + ObservedTypeError, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented, +) +from ..graph_bytecode_inputs import get_external_object_by_index +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + CallFunctionNoArgsSource, + DataclassFieldsSource, + DictGetItemSource, + GetItemSource, + RandomValueSource, + TypeDictSource, + TypeMROSource, + TypeSource, + UnspecializedParamBufferSource, +) +from ..utils import ( + check_constant_args, + cmp_name_to_op_mapping, + dict_methods, + frozenset_methods, + get_custom_getattr, + has_torch_function, + is_frozen_dataclass, + is_lru_cache_wrapped_function, + is_namedtuple_cls, + is_wrapper_or_member_descriptor, + istype, + list_methods, + namedtuple_fields, + object_has_getattribute, + proxy_args_kwargs, + raise_args_mismatch, + raise_on_overridden_hash, + set_methods, + tensortype_to_dtype, + tuple_methods, + unpatched_nn_module_getattr, +) +from .base import raise_type_error_exc, ValueMutationNew, VariableTracker +from .dicts import ConstDictVariable, DefaultDictVariable +from .lists import SizeVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +try: + from torch.utils._cxx_pytree import PyTreeSpec +except ImportError: + PyTreeSpec = type(None) + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + from .constant import ConstantVariable + + +def is_standard_setattr(val): + return val in (object.__setattr__, BaseException.__setattr__) + + +def is_standard_delattr(val): + return val in (object.__delattr__, BaseException.__delattr__) + + +def is_forbidden_context_manager(ctx): + f_ctxs = [] + + try: + from _pytest.python_api import RaisesContext + from _pytest.recwarn import WarningsChecker + + f_ctxs.append(RaisesContext) + f_ctxs.append(WarningsChecker) + except ImportError: + pass + + if m := sys.modules.get("torch.testing._internal.jit_utils"): + f_ctxs.append(m._AssertRaisesRegexWithHighlightContext) + + return ctx in f_ctxs + + +def is_cython_function(obj): + return ( + callable(obj) + and hasattr(type(obj), "__name__") + and type(obj).__name__ == "cython_function_or_method" + ) + + +class UserDefinedVariable(VariableTracker): + value: object + + +class UserDefinedClassVariable(UserDefinedVariable): + value: type[object] + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + # Used when we materialize class.__dict__ to a MappingProxyObject. In + # this case, we don't want to allow mutation in the class because there + # is no way to reflect it in the created MappingProxyVariable. + self.ban_mutation = False + + def as_python_constant(self): + return self.value + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value})" + + @staticmethod + @functools.cache + def _constant_fold_classes(): + return { + torch.device, + torch.finfo, + torch.iinfo, + torch.Size, + } + + @staticmethod + @functools.cache + def _in_graph_classes(): + _in_graph_class_list = { + torch.Tensor, + torch.cuda.FloatTensor, + torch.cuda.DoubleTensor, + torch.cuda.HalfTensor, + torch.cuda.BFloat16Tensor, + torch.cuda.ByteTensor, + torch.cuda.CharTensor, + torch.cuda.IntTensor, + torch.cuda.ShortTensor, + torch.cuda.LongTensor, + torch.Stream, + torch.Event, + torch.cuda.Stream, + torch.cuda.Event, + torch.xpu.Stream, + torch.xpu.Event, + } + if hasattr(torch, "hpu"): + _in_graph_class_list.update( + { + torch.hpu.Stream, + torch.hpu.Event, + } + ) + + return set(tensortype_to_dtype.keys()) | _in_graph_class_list + + @staticmethod + @functools.cache + def supported_c_new_functions(): + exceptions = [ + getattr(builtins, name).__new__ + for name in dir(builtins) + if isinstance(getattr(builtins, name), type) + and issubclass(getattr(builtins, name), BaseException) + ] + return { + object.__new__, + dict.__new__, + set.__new__, + frozenset.__new__, + tuple.__new__, + list.__new__, + }.union(exceptions) + + @staticmethod + def is_supported_new_method(value): + # TODO(anijain2305) - Extend this to support objects with default tp_new + # functions. + return value in UserDefinedClassVariable.supported_c_new_functions() + + def can_constant_fold_through(self): + return self.value in self._constant_fold_classes() + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from . import ConstantVariable, EnumVariable + + source = AttrSource(self.source, name) if self.source is not None else None + + if name == "__name__": + return ConstantVariable.create(self.value.__name__) + elif name == "__qualname__": + return ConstantVariable.create(self.value.__qualname__) + elif name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + elif name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) + + # Special handling of collections.OrderedDict.fromkeys() + # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with + # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method(). + # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys), + # and we need duplicate code to handle both cases. + if ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + return super().var_getattr(tx, name) + + try: + obj = inspect.getattr_static(self.value, name) + except AttributeError: + if type(self.value) is type: + raise_observed_exception( + AttributeError, + tx, + args=[ + f"type object '{self.value.__name__}' has no attribute '{name}'" + ], + ) + else: + # Cannot reason about classes with a custom metaclass + # See: test_functions::test_getattr_metaclass + obj = None + + if name == "__new__" and UserDefinedClassVariable.is_supported_new_method(obj): + return super().var_getattr(tx, name) + + if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType): + return variables.GetAttrVariable(self, name, source=source) + + if isinstance(obj, staticmethod): + return VariableTracker.build(tx, obj.__get__(self.value), source) + elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + fget_vt = VariableTracker.build(tx, obj.__func__.fget) + return fget_vt.call_function(tx, [self], {}) + return variables.UserMethodVariable(obj.__func__, self, source=source) + elif isinstance(obj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static(dict, "fromkeys") + # inspect.getattr_static(itertools.chain, "from_iterable") + func = obj.__get__(None, self.value) + return VariableTracker.build(tx, func, source) + elif source: + if inspect.ismemberdescriptor(obj): + return VariableTracker.build(tx, obj.__get__(self.value), source) + + if ConstantVariable.is_literal(obj): + return ConstantVariable.create(obj) + elif isinstance(obj, enum.Enum): + return EnumVariable(obj) + elif self.value is collections.OrderedDict: + return variables.GetAttrVariable(self, name) + elif name in getattr(self.value, "__dict__", {}) or ( + self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ): + if source: + return VariableTracker.build(tx, obj, source) + + if ( + source + and not inspect.ismethoddescriptor(obj) + and not is_wrapper_or_member_descriptor(obj) + ): + return VariableTracker.build(tx, obj, source) + + return super().var_getattr(tx, name) + + def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs): + """ + functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional loss call: input, target, optional_output + """ + from . import ConstantVariable + + def normalize_args( + weight=ConstantVariable.create(None), + size_average=ConstantVariable.create(None), + ignore_index=ConstantVariable.create(-100), + reduce=ConstantVariable.create(None), + reduction=ConstantVariable.create("mean"), + label_smoothing=ConstantVariable.create(0.0), + ): + return ( + weight, + size_average, + ignore_index, + reduce, + reduction, + label_smoothing, + ) + + ( + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ) = normalize_args(*args, **kwargs) + + def fake_cross_entropy_loss(input, target): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.nn.functional.cross_entropy, + *proxy_args_kwargs( + [ + input, + target, + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ], + {}, + ), + ), + ) + + return variables.LambdaVariable(fake_cross_entropy_loss) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + name == "__subclasses__" + and len(args) == 0 + and not kwargs + and "__subclasses__" not in self.value.__dict__ + ): + source = self.source + if self.source: + source = AttrSource(self.source, "__subclasses__") + source = CallFunctionNoArgsSource(source) + return VariableTracker.build(tx, self.value.__subclasses__(), source) + elif ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + return variables.BuiltinVariable.call_custom_dict_fromkeys( + tx, self.value, *args, **kwargs + ) + elif self.value is collections.OrderedDict and name == "move_to_end": + return args[0].call_method(tx, name, [*args[1:]], kwargs) + elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value == args[0].value) + elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value != args[0].value) + elif issubclass(self.value, dict) and name != "__new__": + # __new__ is handled below + return variables.BuiltinVariable(dict).call_method(tx, name, args, kwargs) + elif issubclass(self.value, (set, frozenset)) and name != "__new__": + # __new__ is handled below + return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs) + elif ( + name == "__new__" + and self.value is collections.OrderedDict + and isinstance(args[0], UserDefinedClassVariable) + and args[0].value is collections.OrderedDict + ): + if kwargs and len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return variables.ConstDictVariable( + {}, collections.OrderedDict, mutation_type=ValueMutationNew() + ) + elif name == "__new__" and UserDefinedClassVariable.is_supported_new_method( + self.value.__new__ + ): + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + elif name == "__setattr__" and self.ban_mutation: + unimplemented( + gb_type="Class attribute mutation when the __dict__ was already materialized", + context=str(self.value), + explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", + hints=graph_break_hints.SUPPORTABLE, + ) + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..side_effects import SideEffects + from .builder import wrap_fx_proxy + + constant_args = check_constant_args(args, kwargs) + + if self.can_constant_fold_through() and constant_args: + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif self.value is torch.nn.CrossEntropyLoss: + return self._call_cross_entropy_loss(tx, args, kwargs) + elif self.value is contextlib.nullcontext: + # import here to avoid circular dependency + from .ctx_manager import NullContextVariable + + return NullContextVariable(*args, **kwargs) + elif self.value is collections.OrderedDict: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [self, *args], + kwargs, + ) + elif self.value is collections.defaultdict: + if len(args) == 0: + default_factory = variables.ConstantVariable.create(None) + else: + default_factory, *args = args + dict_vt = variables.BuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) + return DefaultDictVariable( + dict_vt.items, + collections.defaultdict, + default_factory, + mutation_type=ValueMutationNew(), + ) + elif is_typeddict(self.value): + if self.value.__optional_keys__: + unimplemented( + gb_type="TypedDict with optional keys", + context=str(self.value), + explanation="Dyanmo does not support tracing TypedDict with optional keys", + hints=[ + "Avoid using TypedDict with optional keys", + *graph_break_hints.SUPPORTABLE, + ], + ) + return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) + elif self.value is collections.deque: + maxlen = variables.ConstantVariable.create(None) + + def deque_signature(iterable=None, maxlen=None): + pass + + try: + bound_args = inspect.signature(deque_signature).bind(*args, **kwargs) + except TypeError as e: + unimplemented( + gb_type="collections.deque() with bad arguments", + context=f"args={args}, kwargs={kwargs}", + explanation="Detected call to collections.deque() with bad arguments.", + hints=[ + "Fix the call to collections.deque().", + *graph_break_hints.USER_ERROR, + ], + from_exc=e, + ) + + if "iterable" in bound_args.arguments: + if not bound_args.arguments["iterable"].has_force_unpack_var_sequence( + tx + ): + unimplemented( + gb_type="collections.deque() with bad iterable argument", + context=f"args={args}, kwargs={kwargs}", + explanation="Call to collections.deque() has an iterable argument that Dynamo cannot " + "convert to a list.", + hints=[ + "Use a simpler sequence type that Dynamo can convert to a list " + "(e.g. list, tuple, list iterator, etc.)", + *graph_break_hints.USER_ERROR, + ], + ) + items = bound_args.arguments["iterable"].force_unpack_var_sequence(tx) + else: + items = [] + + if "maxlen" in bound_args.arguments: + maxlen = bound_args.arguments["maxlen"] + + return variables.lists.DequeVariable( + items, maxlen=maxlen, mutation_type=ValueMutationNew() + ) + elif self.value is weakref.ref: + if len(args) > 1: + callback = args[1] + else: + callback = variables.ConstantVariable.create(None) + return variables.WeakRefVariable(args[0], callback) + elif self.value is functools.partial: + if not args: + unimplemented( + gb_type="missing args to functools.partial", + context="", + explanation="functools.partial requires at least one argument", + hints=[ + "Fix the functools.partial call.", + *graph_break_hints.USER_ERROR, + ], + ) + # The first arg, a callable (the ctor below will assert on types) + fn = args[0] + rest_args = args[1:] + # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the + # args and keywords + return variables.functions.FunctoolsPartialVariable( + fn, args=rest_args, keywords=kwargs + ) + elif self.value is warnings.catch_warnings and not args: + return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs) + elif self.value is torch.cuda.device and not kwargs and len(args) == 1: + if not args[0].is_python_constant(): + raise_type_error_exc( + tx, "torch.cuda.device() requires a constant argument" + ) + return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant()) + elif ( + issubclass(type(self.value), type) + and hasattr( + self.value, "__enter__" + ) # TODO(voz): These can invoke user code! + and hasattr( + self.value, "__exit__" + ) # TODO(voz): These can invoke user code! + and self.is_standard_new() + and SideEffects.cls_supports_mutation_side_effects(self.value) + and self.source + and not is_forbidden_context_manager(self.value) + ): + from . import TorchCtxManagerClassVariable + from .functions import ( + BaseUserFunctionVariable, + FunctionDecoratedByContextlibContextManagerVariable, + ) + + # graph break on any contextlib.* that it is not contextlib.contextmanager + # Some of the APIs below are not supported because they rely on features + # that Dynamo doesn't play well today (i.e. contextlib.suppress) + if self.value in ( + contextlib._AsyncGeneratorContextManager, + contextlib.closing, + contextlib.redirect_stdout, + contextlib.redirect_stderr, + contextlib.suppress, + contextlib.ExitStack, + contextlib.AsyncExitStack, + ): + # We are not changing the behavior of Dynamo as these function were + # already ignored on trace_rules.py before #136033 landed + unimplemented( + gb_type="unsupported contextlib.* API", + context=f"{self.value}", + explanation=f"{self.value} not supported. This may be due to its use of " + "context-specific operations that are not supported in " + "Dynamo yet (i.e. Exception handling)", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if self.value is contextlib._GeneratorContextManager and isinstance( + args[0], (BaseUserFunctionVariable, TorchCtxManagerClassVariable) + ): + if not torch._dynamo.config.enable_trace_contextlib: + unimplemented( + gb_type="attempted to trace contextlib.contextmanager", + context=f"args={args}", + explanation="Tracing contextlib.contextmanager is disabled.", + hints=[ + "Set torch._dynamo.config.enable_trace_contextlib = True", + ], + ) + + # Special treatments for certain context managers created via + # contextlib, because + # 1. we (pytorch) own their impls + # 2. it's tedious to trace through them, so we effectively + # "allow in graph" them without sacrificing soundness. + # + # We would typically reach here via either + # 1. the instance construction in `with ctx_manager(...):`: + # https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L301 + # 2. calling a function decorated with a context manager: + # https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L122 + # + # So we basically trace through the surface part of the + # contextlib code, and then special case the shared remaining + # logic (the actual context manager instance construction and + # usage later on). + if isinstance(args[0], TorchCtxManagerClassVariable): + fn_var = args[0] + args_list = args[1].items + kwargs_dict = args[2].keys_as_python_constant() + return fn_var.call_function(tx, args_list, kwargs_dict) + + # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable + # if the function is annotated with @contextlib.contextmanager + # This shouldn't be necessary once generator functions are fully + # supported in dynamo + args = [ + FunctionDecoratedByContextlibContextManagerVariable( + args[0], source=args[0].source + ) + ] + args[1:] + + cm_obj = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), + self, + args, + ) + cm_obj.call_method(tx, "__init__", args, kwargs) + return cm_obj + elif is_namedtuple_cls(self.value): + fields = namedtuple_fields(self.value) + # check if this a quasi-namedtuple or a real one + if self.value.__module__ == "torch.return_types": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + "torch.return_types", + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + items = args[0].force_unpack_var_sequence(tx) + else: + field_defaults = self.value._field_defaults + + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + var_tracker_kwargs = {} + for field_name, var_tracker in zip(fields, items): + if var_tracker is None: + if field_name in kwargs: + field_var = kwargs[field_name] + else: + assert field_name in field_defaults + field_var = VariableTracker.build( + tx, field_defaults[field_name] + ) + var_tracker_kwargs[field_name] = field_var + + for name, value in var_tracker_kwargs.items(): + assert name in fields + items[fields.index(name)] = value + + assert all(x is not None for x in items) + + # Modify mutability of namedtuple for sourcelesss instantiations. + from .base import AttributeMutationNew + + return variables.NamedTupleVariable( + items, self.value, mutation_type=AttributeMutationNew() + ) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) + elif is_frozen_dataclass(self.value) and self.is_standard_new(): + fields = dataclasses.fields(self.value) + fields_source = DataclassFieldsSource(self.source) + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + default_kwargs = {} + for ind, field, var_tracker in zip(itertools.count(), fields, items): + if var_tracker is None: + if field.name in kwargs: + var_tracker = kwargs[field.name] + else: + if not field.init: + continue + + if field.default is not dataclasses.MISSING: + var_tracker = VariableTracker.build( + tx, + field.default, + source=AttrSource( + GetItemSource(fields_source, ind), "default" + ), + ) + elif field.default_factory is not dataclasses.MISSING: + factory_fn = VariableTracker.build( + tx, field.default_factory + ) + var_tracker = factory_fn.call_function(tx, [], {}) + else: + # if we are subclass, the constructor could possibly + # be missing args + continue + + default_kwargs[field.name] = var_tracker + kwargs.update(default_kwargs) + + var = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), self, args + ) + var.call_method(tx, "__init__", args, kwargs) + return var + elif ( + self.value in self._in_graph_classes() + or is_traceable_wrapper_subclass_type(self.value) + ): + # torch.LongTensor cannot accept a list of FakeTensors. + # So we stack the list of FakeTensors instead. + if ( + np + and self.value in tensortype_to_dtype + and len(args) == 1 + and isinstance(args[0], variables.ListVariable) + and len(args[0].items) > 1 + and all(x.is_tensor() for x in args[0].items) + ): + # Stack FakeTensor + stacked = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.stack, + *proxy_args_kwargs(args, kwargs), + ), + ) + args = [stacked] + + if issubclass(self.value, torch.Stream): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created stream for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + stream = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import StreamVariable + + ind = register_graph_created_object( + stream, + StreamVariable.make_construct_in_graph_stream_fn( + var_args, var_kwargs + ), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + ) + elif issubclass(self.value, torch.Event): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created event for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + event = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import EventVariable + + ind = register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + var_args, var_kwargs + ), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + ) + else: + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + *proxy_args_kwargs(args, kwargs), + ), + ) + + return tensor_variable + elif self.value is random.Random: + if len(args) == 1 and args[0].is_python_constant(): + seed = args[0].as_python_constant() + else: + seed = None + random_object = random.Random(seed) + return RandomVariable(random_object) + elif ( + self.value is types.MappingProxyType + and len(args) == 1 + and isinstance(args[0], variables.ConstDictVariable) + ): + # types.MappingProxyType is a read-only proxy of the dict. If the + # original dict changes, the changes are reflected in proxy as well. + return variables.MappingProxyVariable(args[0]) + elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source: + with do_not_convert_to_tracable_parameter(): + return tx.inline_user_function_return( + VariableTracker.build( + tx, polyfills.instantiate_user_defined_class_object + ), + [self, *args], + kwargs, + ) + return super().call_function(tx, args, kwargs) + + def is_standard_new(self): + """Check for __new__ being overridden""" + new_fn = inspect.getattr_static(self.value, "__new__", None) + if isinstance(new_fn, staticmethod): + new_fn = new_fn.__func__ + return new_fn is object.__new__ + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "ConstantVariable": + if self.source: + source = AttrSource(self.source, name) + install_guard(source.make_guard(GuardBuilder.HASATTR)) + return variables.ConstantVariable(hasattr(self.value, name)) + return super().call_obj_hasattr(tx, name) + + def const_getattr(self, tx: "InstructionTranslator", name): + if name == "__name__": + return self.value.__name__ + return super().const_getattr(tx, name) + + def is_python_hashable(self): + return True + + def get_python_hash(self): + return hash(self.value) + + def is_python_equal(self, other): + return ( + isinstance(other, variables.UserDefinedClassVariable) + and self.value is other.value + ) + + +class UserDefinedExceptionClassVariable(UserDefinedClassVariable): + @property + def fn(self): + return self.value + + +class NO_SUCH_SUBOBJ: + pass + + +def call_random_fn(tx, fn, args, kwargs): + from .builder import VariableBuilder + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + random_call_index = len(tx.output.random_calls) + example_value = fn(*args, **kwargs) + source = RandomValueSource(random_call_index) + tx.output.random_calls.append((fn, args, kwargs)) + # TODO: arguably, this should route to wrap_symint/wrap_symfloat + # (currently hypothetical), but I'm not going to poke my hand in + # this nest for now + return VariableBuilder(tx, source).wrap_unspecialized_primitive(example_value) + + +class UserDefinedObjectVariable(UserDefinedVariable): + """ + Mostly objects of defined type. Catch-all for something where we only know the type. + """ + + _nonvar_fields = { + "value", + "value_type", + "attrs_directly_modifed_on_dict", + *UserDefinedVariable._nonvar_fields, + } + + def __init__( + self, + value, + *, + value_type=None, + cls_source=None, + base_cls_vt=None, + init_args=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.value = value + self.value_type = value_type or type(value) + assert type(value) is self.value_type + # This is used with __new__, when the new object is sourceless but the user class can be sourceful. + self.cls_source = cls_source + if cls_source is None and self.source is not None: + self.cls_source = TypeSource(self.source) + + # These attributes are used to reconstruct the user defined object. The + # pseudo code looks like this. Builtin C __new__ do not support kwargs, + # so init_args is sufficient. + # obj = base_cls.__new__(user_cls, *args) + self.base_cls_vt = base_cls_vt + self.init_args = init_args + + # This records names of the attributes that were modified via instance + # `__dict__` directly, rather than the normal setattr path. + # + # TODO consider emulating `obj.__dict__` as a `ConstDictVariable` to get + # rid of these workarounds here and in `GetAttrVariable`. + self.attrs_directly_modifed_on_dict = set() + + import torch.utils._pytree as pytree + + self.is_pytree_constant_class = pytree.is_constant_class(self.value_type) + if pytree.is_constant_class(self.value_type) and self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + + def __str__(self) -> str: + inner = self.value_type.__name__ + if inner in [ + "builtin_function_or_method", + "getset_descriptor", + "method_descriptor", + "method", + ]: + inner = str(getattr(self.value, "__name__", None)) + return f"{self.__class__.__name__}({inner})" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + def is_underlying_vt_modified(self, side_effects): + return False + + def python_type(self): + return self.value_type + + def as_python_constant(self): + if self.is_pytree_constant_class and self.source: + # NOTE pytree constants created in the torch.compile region will + # NOT be guarded (even though they have a source set) + return self.value + # TODO else try reconstructing the object by, e.g., leveraging side + # effects and `as_python_constant`. + return super().as_python_constant() + + def guard_as_python_constant(self): + if self.source: + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + return self.value + return super().guard_as_python_constant() + + def torch_function_check(self): + assert has_torch_function(self), ( + f"calling torch function on object without __torch_function__ {self}" + ) + + def get_torch_fn(self, tx): + self.torch_function_check() + from .torch_function import get_torch_function_fn + + return get_torch_function_fn(tx, self) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + self.torch_function_check() + + from .torch_function import call_torch_function + + return call_torch_function( + tx, + self.get_torch_fn(tx), + fn, + types, + args, + kwargs, + ) + + @staticmethod + @functools.cache + def _supported_random_functions(): + fns = { + random.random, + random.randint, + random.randrange, + random.uniform, + } + return fns + + def _maybe_get_baseclass_method(self, name): + if name not in getattr(self.value, "__dict__", {}): + try: + return inspect.getattr_static(type(self.value), name) + except AttributeError: + pass + return None + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable, UserMethodVariable + + method = self._maybe_get_baseclass_method(name) + if method is not None: + if method is object.__init__: + return ConstantVariable.create(None) + + if is_standard_setattr(method) or isinstance(self.value, threading.local): + return self.method_setattr_standard(tx, *args, **kwargs) + + if is_standard_delattr(method): + return self.method_setattr_standard( + tx, args[0], variables.DeletedVariable() + ) + + if method is object.__eq__ and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(other, UserDefinedObjectVariable): + return variables.ConstantVariable.create(NotImplemented) + + # TODO(anijain2305) - Identity checking should already be a part + # of the cmp_eq polyfill function. + return ConstantVariable.create(self.value is other.value) + + if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( + self.value, types.GeneratorType + ): + unimplemented( + gb_type="call_method on generator", + context=f"object={self.value}, method={name}, args={args}, kwargs={kwargs}", + explanation="Detected a method call to a user-defined generator object. " + "This is not fully supported.", + hints=[ + "Set `torch._dynamo.config.enable_faithful_generator_behavior = False`. Note that this " + "may cause silent incorrectness, since we will eagerly unpack generators instead of lazily " + "evaluating them.", + ], + ) + + # check for methods implemented in C++ + if isinstance(method, types.FunctionType): + source = self.source + source_fn = None + if source: + source_fn = self.get_source_by_walking_mro(name) + # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init + return UserMethodVariable( + method, self, source_fn=source_fn, source=source + ).call_function(tx, args, kwargs) + + if method is list.__len__ and self.source and not (args or kwargs): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return ConstantVariable(len(self.value)) + + return super().call_method(tx, name, args, kwargs) + + def method_setattr_standard( + self, tx: "InstructionTranslator", name, value, directly_update_dict=False + ): + try: + name = name.as_python_constant() + except NotImplementedError: + unimplemented( + gb_type="non-const setattr name on user-defined object", + context=f"object={self}, name={name}, value={value}", + explanation="Detected a call to `setattr` of a user-defined object with a non-constant name.", + hints=["Ensure that the name is a string."], + ) + assert tx.output.side_effects.is_attribute_mutation(self), ( + "Attempted setattr on a user-defined object that does not have " + "an AttributeMutation mutation_type" + ) + + if directly_update_dict: + self.attrs_directly_modifed_on_dict.add(name) + else: + tmp = self.try_get_descritor_and_setter_py_func(name) + if tmp: + descriptor, setter = tmp + # Emulate + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452 + desc_source = None + func_source = None + if self.cls_source: + desc_source = self.get_source_by_walking_mro(name) + # use `type(...)` to ignore instance attrs. + func_source = AttrSource(TypeSource(desc_source), "__set__") + desc_var = VariableTracker.build(tx, descriptor, desc_source) + func_var = VariableTracker.build(tx, setter, func_source) + args = [desc_var, self, value] + return func_var.call_function(tx, args, {}) + # NOTE: else we assume the descriptor (if any) has a + # side-effect-free `__set__` as far as Dynamo tracing is concerned. + + # Emulate the standard setattr on instance dict. + tx.output.side_effects.store_attr(self, name, value) + return variables.ConstantVariable(None) + + def needs_slow_setattr(self): + return not is_standard_setattr( + inspect.getattr_static(self.value, "__setattr__", None) + ) and not isinstance(self.value, threading.local) + + def unpack_var_sequence(self, tx): + if ( + self.source + and self._maybe_get_baseclass_method("__iter__") is list.__iter__ + and self._maybe_get_baseclass_method("__len__") is list.__len__ + and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__ + ): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return [ + variables.LazyVariableTracker.create( + self.value[k], + source=GetItemSource(self.source, k), + ) + for k in range(len(self.value)) + ] + return super().unpack_var_sequence(tx) + + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + try: + variables.BuiltinVariable(iter).call_function(tx, [self], {}) + return True + except ObservedTypeError: + handle_observed_exception(tx) + return False + + def force_unpack_var_sequence(self, tx): + result = [] + iter_ = variables.BuiltinVariable(iter).call_function(tx, [self], {}) + + while True: + try: + r = iter_.next_variable(tx) + result.append(r) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + def next_variable(self, tx): + return self.call_method(tx, "__next__", [], {}) + + def is_supported_random(self): + try: + return self.value in self._supported_random_functions() + except TypeError: + # TypeError: unhashable type + return False + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + self.is_supported_random() + and all(k.is_python_constant() for k in args) + and all(v.is_python_constant() for v in kwargs.values()) + ): + return call_random_fn(tx, self.value, args, kwargs) + elif istype(self.value, types.MethodType): + func = self.value.__func__ + obj = self.value.__self__ + if ( + func is torch.utils._contextlib._DecoratorContextManager.clone + and variables.TorchCtxManagerClassVariable.is_matching_cls( + obj.__class__ + ) + and not (args or kwargs) + ): + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, args, kwargs) + + if ( + func is torch.autograd.grad_mode.inference_mode.clone + and obj.__class__ is torch.autograd.grad_mode.inference_mode + ): + # simulate the inference_mode.clone implementation + var = variables.ConstantVariable(obj.mode) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, [var], kwargs) + + if self.source is None: + unimplemented( + gb_type="attempted to call sourceless user-defined object as a method", + context=f"object={self.value}, function={func}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + f"Ensure the user-defined object {self.value} is constructed outside the compiled region.", + ], + ) + func_src = AttrSource(self.source, "__func__") + func_var = VariableTracker.build(tx, func, func_src) + obj_src = AttrSource(self.source, "__self__") + obj_var = VariableTracker.build(tx, obj, obj_src) + return func_var.call_function(tx, [obj_var] + args, kwargs) + elif callable(self.value): + if self.source: + source = AttrSource(self.cls_source, "__call__") + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return self.call_method(tx, "__call__", args, kwargs) + + return super().call_function(tx, args, kwargs) + + def _check_for_getattr(self): + return get_custom_getattr(self.value) + + def _is_c_defined_property(self, subobj): + if not isinstance(subobj, property): + return False + + # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose + # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check + + # If we have a PyCFunction, we make an assumption that there is no side effect. + return isinstance( + subobj.fget, types.BuiltinFunctionType + ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget) + + def _getattr_static(self, name): + subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) + + # In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local + # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup. + # NOTE we assume the following descriptors are side-effect-free as far + # as Dynamo tracing is concerned. + if not object_has_getattribute(self.value) and ( + subobj is NO_SUCH_SUBOBJ # e.g., threading.local + or inspect.ismemberdescriptor(subobj) # e.g., __slots__ + or inspect.isgetsetdescriptor(subobj) # e.g., __dict__ + or self._is_c_defined_property(subobj) + ): + # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't + # want to call getattr because it can be user-overridden. + subobj = type(self.value).__getattribute__(self.value, name) + elif object_has_getattribute(self.value) and subobj is NO_SUCH_SUBOBJ: + # If the object has an overridden getattribute method, Dynamo has + # already tried tracing it, and encountered an AttributeError. We + # call getattr_static only when the __getattribute__ tracing fails + # (check var_getattr impl). So, it is safe here to raise the + # AttributeError. + raise AttributeError + + return subobj + + def should_skip_descriptor_setter(self, attr_name): + # Check if `attr_name` corresponds to a descriptor. + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if setter: + # Skip if `__set__` was traceable (no need to redo the side effect). + if inspect.isfunction(setter): + return True + # For untraceable `__set__` we should still skip if the attribute + # was mutated via instance `__dict__`. + elif attr_name in self.attrs_directly_modifed_on_dict: + return True + return False + + def try_get_descritor_and_setter_py_func(self, attr_name): + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if inspect.isfunction(setter): + return (descriptor, setter) + return None + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def get_source_by_walking_mro(self, name): + assert self.cls_source is not None + + for idx, klass in enumerate(type(self.value).__mro__): + if name in klass.__dict__: + if idx != 0: + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, idx) + else: + klass_source = self.cls_source + dict_source = TypeDictSource(klass_source) + out_source = DictGetItemSource(dict_source, name) + + for absent_idx in range(1, idx): + # Insert a guard that the name is not present in the mro hierarchy + mro_source = TypeMROSource(self.cls_source) + klass_source = GetItemSource(mro_source, absent_idx) + dict_source = TypeDictSource(klass_source) + install_guard( + dict_source.make_guard( + functools.partial( + GuardBuilder.DICT_CONTAINS, key=name, invert=True + ) + ) + ) + # Insert a guard that the name is not present in the object __dict__ + if ( + self.source + and hasattr(self.value, "__dict__") + and name not in self.value.__dict__ + ): + install_guard( + self.source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr=name + ) + ) + ) + return out_source + + unimplemented( + gb_type="could not find name in object's mro", + context=f"name={name}, object type={type(self.value)}, mro={type(self.value).__mro__}", + explanation=f"Could not find name `{name}` in mro {type(self.value).__mro__}", + hints=[ + f"Ensure the name `{name}` is defined somewhere in {self.value}'s type hierarchy.", + *graph_break_hints.USER_ERROR, + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + from . import ConstantVariable + + source = AttrSource(self.source, name) if self.source else None + + if object_has_getattribute(self.value): + getattribute_fn = inspect.getattr_static( + type(self.value), "__getattribute__" + ) + if self.source: + new_source = AttrSource(self.source, "__getattribute__") + try: + return variables.UserMethodVariable( + getattribute_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + except ObservedAttributeError: + # Pass through to __getattr__ if __getattribute__ fails + handle_observed_exception(tx) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception( + AttributeError, + tx, + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], + ) + return result + + if name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + + # TODO(anijain2305) - Investigate if we need specialization for more + # dunder attrs. inspect.getattr_static does not return correct value for + # them. + if name == "__class__": + cls_source = source + if cls_source is None: + cls_source = self.cls_source + options = {"source": cls_source} + return UserDefinedClassVariable(type(self.value), **options) + + try: + subobj = self._getattr_static(name) + except AttributeError: + subobj = NO_SUCH_SUBOBJ + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + if ( + getattr_fn is unpatched_nn_module_getattr + and isinstance(self, variables.UnspecializedNNModuleVariable) + # prevent against overwriting of params/buffers/submodules + and istype(self.value._parameters, dict) + and istype(self.value._buffers, dict) + and istype(self.value._modules, dict) + ): + # Manually trace out the nn module __getattr__ to avoid large compilation latency. + out = self.manually_trace_nn_module_getattr(tx, name) + else: + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + out = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + + if self.source and getattr_fn is torch.nn.Module.__getattr__: + if isinstance( + out, + ( + variables.UnspecializedNNModuleVariable, + variables.NNModuleVariable, + ), + ): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + elif getattr_fn is not None: + unimplemented( + gb_type="User-defined object with non-function __getattr__", + context=f"object={self.value}, name={name}, getattr_fn={getattr_fn}", + explanation=f"Found a non-function __getattr__ {getattr_fn} from a user-defined object {self.value} " + f" when attempting to getattr `{name}`", + hints=[ + "Ensure the object's __getattr__ is a function type.", + ], + ) + + from ..mutation_guard import unpatched_nn_module_init + + if subobj is torch.nn.Module.__init__: + subobj = unpatched_nn_module_init + + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + is_accessible_from_type_mro = ( + subobj_from_class is subobj + and self.cls_source is not None + and self.source is not None + and hasattr(self.value, "__dict__") + and name not in self.value.__dict__ + ) + + if isinstance(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = self.get_source_by_walking_mro(name) + # Get the getter function + source = AttrSource(source, "fget") + + fget_vt = VariableTracker.build(tx, subobj.fget, source=source) + return fget_vt.call_function(tx, [self], {}) + elif isinstance(subobj, _collections._tuplegetter): + # namedtuple fields are represented by _tuplegetter, and here we + # emulate its `__get__`, which is implemented in C. + _, (idx, _) = subobj.__reduce__() + # Don't go through the `__getitem__` method anymore, see + # https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690 + assert isinstance(self, UserDefinedTupleVariable) + return self._tuple_vt.items[idx] + elif isinstance(subobj, staticmethod): + # Safe because `staticmethod.__get__` basically won't trigger user + # code and just returns the underlying `__func__`: + # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100 + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a staticmethod object, so access the __func__ + # attribute to get to the actual function. + source = AttrSource(self.get_source_by_walking_mro(name), "__func__") + func = subobj.__get__(self.value) + return VariableTracker.build(tx, func, source) + elif isinstance(subobj, classmethod): + source_fn = None + if is_accessible_from_type_mro: + # Accessing from __dict__ does not resolve the descriptor, it + # returns a classmethod object, so access the __func__ + # attribute to get to the actual function. + source_fn = AttrSource(self.get_source_by_walking_mro(name), "__func__") + return variables.UserMethodVariable( + subobj.__func__, + self.var_getattr(tx, "__class__"), + source_fn=source_fn, + source=source, + ) + elif isinstance(subobj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static({}, "fromkeys") + func = subobj.__get__(self.value, None) + return VariableTracker.build(tx, func, source) + elif is_lru_cache_wrapped_function(subobj): + # getattr_static returns the lru_wrapped function, and we cannot + # extract the underlying method from the wrapped function. To handle + # it, manually create a wrapped user method vt. + return variables.WrapperUserMethodVariable( + subobj, "__wrapped__", self, source=source + ) + elif inspect.getattr_static( + type(subobj), "__get__", NO_SUCH_SUBOBJ + ) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor( + type(subobj).__get__ + ): + # Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285 + # + # Attribute has a __get__ method. Create a user defined object vt + # for the subobj, and then trace the __get__ method. + descriptor_source = None + descriptor_get_source = None + if self.cls_source: + # To access the method descriptor from the udf object w/o using + # inspect.getattr_static, we can look into the class mro + descriptor_source = self.get_source_by_walking_mro(name) + descriptor_get_source = AttrSource( + TypeSource(descriptor_source), "__get__" + ) + descriptor_var = VariableTracker.build(tx, subobj, descriptor_source) + else: + # Sourceless Builder does not support user defined objects + descriptor_var = UserDefinedObjectVariable(subobj) + + # The arguments of the __get__ function are (self, instance, owner) + # self - descriptor_var + # instance - instance of the class, represented by self here + # owner - class object + owner_var = UserDefinedClassVariable(type(self.value)) + return variables.UserMethodVariable( + subobj.__get__.__func__, descriptor_var, source=descriptor_get_source + ).call_function(tx, [self, owner_var], {}) + elif isinstance(subobj, types.FunctionType) or ( + isinstance(subobj, types.MethodType) + and isinstance(self.value, torch.nn.Module) + ): + # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. + # Static lookup can't tell us it's a method or function correctly, + # so we trigger dynamic lookup here to get the correct type. + dynamic_subobj = getattr(self.value, name) + + while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"): + subobj = subobj._torchdynamo_inline + dynamic_subobj = subobj + source = AttrSource(source, "_torchdynamo_inline") if source else None + + if isinstance(subobj, types.MethodType): + if dynamic_subobj.__self__ is not self.value: + if not isinstance(dynamic_subobj.__func__, types.FunctionType): + unimplemented( + gb_type="User-defined object method with non-function __func__", + context=f"object={self.value}, name={name}, method={dynamic_subobj}, " + f"method.__self__={dynamic_subobj.__self__}, method.__func__={dynamic_subobj.__func__}", + explanation=f"Method {dynamic_subobj} (name={name}) of user-defined object {self.value} has a " + f"__func__ ({dynamic_subobj.__func__}) that is not a function type.", + hints=[ + "Ensure that the method's __func__ is a function type.", + ], + ) + + # Use the __self__ attribute of the method to find the + # source of the new self object. + self_source = None + if source is not None: + self_source = AttrSource(source, "__self__") + object_vt = VariableTracker.build( + tx, dynamic_subobj.__self__, self_source + ) + + return variables.UserMethodVariable( + dynamic_subobj.__func__, object_vt + ) + func = subobj.__func__ + else: + assert isinstance(subobj, types.FunctionType) + func = subobj + + if inspect.ismethod(dynamic_subobj): + source_fn = None + if is_accessible_from_type_mro: + source_fn = self.get_source_by_walking_mro(name) + return variables.UserMethodVariable( + func, self, source_fn=source_fn, source=source + ) + elif inspect.isfunction(dynamic_subobj): + return VariableTracker.build(tx, func, source) + + if ( + # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to + # keep Dynamo behavior compatible with no inlining, as there will be some delay to turn on the flag in + # fbcode. + ( + torch._dynamo.config.inline_inbuilt_nn_modules + or isinstance(self, variables.FSDPManagedNNModuleVariable) + ) + and source + and isinstance(self, variables.UnspecializedNNModuleVariable) + # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export + # usecase for now. + and (not tx.output.export or torch._dynamo.config.install_free_tensors) + ): + # Recalculate source for params/buffers + if name in ("_buffers", "_parameters"): + source = UnspecializedParamBufferSource(self.source, name) + source = self._wrap_source(source) + + if subobj is not NO_SUCH_SUBOBJ: + if ( + is_wrapper_or_member_descriptor(subobj) + or torch._C._dynamo.utils.is_instancemethod(subobj) + or is_cython_function(subobj) + ): + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + if source: + if is_accessible_from_type_mro: + source = self.get_source_by_walking_mro(name) + + return variables.LazyVariableTracker.create(subobj, source) + else: + # Check if the subobj is accessible from the class itself. If the class source is known, we can create a + # sourceful variable tracker. + if self.cls_source is not None: + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + if subobj_from_class is subobj: + src_from_class = AttrSource(self.cls_source, name) + return variables.LazyVariableTracker.create( + subobj_from_class, src_from_class + ) + + return VariableTracker.build(tx, subobj) + + # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError. + raise_observed_exception( + AttributeError, + tx, + args=[f"'{type(self.value).__name__}' object has no attribute '{name}'"], + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + try: + var_vt = self.var_getattr(tx, name) + return variables.ConstantVariable.create( + not isinstance(var_vt, variables.DeletedVariable) + ) + except ObservedAttributeError: + handle_observed_exception(tx) + return variables.ConstantVariable.create(False) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return True + + def get_python_hash(self): + # default hash + return hash(self.value) + + def is_python_equal(self, other): + # id check + return self.value is other.value + + +class FrozenDataClassVariable(UserDefinedObjectVariable): + @staticmethod + def create(tx, value, source): + from dataclasses import fields + + assert is_frozen_dataclass(value) + + field_map = {} + for field in fields(value): + if hasattr(value, field.name): + field_map[field.name] = VariableTracker.build( + tx, + getattr(value, field.name), + source and AttrSource(source, field.name), + ) + + return FrozenDataClassVariable(value, fields=field_map, source=source) + + def __init__(self, value, fields=None, **kwargs) -> None: + super().__init__(value, **kwargs) + if fields is None: + fields = {} + self.fields = fields + + def as_python_constant(self): + # NOTE: this is an intentionally limited version of + # `as_python_constant` for `nonstrict_trace` implementation. + from dataclasses import fields + + import torch.utils._pytree as pytree + + if not istype( + self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode) + ): + # TODO loosen this restriction and fix `as_proxy`. + raise NotImplementedError( + "currently can't reconstruct arbitrary frozen dataclass instances" + ) + + # LeafSpec is deprecated, use treespec_leaf() instead + if istype(self.value, pytree.LeafSpec): + return pytree.treespec_leaf() + + args = [] + kwargs = {} + for field in fields(self.value): + if field.init: + data = self.fields[field.name].as_python_constant() + if getattr(field, "kw_only", False): + kwargs[field.name] = data + else: + args.append(data) + + # This is safe because we know the TreeSpec classes constructors don't + # have external side effects. + ctor = self.python_type() + return ctor(*args, **kwargs) + + def as_proxy(self): + from dataclasses import fields + + args = [] + kwargs = {} + for field in fields(self.value): + proxy = self.fields[field.name].as_proxy() + if hasattr(field, "kw_only") and field.kw_only: + kwargs[field.name] = proxy + else: + args.append(proxy) + + # TODO this isn't really safe, because + # 1. it could invoke a user defined `__post_init__`. + # 2. it could invoke a user defined `__init__` if the class _subclasses_ + # a frozen dataclass. + # Either of the above could end up mutating external state. + ctor = self.python_type() + return ctor(*args, **kwargs) + + def reconstruct(self, codegen: "PyCodegen") -> None: + from dataclasses import fields + + # Handle specific pytree classes + import torch.utils._pytree as pytree + + if isinstance(self.value, pytree.TreeSpec) and self.value.is_leaf(): + # Create a new LeafSpec instance by calling the constructor + codegen.add_push_null( + lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec") + ) + codegen.extend_output(create_call_function(0, False)) + return + + # For general frozen dataclasses, reconstruct by calling the constructor + # with the field values as arguments + dataclass_cls = self.python_type() + + if hasattr(dataclass_cls, "__post_init__"): + unimplemented( + gb_type="Frozen dataclass with __post_init__", + context=f"dataclass={dataclass_cls.__name__}", + explanation="Cannot reconstruct frozen dataclass with __post_init__ method, " + "as it may have side effects that would be incorrectly replayed.", + hints=[ + "Remove the __post_init__ method from the frozen dataclass.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Collect positional and keyword-only arguments + pos_args = [] + kw_args = [] + for field in fields(dataclass_cls): + if not field.init: + continue + field_vt = self.fields.get(field.name) + if field_vt is None: + unimplemented( + gb_type="Frozen dataclass with missing field", + context=f"dataclass={dataclass_cls.__name__}, field={field.name}", + explanation=f"Cannot reconstruct frozen dataclass: field '{field.name}' " + "was not tracked during tracing.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if getattr(field, "kw_only", False): + kw_args.append((field.name, field_vt)) + else: + pos_args.append(field_vt) + + # Load the dataclass constructor + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(dataclass_cls) + ) + ) + # Reconstruct all arguments + for arg_vt in pos_args: + codegen(arg_vt) + for _, arg_vt in kw_args: + codegen(arg_vt) + # Call the constructor + total_args = len(pos_args) + len(kw_args) + if kw_args: + kw_names = tuple(name for name, _ in kw_args) + codegen.extend_output( + codegen.create_call_function_kw(total_args, kw_names, push_null=False) + ) + else: + codegen.extend_output(create_call_function(total_args, False)) + + # NB: This is called during __init__ for a frozen dataclass + # use this to accumulate the most up-to-date field values + def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + self.fields[name.as_python_constant()] = value + return super().method_setattr_standard(tx, name, value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + def is_python_hashable(self): + # TODO - Check corner cases like eq=False, hash=False etc + return True + + def get_python_hash(self): + return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) + + def is_python_equal(self, other): + is_class_same = self.python_type() is other.python_type() + is_field_name_same = self.fields.keys() == other.fields.keys() + is_field_value_same = all( + value_a.is_python_equal(value_b) + for value_a, value_b in zip(self.fields.values(), other.fields.values()) + ) + return is_class_same and is_field_name_same and is_field_value_same + + +class SourcelessGraphModuleVariable(UserDefinedObjectVariable): + def __init__( + self, + value, + **kwargs, + ) -> None: + super().__init__(value, **kwargs) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + fn_variable = VariableTracker.build(tx, self.value.forward.__func__) + args = [self] + args + return tx.inline_user_function_return( + fn_variable, + args, + kwargs, + ) + + +class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable): + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.exc_vt = variables.ExceptionVariable(self.value_type, ()) + + @property + def fn(self): + return self.value_type + + def call_method(self, tx, name, args, kwargs): + if ( + name == "__init__" + and (method := self._maybe_get_baseclass_method(name)) + and inspect.ismethoddescriptor(method) + and len(kwargs) == 0 + ): + self.exc_vt.args = args + self.value.args = args + return variables.ConstantVariable(None) + elif ( + name == "__setattr__" + and len(args) == 2 + and args[0].is_constant_match( + "__cause__", "__context__", "__suppress_context__", "__traceback__" + ) + ): + self.exc_vt.call_setattr(tx, args[0], args[1]) + elif name == "with_traceback": + return self.exc_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + @property + def __context__(self): + return self.exc_vt.__context__ + + @property + def args(self): + return self.exc_vt.args + + def set_context(self, context: "variables.ExceptionVariable"): + return self.exc_vt.set_context(context) + + @property + def exc_type(self): + return self.exc_vt.exc_type + + +class KeyedJaggedTensorVariable(UserDefinedObjectVariable): + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torchrec.sparse.jagged_tensor") + return mod is not None and type(obj) is mod.KeyedJaggedTensor + + def __init__(self, value, **kwargs) -> None: + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + assert type(value) is KeyedJaggedTensor + super().__init__(value, **kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + if ( + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt + and self.source is not None + and name in ("_length_per_key", "_offset_per_key") + ): + with TracingContext.patch(force_unspec_int_unbacked_size_like=True): + return super().var_getattr(tx, name) + return super().var_getattr(tx, name) + + +class IntWrapperVariable(UserDefinedObjectVariable): + # Dummy class to check if the object is an IntWrapper, and turn it into a + # symint + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torch.export.dynamic_shapes") + return mod is not None and type(obj) is mod._IntWrapper + + +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + +class RemovableHandleVariable(VariableTracker): + REMOVED = -1 + + def __init__( + self, + mutation_type=None, + # index of the registration in the side_effects owned register_hook/handle list, used during removal. + idx=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mutation_type = mutation_type + self.idx = idx + + def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs): + if method_name == "remove": + if self.idx != self.REMOVED: + tx.output.side_effects.remove_hook(self.idx) + self.idx = self.REMOVED + return variables.ConstantVariable.create(None) + super().call_method(tx, method_name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + if self.idx == self.REMOVED: + # Hook has already been removed, return a dummy handle + codegen.add_push_null( + lambda: codegen.load_import_from( + "torch._dynamo.utils", "invalid_removeable_handle" + ) + ) + codegen.extend_output(create_call_function(0, False)) + return + # unreachable due to codegen.add_cache() when the hook is installed + super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass + + +class UserDefinedDictVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of dict/OrderedDict. + + Internally, it uses a ConstDictVariable to represent the dict part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, dict_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._dict_vt = dict_vt + if self._dict_vt is None: + assert self.source is None, ( + "dict_vt must be constructed by builder.py when source is present" + ) + self._dict_vt = variables.ConstDictVariable( + {}, type(value), mutation_type=ValueMutationNew() + ) + self._dict_methods = dict_methods + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._dict_methods: + # Dict subclasses can override __missing__ to provide fallback + # behavior instead of raising a KeyError. This is used, for example, + # by collections.Counter. + try: + return self._dict_vt.call_method(tx, name, args, kwargs) + except ObservedKeyError: + if ( + name == "__getitem__" + and issubclass(self.python_type(), dict) + and self._maybe_get_baseclass_method("__missing__") + ): + return self.call_method(tx, "__missing__", args, kwargs) + else: + raise + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + if type(self.value).__iter__ in ( + dict.__iter__, + collections.OrderedDict.__iter__, + ): + return self._dict_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._dict_vt) + + @property + def user_cls(self): + return self._dict_vt.user_cls + + @property + def items(self): + return self._dict_vt.items + + def install_dict_keys_match_guard(self): + return self._dict_vt.install_dict_keys_match_guard() + + def install_dict_contains_guard(self): + return self._dict_vt.install_dict_contains_guard() + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + + +class UserDefinedSetVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of set. + + Internally, it uses a SetVariable to represent the set part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, set_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._set_vt = set_vt + + python_type = set if isinstance(value, set) else frozenset + self._set_methods = set_methods if python_type is set else frozenset_methods + + if self._set_vt is None: + assert self.source is None, ( + "set_vt must be constructed by builder.py when source is present" + ) + if python_type is set: + # set is initialized later + self._set_vt = variables.SetVariable( + {}, mutation_type=ValueMutationNew() + ) + else: + init_args = kwargs.get("init_args", {}) + tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() + self._set_vt = variables.BuiltinVariable(python_type).call_function( + tx, init_args, {} + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._set_methods: + return self._set_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def as_python_constant(self): + return self._set_vt.as_python_constant() + + def unpack_var_sequence(self, tx): + if inspect.getattr_static(self.value, "__iter__") in ( + set.__iter__, + frozenset.__iter__, + ): + return self._set_vt.unpack_var_sequence(tx) + raise NotImplementedError + + @property + def set_items(self): + return self._set_vt.set_items + + @property + def items(self): + return self._set_vt.items + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._set_vt) + + def install_dict_keys_match_guard(self): + return self._set_vt.install_dict_keys_match_guard() + + def install_dict_contains_guard(self): + return self._set_vt.install_dict_contains_guard() + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._set_vt.is_python_hashable() + + def get_python_hash(self): + return self._set_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedSetVariable + ) and self._set_vt.is_python_equal(other._set_vt) + + +class UserDefinedListVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of lists. + + Internally, it uses a ListVariable to represent the list part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, list_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._list_vt = list_vt + if self._list_vt is None: + assert self.source is None, ( + "list_vt must be constructed by builder.py when source is present" + ) + self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew()) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._list_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in list_methods: + return self._list_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._list_vt is not None + if type(self.value).__iter__ is list.__iter__: + return self._list_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._list_vt) + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return False + + +class UserDefinedTupleVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of tuple. + + Internally, it uses a TupleVariable to represent the tuple part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): + super().__init__(value, init_args=init_args, **kwargs) + self._tuple_vt = tuple_vt + if self._tuple_vt is None: + assert self.source is None, ( + "tuple_vt must be constructed by builder.py when source is present" + ) + # Emulate `tuple.__new__` + # https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710 + # + # TODO this duplicates the logic in `BuiltinVariable(tuple)` + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + elems = init_args[0].force_unpack_var_sequence(tx) + self._tuple_vt = variables.TupleVariable( + elems, mutation_type=ValueMutationNew() + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._tuple_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in tuple_methods: + return self._tuple_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._tuple_vt is not None + if type(self.value).__iter__ is tuple.__iter__: + return self._tuple_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_python_hashable(self): + raise_on_overridden_hash(self.value, self) + return self._tuple_vt.is_python_hashable() + + def get_python_hash(self): + return self._tuple_vt.get_python_hash() + + def is_python_equal(self, other): + return isinstance( + other, UserDefinedTupleVariable + ) and self._tuple_vt.is_python_equal(other._tuple_vt) + + +class MutableMappingVariable(UserDefinedObjectVariable): + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.generic_dict_vt = variables.ConstDictVariable({}) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # A common pattern in the init code of MutableMapping objects is to + # update the __dict__ attribute. To prevent graph break, we directly + # return a ConstDictVariable for the __dict__attr. + # + # However, users can try to add a new attribute to the class using the + # __dict__ attribute. To catch this, we save the ConstDictVariable for + # the __dict__ and then lookup into this vt for each attr lookup. + if name == "get" and type(self.value).get in ( + collections.abc.Mapping.get, + dict.get, + ): + return variables.UserMethodVariable(polyfills.mapping_get, self) + elif name == "__dict__" and self.source: + self.generic_dict_vt = variables.LazyVariableTracker.create( + self.value.__dict__, AttrSource(self.source, "__dict__") + ) + return self.generic_dict_vt + elif out := self.generic_dict_vt.maybe_getitem_const( + variables.ConstantVariable(name) + ): + return out + else: + return super().var_getattr(tx, name) + + +class RandomVariable(UserDefinedObjectVariable): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa9ec30c964a7d4ec3e588cec708bc42f2b0321 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e26c16bf4f1928c4f06f3b6d6b64a7ed9ab3010a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af98cff1fb30dea19ea39509c7ca9e421640c829 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/apis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01e77ed1edd1902b7c68ccb24545621a5b48774b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24fb630d7dd30e5d97b1cbd2ff6ba1deae861963 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d59f625843f4d25c44480c45b1bd552f14df8139 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81c3c137cd09eb7359399f1dd444a27524c1fa5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47bfda4a74154c170666ab5f5ed13f71426003c7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/compilers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4653e90722395513be79ffdac9b387674e06e6da Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..927cde9a48af943c5cac504dd22ea5d2b3fdda5f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/deprecated.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b189cd67dd6b727aa43cea9c9a7092e4b76e2b3e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b28f470cb7a9aaa9a4ce1209c0736c1fe022bd6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/functional_call.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00daed4ae408002788eb7bd5889219b5bc6843c2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..268bf5e3f63bd1862840efb1b552f40e92915834 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/make_functional.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ed037a83ee84e596da986376a8def39dc105494 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/predispatch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6f8df217e45917ac17fe03d861f6ebbd969e1c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..019bd35c383aa6f916bcd8422b4cebacbb0040cb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/python_key.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..929332d268d95451c9b8ec6eafc2e1399f690bb1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..319d8ca31ce04bde5bbb0652cc3a252f00d30762 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..079ca9a711111c3081d61573e4dce95c36ece6f7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63a831e9774f566b799a996b3bb8f0ba4e74275 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/__pycache__/vmap.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e701ff0faa8a2b73ff494ced51bd8816e21c98c6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a306e32888dbf24315f6c4ff2bbb0397309e3b8f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10ac203fd0abf41835a924fcb2415fae80bc0ba8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a7c792c70bcdff0382c11450940c06faa553000 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c2368bd0612d2abaf7352e6bbe68a708c4c77c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc7516ba7db898694547e570bd608a343c132779 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/remat_using_tags_for_fwd_loss_bwd_graph_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a55772ab58b21573a6eba0356ddd3080164ac7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..015241d8fb779b207f0ba1296fd82399ce2c7a22 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87975ad02113e4f109f5cabb8a4cd6f2c0ed2726 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/__pycache__/activation_offloading.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..0a209ef4d824b524564709475eff9954c59cf126 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_activation_offloading/activation_offloading.py @@ -0,0 +1,824 @@ +""" +Activation offloading for memory optimization in (more like post) partitioners. + +This module provides functionality to offload activations to CPU during forward pass +and reload them during backward pass, reducing GPU memory usage. + +Additional TODO: +* given the fact that PT2 stream support is in active development, testings should + be done once that is more finalized. A issue currently known is that with streams, + each iteration will have its own offload streams, but the streams should be shared + across the iterations. +""" + +import logging +import operator +from dataclasses import dataclass + +import torch +import torch.fx as fx +from torch._dynamo.variables.streams import get_current_stream, new_event, new_stream +from torch._inductor import config as inductor_config +from torch._inductor.fx_passes.overlap_scheduling import benchmark_node, is_compute_node +from torch._subclasses.fake_tensor import extract_tensor_metadata +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..partitioners import _size_of, get_default_op_list, OpTypes + + +log: logging.Logger = logging.getLogger(__name__) + + +# Node name prefixes for offload/reload operations +# NOTE: right now we are using these prefixes as identifiers for offload/reload +CPU_OFFLOAD_PREFIX = "cpu_offload_" +GPU_RELOAD_PREFIX = "gpu_reload_" + + +@dataclass +class ReloadNodeInfo: + """ + Information about backward reload related nodes for each reload operation. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This pattern is divided into two logical groups for optimization purposes: + - Reload group (fork → wait_stream → device_put → record_event → join): + Performs the actual asynchronous data transfer on a separate stream. + These nodes can be moved earlier in the graph to overlap with computation. + - Wait group (wait_event): + Synchronization point that blocks until the data transfer completes. + This must remain at the point where the reloaded data is first needed. + """ + + reload_group_nodes: list[fx.Node] + wait_event_node: fx.Node + transfer_size_bytes: int + transfer_time_ms: float + + +@dataclass +class ReloadQueueEntry: + """ + Entry in the reload queue for prefetch scheduling. + + Attributes: + pattern: The reload pattern information + remaining_time_ms: Remaining overlap time needed in milliseconds + """ + + pattern: ReloadNodeInfo + remaining_time_ms: float + + +def offload_activation_fw(graph: fx.Graph) -> None: + """ + Insert CPU offload operations in the forward pass graph. + + Offload operations are placed after the last effective use of each tensor marked + for offloading. This ensures the tensor is no longer needed on the GPU before + transferring it to CPU memory. + + NOTE: An alternative approach would offload tensors immediately after generation + to maximize compute-communication overlap. However, this requires additional + synchronization to ensure tensor deletion (which occurs on the default stream) + waits for the asynchronous offload operation to complete. This would necessitate + more complex tracking to separate operation scheduling from memory cleanup. + + Args: + graph: The forward graph to modify + """ + + op_types: OpTypes = get_default_op_list() + + def find_all_effective_users(node: fx.Node) -> OrderedSet[fx.Node]: + """ + Find all effective users of a node, where view ops extend the lifetime of the + original node. If a user is a view op, recursively find users of the view. + """ + effective_users: OrderedSet[fx.Node] = OrderedSet() + for user in node.users: + if user.op == "output": + continue + effective_users.add(user) + if op_types.is_view(user): + effective_users.update(find_all_effective_users(user)) + + return effective_users + + output_node: fx.Node = graph.find_nodes(op="output")[0] + fwd_outputs: tuple[fx.Node, ...] = output_node.args[ + 0 + ] # pyrefly: ignore [bad-assignment] + node_to_offload: dict[fx.Node, fx.Node] = dict() + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + + for node in fwd_outputs: + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the last use + all_effective_users: OrderedSet[fx.Node] = find_all_effective_users(node) + if all_effective_users := find_all_effective_users(node): + last_user = max(all_effective_users, key=lambda n: node_to_index[n]) + else: + last_user: fx.Node = node + + # Insert the CPU offload operation after the last user + with graph.inserting_after(last_user): + cpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, torch.device("cpu")), + kwargs={"non_blocking": True}, + name=CPU_OFFLOAD_PREFIX + str(node.name), + ) + cpu_node.meta["val"] = node.meta["val"].to(torch.device("cpu")) + cpu_node.meta["tensor_meta"] = extract_tensor_metadata(cpu_node.meta["val"]) + + node_to_offload[node] = cpu_node + + # Update the return node args + output_node.update_arg( + 0, tuple(node_to_offload.get(node, node) for node in fwd_outputs) + ) + + +def reload_activation_bw(graph: fx.Graph) -> None: + """ + Insert GPU reload operations in the backward pass graph. + + Reload operations are placed before the first use of each offloaded tensor, + transferring it from CPU back to GPU memory before it's needed for computation. + + Args: + graph: The backward graph to modify + """ + + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + output_node: fx.Node = graph.find_nodes(op="output")[0] + + for node in graph.find_nodes(op="placeholder"): + if node.meta.get("saved_for_offloading", False) is False: + continue + + # Find insertion point, which is the first use or output node if no users + # The later should not happen, but inserting before output node is safe + insert_point: fx.Node = ( + min(node.users.keys(), key=lambda n: node_to_index[n]) + if node.users + else output_node + ) + + # Insert the GPU reload operation before the first user + original_device: torch.Device = node.meta["original_device"] + with graph.inserting_before(insert_point): + gpu_node: fx.Node = graph.call_function( + torch.ops.prims.device_put.default, + args=(node, original_device), + kwargs={"non_blocking": True}, + name=str(node.name).replace(CPU_OFFLOAD_PREFIX, GPU_RELOAD_PREFIX), + ) + gpu_node.meta["val"] = node.meta["val"].to(original_device) + gpu_node.meta["tensor_meta"] = extract_tensor_metadata(gpu_node.meta["val"]) + + # Replace all uses of the CPU tensor with the GPU tensor + for user in list(node.users.keys()): + if user != gpu_node: + user.replace_input_with(node, gpu_node) + + +def can_offload( + node: fx.Node, + fwd_outputs: OrderedSet[fx.Node], + model_outputs: OrderedSet[fx.Node], + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Determine if a node can be offloaded to CPU. + + Args: + node: The node to check + fwd_outputs: Forward module outputs, including model outputs and activations + model_outputs: Model outputs + + NOTE: Additional context for the logic behind these offloading checks: + + * fwd_outputs: Only saved intermediate tensors should be offloaded. + + * model_outputs / static_lifetime_input_nodes: Tensors that may be accessed outside + the compiled region (e.g., model outputs, static inputs) cannot be offloaded as + they must remain accessible beyond the scope of the compiled graph. + + * views / getitems: Offloading such nodes can lead to segmentation faults. + + * contiguous: Offloading non-contiguous tensors causes CPU-side stride changes + during both forward and backward passes when using the Inductor backend. While + these stride changes cancel each other out, they introduce significant compute + overhead. This is due to the contiguity check in ir.py (see link below). + TODO: This restriction could potentially be bypassed in the future. + Reference: https://github.com/pytorch/pytorch/blob/44ac69388a4a5eb463dbd2a13f00d1e3b924566c/torch/_inductor/ir.py#L3214 + + Additional criteria to consider for offloading optimization: + + * Tensor size: Small tensors may not fully utilize available bandwidth, reducing the + efficiency gains from offloading. + + * Position in forward/backward graph: Activations generated near the end of the forward + pass are typically consumed near the beginning of the backward pass. Offloading such + tensors may be counterproductive since they are quickly reloaded, not having sufficient + time to overlap the transfer with computation. + """ + + log.debug(f"Checking node {node.name} for offloading...") # noqa: G004 + + op_types: OpTypes = get_default_op_list() + + if node not in fwd_outputs: + log.debug("\tSkipped! Can only offload nodes in fwd_module_outputs.") + return False + if node in model_outputs: + log.debug("\tSkipped! Cannot offload model outputs.") + return False + if node in static_lifetime_input_nodes: + log.debug("\tSkipped! Cannot offload static input nodes.") + return False + if op_types.is_view(node): + log.debug("\tSkipped! Cannot offload views.") + return False + if node.target == operator.getitem: + log.debug("\tSkipped! Cannot offload getitems.") + return False + if hasattr(node, "meta") and "val" in node.meta: + if ( + isinstance(val := node.meta["val"], torch.Tensor) + and not val.is_contiguous() + ): + log.debug("\tSkipped! Cannot offload non-contiguous tensors.") + return False + + log.debug("\tGood!") + return True + + +def choose_offload_sets( + fwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> bool: + """ + Decide which nodes will be offloaded based on the marked nodes and feasibility. + Marks nodes with "saved_for_offloading" if they should and can be offloaded. + + Args: + fwd_module: Forward graph module + bwd_module: Backward graph module + num_fwd_outputs: Number of forward outputs + + Returns: + bool: Whether activation offloading should be performed + """ + + fwd_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0] + ) + model_outputs: OrderedSet[fx.Node] = OrderedSet( + fwd_module.graph.find_nodes(op="output")[0].args[0][:num_fwd_outputs] + ) + + should_perform_offloading = False + for node in fwd_module.graph.nodes: + if node.meta.get("should_offload", False) and can_offload( + node, fwd_outputs, model_outputs, static_lifetime_input_nodes + ): + node.meta["saved_for_offloading"] = True + node.meta["original_device"] = node.meta["val"].device + should_perform_offloading = True + + return should_perform_offloading + + +def offload_chosen_sets( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add offload and reload nodes to the forward and backward graphs. + This function adds device_put operations without any stream handling. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + # Add offload nodes in forward graph + offload_activation_fw(fwd_module.graph) + + # Update backward graph inputs to be offloaded tensors + bwd_inputs: dict[str, fx.Node] = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for fwd_node in fwd_module.graph.find_nodes(op="output")[0].args[0]: + if CPU_OFFLOAD_PREFIX not in fwd_node.name: + continue + + bwd_node: fx.Node = bwd_inputs[fwd_node.name.replace(CPU_OFFLOAD_PREFIX, "")] + with bwd_module.graph.inserting_after(bwd_node): + bwd_offload_node: fx.Node = bwd_module.graph.placeholder(name=fwd_node.name) + + bwd_offload_node.meta.update(fwd_node.meta) + bwd_offload_node.meta["saved_for_offloading"] = True + bwd_offload_node.meta["original_device"] = bwd_node.meta["val"].device + bwd_node.replace_all_uses_with(bwd_offload_node) + bwd_module.graph.erase_node(bwd_node) + + # Add reload nodes in backward graph + reload_activation_bw(bwd_module.graph) + + +def add_forward_offload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for forward pass CPU offloading. + + Pattern: record_event → fork → wait_event → record_stream → device_put → record_event_2 → join → wait_event_2 + + This ensures that: + 1. Offloading waits for the last use to complete (record_event on default stream) + 2. Offloading happens on a separate stream (fork → wait_event → device_put) + 3. The tensor is marked as used in the offload stream (record_stream) + 4. Execution returns to the default stream after offloading and + waits for offload to complete (record_event_2 → join → wait_event_2) + + NOTE: For stream optimization and overlapping compute with communication, + the "wait_event_2" ops can be sinked to the end of the graph. + + Args: + graph: The forward graph to modify + """ + + # Find all CPU offload nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if CPU_OFFLOAD_PREFIX in node.name and node.op == "call_function" + ] + if not offload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + offload_nodes[0].args[0].meta["val"].device # type: ignore[assignment] + ) + offload_stream_id: int = new_stream() + + for offload_node in offload_nodes: + offload_ready_event_id: int = new_event() + offload_completion_event_id: int = new_event() + + # Get the tensor being offloaded + tensor_node: fx.Node = offload_node.args[0] # type: ignore[assignment] + + with graph.inserting_before(offload_node): + # Record event on default stream to ensure last use completes + graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_ready_event_id, current_stream_id), + ) + # Fork to offload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, offload_stream_id), + name=f"stream_in_{offload_node.name}", + ) + # Wait for the event on offload stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_ready_event_id, offload_stream_id), + ) + # Inform the CUDA Caching Allocator that this tensor will be accessed in the + # offload stream. Without this, the program may prematurely free its memory + # even though the async offload operation is still in progress, and this can + # lead to memory corruption, especially with reordering for compute and + # communication overlaps. + graph.call_function( + torch.ops.streams.record_stream.default, + args=(tensor_node, offload_stream_id), + name=f"record_stream_{tensor_node.name}", + ) + with graph.inserting_after(offload_node): + # Record event on offload stream after device_put completes + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(offload_completion_event_id, offload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(offload_stream_id, current_stream_id), + name=f"stream_out_{offload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the offload to complete on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(offload_completion_event_id, current_stream_id), + ) + + +def add_backward_reload_stream_ops(graph: fx.Graph) -> None: + """ + Add stream operations for backward pass GPU reloading. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This ensures that: + 1. Reloading doesn't start prematurely (fork → wait_stream) + 2. Reloading happens on a separate stream (device_put) + 3. First use waits for reload completion (record_event → join → wait_event) + + NOTE: The pattern consists of two logical groups: + - First group (fork → wait_stream → device_put → record_event → join): + Performs asynchronous data transfer on a separate stream + - Second group (wait_event): + Data transfer completion check when the data is actually needed + + For prefetch optimization, the first group can be moved earlier in the graph + to overlap computation with data transfer, while the wait_event must remain + at its current position to prevent blocking computation unnecessarily. + + Args: + graph: The backward graph to modify + """ + + # Find all GPU reload nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.nodes + if GPU_RELOAD_PREFIX in node.name and node.op == "call_function" + ] + if not reload_nodes: + return + + # Get default stream id and offload stream id + current_stream_id: int = get_current_stream( + reload_nodes[0].args[0].meta["original_device"] # type: ignore[assignment] + ) + reload_stream_id: int = new_stream() + + for reload_node in reload_nodes: + event_id: int = new_event() + + with graph.inserting_before(reload_node): + # Fork to reload stream + graph.call_function( + torch.ops.streams.fork.default, + args=(current_stream_id, reload_stream_id), + name=f"stream_in_{reload_node.name}", + ) + # Wait for default stream to prevent premature reloading + graph.call_function( + torch.ops.streams.wait_stream.default, + args=(reload_stream_id, current_stream_id), + ) + with graph.inserting_after(reload_node): + # Record event on reload stream after device_put + record_event_node = graph.call_function( + torch.ops.streams.record_event.default, + args=(event_id, reload_stream_id), + ) + with graph.inserting_after(record_event_node): + # Join back to default stream + join_node = graph.call_function( + torch.ops.streams.join.default, + args=(reload_stream_id, current_stream_id), + name=f"stream_out_{reload_node.name}", + ) + with graph.inserting_after(join_node): + # Wait for the event on default stream + graph.call_function( + torch.ops.streams.wait_event.default, + args=(event_id, current_stream_id), + ) + + +def put_offload_nodes_on_separate_stream( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: + """ + Add stream and event related operations around offload nodes. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + """ + + add_forward_offload_stream_ops(fwd_module.graph) + add_backward_reload_stream_ops(bwd_module.graph) + + +def _validate_pattern_nodes( + fork_node: fx.Node, + wait_stream_node: fx.Node, + record_event_node: fx.Node, + join_node: fx.Node, + wait_event_node: fx.Node, +) -> None: + """ + Validate that the pattern nodes match the expected structure. + + Raises ValueError if any node doesn't match expectations. + """ + + if not ( + fork_node.op == "call_function" + and fork_node.target == torch.ops.streams.fork.default + ): + raise ValueError("Expected fork node two nodes before device_put node") + + if not ( + wait_stream_node.op == "call_function" + and wait_stream_node.target == torch.ops.streams.wait_stream.default + ): + raise ValueError("Expected wait_stream node one node before device_put node") + + if not ( + record_event_node.op == "call_function" + and record_event_node.target == torch.ops.streams.record_event.default + ): + raise ValueError("Expected record_event node one node after device_put node") + + if not ( + join_node.op == "call_function" + and join_node.target == torch.ops.streams.join.default + ): + raise ValueError("Expected join node two nodes after device_put node") + + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError("Expected wait_event node three nodes after device_put node") + + +def _calculate_transfer_size(device_put_node: fx.Node) -> int: + """Calculate the size in bytes of data being transferred.""" + + return _size_of(device_put_node.args[0]) # pyrefly: ignore [bad-argument-type] + + +def _estimate_transfer_time_in_ms(transfer_size_bytes: int) -> float: + """ + Estimate transfer time in milliseconds based on size and bandwidth. + NOTE: potentially could be standardized in node estimator class + """ + + return transfer_size_bytes / (1024**3) * 1_000 / inductor_config.cpu_gpu_bw + + +def identify_reload_patterns( + graph: fx.Graph, nodes_list: list[fx.Node], node_to_idx: dict[fx.Node, int] +) -> dict[fx.Node, ReloadNodeInfo]: + """ + Identify backward reload patterns in the graph. + + Pattern: fork → wait_stream → device_put → record_event → join → wait_event + + This uses position-based matching since these nodes are inserted together in + add_backward_reload_stream_ops() in a specific order. Since stream operations + do not have data dependencies between them, they are unsuitable for subgroup + pattern matching type of checks. + + Returns a dict mapping device_put node to ReloadNodeInfo containing: + - reload_group_nodes: fork → wait_stream → device_put → record_event → join + - wait_event_node: the wait_event node + - transfer_size_bytes: size of data being transferred + - transfer_time_ms: estimated transfer time in milliseconds + """ + patterns: dict[fx.Node, ReloadNodeInfo] = {} + + # Find all GPU reload device_put nodes whose inputs are placeholder nodes + reload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if GPU_RELOAD_PREFIX in node.name + and ( + node.args + and isinstance(node.args[0], fx.Node) + and node.args[0].op == "placeholder" + ) + ] + + # Extract patterns for each reload device_put node + for reload_node in reload_nodes: + reload_node_idx: int = node_to_idx[reload_node] + + fork_node: fx.Node = nodes_list[reload_node_idx - 2] + wait_stream_node: fx.Node = nodes_list[reload_node_idx - 1] + record_event_node: fx.Node = nodes_list[reload_node_idx + 1] + join_node: fx.Node = nodes_list[reload_node_idx + 2] + wait_event_node: fx.Node = nodes_list[reload_node_idx + 3] + + # Validate the nodes are what we expect + _validate_pattern_nodes( + fork_node, + wait_stream_node, + record_event_node, + join_node, + wait_event_node, + ) + + # Calculate transfer size and time + transfer_size_bytes: int = _calculate_transfer_size(reload_node) + transfer_time_ms: float = _estimate_transfer_time_in_ms(transfer_size_bytes) + + patterns[reload_node] = ReloadNodeInfo( + reload_group_nodes=[ + fork_node, + wait_stream_node, + reload_node, + record_event_node, + join_node, + ], + wait_event_node=wait_event_node, + transfer_size_bytes=transfer_size_bytes, + transfer_time_ms=transfer_time_ms, + ) + + return patterns + + +def reorder_for_prefetch( + nodes_list: list[fx.Node], + reload_patterns: dict[fx.Node, ReloadNodeInfo], +) -> None: + """ + Reorder nodes to prefetch reload operations by directly manipulating the graph. + + This follows the algorithm as follows: + - Go through nodes in reverse order + - When encountering a reload pattern, add it to a queue with its transfer time + - When encountering a compute node, use its runtime to satisfy overlap requirements + - Place reload patterns when their overlap requirement is satisfied + - When encountering placeholder nodes, flush queue as reloads cannot move before inputs + """ + + # Build a set of all nodes in reload groups for quick lookup + reload_group_nodes_set: set[fx.Node] = set() + for pattern in reload_patterns.values(): + reload_group_nodes_set.update(pattern.reload_group_nodes) + + # Queue to hold reload group nodes waiting to be placed (FIFO) + reload_queue: list[ReloadQueueEntry] = [] + + # Loop through nodes in reverse + for node in reversed(nodes_list): + if node.op == "output": + continue + elif node.op == "placeholder": + # Flush queue - place all remaining reloads after the last placeholder + while reload_queue: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in reversed(entry.pattern.reload_group_nodes): + node.append(reload_group_node) + break + elif node in reload_patterns: + pattern: ReloadNodeInfo = reload_patterns[node] + reload_queue.append( + ReloadQueueEntry( + pattern=pattern, remaining_time_ms=pattern.transfer_time_ms + ) + ) + elif node in reload_group_nodes_set: + continue + else: + if not reload_queue: + continue + compute_runtime_ms: float = ( + benchmark_node(node) if is_compute_node(node) else 0 + ) + reload_queue[0].remaining_time_ms -= compute_runtime_ms + + # Pop and place reload if its remaining time is satisfied (<= 0) + if reload_queue[0].remaining_time_ms <= 0: + entry: ReloadQueueEntry = reload_queue.pop(0) + for reload_group_node in entry.pattern.reload_group_nodes: + node.prepend(reload_group_node) + + +def activation_offload_sink_wait(fwd_module: fx.GraphModule) -> None: + """ + Sink wait_event operations for offload completion to the end of the graph. + + This function identifies wait_event nodes for offload completion and moves them + to the end of the graph, allowing computation to overlap with offload operations. + + Args: + fwd_module: Forward module graph + """ + graph: fx.Graph = fwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Find all CPU offload device_put nodes + offload_nodes: list[fx.Node] = [ + node + for node in graph.find_nodes( + op="call_function", target=torch.ops.prims.device_put.default + ) + if CPU_OFFLOAD_PREFIX in node.name + ] + + # Collect all wait_event nodes that need to be moved + wait_nodes_to_sink: list[fx.Node] = [] + for offload_node in offload_nodes: + offload_idx: int = node_to_idx[offload_node] + wait_event_node: fx.Node = nodes_list[offload_idx + 3] + + # Validate it's actually a wait_event node + if not ( + wait_event_node.op == "call_function" + and wait_event_node.target == torch.ops.streams.wait_event.default + ): + raise ValueError( + f"Expected wait_event node three positions after {offload_node.name}" + ) + + wait_nodes_to_sink.append(wait_event_node) + + # Find the output node, and move all wait_event nodes to just before the output node + output_node: fx.Node = graph.find_nodes(op="output")[0] + for wait_node in wait_nodes_to_sink: + output_node.prepend(wait_node) + + +def activation_reload_prefetch(bwd_module: fx.GraphModule) -> None: + """ + Prefetch backward reload operations by moving them earlier in the graph + to overlap communication with computation. + + This function identifies backward reload patterns (fork → wait_stream → device_put → + record_event → join) and moves them earlier in the execution order to overlap + the data transfer with computation, while keeping the wait_event at its original + position. + + Args: + bwd_module: Backward module graph + """ + graph: fx.Graph = bwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) + node_to_idx: dict[fx.Node, int] = {node: idx for idx, node in enumerate(nodes_list)} + + # Step 1: Identify reload patterns + reload_patterns: dict[fx.Node, ReloadNodeInfo] = identify_reload_patterns( + graph, nodes_list, node_to_idx + ) + + # Step 2: Reorder nodes by directly manipulating the graph + reorder_for_prefetch(nodes_list, reload_patterns) + + +def enable_activation_offloading( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + num_fwd_outputs: int, + static_lifetime_input_nodes: OrderedSet[fx.Node], +) -> None: + """ + Main entry point for activation offloading. + + Args: + fwd_module: Forward module graph + bwd_module: Backward module graph + num_fwd_outputs: Number of forward outputs + """ + + # Step 1: Decide which nodes to offload and mark them + should_perform_offloading: bool = choose_offload_sets( + fwd_module, + num_fwd_outputs, + static_lifetime_input_nodes, + ) + if not should_perform_offloading: + return + + # Step 2: Add offload and reload nodes to the graphs + offload_chosen_sets(fwd_module, bwd_module) + + # Step 3: Put offload nodes on separate stream if configured + if config.activation_offload_separate_stream: + put_offload_nodes_on_separate_stream(fwd_module, bwd_module) + if config.activation_offload_sink_wait: + activation_offload_sink_wait(fwd_module) + if config.activation_reload_prefetch: + activation_reload_prefetch(bwd_module) + + fwd_module.graph.lint() + bwd_module.graph.lint() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f17b9b8526f0df4ec3f1ff5832aaa48b04b82e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..288770c7d94aebd2a9a8f459441caa62e7e75863 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/aot_autograd_result.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c65c18259e3138ff597b8be0faed719439a16d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1f556eb0e611465013614885f912999ccae601 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea97d092a3c357b5d6d8eb5547594f3f93bde069 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/descriptors.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9b3c0d94c34729120fa36282080096143a1c2c4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/frontend_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46316014850cd59ebf270bdef16b3ddf00084c5a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca54d73f068977bd3e00f6b3cff94129f4039d3d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/fx_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad62bef5ea85846efd38cea6fa8df36bfc9091b5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16fa29d7059da68a4f2dad790eded5b8acb4a6c1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_capture_wrappers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..345f97184f0d88363878edb84983a48c5a1ad1c8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/graph_compile.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..597b90fa63db299b82ff24b496fdeb986858c19d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/indexed_dict.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64cbd10a0d0ebf22b3f9d2e828da957d8cf729af Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b06b0225a9383014029eccdc5bc0be7f8ca975f0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ba0e9a63709a5e53316244b9086f09d62d1185 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..865c5eeceb4680a837aa3a2ed488c47551f387f3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2edd9640497a336241e7fab9db83e89912b306cd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/streams.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653eacf78a996a22eed8f9780f521cc548ec5c20 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..856f684e402274133af82947fc36a66ca2511198 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c197154ced24822714cd6779c83588c3010f33e5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..11cef0f9205a511605162042b0c016041b5e8413 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -0,0 +1,873 @@ +# mypy: allow-untyped-defs +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the analysis here constructs view and mutation metadata from running +a functionalized version of the graph under compilation. +""" + +import collections +import contextlib +import logging +from collections.abc import Callable +from typing import Optional + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._guards import detect_fake_mode +from torch._logging import getArtifactLogger +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch._subclasses.meta_utils import safe_is_leaf +from torch.fx.experimental.symbolic_shapes import is_concrete_int +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + +from .descriptors import ( + AOTInput, + AOTOutput, + InputMutationAOTOutput, + IntermediateBaseAOTOutput, + PlainAOTOutput, + TangentAOTInput, +) +from .functional_utils import ( + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + MetadataKey, + to_fun, + ViewMetaSequence, + was_inductor_storage_resized, +) +from .schemas import ( + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .subclass_utils import create_subclass_meta +from .utils import _get_autocast_states, KNOWN_TYPES, simple_wraps, strict_zip + + +zip = strict_zip + +log = logging.getLogger(__name__) +static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs") + + +# Note [Tangents memory format] +# We assume tangents memory format to be similar to corresponding output's memory_format. +# The idea is that we are technically making a guess about the strides of our tangents, +# while we trace out the joint. +# If runtime specified tangents will not have the same memory format as predicted traced tangents, +# we coerce them at runtime to traced tangents memory format. + + +# Coercing and collecting traced tangents memory format in one recursive traversal +# mypy: ignore-errors +def coerce_tangent_and_suggest_memory_format(x: Tensor): + updated = False + if not isinstance(x, Tensor): + return x, None, updated + + out = x.detach() + + is_subclass = is_traceable_wrapper_subclass(out) + + memory_format = MemoryFormatMeta.from_tensor(out) + + # pyrefly: ignore [missing-attribute] + if memory_format.memory_format is not None: + was = out + # pyrefly: ignore [bad-argument-type] + out = out.contiguous(memory_format=memory_format.memory_format) + updated = was is not out + + # For subclass we keep memory format of outer strides at the beginning of the list + out_memory_format = [memory_format] if is_subclass else memory_format + + # Note [Tangents memory format, Part 2] + # In the same way that "what strides do we assigns to our tangents" is a question + # that we can not answer (and therefore have to guess) as we trace the backward ahead-of-time, + # The same applies to any tensor subclass metadata, when we have tangents that are subclasses. + # To handle this situation, we have two new methods that a tensor subclass can implement: + # (1) __coerce_tangent_metadata__(self) + # Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata. + # The main example here is a DTensor with the "_Partial" placement. + # If we have a forward output with a _Partial placement, and corresponding tangent + # with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement. + # This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never + # have a tangent with "problematic" metadata, that we cannot convert to. + # (1) __coerce_same_metadata_as_tangent__(self, metadata) + # Given a subclass, and a target differing metadata, + # convert self to have the same metadata as the target. + # With DTensor being the main example, we can use this to convert a DTensor with a Replicate() + # placement into one with a Shard() placement, in the case that we "guessed wrong", + # and traced tangents with a Shard() placement at compile time. + # + if is_subclass and hasattr(out, "__coerce_tangent_metadata__"): + out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined] + + if is_subclass: + # pyrefly: ignore [missing-attribute] + attrs = out.__tensor_flatten__()[0] + + for attr in attrs: + elem = getattr(out, attr) + ( + new_elem, + new_elem_memory_format, + elem_updated, + ) = coerce_tangent_and_suggest_memory_format(elem) + # pyrefly: ignore [missing-attribute] + out_memory_format.append(new_elem_memory_format) + if elem_updated: + setattr(out, attr, new_elem) + + return out, out_memory_format, updated + + +# This is a version of functionalization that is specifically designed +# for the AOTAutograd use case. +# +# Unlike functorch's variant, this doesn't use the functorch level system, +# instead it directly uses PyTorch's conventional dispatcher to hit the +# functionalization key. In particular, this means that FunctionalTensorWrapper +# can have autograd data stored directly on it. +# +# In typical AOTAutograd usage, the dispatch key order will look like: +# +# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor +# outer tensor inner tensor +# +# Returns: +# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and +# The list of outputs from the forward, but **only** the outputs that we need +# to pass in as tangents into the backward. +# Specifically, aliased outputs from the forward get regenerated, and don't participate +# in the compiled backward function. +def run_functionalized_fw_and_collect_metadata( + f, + *, + flat_args_descs: list[AOTInput], + keep_input_mutations: bool, + # TODO: refactor to kill this flag + is_train: bool = False, + # Note: this is guaranteed to be set when running under dynamo + static_input_indices: Optional[list[int]] = None, + pre_dispatch: bool = False, +) -> Callable[..., ViewAndMutationMeta]: + memo: dict[Tensor, Tensor] = {} + + def _to_fun(t): + if isinstance(t, Tensor): + if t in memo: + return memo[t] + r = to_fun(t) + memo[t] = r + return r + else: + return t + + @simple_wraps(f) + def inner(*flat_args): + # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. + assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) + + input_info: list[InputAliasInfo] = [] + output_info: list[OutputAliasInfo] = [] + + prior_grad_enabled = torch.is_grad_enabled() + prior_autocast_states = _get_autocast_states() + + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + # It doesn't matter if we run this under predispatch or not because it is + # only for figuring out metadata + mode = FunctionalTensorMode(_allow_token_discovery=True) + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: + # precondition: The passed in function already handles unflattening inputs + flattening outputs + flat_f_args = pytree.tree_map(_to_fun, flat_args) + flat_f_args_descs = flat_args_descs + flat_f_outs = f(*flat_f_args) + + # Assert that f does NOT have an AOTOutputs in it, easy mistake to + # make! You need to drop the second output before calling this + # function + assert not pytree.tree_any( + lambda x: isinstance(x, AOTOutput), flat_f_outs + ), ( + f"{f} returned AOTOutput when it shouldn't. Did you remember to wrap the " + "function with without_output_descs before passing it here?" + ) + + # NB: this is just to setup the input descriptors, we will + # recreate these descriptors (with the same convention!) when we + # actually do the trace + flat_f_outs_descs = [PlainAOTOutput(i) for i in range(len(flat_f_outs))] + + # We didn't do any tracing, so we don't need to process the + # unbacked symbols, they will just disappear into the ether. + # Also, prevent memoization from applying. + if fake_mode: + fake_mode.epoch += 1 + fake_mode.reset_nt_tensor_id_counter() + + if prior_autocast_states != _get_autocast_states(): + raise RuntimeError( + "AOTAutograd does not support tracing graphs that mutate the autocast state. " + "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, " + "which will unwind all of their mutations to autocast state before the graph exits. " + "If you encounter this error while using torch.compile, please file a bug." + ) + + # Inspect the state of the input tensor functional wrapper to detect input mutation info + # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version + for arg, f_arg in zip(flat_args, flat_f_args): + # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in + # strides between the functionalized arg inner tensors and non-functionalized arg inner + # tensors. This is a problem as the inner tensor stride change may not be reflected + # correctly in the outer tensor, so disallow this for now. + mutates_data = has_data_mutation(f_arg) + mutates_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=False + ) + if mutates_metadata and is_traceable_wrapper_subclass(arg): + raise RuntimeError( + "Metadata mutations are currently not allowed on tensor subclasses" + ) + mutates_storage_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=True + ) + mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd( + f_arg + ) + mutations_under_no_grad_or_inference_mode = ( + mutates_data + and are_all_mutations_under_no_grad_or_inference_mode(f_arg) + ) + mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg) + + if mutates_storage_metadata: + mutates_data = False + + requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + + input_info.append( + InputAliasInfo( + is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg), + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutates_storage_metadata=mutates_storage_metadata, + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + mutation_inductor_storage_resize=mutation_inductor_storage_resize, + requires_grad=requires_grad, + keep_input_mutations=keep_input_mutations, + ) + ) + + # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, + # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view + # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad + # on the base tensor, but we are obligated to properly set requires-gradness on the real output. + + inp_storage_refs = { + StorageWeakRef(inpt.untyped_storage()): idx + for idx, inpt in enumerate(flat_f_args) + if isinstance(inpt, Tensor) + } + + # We need inp tensor id's to be able to tell if an outputs **are** inputs. + inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)} + # We need output tensor id's to tell if any output._base` attributes **are** other outputs. + # (This is also a dict because we need to know that output's index, so we can regenerate + # the alias from it). + out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)} + + # Keep track of which outputs alias other outputs + out_tensor_alias_counts: collections.defaultdict = collections.defaultdict(int) + # This tells us, for a given group of outputs that alias each other, + # whether they e.g. all came from an unbind call + num_aliased_tensors_that_are_multi_output_views: collections.defaultdict = ( + collections.defaultdict(int) + ) + + out_storage_to_metadata_key_to_tensors: collections.defaultdict[ + Optional[StorageWeakRef], + collections.defaultdict[MetadataKey, set[torch.Tensor]], + ] = collections.defaultdict(lambda: collections.defaultdict(set)) + + curr_storage = None + for o in flat_f_outs: + if isinstance(o, torch.Tensor): + curr_storage = StorageWeakRef(o.untyped_storage()) + out_tensor_alias_counts[curr_storage] += 1 + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # This is an optimization on top of the "alias of intermediates" logic, + # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!] + # + # Before describing the optimization: this is important for AOTAutograd to have good + # perf around, multi-output views. HOWEVER: + # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case, + # around using pre-dispatch tracing to partition out a graph so we can faithfully replay all + # views without having to regenerate them at runtime. + # - It's loosely described in this doc (more details will be added soon): + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit + # - Once that change lands, we should just rip out this "optimization", since: + # (1) It will be fully unnecessary + # (2) Although it is only a few lines of code, it is a bit difficult to reason about + # its correctness with the autograd engine in all cases. + # + # + # What is this optimization? Consider the below case: + # def f(x): + # intermediate = x.mul(2) + # # x and intermediate here require grad + # o1, o2, ... o10 = intermediate.unbind(-1) + # return intermediate, o1, o2, ... o10 + # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following: + # (1) return "intermediate as an extra output of the compiled graph + # (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function. + # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know + # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function, + # this information will be hidden. + # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases + # (like their grad_fn, for example, when the autograd engine needs to do view-replay). + # + # However, intermediate_base logic can be bad for backward performance (we sometimes generate + # as_strided calls during the intermediate base logic, which can have a slow backward formula). + # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd? + # + # For a set of outputs of the graph that alias each other, o_1...o_k, consider: + # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0) + # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate), + # **at most** 1 can escape from the graph (e.g. there is not some other graph input/output + # o_other, that aliases these outputs) + # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad. + # This condition is important because it's what causes slowness in the intermediate_base + # codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and + # aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn. + # "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward. + # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta + # of the other aliases? + # + # Claim: No! Consider a few example (which I'm pretty sure cover all cases of mutation w.r.t. autograd): + # (a) What happens if we mutate any of o_1 through o_k directly? + # Autograd raises an error: + # "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is + # the output of a function that returns multiple views. Such functions do not allow the output + # views to be modified inplace. You should replace the inplace operation by an out-of-place one." + # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)? + # Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views. + # (c) What if we mutate o_k under no_grad? + # Autograd raises the same error + # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)? + # Autograd allows this, *but* autograd updates all alias's grad_fn's to be error functions when accessed. + # Autograd raises the same error + # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view? + # We promised that there is at most **one** such alias, e.g. intermediate in the example above. + # You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k + # to be error fn's. + # Since intermediate was the *only* non-multi-output-alias, there are no other aliases + # of `intermediate` around that were produced by the compiled fn and have a valid grad_fn. + # + # Coming back to this optimization: + # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias + # without causing an error in eager mode, we will simple hide the aliasing from autograd during torch.compile + # if all of the above conditions are met. + # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on + # in eager but fail to during torch.compile, but it has the benefit that this code has much better performance. + # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here: + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit, + # then this optimization will probably matter less and might be ok to remove. + is_cur_tensor_multi_out_view = isinstance( + o, FunctionalTensor + ) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined] + o.elem + ) + if is_cur_tensor_multi_out_view: + num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1 + if o.requires_grad: + out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ].add(o) + + # maps the id of an intermediate base to its index in the output of the compiled forward + intermediate_base_tensor_id_to_output_idx: dict[int, int] = {} + intermediate_bases: list[torch.Tensor] = [] + intermediate_bases_descs: list[AOTInput] = [] + # Why Do We Care If Storage Changed? + # It's important to understand the implications of storage changes in complex scenarios. Take this example: + # + # def f(x): + # x_storage = x.untyped_storage() + # non_leaf_tensor = torch.ones(4, requires_grad=True).clone() + # + # # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(non_leaf_tensor.untyped_storage()) + # + # out = x.view(-1) + # + # # Restoring x to its original storage, again simulating .data = operation + # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x): + # x.set_(x_storage) + # + # return out + # + # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing. + # However, due to how set_() and more specificlaly, set is functionalized, is defined to preserve eager semantics, + # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'. + # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated, + # which could lead to issues later in the code. + for o, desc in zip(flat_f_outs, flat_f_outs_descs): + functional_tensor_storage_changed = isinstance( + o, FunctionalTensor + ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] + o.elem + ) + curr_storage = ( + None + if not isinstance(o, torch.Tensor) + else StorageWeakRef(o.untyped_storage()) + ) + outs_with_identical_metadata_that_require_grad = ( + [] + if not isinstance(o, Tensor) + else [ + curr + for curr in out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ] + if o is not curr + ] + ) + + # See Note [Accessing .grad_fn on FunctionalTensor] + # In-place operations on views will trigger a lazy rebase of the autograd graph; + # this runs during access to the .grad_fn. The rebase logic will invoke view ops + # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure + # these op calls succeed. + grad_fn = None + if isinstance(o, Tensor): + with FunctionalTensorMode(): + grad_fn = o.grad_fn + + is_result_of_custom_autograd_fn = False + # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) + # autograd fns + if type(grad_fn).__name__ == "CppFunction": + is_result_of_custom_autograd_fn = True + if isinstance(grad_fn, torch.autograd.function.BackwardCFunction): + is_result_of_custom_autograd_fn = True + + if not isinstance(o, Tensor): + output_type = OutputType.non_alias + base_idx = None + elif ( + curr_storage in inp_storage_refs + and grad_fn is not None + and is_result_of_custom_autograd_fn + ): + output_type = OutputType.custom_function_view + base_idx = None + elif ( + curr_storage in inp_storage_refs + and not functional_tensor_storage_changed + ): + # pyrefly: ignore [index-error] + base_idx = inp_storage_refs[curr_storage] + is_input_tensor = id(o) in inp_tensor_ids + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + if ( + grad_fn is not None + and num_aliased_outs_that_are_not_multi_output_views == 0 + ): + # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # In particular, given: + # def f(x): + # return list(x.unbind(0)) + # The main reason we ordinarily try to regenerate these output aliases outside of the + # compiled autograd.Function is because if any of the outputs are later mutated, + # autograd needs to perform view-replay to regenerate them. + # However, autograd does not allow users to mutate multi-output views + # in any way that can change the autograd metadata of other aliases. + # So we hide this aliasing from autograd here. + log.debug( + "Encountered AOTAutograd case: differentiable outputs that \ +alias each other from a multi-output view call" + ) + output_type = OutputType.non_alias + elif is_input_tensor: + output_type = OutputType.is_input + else: + output_type = OutputType.alias_of_input + elif functional_tensor_storage_changed and id(o) in inp_tensor_ids: + # When there is a set_() on an input, we cannot rely on checking storages + # to detect if we are returning an input (since the inputs storage is different) + assert curr_storage is not None + base_idx = inp_storage_refs[curr_storage] + output_type = OutputType.is_input + + # We only need to handle the intermediate base case when both + # the intermediate base and the output require gradients. + # See Note [AOT Autograd: outputs aliasing inputs or intermediates!] + elif o._base is not None and o.requires_grad and o._base.requires_grad: + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + if ( + out_tensor_alias_counts[curr_storage] == 1 + or num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + # Note [Intermediate Bases Optimization] + # Normally if we have an output that aliases an intermediate, + # we need to add the extra "intermediate base" logic further down + # to prevent autograd from yelling at us if the user later tries to + # mutate that output. + # However, the common case here is if we have an output that aliases an intermediate, + # but doesn't alias any other outputs. + # In that case, autograd shouldn't have to worry about the aliasing at all + # (if that output is mutated, there are no other live aliases for autograd to worry about). + # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs. + # So as an optimization, we won't do intermediate base handling in this case. + # Instead, we'll hide the aliasing from autograd using aten._unsafe_view(). + if ( + out_tensor_alias_counts[curr_storage] != 1 + and num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + log.debug( + "Encountered AOTAutograd case: differentiable outputs that alias each other \ +from a multi-output view call" + ) + output_type = OutputType.unsafe_view_alias + base_idx = None + else: + # First, check if o's ._base is an existing output + maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None) + if maybe_existing_out_idx is not None: + # Special case where the output is an alias of a graph intermediate, but that intermediate + # is itself also a user output. + output_type = ( + OutputType.alias_of_intermediate_base_is_user_output + ) + base_idx = maybe_existing_out_idx + else: + # Next, check if o's ._base is an intermediate base that we already returned + maybe_existing_base_output_idx = ( + intermediate_base_tensor_id_to_output_idx.get( + id(o._base), None + ) + ) + if maybe_existing_base_output_idx is not None: + output_type = OutputType.alias_of_intermediate + base_idx = maybe_existing_base_output_idx + else: + # Otherwise, take o._base and explicitly return it as an output in the compiled graph + new_out_idx = len(intermediate_bases) + base_idx = new_out_idx + # Indicate to the logic later on (when we trace the joint) + # that this particular output should get it's ._base appended to the forward graph outputs + output_type = ( + OutputType.alias_of_intermediate_save_as_output + ) + intermediate_base_tensor_id_to_output_idx[id(o._base)] = ( + new_out_idx + ) + intermediate_bases.append(o._base) + # NB: The desc we picked here is guaranteed to be + # synchronized with the one in + # graph_capture_wrappers.py because we + # SPECIFICALLY notated this output as + # alias_of_intermediate_save_as_output + intermediate_bases_descs.append( + TangentAOTInput(IntermediateBaseAOTOutput(desc)) + ) + elif ( + # See https://github.com/pytorch/pytorch/issues/100348 for this case. + # This protects against the specific case where a user fn returns (output, output.detach()) + out_tensor_alias_counts[curr_storage] > 1 + and len(outs_with_identical_metadata_that_require_grad) > 0 + and not o.requires_grad + ): + # In theory we could use any of these tensors to regenerate the aliased outputs from, + # since they all alias each other and have identical metadata + out_alias = outs_with_identical_metadata_that_require_grad[0] + existing_out_idx = out_tensor_ids[id(out_alias)] + output_type = OutputType.alias_of_intermediate_base_is_user_output + base_idx = existing_out_idx + else: + output_type = OutputType.non_alias + base_idx = None + + if isinstance(o, torch.Tensor): + dynamic_dims = { + i for i, s in enumerate(o.shape) if not is_concrete_int(s) + } + else: + dynamic_dims = None + + # Save the current FunctionalTensor output. + # + # This will be used at runtime for reconstructing output views from + # their respective base tensors. + # + # The FunctionalTensor will be saved if one of the 2 conditions below + # is true: + view_meta_sequence = None + if ( + # 1. If the output_type is either of: + # (i) alias_of_intermediate; + # (ii) alias_of_intermediate_save_as_output; or + # (iii) alias_of_intermediate_base_is_user_output. + # + # No need to worry about in-place view operations here, since + # this functionalization step elimitates mutations. + # + # i.e. we have access to the actual base tensor, before the + # in-place operation was applied. + output_type + in ( + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output, + ) + ) or ( + # 2. If the output_type is alias_of_input, and no in-place view + # operationthe was run on the input (base tensor). + # + # In this case, we need to check for metadata mutation because + # the runtime explicitly reconstructs the inputs, before actually + # reconstructing the outputs. Due to in-place view operations, the + # fully reconstructed input may not be this output base tensor + # anymore. + output_type == OutputType.alias_of_input + and base_idx is not None + and not input_info[base_idx].mutates_metadata + ): + if isinstance(o, FunctionalTensor): + view_meta_sequence = ViewMetaSequence(o) + + out_info = OutputAliasInfo( + output_type=output_type, + raw_type=type(o), + base_idx=base_idx, + dynamic_dims=dynamic_dims, + requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, + view_meta_sequence=view_meta_sequence, + ) + output_info.append(out_info) + + # See Note [AOT Autograd: Views to avoid tangents aliasing inputs] + def view_avoid_dupes_with_primals(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + return transform_subclass( + t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t) + ) + if isinstance(t, Tensor): + return t.view(t.shape) + return t + + # This analysis function returns *only* the outputs that are meant to be tangents to the backwards. + # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates) + # are *regenerated* later, and not used directly in the autograd graph + def _plain_fake_tensor_like_subclass(x): + # pyrefly: ignore [bad-context-manager] + with detect_fake_mode(): + return torch.empty( + x.shape, dtype=x.dtype, device=x.device, layout=x.layout + ) + + def _is_subclass_mutated_input_tangent_always_subclass(inp): + return ( + isinstance(inp, torch.nested._internal.nested_tensor.NestedTensor) + or torch._functorch.config.disable_guess_zero_tangent_for_mutated_input_subclass + ) + + f_input_tangents_pairs = [ + # Note: [AOTAutograd Tangent Subclassness for mutated inputs] + # Generally when creating tangents to trace with, we assume that tangents will have + # the same subclass-ness as their forward outs + # however: for tangents that correspond to input mutations, in practice it is more likely + # that these tangents will be plain tensors of zeros at runtime, so we tweak our guess + # to assume that these tangents should always be plaint tensors. + # Example: + # def f(x): + # x.mul_(2) + # return x + 1 + # out = f(x) + # out.sum().backward() + # In the above code, we will have a tangent "x_updated_tangent", + # which will be a plain tensor of zeros, *unless* x is used in some compute after executing f + # + # However, there are exceptions to this logic. If a view is created from mutated input and is used in backward, + # The tangent for this subclass input will be a subclass tensor. + # Example: + # def f(a, b): + # a.mul_(2) + # b.mul_(3) + # return b.view(b.shape), a + b + # a_out, b_out = f(..., Subclass) + # (a * b).sum().backward() + # + # We can not deduce it easily now, so introducing a debug config to be able to turn off this for specific cases. + # NJT guarantees to have its tangent as NJT, because it has dedicated integration in Autograd + # See torch/csrc/autograd/python_function.cpp, use_zeros_like. + ( + ( + _plain_fake_tensor_like_subclass(inp) + if is_traceable_wrapper_subclass(inp) + and not _is_subclass_mutated_input_tangent_always_subclass(inp) + else inp + ), + TangentAOTInput(InputMutationAOTOutput(inp_desc)), + ) + for inp, inp_desc, info in zip(flat_f_args, flat_f_args_descs, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + and info.mutates_data + and info.requires_grad + ] + f_input_tangents, f_input_tangents_descs = ( + [x[0] for x in f_input_tangents_pairs], + [x[1] for x in f_input_tangents_pairs], + ) + + f_output_tangents_pairs = [ + (o, TangentAOTInput(desc)) + for o, info, desc in zip(flat_f_outs, output_info, flat_f_outs_descs) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + f_output_tangents, f_output_tangents_descs = ( + [x[0] for x in f_output_tangents_pairs], + [x[1] for x in f_output_tangents_pairs], + ) + + # intermediate bases are also included in the backward graph + f_tangents = f_input_tangents + f_output_tangents + intermediate_bases + f_tangents_descs = ( + f_input_tangents_descs + f_output_tangents_descs + intermediate_bases_descs + ) + + # TODO: I'm pretty sure you don't need a tree_map here + traced_tangents = pytree.tree_map(from_fun, f_tangents) + traced_tangents = pytree.tree_map( + view_avoid_dupes_with_primals, traced_tangents + ) + traced_tangents = [ + coerce_tangent_and_suggest_memory_format(tt)[0] + for i, tt in enumerate(traced_tangents) + ] + # NB: update this if the maps above ever change structure. + # Also, it might be helpful to add coercion information to the tangent desc! + traced_tangents_descs = f_tangents_descs + + nonlocal static_input_indices + static_input_indices = static_input_indices or [] + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + passed_indices = set(static_input_indices) + static_input_indices = [ + i + for i, arg in enumerate(flat_args) + if (isinstance(arg, torch.nn.Parameter) or i in passed_indices) + ] + + static_input_logger.debug( + "static input indices metadata analysis: %s", static_input_indices + ) + + f_mutated_inputs = [ + inp + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + f_metadata_mutated_inputs = [ + inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata + ] + # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. + # When handling subclasses, we need info about **all** outputs of compiled forward graph, + # so we know precisely which graph outputs to wrap back into tensor subclasses + # Ideally we would refactor this so not have an is_train flag, and have the separate + # inference and training paths decide which inputs/output to ask for subclass info on. + # However, we currently stash indexing information on each SubclassMeta about its order + # in the graph outputs list. + f_fw_graph_outs = list(flat_f_outs) + if is_train or not keep_input_mutations: + f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs + else: + # even when "keep_input_mutations" is True, + # we never keep metadata-only mutations in the fw graph + f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs + if is_train: + f_fw_graph_outs = f_fw_graph_outs + intermediate_bases + fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) + + grad_enabled_mutation = None + if torch.is_grad_enabled() != prior_grad_enabled: + grad_enabled_mutation = torch.is_grad_enabled() + torch.set_grad_enabled( + prior_grad_enabled + ) # Restore the prior state after tracing it + log.debug( + ( + "grad_mode mutation encountered in graph. " + "Will emit mutation epilogue, to set grad_mode=%s" + ), + grad_enabled_mutation, + ) + + metadata = ViewAndMutationMeta( + input_info=input_info, + output_info=output_info, + num_intermediate_bases=len(intermediate_bases), + keep_input_mutations=keep_input_mutations, + traced_tangents=traced_tangents, + traced_tangents_descs=traced_tangents_descs, + subclass_inp_meta=create_subclass_meta(flat_args), + subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), + subclass_tangent_meta=create_subclass_meta( + traced_tangents, count_symints=False, with_memory_format=True + ), + is_train=is_train, + grad_enabled_mutation=grad_enabled_mutation, + static_input_indices=static_input_indices, + tokens=mode._tokens, + ) + return metadata + + return inner diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py new file mode 100644 index 0000000000000000000000000000000000000000..3d480cdf6f9ac66c12c394b0c43fe6e1aacc06c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/descriptors.py @@ -0,0 +1,749 @@ +""" +AOTAutograd descriptors are a path-like data structure (similar to pytree +paths and sources) that describe the semantic meaning of an input/output to FX +graphs. Although you may know the input/output meaning at the top level of +the original function you traced, because we have many graph capture wrappers +that change the calling convention, it can be difficult to tell how these +correspond to the actual FX graph you get back, to say nothing about the extra +arguments/outputs for tangents, gradients, etc. Descriptors describe the meaning +of arguments. + +Examples +-------- + +Before we talk about the precise semantics, it's helpful to look at some +examples to get some intuition for the meaning of descriptors. Here are some +input descriptors you might find on the joint FX graph: + +* PlainAOTInput(idx=0) - the first input from the original callable, as is + +* ParamAOTInput(target="mod.weight") - the parameter with FQN mod.weight + +* TangentAOTInput(output=PlainAOTOutput(idx=1)) - the input tangent + corresponding to the gradients for the second output in the forward graph + +* ViewBaseAOTInput(base_of=PlainAOTInput(idx=0)) - it turned out the first + input was actually a (differentiable) view of a tensor which aliased with + another input tensor. We replaced this input with a single input for the + base of all of these inputs, replacing the original inputs (one of which is + mentioned in base_of). We would generate a GradAOTOutput for *this* input + (and not the original PlainAOTInputs!) If you have a joint graph where a + view base like this is undesirable, you can eliminate this by cloning + the views outside of the compiled region (assuming you aren't mutating this + tensor). + +* SubclassGetAttrAOTInput(base=AOTInput(idx=0), attr="inner") - this tensor + corresponds to the "inner" tensor from the tensor subclass that is at the + first index. In general, joint graphs from AOTAutograd never take tensor + subclasses as inputs; they are always unpacked into their constituent plain + tensor pieces; use the descriptors to identify the parts of the tensor that + are related. Note that this can be nested (if you have nested tensor + subclasses!) + +Here are some output descriptors you might find on the Joint FX graph: + +* PlainAOTOutput(idx=0) - the first output from the original forward function, + as is + +* GradAOTOutput(grad_of=PlainAOTInput(idx=1)) - the computed gradient for the + second input to the graph, an output of the backward graph + +* InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=0)) - when the first + input is mutated, the new value to be copied into the first input of the + graph. Sometimes, these outputs can be elided and the ``copy_`` is done directly + in the graph (controlled by keep_input_mutations), but if the input + mutation must be differentiated through we always generate an output like this + +* IntermediateBaseAOTOutput(base_of=PlainAOTOutput(idx=0)) - if we return + multiple outputs which alias each other, we instead replace them with a single + output tensor representing the base of all the aliases. This output indicates + it is the base for /one/ of those original outputs. If this is undesirable in + the joint graph, clone all outputs before returning from the graph. + +* SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), idx="inner") - this + tensor correspondings to the inner tensor of the first original output which + is a tensor subclass. This and other subclass components of that output will + get repacked into a tensor subclass. + +High level semantics +-------------------- + +OK, let's formally define a descriptor. Intuitively, suppose we have:: + + def wrapped_graph(*args): + ret = graph(*in_transform(args)) + return out_transform(ret) + +Then the descriptor for input[i] to graph describes a function fin_i such that:: + + fin_i(args) == in_transform(args)[i] + +and the descriptor for output[j] from graph describes a function fout_j such that:: + + fout_j(out_transform(ret)) == ret[j] + +AKA input descriptors tell you how to get from outer inputs to inner inputs, +while output descriptors tell you how to get from outer outputs to inner +outputs (inverse data flow!) + +We haven't said anything about what these transformations actually do. There +are three major transformations AOTAutograd does (performed in this order): + +* View/mutation handling +* Autograd +* Subclasses + +So intuitively, descriptors are built like this: + +1. **PlainAOTInput, PlainAOTOutput.** + + We start off descriptors describing the exact inputs/outputs of the + original flattened user function. This user function is assumed to already + be flattened; you would chain on pytree KeyPaths to further describe where + in the pytree each input/output lived if you needed to deal with + unflattened functions: this can be done from userland on top of + descriptors, so the main descriptors mechanism doesn't handle it. + +2. **SyntheticBaseAOTInput, ViewBaseAOTInput, MetadataMutationAOTOutput, + InputMutationAOTOutput, IntermediateBaseAOTOutput** + + We deal with mutations and aliasing by removing duplicate PlainAOTInputs + and introduce some new artificial inputs/outputs. These inputs do not + have a straightforward correspondence to the original user inputs, but if + you are implementing a pass that doesn't care about the exact semantics of + inputs, you should handle all of these uniformly in the same way as regular + inputs. + +3. **TangentAOTInput, GradAOTOutput** + + We deal with autograd by introducing a tangent input for every + differentiable AOTOutput (including the new ones introduced above), and a + gradient output for every differentiable AOTInput (also including new ones + introduced above.) The arguments to these AOTInput/AOTOutput can ONLY be + the ones we already have above (from steps 1-2). As AOTAutograd does not + currently support double backwards, you never have tangents of grads or + vice versa (but in the future we could!) + +4. **SubclassGetAttrAOTInput, SubclassGetAttrAOTOutput, et al.** + + We deal with subclasses by introducing flattened inputs/outputs (including + potentially symbolic sizes/strides) for every AOTInput/AOTOutput that was a + subclass. As above, the arguments to these AOTInput/AOTOutput can ONLY be + the ones we have above (from steps 1-3). Recursive subclasses are + supported, so these descriptors can nest with each other (so descriptors + from step 4 are fair game as well.) + +5. **ForwardTokenAOTInput, ForwardTokenAOTOutput, BackwardTokenAOTInput, BackwardTokenAOTOutput.** + + Some extra token inputs/outputs get added, these are synthetic and are just here to + prevent DCE/reordering. + +The important thing about the pipeline is that descriptors can ONLY be +created from top-to-bottom. So for example, you can have:: + + SubclassGetAttrAOTInput(TangentAOTInput(PlainAOTOutput(...))) # OK + +As you can see that PlainAOTOutput -> TangentAOTInput -> +SubclassGetAttrAOTInput is consistent with the pipeline ordering), but you can +NEVER have:: + + TangentAOTInput(SubclassGetAttrAOTOutput(PlainAOTOutput(...)) # BAD + +This is inconsistent; we always do autograd BEFORE we process subclasses! + +Similarly, for example, this is illegal:: + + GradAOTOutput(SubclassGetAttrAOTInput(PlainAOTInput(...))) # BAD + +It is illegal because subclasses are handled *after* create joint during +wrapper construction. Instead, you would have:: + + SubclassGetAttrAOTOutput(GradAOTOutput(PlainAOTInput(...))) # OK + +This intuitively captures the fact that we always to autograd directly on the +subclass, rather than after desugaring the subclass into its inner tensors. + +Descriptor index +---------------- + +Here is a list of all AOTInput/AOTOutput, organized by how likely you need to +handle them: + +* AOTInput + + * Important: + + * PlainAOTInput (the primals!) + * ParamAOTInput + * TangentAOTInput + * SubclassGetAttrAOTInput et al. (if you use subclasses) + + * View related (can be eliminated by cloning inputs to graph; if you don't + eliminate them, make sure to handle pairing them with GradAOTOutput): + + * ViewBaseAOTInput + * SyntheticBaseAOTInput + + * Non-tensor, mostly just ignore them: + + * DummyAOTInput + * PhiloxForwardSeedAOTInput + * PhiloxForwardBaseOffsetAOTInput + * PhiloxBackwardSeedAOTInput + * PhiloxBackwardBaseOffsetAOTInput + * ForwardTokenAOTInput + * BackwardTokenAOTInput + +* AOTOutput + + * Important: + + * PlainAOTOutput + * GradAOTOutput + * SubclassGetAttrAOTOutput et al. (if you use subclasses) + + * More obscure (if not eliminated, make sure you handle pairing them with + TangentAOTInput): + + * InputMutationAOTOutput (can be eliminated if mutations are non-differentiable) + * IntermediateBaseAOTOutput (can be eliminated by cloning outputs of graph) + * MetadataMutationAOTOutput (uhh, just don't mutate metadata?) + + * Non-tensor, mostly just ignore them: + + * PhiloxUpdatedForwardOffsetAOTOutput + * PhiloxUpdatedBackwardOffsetAOTOutput + * ForwardTokenAOTOutput + * BackwardTokenAOTOutput + * DummyAOTOutput + +For convenience, we also have DifferentiableAOTInput and +DifferentiableAOTOutput to help you classify which inputs/outputs can be +wrapped by GradAOTOutput/TangentAOTInput (respectively), which are essentially +all tensor AOTInput/AOTOutput excluding the subclass descriptors. + +Implementation details +---------------------- + +The stylized view above is good for understanding how to interpret +descriptors, but the way that descriptors are generated in code is a bit more +complicated. Specifically, AOTAutograd is structured as a series of wrappers +on the original user function, which are composed together to form the final +function to trace. As a result of this, AOTAutograd ends up first building +the full AOTInputs for a function to be traced (as it builds the wrappers and +modifies the flat arguments to be compatible with the new input signature of +the wrapper), and then in reverse builds up the AOTOutput as it is tracing. + +There is one major exception to this general idea of "build AOTInput first", +and then "build AOTOutput second": when we create TangentAOTInput, we need to +reference AOTOutputs (which output we are the tangents of) which we generally +haven't created yet. There's two ways we deal with this: + +- After the precompile steps (dedup and synthetic base handling), we do an + initial pass to collect forward metadata that produces the initial set of + PlainAOTOutputs which we use to create the tangent inputs. + +- We also sometimes just violate causality and predict that an AOTOutput will + be created in a particular way at some later point in time when we build an + AOTInput. + +As of July 2025, here is an exhaustive description of how inputs/outputs +traverse the wrappers from AOTAutograd, and what descriptors can be introduced +at these phases. + +:: + + Build wrappers (FLOWS DOWN) Run trace (FLOWS UP) + ------------------------------------------------------------------------------------------------- + Begin PlainAOTInput (n/a) + ParamAOTInput + + Precompile dedupe (remove dupes) (nothing) + + Precompile synthetic base SyntheticBaseAOTInput MetadataMutationAOTOutput + ViewBaseAOTInput + + Forward metadata trace PlainAOTOutput (n/a) + MetadataMutationAOTOutput + + Prepare for autograd (nothing) InputMutationAOTOutput + IntermediateBaseAOTOutput + + Create joint TangentAOTInput GradAOTOutput + w/ InputMutationAOTOutput + w/ IntermediateBaseAOTOutput + + Precompile subclass SubclassGetAttrAOTInput et al. SubclassGetAttrAOTOutput et al. + + Effect tokens ForwardTokenAOTInput ForwardTokenAOTOutput + BackwardTokenAOTInput BackwardTokenAOTOutput + + End (n/a) PlainAOTOutput + +It can be helpful to separately write down the input flow and the output flow +for ease of understanding the data flow: + +* Input desc propagation (happens as we build wrappers) + + * [IN] Begin with original calling convention (PlainAOTInput, ParamAOTInput) + * [IN] Precompile dedupe: (removes duplicate AOTInputs) + * [IN] Precompile synthetic base: SyntheticBaseAOTInput, ViewBaseAOTInput + * Forward metadata trace (mini output desc propagation) + + * [OUT] Original output convention: PlainAOTOutput + * [OUT] Precompile synthetic base: MetadataMutationAOTOutput + + * [IN] Prepare for autograd: (nothing) + * [IN] Create joint: TangentAOTInput (potentially w/ + IntermediateBaseAOTOutput, InputMutationAOTOutput) + * [IN] Precompile subclass: SubclassGetAttrAOTInput et al. + * [IN] Effect tokens: ForwardTokenAOTInput, BackwardTokenAOTInput + (Note: BackwardTokenAOTInput is technically generated not by a wrapper but + actually done by token_discovery which implicitly adds extra arguments + to the FX trace on-the-fly.) + +* Trigger a trace with the modified inputs on the wrapper +* Output desc propagation (happens as we unwind from the user function call in trace) + + * [OUT] Begin with original calling convention: PlainAOTOutput + * [OUT] Effect tokens: ForwardTokenAOTOutput, BackwardTokenAOTOutput + * [OUT] Precompile subclass: SubclassGetAttrAOTOutput et al. + * [OUT] Create joint: GradAOTOutput + * [OUT] Prepare for autograd: InputMutationAOTOutput, IntermediateBaseAOTOutput + * [OUT] Precompile synthetic base: MetadataMutationAOTOutput + * [OUT] Precompile dedupe: (nothing) +""" + +import dataclasses + + +# TODO: the is_* predicates are a little suspicious because (1) they're not +# used by anything and (2) they always report False even when a parameter got +# swizzled into a view base or deduped with a non-parameter. It is pretty +# difficult to exercise these cases but it's not clear if you will write code +# that works correctly in those cases. + + +@dataclasses.dataclass(frozen=True) +class AOTInput: + """Describes where an input from an AOTAutograd produced FX graph comes from""" + + def expr(self) -> str: + raise NotImplementedError("Subclasses must implement expr()") + + def is_param(self) -> bool: + """True if this input is a parameter or derived from a parameter (e.g., subclass attr)""" + return False + + def is_buffer(self) -> bool: + """True if this input is a buffer or derived from a buffer (e.g., subclass attr)""" + return False + + def is_tangent(self) -> bool: + """True if this input is a tangent or derived from a tangent (e.g., subclass attr)""" + return False + + +# Note: Currently, our typing discipline for differentiable versus not is not +# very good, so feel free to rely on runtime tests instead. + + +@dataclasses.dataclass(frozen=True) +class DifferentiableAOTInput(AOTInput): + """A subclass that classifies AOTInput that can be wrapped by GradAOTOutput""" + + +@dataclasses.dataclass(frozen=True) +class AOTOutput: + """Describes where an output from an AOTAutograd produced FX graph will + eventually be bundled into the final output""" + + def expr(self) -> str: + raise NotImplementedError("Subclasses must implement expr()") + + def is_grad(self) -> bool: + """True if this output is a grad or derived from a grad (e.g., subclass attr)""" + return False + + +@dataclasses.dataclass(frozen=True) +class DifferentiableAOTOutput(AOTOutput): + """A subclass that classifies AOTOutput that can be wrapped by TangentAOTInput""" + + +# ------------ + +# AOTInput + +# ------------ + + +@dataclasses.dataclass(frozen=True) +class ParamAOTInput(DifferentiableAOTInput): + """The input is a parameter, whose FQN is target""" + + target: str + + def expr(self) -> str: + return f"self.get_parameter({self.target!r})" + + def is_param(self) -> bool: + return True + + def is_buffer(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class BufferAOTInput(DifferentiableAOTInput): + """The input is a buffer, whose FQN is target""" + + target: str + + def expr(self) -> str: + return f"self.get_buffer({self.target!r})" + + def is_param(self) -> bool: + return False + + def is_buffer(self) -> bool: + return True + + +@dataclasses.dataclass(frozen=True) +class DummyAOTInput(AOTInput): + """In some circumstances, we want to call into a function that expects AOTInput, but + we don't actually care about that logic (most typically, because some code is being used + for both compile-time and run-time; AOTInput processing is not needed in this situation. + Pass a dummy in this situation; but it is better to just have a version of the function + that doesn't have this at all.""" + + idx: int + + def expr(self) -> str: + return f"__dummy{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class PlainAOTInput(DifferentiableAOTInput): + """The input is a plain input, corresponding to a particular positional index. + + Note that AOTInput is always relative to a function with a *flat* calling convention, + e.g., as accepted by `aot_module_simplified`. There are some AOTAutograd APIs that + flatten pytrees, and we don't record PyTree key paths from the flattening (but we + could and should!) + """ + + idx: int + + def expr(self) -> str: + return f"args[{self.idx}]" + + +@dataclasses.dataclass(frozen=True) +class SubclassGetAttrAOTInput(AOTInput): + """Subclass inputs get unpacked into their constituent pieces before going into an FX + graph. This tells you which particular attribute of the subclass this particular + input corresponds to (of the 'base' originally subclass argument.) + """ + + base: AOTInput + attr: str + + def expr(self) -> str: + return f"{self.base.expr()}.{self.attr}" + + def is_param(self) -> bool: + return self.base.is_param() + + def is_buffer(self) -> bool: + return self.base.is_buffer() + + def is_tangent(self) -> bool: + return self.base.is_tangent() + + +@dataclasses.dataclass(frozen=True) +class SubclassSizeAOTInput(AOTInput): + """Which subclass this particular outer size SymInt input (at dim idx) came from.""" + + base: AOTInput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.size({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class SubclassStrideAOTInput(AOTInput): + """Which subclass this particular outer stride SymInt input (at dim idx) came from.""" + + base: AOTInput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.stride({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class ViewBaseAOTInput(DifferentiableAOTInput): + """ + When multiple differentiable inputs are views of the same input, AOTAutograd will replace all of these + views with a single input representing the base. If this is undesirable, you can clone the views + example inputs before passing them into AOTAutograd. + + TODO: In principle we could report ALL of the inputs who this is a base of. + """ + + base_of: AOTInput + + def expr(self) -> str: + return f"{self.base_of.expr()}._base" + + +@dataclasses.dataclass(frozen=True) +class SyntheticBaseAOTInput(DifferentiableAOTInput): + """This is similar to ViewBaseAOTInput, but this happens when none of the views were differentiable, so + we weren't able to get our hands on the true original view and constructed a synthetic one instead + for the sake of autograd. + """ + + base_of: AOTInput + + def expr(self) -> str: + return f"__make_synthetic_base({self.base_of.expr()})" + + +@dataclasses.dataclass(frozen=True) +class PhiloxForwardSeedAOTInput(AOTInput): + """The seed for functionalized Philox RNG calls, specifically for forward graph.""" + + def expr(self) -> str: + return "__philox_forward_seed" + + +@dataclasses.dataclass(frozen=True) +class PhiloxForwardBaseOffsetAOTInput(AOTInput): + """The offset for functionalized Philox RNG calls, specifically for forward graph.""" + + def expr(self) -> str: + return "__philox_forward_base_offset" + + +@dataclasses.dataclass(frozen=True) +class PhiloxBackwardSeedAOTInput(AOTInput): + """The seed for functionalized Philox RNG calls, specifically for backward graph.""" + + def expr(self) -> str: + return "__philox_backward_seed" + + +@dataclasses.dataclass(frozen=True) +class PhiloxBackwardBaseOffsetAOTInput(AOTInput): + """The offset for functionalized Philox RNG calls, specifically for backward graph.""" + + def expr(self) -> str: + return "__philox_backward_base_offset" + + +@dataclasses.dataclass(frozen=True) +class ForwardTokenAOTInput(AOTInput): + """The world token which is threaded through side-effectful operations""" + + idx: int + + def expr(self) -> str: + return f"__forward_token{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class BackwardTokenAOTInput(AOTInput): + """The world token which is threaded through side-effectful operations, for backwards""" + + idx: int + + def expr(self) -> str: + return f"__backward_token{self.idx}" + + +# Technically the "output" here is redundant, tangents always correspond to +# outputs +# NB: this is marked differentiable as it /would/ be differentiable if we +# support double backwards, but we never generate this today because we +# don't support double backwards. +@dataclasses.dataclass(frozen=True) +class TangentAOTInput(DifferentiableAOTInput): + """An input to the joint graph representing the tangent of an output.""" + + output: DifferentiableAOTOutput + + def __post_init__(self) -> None: + assert isinstance(self.output, DifferentiableAOTOutput) + + def expr(self) -> str: + return f"__output_tangent({self.output.expr()})" + + def is_tangent(self) -> bool: + return True + + +# ------------ + +# AOTOutput + +# ------------ + + +@dataclasses.dataclass(frozen=True) +class PlainAOTOutput(DifferentiableAOTOutput): + """A plain tensor output at position idx of the output tuple""" + + idx: int + + def expr(self) -> str: + return f"output[{self.idx}]" + + +@dataclasses.dataclass(frozen=True) +class InputMutationAOTOutput(DifferentiableAOTOutput): + """The mutated value of an input tensor, returned so we can appropriately propagate autograd.""" + + mutated_input: AOTInput + + def expr(self) -> str: + return f"__input_mutation({self.mutated_input.expr()})" + + +@dataclasses.dataclass(frozen=True) +class IntermediateBaseAOTOutput(DifferentiableAOTOutput): + """An intermediate base of multiple outputs which alias each other. We only report ONE of + the outputs that contributed to this base""" + + base_of: "AOTOutput" + + def expr(self) -> str: + return f"__intermediate_base({self.base_of.expr()})" + + +# TODO: it's a little dodgy this is differentiable lol, but we do generate +# these BEFORE autograd is handled +@dataclasses.dataclass(frozen=True) +class MetadataMutationAOTOutput(DifferentiableAOTOutput): + idx: int + + def expr(self) -> str: + return f"__aliased_arg_with_metadata_mutation{self.idx}" + + +# NB: this is marked differentiable as it /would/ be differentiable if we +# support double backwards, but we never generate this today because we +# don't support double backwards. +@dataclasses.dataclass(frozen=True) +class GradAOTOutput(DifferentiableAOTOutput): + """An output representing the computed gradient for a differentiable input, in the joint graph""" + + grad_of: DifferentiableAOTInput + + def __post_init__(self) -> None: + assert isinstance(self.grad_of, DifferentiableAOTInput) + + def expr(self) -> str: + return f"__grad({self.grad_of.expr()})" + + def is_grad(self) -> bool: + return True + + +@dataclasses.dataclass(frozen=True) +class PhiloxUpdatedForwardOffsetAOTOutput(AOTOutput): + """The final offset from the functionalized RNG calls, forward only""" + + def expr(self) -> str: + return "__philox_updated_forward_offset" + + +@dataclasses.dataclass(frozen=True) +class PhiloxUpdatedBackwardOffsetAOTOutput(AOTOutput): + """The final offset from the functionalized RNG calls, backward only""" + + def expr(self) -> str: + return "__philox_updated_backward_offset" + + +@dataclasses.dataclass(frozen=True) +class ForwardTokenAOTOutput(AOTOutput): + """The world token output for side-effectful calls, returned so we cannot DCE it, forward only""" + + idx: int + + def expr(self) -> str: + return f"__forward_token{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class BackwardTokenAOTOutput(AOTOutput): + """The world token output for side-effectful calls, returned so we cannot DCE it, backward only""" + + idx: int + + def expr(self) -> str: + return f"__backward_token{self.idx}" + + +# These are seemingly symmetric with their AOTInput counterparts. The way to +# think about it is that a subclass could be an input or an output, and they +# get exploded into plain tensors on the way in and out. So we need +# descriptors for both. +@dataclasses.dataclass(frozen=True) +class SubclassGetAttrAOTOutput(AOTOutput): + """This output will be bundled into a subclass at this location""" + + base: AOTOutput + attr: str + + def expr(self) -> str: + return f"{self.base.expr()}.{self.attr}" + + def is_grad(self) -> bool: + return self.base.is_grad() + + +@dataclasses.dataclass(frozen=True) +class SubclassSizeAOTOutput(AOTOutput): + """This output size will be bundled into a subclass at this location""" + + base: AOTOutput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.size({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class SubclassStrideAOTOutput(AOTOutput): + """This output stride will be bundled into a subclass at this location""" + + base: AOTOutput + idx: int + + def expr(self) -> str: + return f"{self.base.expr()}.stride({self.idx})" + + +@dataclasses.dataclass(frozen=True) +class DummyAOTOutput(AOTOutput): + """For cases when you don't actually care about descriptor propagation, do not use under normal + circumstances.""" + + idx: int + + def expr(self) -> str: + return f"__dummy{self.idx}" + + +@dataclasses.dataclass(frozen=True) +class SavedForBackwardsAOTOutput(AOTOutput): + idx: int + + def expr(self) -> str: + return f"__saved_for_backwards_{self.idx}" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..7dceaee3dacb23e9fa7d83e8b200628d2d1a71e4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py @@ -0,0 +1,506 @@ +# mypy: allow-untyped-defs +""" +This module dispatches the graphs to either the forward-only or joint compilation +pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. +""" + +import contextlib +import dataclasses +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torchgen.utils import dataclass_repr + +from .. import config +from .descriptors import AOTInput, BackwardTokenAOTInput +from .functional_utils import ( + assert_functional_graph, + propagate_input_mutation_stacktraces, +) +from .graph_capture_wrappers import ( + aot_dispatch_subclass, + create_functionalized_fn, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, + handle_effect_tokens_fn, +) +from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta +from .streams import assign_backward_streams, insert_backward_syncs, sync_deallocations +from .utils import ( + call_and_expect_output_descs, + copy_fwd_metadata_to_bw_nodes, + fn_wrappers, + register_buffer_assignment_hook, + root_module_when_exporting_non_strict, + simple_wraps, + unlift_tokens, +) + + +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + + +def _create_graph( + f, + args: list[torch.Tensor], + args_descs: Optional[ + list[AOTInput] + ] = None, # keep compat with old clients; maybe we should split into two impls + *, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + # FunctionalTensorMode must be enabled here. + # See Note [Accessing .grad_fn on FunctionalTensor] + out_descs = None + + if args_descs is None: + inner_f = f + else: + + @simple_wraps(f) + def inner_f(*args): + nonlocal out_descs + assert out_descs is None + out, out_descs = call_and_expect_output_descs(f, args) + return out + + if aot_config.disable_functionalization: + ctx = contextlib.nullcontext() + else: + ctx = FunctionalTensorMode( # type: ignore[assignment] + pre_dispatch=aot_config.pre_dispatch, + export=aot_config.is_export, + # Allow token discovery for joint fn tracing as tokens can be used in backward. + _allow_token_discovery=True, + ) + + with ( + enable_python_dispatcher(), + ctx, + ): + fx_g = make_fx( + inner_f, + decomposition_table=aot_config.decompositions, + record_module_stack=True, + pre_dispatch=aot_config.pre_dispatch, + )(*args) + + if args_descs is not None: + flat_args_descs, _ = pytree.tree_flatten(args_descs) + flat_out_descs, _ = pytree.tree_flatten(out_descs) + + # Unfortunately, flat_args_descs is not guaranteed to match the + # number of actual arguments that show up on the FX graph. + # Specifically, allow_token_discovery=True means that we will + # silently add extra token arguments to the backwards graph. + # + # Although there are a few ways to detect what these tokens are, + # we are going to settle for something dodgy but simple to + # implement: match tangents_token placeholders specifically, + # as these are the only placeholders that are created by token + # discovery (NB: there is NO other code that treats this name + # as load bearing, so this is a bit naughty!) + # + # I originally wanted to detect tokens in exactly the same way + # that they are detected at normal runtime, but to be honest + # the normal runtime detection is pretty strange: it seems the + # backward tokens are not reliably at the end of the argument list + # but *precede* the RNG arguments (I don't understand why this is + # the case). And in unlift_tokens, token arguments are detected + # by seeing if they feed into an effects call! Dastardly. Why + # didn't we just introduce a new type. + + i = 0 + j = 0 + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if n.name.startswith("tangents_token"): + n.meta["desc"] = BackwardTokenAOTInput(j) + j += 1 + else: + assert i < len(flat_args_descs), ( + (fn_wrappers(inner_f)), + [n for n in fx_g.graph.nodes if n.op == "placeholder"], + flat_args_descs, + ) + n.meta["desc"] = flat_args_descs[i] + i += 1 + elif n.op == "output": + n.meta["desc"] = flat_out_descs + + return fx_g + + +# TODO: Refactor the following code so detach() persists item_memo +def _detach_and_copy_item_memo(t): + detached_t = t.detach() + if hasattr(t, "item_memo"): + detached_t.item_memo = t.item_memo + return detached_t + + +def aot_dispatch_base_graph( + flat_fn: TraceFn, + flat_args: list[FxValue], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, list[FxValue], list[AOTInput], Optional[SubclassMeta]]: + # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. + # The cases that aot_dispatch_base doesn't need to handle include: + # - outputs that are aliases of graph intermediates + # - outputs that are aliases of graph inputs + # While cases that it does need to handle include: + # - input mutations (including when inputs are aliases of each other) + # - input metadata mutations + fn_to_trace = fn_input_mutations_to_outputs( + flat_fn, + flat_args_descs, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + + if aot_config.disable_functionalization: + updated_flat_args, updated_flat_args_descs = ( + flat_args, + flat_args_descs, + ) + else: + fn_to_trace, updated_flat_args, updated_flat_args_descs = ( + create_functionalized_fn( + fn_to_trace, + flat_args, + flat_args_descs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=False, + ) + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + maybe_subclass_meta, + ) = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + updated_flat_args_descs, + is_joint_structure=False, + meta=fw_metadata, + fw_only=flat_fn, + ) + + if not aot_config.disable_functionalization: + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + ) = handle_effect_tokens_fn( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + meta=fw_metadata, + trace_joint=False, + ) + + aot_graphs_log.debug( + "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(maybe_subclass_meta), + ) + + # We track buffer assignments when exporting in non-strict mode. + # (In contrast, strict mode errors on any attribute assignment.) + mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn) + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be added as a buffer mutation output. + assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + mod_when_exporting_non_strict, assigned_buffers + ) + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, + _detach_and_copy_item_memo, + updated_flat_args_subclasses_desugared, + ) + else: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared + ) + saved_updated_flat_args_subclasses_desugared_descs = ( + updated_flat_args_subclasses_desugared_descs + ) + + fw_module = _create_graph( + fn_to_trace, + updated_flat_args_subclasses_desugared, + updated_flat_args_subclasses_desugared_descs, + aot_config=aot_config, + ) + + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # We update metadata to consider any assigned buffers as buffer mutations. + i = len(dict(mod_when_exporting_non_strict.named_parameters())) + for name, _ in mod_when_exporting_non_strict.named_buffers(): + if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data: # type: ignore[possibly-undefined] + fw_metadata.input_info[i] = dataclasses.replace( + fw_metadata.input_info[i], mutates_data=True + ) + fw_metadata.num_mutated_inp_runtime_indices += 1 + i += 1 + + # We add nodes corresponding to buffer assignments as output nodes in the graph. + add_nodes = [] + output_node = list(fw_module.graph.nodes)[-1] + for name in assigned_buffers.values(): # type: ignore[possibly-undefined] + for node in fw_module.graph.nodes: + if node.name == name: + add_nodes.append(node) + node.users[output_node] = None + output_node.args = ((*add_nodes, *output_node.args[0]),) + + hook.remove() # type: ignore[possibly-undefined] + + # As long as we opted to remove input mutations, then + # there should be *NO* mutating ops in the graph at this point. + if not aot_config.disable_functionalization: + copy_count = assert_functional_graph(fw_module.graph) + fw_module.graph.eliminate_dead_code() + fw_module.recompile() + copy_count2 = assert_functional_graph(fw_module.graph) + propagate_input_mutation_stacktraces(fw_module.graph) + assert copy_count == copy_count2 + else: + fw_module.graph.eliminate_dead_code() + + # See Note [Side-Effectful Tokens in AOTAutograd] + num_tokens = len(fw_metadata.tokens) + if num_tokens != 0 and config.unlift_effect_tokens: + unlift_tokens(fw_module, fw_metadata, aot_config) + saved_updated_flat_args_subclasses_desugared = ( + saved_updated_flat_args_subclasses_desugared[num_tokens:] + ) + saved_updated_flat_args_subclasses_desugared_descs = ( + saved_updated_flat_args_subclasses_desugared_descs[num_tokens:] + ) + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_inference_graph", + payload_fn=lambda: fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ), + ) + + # TODO: should factor this into a separate function for export that always only returns just the graph. + if aot_config.is_export: + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) + return ( + fw_module, + saved_updated_flat_args_subclasses_desugared, + saved_updated_flat_args_subclasses_desugared_descs, + maybe_subclass_meta, + ) + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd_graph( + flat_fn: TraceFn, + flat_args: list[Any], + flat_args_descs: list[AOTInput], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[ + torch.fx.GraphModule, + tuple[list[Any], list[Any]], + tuple[list[AOTInput], list[AOTInput]], + Optional[SubclassMeta], +]: + # NB: flat_fn here is the original user function (as far as + # aot_module_simplified is concerned) + + # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations. + joint_inputs = (flat_args, fw_metadata.traced_tangents) + joint_inputs_descs = (flat_args_descs, fw_metadata.traced_tangents_descs) + + fn_prepared_for_autograd = fn_prepped_for_autograd( + flat_fn, + flat_args_descs, + fw_metadata, + aot_config, + ) + joint_fn_to_trace = create_joint( + fn_prepared_for_autograd, flat_args_descs, aot_config=aot_config + ) + joint_fn_handle = joint_fn_to_trace.handle + + if aot_config.disable_functionalization: + updated_joint_inputs, updated_joint_inputs_descs = ( + joint_inputs, + joint_inputs_descs, + ) + else: + joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = ( + create_functionalized_fn( + joint_fn_to_trace, + joint_inputs, + joint_inputs_descs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + joint_fn_handle=joint_fn_handle, + ) + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + subclass_tracing_info = aot_dispatch_subclass( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + is_joint_structure=True, + meta=fw_metadata, + fw_only=flat_fn, + ) + + joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_joint_inputs = subclass_tracing_info.plain_tensor_args + updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs + + if not aot_config.disable_functionalization: + (joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = ( + handle_effect_tokens_fn( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + meta=fw_metadata, + trace_joint=True, + ) + ) + + # When we call _create_graph, this may mutate the metadata of joint + # inputs. But callers are expecting to get the original joint inputs. So + # we make aliases of all the inputs to make sure we have a copy that + # doesn't get modified. + # + # This destroys requires_grad/grad_fn information. However, backends + # beneath AOTAutograd are indifferent to this information, so it doesn't + # matter. + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, _detach_and_copy_item_memo, updated_joint_inputs + ) + else: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_joint_inputs + ) + maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta + + fx_g = _create_graph( + joint_fn_to_trace, + updated_joint_inputs, + updated_joint_inputs_descs, + aot_config=aot_config, + ) + + # Redundant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + + # Have to copy before eliminate_dead_code otherwise the + # fw node match might be erased + copy_fwd_metadata_to_bw_nodes(fx_g) + + # After copying metadata, assign streams to gradient accumulation nodes + assign_backward_streams(fx_g) + + # Insert syncs for newly assigned backward streams + insert_backward_syncs(fx_g) + + # Sync deallocations for tensors where the stream w/ their last usage + # is distinct from their allocation strea + sync_deallocations(fx_g) + + fx_g.graph.eliminate_dead_code() + if not aot_config.disable_functionalization: + # There should be *NO* mutating ops in the graph at this point. + assert_functional_graph(fx_g.graph) + + fx_g.recompile() + + # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect + # when we need to manually detach() some inputs in the forward. + # Higher order ops might eventually need to do the same. + if aot_config.is_export: + assert maybe_subclass_meta is None, ( + "aot_export_module does not support tensor subclass inputs for now." + ) + return ( + fx_g, + saved_updated_joint_inputs, + updated_joint_inputs_descs, + maybe_subclass_meta, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b1939a741e57daee2dd0fde613730743225ddb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py @@ -0,0 +1,2338 @@ +# mypy: allow-untyped-defs +""" +Functions in this module do most of the "work" of AOTAutograd. +An aot_dispatch_* function: +- Takes in the input flat_fn, flat_args, and some metadata +- Runs a set of pre compile wrappers (e.g. argument deduping) +- Runs the actual compiler +- Wraps the returned callable in a set of post compile wrappers +- Returns the wrapped callable and metadata. +""" + +import copy +import dataclasses +import itertools +import logging +import operator +import time +import traceback +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from typing import Any, Optional, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from collections.abc import Sequence + +import threading +from contextlib import contextmanager + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import ( + CompileEventLogger, + detect_fake_mode, + dynamo_timed, + lazy_format_graph_code, +) +from torch._guards import CompileContext, TracingContext +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses import FakeTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals, guard_or_true +from torch.fx.graph_module import GraphModule +from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars +from torch.multiprocessing.reductions import StorageWeakRef +from torch.types import py_sym_types +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torchgen.utils import dataclass_repr + +from .. import config +from .aot_autograd_result import GenericAOTAutogradResult, serialize_graph_module +from .autograd_cache import ( + AOTAutogradCache, + should_bundle_autograd_cache, + should_use_remote_autograd_cache, +) +from .descriptors import AOTOutput, PlainAOTOutput +from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph +from .logging_utils import track_graph_compiling +from .runtime_wrappers import ( + AOTDedupeWrapper, + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + AOTSyntheticBaseWrapper, + AutogradLazyBackwardCompileInfo, + CompilerWrapper, + DebugAssertWrapper, + EffectTokensWrapper, + FakifiedOutWrapper, + FunctionalizedRngRuntimeWrapper, + make_runtime_safe, + post_compile, + pre_compile, + RuntimeWrapper, + SerializableCompiledFunction, +) +from .schemas import ( + AOTConfig, + AOTGraphCapture, + AOTState, + FlatFn, + FxValue, + MutationType, + SubclassMeta, + ViewAndMutationMeta, +) +from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta +from .utils import ( + contain_metadata_mutation_ops, + get_cuda_generator_meta_val, + make_boxed_func, + simple_wraps, + strict_zip, + unlift_tokens, +) + + +_thread_local = threading.local() + + +@contextmanager +def maybe_skip_decompose(aot_config: AOTConfig): + old_decomp = aot_config.decompositions + try: + if config.selective_decompose: + aot_config.decompositions = {} + yield + finally: + aot_config.decompositions = old_decomp + + +# Saved tensor hooks context +# Compiled saved tensor hooks are convenient way to inline some logic in the graphs +# for saved nodes from forward to backward. (E.g. activations quantization) +# In base implementation user does not have any additional information about saved value +# in the hook, except FakeTensor shape, dtype, device etc. +# _get_saved_tensor_hook_context gives additional graph information about that saved value, +# that can be used to make a decisions which pack/unpack to apply for particular saved value. +# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in +# graph aware way. +# Alternative to this will be making user to write a custom pass that mucks with forward outputs, +# backward input metadata, which requires significantly more effort. +# +# As for now in context we expose forward graph, backward graph and current saved node, +# which contains node.meta with additional information about that fx.Node. +# Warning: This API may change without backward compatibility. +@contextmanager +def _saved_tensor_hook_context(state: dict[str, Any]): + previous_state = getattr(_thread_local, "state", None) + try: + _thread_local.state = state + yield + finally: + # Clean up: restore previous state or remove attribute + if previous_state is not None: + _thread_local.state = previous_state + else: + if hasattr(_thread_local, "state"): + delattr(_thread_local, "state") + + +def _get_saved_tensor_hook_context() -> dict[str, Any] | None: + return getattr(_thread_local, "state", None) + + +zip = strict_zip + +log = logging.getLogger(__name__) +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + +aten = torch.ops.aten + +# Returns a Callable and a ViewAndMutationMeta. +# Currently, only export needs the ViewAndMutationMeta after this function. +# TODO: Refactor this +DispatchReturn = tuple[Callable, ViewAndMutationMeta] + + +def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]: + """ + Wrappers that run on every dispatch function + """ + return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] + + +def aot_stage1_graph_capture( + aot_state: AOTState, + orig_flat_fn: FlatFn, +) -> AOTGraphCapture: + # NB: flat_fn at this point coincides with the initial info from forward + # metadata collection returning a list[Tensor]. We are now going to + # augment the output to return a tuple[list[Tensor], list[AOTOutput]] and + # then preserve this convention through the rest of the passes. + + # TODO: We could test for consistency with fw_metadata, but this is not a + # big deal + @simple_wraps(orig_flat_fn) + def orig_flat_fn2(*args: FxValue) -> tuple[list[FxValue], list[AOTOutput]]: + out = orig_flat_fn(*args) + out_descs: list[AOTOutput] = type(out)( # type: ignore[assignment] + PlainAOTOutput(i) # type: ignore[misc] + for i in range(len(out)) # type: ignore[misc] + ) + return out, out_descs + + aot_config = aot_state.aot_config + + wrappers = _create_wrappers_for_dispatch(aot_state.needs_autograd) + flat_fn, aot_state.flat_args, aot_state.flat_args_descs, aot_state.fw_metadata = ( + pre_compile( + wrappers, + orig_flat_fn2, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + ) + + # NB: This is currently only used for backwards, where fwd/bwd + # deterministic TLS can be different + aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + + with maybe_skip_decompose(aot_config): + # if config.selective_decompose, skip decomposition and apply selective_decompose + # after we get the joint graph. See [Note: Selective Decomposition] for details. + if aot_state.needs_autograd and not aot_config.pre_dispatch: + # FYI: this being moved to trigger in export is new, seems fine! + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + ( + graph, + updated_flat_args, + updated_flat_args_descs, + maybe_subclass_meta, + ) = aot_dispatch_autograd_graph( + flat_fn, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + else: + graph, updated_flat_args, updated_flat_args_descs, maybe_subclass_meta = ( + aot_dispatch_base_graph( + flat_fn, + aot_state.flat_args, + aot_state.flat_args_descs, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) + ) + # Apply AC rematerialization to forward+loss+bwd graph + if torch._functorch.config.remat_using_tags_for_fwd_loss_bwd_graph: + from torch._functorch._activation_checkpointing.remat_using_tags_for_fwd_loss_bwd_graph_pass import ( + remat_using_tags_for_fwd_loss_bwd_graph, + ) + + graph = remat_using_tags_for_fwd_loss_bwd_graph(graph) + + if config.selective_decompose: + from torch.fx.experimental.proxy_tensor import selective_decompose + from torch.fx.passes.regional_inductor import _needs_inductor_compile + + graph = selective_decompose( + graph, + *updated_flat_args, + decomposition=aot_config.decompositions, + should_decompose=_needs_inductor_compile, + trace_joint_graph=aot_state.needs_autograd and not aot_config.pre_dispatch, + ) + + return AOTGraphCapture( + wrappers=wrappers, + graph_module=graph, + updated_flat_args=updated_flat_args, + updated_flat_args_descs=updated_flat_args_descs, + maybe_subclass_meta=maybe_subclass_meta, + ) + + +def aot_stage2_export( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture +) -> DispatchReturn: + graph = aot_graph_capture.graph_module + aot_config = aot_state.aot_config + wrappers = aot_graph_capture.wrappers + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="export") + + # NB: the wrappers that run in pre_compile for export are + # either a no-op, because they're not needed, or will raise a runtime error, + # since they don't support export. + # We still run these wrappers to make sure that they're not needed pre compile, + # but we technically don't need to run them post compile at all here. + compiled_fn, aot_state.fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=aot_state.fw_metadata + ) + + # Therefore, since no wrapperes run, we don't get back a callable - we get back the raw fx graph + # (either a joint or an inference-only graph) + assert isinstance(compiled_fn, torch.fx.GraphModule) + return compiled_fn, aot_state.fw_metadata + + +def sanitize_aot_config(input: AOTConfig) -> AOTConfig: + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + inference_compiler=None, + num_params_buffers=input.num_params_buffers, + aot_id=input.aot_id, + keep_inference_input_mutations=input.keep_inference_input_mutations, + is_export=input.is_export, + no_tangents=input.no_tangents, + aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source, + dynamic_shapes=input.dynamic_shapes, + enable_log=input.enable_log, + static_input_indices=input.static_input_indices, + pre_dispatch=input.pre_dispatch, + cache_info=None, + precompile_backend_id=input.precompile_backend_id, + ) + + +def _get_inner_meta( + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, +) -> ViewAndMutationMeta: + """ + Util to get view and mutation metadata. + """ + return ( + fw_metadata if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata + ) + + +def _apply_tensorify_python_scalars(module: torch.fx.GraphModule) -> None: + """ + Util to apply tensorify_python_scalars. + """ + # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes. + fake_mode = detect_fake_mode() + if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(module, fake_mode.shape_env, fake_mode) + + +def aot_stage2_compile( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, + partition_fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + inference_compiler: Optional[Callable] = None, +) -> DispatchReturn: + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + # Update the AOTState with the provided compilers + aot_state.aot_config.partition_fn = partition_fn + aot_state.aot_config.fw_compiler = fw_compiler + aot_state.aot_config.bw_compiler = bw_compiler + aot_state.aot_config.inference_compiler = inference_compiler + + if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch: + return aot_stage2_autograd(aot_state, aot_graph_capture) + else: + return aot_stage2_inference(aot_state, aot_graph_capture) + + +def _log_inference_graph( + fw_module: torch.fx.GraphModule, + aot_config: AOTConfig, +) -> Optional[str]: + """ + Log the inference graph to the structured logger. + Return a str representation of the graph. + """ + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(), + ) + + # Save the forward_graph_str right after aot_dispatch_base_graph, + # to save in the cache + aot_forward_graph_str = None + if aot_config.cache_info is not None: + aot_forward_graph_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + expanded_def=True, + ) + + return aot_forward_graph_str + + +def _aot_stage2b_inference_compile( + fw_module: torch.fx.GraphModule, + updated_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config, +) -> Callable: + return _aot_stage2b_compile_forward_or_inference( + fw_module, + updated_flat_args, # type: ignore[arg-type] + maybe_subclass_meta, + fw_metadata, + aot_config, + is_inference=True, + )[1] + + +def aot_stage2_inference( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + """ + Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler. + """ + + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + fw_module = aot_graph_capture.graph_module + wrappers = aot_graph_capture.wrappers + updated_flat_args = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference") + aot_forward_graph_str = _log_inference_graph(fw_module, aot_config) + + assert isinstance(fw_module, GraphModule) + _apply_tensorify_python_scalars(fw_module) + + compiled_fw = _aot_stage2b_inference_compile( + fw_module, + updated_flat_args, # type: ignore[arg-type] + maybe_subclass_meta, + fw_metadata, + aot_config, + ) + + entry = _cache_inference_info( + aot_config, + fw_metadata, + maybe_subclass_meta, + compiled_fw, + aot_forward_graph_str, + wrappers, + ) + + return _aot_stage2c_make_inference_function( + aot_config, + fw_metadata, + compiled_fw, + wrappers, + entry, + ) + + +def _cache_inference_info( + aot_config, + fw_metadata, + maybe_subclass_meta, + compiled_fw, + aot_forward_graph_str, + wrappers, +): + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + cache_info = aot_config.cache_info + + def should_save_cache(): + if should_bundle_autograd_cache(): + return True + else: + return hasattr(compiled_fw, "_fx_graph_cache_key") + + entry: Optional[GenericAOTAutogradResult] = None + if cache_info is not None and should_save_cache(): + time_taken_ns = time.time_ns() - cache_info.start_time_ns + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + entry = AOTAutogradCache.make_entry( + compiled_fw_func=compiled_fw, # type: ignore[arg-type] + compiled_bw_func=None, + aot_joint_graph_str=None, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=None, + runtime_metadata=fw_metadata, + dispatch_wrappers=wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + indices_of_inps_to_detach=[], + forward_time_taken_ns=time_taken_ns, + backward_time_taken_ns=0, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=None, + num_symints_saved_for_bw=None, + serialized_bw_module=None, + ) + AOTAutogradCache.save( + cache_info.cache_key, + entry, + remote=should_use_remote_autograd_cache(), + ) + + return entry + + +def _aot_stage2c_make_inference_function( + aot_config, + fw_metadata, + compiled_fw, + wrappers, + entry, +): + if entry is not None: + compiled_fw = SerializableCompiledFunction(compiled_fw, lambda: entry) + + disable_amp = torch._C._is_any_autocast_enabled() + compiled_fn = RuntimeWrapper( + indices_of_inps_to_detach=[], + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fn = post_compile( + wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata + ) + return compiled_fn + + +def collect_fw_donated_buffer_idxs( + fw_ins: list[Optional[FakeTensor]], + user_fw_outs: list[Optional[FakeTensor]], + bw_outs: list[Optional[FakeTensor]], + saved_tensors: list[FakeTensor], +) -> list[int]: + """ + Checks if the saved tensors are donated buffers, which means a saved tensor is not + an alias of any tensors in fw_ins, user_fw_outs, and bw_outs. + """ + + storage_refs = set() + + for t in itertools.chain(fw_ins, user_fw_outs, bw_outs): + # Only access storage if a tensor has storage (not sparse) + if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t): + storage_refs.add(StorageWeakRef(t.untyped_storage())) + + num_saved_tensor = len(saved_tensors) + donated_buffer_idxs = [] + for i in range(num_saved_tensor): + t = saved_tensors[i] + if ( + t is not None + and not is_sparse_any(t) + and StorageWeakRef(t.untyped_storage()) not in storage_refs + ): + donated_buffer_idxs.append(i) + + return donated_buffer_idxs + + +def collect_bw_donated_buffer_idxs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, +) -> list[int]: + """ + Collects backward donated buffer indexes from fw_module and bw_module. + """ + + # [Note: Metadata mutation in proxy tracing] + # node.meta["val"] is a snapshot of the tensor value when tracing a graph, + # instead of the final state after the graph has run. node.meta["val"] is + # not updated even if later there is a metadata mutation op. + # See: https://github.com/pytorch/pytorch/pull/141308#issuecomment-2495798947 + # + # Currently, metadata mutation op happens only for sacrificial parameter + # specifically the `set_` op. This motivates banning metadata mutation from + # proxy tracing. + # + # Since node.meta["val"] is used to detect donated buffer, we return an empty + # list if there exists metadata mutation op. + if contain_metadata_mutation_ops(fw_module) or contain_metadata_mutation_ops( + bw_module + ): + return [] + + fw_ins = fw_module.graph.find_nodes(op="placeholder") + bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0] + fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0] + + fw_ins = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_ins + ] + fw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_outs + ] + bw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in bw_outs + ] + + user_fw_outs = fw_outs[: fw_metadata.num_forward] + saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice] + + fw_donated_buffer = collect_fw_donated_buffer_idxs( + fw_ins, + user_fw_outs, + bw_outs, + # pyrefly: ignore [bad-argument-type] + saved_tensors, + ) + + assert fw_metadata.num_symints_saved_for_bw is not None + return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer] + + +@dataclasses.dataclass +class InvokeSubgraphHopGraphs: + """ + A data structure to hold all the information needed to partition the + `joint_hop_gm` and joint graph and the restitch the `new_fw_hop_gm` and + `new_bw_hop_gm` into the bigger `joint_gm`. + """ + + # To avoid re-partitioning subgraphs + partitioning_done: bool = False + old_num_fw_outputs: Optional[int] = None + old_num_fw_inputs: Optional[int] = None + + new_fw_hop_gm: Optional[torch.fx.GraphModule] = None + new_bw_hop_gm: Optional[torch.fx.GraphModule] = None + new_num_sym_nodes: Optional[int] = None + new_num_saved_nodes: Optional[int] = None + + +def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder(f"primals_{next(primals_counter)}") + else: + env[node] = new_graph.placeholder(f"tangents_{next(tangents_counter)}") + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + +def run_joint_graph_passes_on_hops( + joint_gm: torch.fx.GraphModule, + joint_inputs: Any, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + """ + This pass runs the joint graph passes on the HOP graph. In torch.compile, we + typically have many passes which work on the joint graph and then end with a + partitioner. + + + The partitioner part is quite mechanical to handle. HOP have their own + forward and backward graph. The process can be broken into following steps + + 1) Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` + 2) Run joint graph passes on the `joint_hop_gm` to get `new_fw_hop_gm` and `new_bw_hop_gm` + 3) Stitch the `new_fw_hop_gm` and `new_bw_hop_gm` back into the `joint_gm`. + + The terminology used in the code is + `joint_graph/joint_gm` : Refers to the main graph. This may contain many HOPs which have their own `hop_graph` + `fw_hop_graph/fw_hop_gm` : Refers to the forward graph associated with a HOP. + `bw_hop_graph/bw_hop_gm` : Refers to the backward graph associated with a HOP. + `joint_hop_graph/joint_hop_gm` : Refers to the subgraph associated with the HOP like invoke_subgraph. + `new_fw_hop_graph/new_fw_hop_gm` : Refers to the forward graph after partitioning is applied to `joint_hop_gm`. + `new_bw_hop_graph/new_bw_hop_gm` : Refers to the backward graph after partitioning is applied to `joint_hop_gm`. + + NB: This pass works for invoke_subgraph today because we took extra care in + the Autograd.Dispatch key of invoke_subgraph to vastly simplify Step 1. + """ + from torch._higher_order_ops import invoke_subgraph + + def num_outputs(mod): + return len(mod.graph.find_nodes(op="output")[0].args[0]) + + def num_inputs(mod): + return len(mod.graph.find_nodes(op="placeholder")) + + new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( + lambda: InvokeSubgraphHopGraphs() + ) + + # Step 1 - Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` This is + # easy to do for `invoke_subgraph` HOP. During the Autograd dispatch key + # tracing, we have put the joint_hop_graph in the backward hop graph itself. + # So to recover the joint_hop_gm, we just have to look at the backward + # HOP graphs. + # So we will merge step 1 and step 2 in this next section + + # Save the fw and bwd hop nodes. We will later in-place modify the graph + # using these nodes. + fw_hop_nodes = [] + bw_hop_nodes = [] + for node in joint_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is invoke_subgraph + and isinstance(node.args[1], str) + ): + if node.args[1].startswith("fw"): + fw_hop_nodes.append(node) + elif node.args[1].startswith("bw"): + bw_hop_nodes.append(node) + + if not bw_hop_nodes: + return joint_gm + + assert len(fw_hop_nodes) == len(bw_hop_nodes) + + # Create a bw to hop node mapping. This helps us in identifying the bw and + # fw subgraph pairs without relying on the identifier. This is important + # because we can have different subgraphs for bwd for same subgraph in the + # fwd because of differing strides in the backward. + bw_to_fw_hop_node = dict(zip(list(reversed(bw_hop_nodes)), fw_hop_nodes)) + + for node in bw_hop_nodes: + identifier = node.args[1].removeprefix("bw") + + # If partitioning already done for this identifier, skip. This saves + # redundant joint graph passes for same subgraphs. + if new_hop_graphs[identifier].partitioning_done: + continue + + # Collect some information from the forward hop graph + fw_hop_node = bw_to_fw_hop_node[node] + fw_hop_gm = getattr(joint_gm, fw_hop_node.args[0].target) + assert isinstance(fw_hop_gm, torch.fx.GraphModule) + num_fw_inputs = num_inputs(fw_hop_gm) + num_fw_outputs = num_outputs(fw_hop_gm) + new_hop_graphs[identifier].old_num_fw_inputs = num_fw_inputs + new_hop_graphs[identifier].old_num_fw_outputs = num_fw_outputs + + # Step 1) - Get the `joint_hop_gm`. As mentioned earlier, the + # backward graph is the joint graph. + joint_hop_gm = getattr(joint_gm, node.args[0].target) + assert isinstance(joint_hop_gm, torch.fx.GraphModule) + + # Prepare the graph for the partitioner + joint_hop_gm = prepare_for_partitioner( + joint_hop_gm, num_fw_inputs, num_fw_outputs + ) + + # TODO: invoke_subgraph should track which of its inputs static indices + # so it can propagate them to the partitioner (and use in cudagraphs) + static_lifetime_input_indices: list[int] = [] + # Step 2) and 3) - Run joint graph passes and partitioner + new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn( + joint_hop_gm, + [], + num_fwd_outputs=num_fw_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + # Save the new forward and backward graph modules + new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm + new_hop_graphs[identifier].new_bw_hop_gm = new_bw_hop_gm + + # Save the number of symints and saved tensors + new_fw_out_nodes = new_fw_hop_gm.graph.find_nodes(op="output")[0].args[0] + extra_outputs = new_fw_out_nodes[num_fw_outputs:] + symint_outputs = [n for n in extra_outputs if is_sym_node(n)] + + new_hop_graphs[identifier].new_num_sym_nodes = len(symint_outputs) + new_hop_graphs[identifier].new_num_saved_nodes = len(extra_outputs) - len( + symint_outputs + ) + + new_hop_graphs[identifier].partitioning_done = True + + # Step 3) Restitch the new fw and bw graphs back into the main graph. + # + # This is a very mechanical process. There are a quite a few pieces that we + # need to connect together to make it work. Lets try to understand the + # problem statement first. + # + # For the forward graph, the signature of the old_fw_hop_gm is + # inputs - (*primals) + # outputs - (*fw_outs) + # Now the signature of the new_fw_hop_gm is + # inputs - (*primals) -- This is same + # outputs - (*fw_outs, *saved_tensors) - This is different + # At a high level, this is an easy transformation, in the new graph we just + # have to replace the old_fw_hop_gm with the new_fw_hop_gm. Everything else + # falls into place, because the input signature (i.e. args) is same. And + # even though output signature is different, fw_outs are still at the same + # indexes as before. So the forward of the `joint_gm` works nicely. + # + # Now, lets look at the backward hop graph. Old signature + # inputs - (*primals, *tangents) + # outputs - (*grad_outs, *fw_outs) + # New signature + # inputs - (*saved_tensors, *tangents) -- Different + # outputs - (*grad_outs) -- Different + # Here both input and output signature change. The output signature handling + # is quite easy because the grads_out are sitting at the right place, so we + # dont have to do anything. + # + # For the input signature, we have to collect the saved tensors from the + # corresponding forward graph output. We collect all saved_tensors when we + # see the forward graph, and save it into a map and then later use it during + # the backward. + + # The stack of fw_nodes for invoke_subgraph HOP. There is an implicit + # assumption about the graph structure, i.e., if we have hop1, hop2, hop3, + # ... in the forward part of the joint graph, we will have .., hop3, hop2, + # hop1 order for the backward. This structure allows us to just use a stack + # to collect all the information that we need to pass from the forward hop + # node to the corresponding backward node. + + already_added_new_hop_mods = set() + + def add_new_hop_gm(new_subgraph_mod, name): + new_subgraph_attr_name = f"partitioned_{name}" + if new_subgraph_attr_name in already_added_new_hop_mods: + return new_subgraph_attr_name + + joint_gm.register_module(new_subgraph_attr_name, new_subgraph_mod) + already_added_new_hop_mods.add(new_subgraph_attr_name) + return new_subgraph_attr_name + + def propagate_meta_info(new_hop_gm, new_call_function_node, old_call_function_node): + # Copy all the fields from the old call_function node. And then override + # the `val` meta field with the outputs of new_hop_gm. + new_call_function_node.meta = copy.copy(old_call_function_node.meta) + + output = new_hop_gm.graph.find_nodes(op="output")[0] + out_example_vals = [n.meta["val"] if n else None for n in output.args[0]] + new_call_function_node.meta["val"] = tuple(out_example_vals) + + for bw_node in reversed(bw_hop_nodes): + identifier = bw_node.args[1].removeprefix("bw") + + # Make changes to the corresponding fw and bw node pair simultaneously. + # The removes the need of any bookkeeping. + + # Fw node changes + # Insert the new_fw_hop_gm. This is straightforward. Get the + # new_fw_hop_gm, insert the hop_gm as a get_attr fw_node, and then + # add a call_function fw_node. Additionally, also use getitem + # call_functions to collect the saved_tensor nodes + + fw_node = bw_to_fw_hop_node[bw_node] + new_fw_hop_gm = new_hop_graphs[identifier].new_fw_hop_gm + assert new_fw_hop_gm is not None + + old_num_fw_outputs = new_hop_graphs[identifier].old_num_fw_outputs + new_num_sym_nodes = new_hop_graphs[identifier].new_num_sym_nodes + new_num_saved_nodes = new_hop_graphs[identifier].new_num_saved_nodes + assert old_num_fw_outputs is not None + assert new_num_sym_nodes is not None + assert new_num_saved_nodes is not None + total_outputs = old_num_fw_outputs + new_num_saved_nodes + new_num_sym_nodes + + extra_fw_outputs = [] + + # Insert the new_fw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(fw_node): + new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}") + new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name) + new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta) + + # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors) + with joint_gm.graph.inserting_after(new_fw_mod_attr): + new_fw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_fw_mod_attr, + new_fw_mod_attr_name, + *fw_node.args[2:], + ), + ) + propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node) + + # old_num_fw_outputs = (*fw_outs) + # new_num_fw_outputs = (*fw_outs, *saved_tensors, *sym_nodes) + with joint_gm.graph.inserting_after(new_fw_node): + for fw_out_idx in range(old_num_fw_outputs, total_outputs): + saved_tensor_node = joint_gm.graph.call_function( + the_function=operator.getitem, args=(new_fw_node, fw_out_idx) + ) + saved_tensor_node.meta = copy.copy(new_fw_node.meta) + saved_tensor_node.meta["val"] = new_fw_node.meta["val"][fw_out_idx] + extra_fw_outputs.append(saved_tensor_node) + + fw_node.replace_all_uses_with(new_fw_node) + joint_gm.graph.erase_node(fw_node) + + # Bw node changes + # Prepare the operands for the bwd graph + # Old bw graph signature : (*primals, *tangents) + # New signature will be : (*sym_nodes, *saved_tensors, *tangents) + # We have already collected the saved_tensors in the forward hop processing. + + # extra_fw_outputs are in the order (*saved_nodes, *sym_nodes). + # Partitioner has this quirk where the backward wants sym_nodes + # first. So extract the sym and saved nodes. + + new_bw_hop_gm = new_hop_graphs[identifier].new_bw_hop_gm + assert new_bw_hop_gm is not None + + saved_tensor_nodes = extra_fw_outputs[:new_num_saved_nodes] + sym_nodes = extra_fw_outputs[new_num_saved_nodes:] + + num_primals = new_hop_graphs[identifier].old_num_fw_inputs + assert num_primals is not None + tangents = list(bw_node.args[2 + num_primals :]) + operands = sym_nodes + saved_tensor_nodes + tangents + + # Insert the new_bw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(bw_node): + new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1]) + new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name) + new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta) + + with joint_gm.graph.inserting_after(new_bw_mod_attr): + new_bw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_bw_mod_attr, + new_bw_mod_attr_name, + *operands, + ), + ) + propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node) + # Since the partitioner is run after the graph passes, we have lost + # the eager information and cannot faithfully extract the eager + # inputs for the new partitioned backward graph. For the forward + # graph, it was fine because the input signature remains same. + new_bw_node.meta.pop("eager_input_vals", None) + + bw_node.replace_all_uses_with(new_bw_node) + joint_gm.graph.erase_node(bw_node) + + joint_gm.graph.eliminate_dead_code() + joint_gm.graph.lint() + joint_gm.recompile() + return joint_gm + + +def maybe_log_graph( + gm, + graph_name, + aot_config, + structured_log_prefix_fn, + out_structured_logs: Optional[list[str]] = None, +): + if not aot_config.enable_log: + return + aot_graphs_log.debug( + "%s", + lazy_format_graph_code( + f"{graph_name}", + gm, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + def gm_str_fn() -> str: + return gm.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + + if out_structured_logs is not None: + out_structured_logs.append(f"{structured_log_prefix_fn()}:{gm_str_fn()}") + else: + trace_structured( + f"{structured_log_prefix_fn()}", + payload_fn=lambda: gm_str_fn(), + ) + + +def create_wrap_fn(fn, args): + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + + from .functional_utils import from_fun, has_data_mutation, to_fun + + def assert_no_mutation(t): + assert not has_data_mutation(t), ( + "Saved tensors hooks with inputs mutations are not allowed" + ) + + @simple_wraps(fn) + def _wrapper(*args): + with maybe_enable_thunkify(): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + f_args = pytree.tree_map(to_fun, args) + f_outs = fn(*f_args) + pytree.tree_map(assert_no_mutation, f_args) + return pytree.tree_map(from_fun, f_outs) + + return _wrapper, args + + +def prepare_hook_gm(aot_config, fn, args): + from torch._functorch._aot_autograd.graph_capture import _create_graph + + fn, args = create_wrap_fn(fn, args) + gm = _create_graph(fn, args, aot_config=aot_config) + return gm + + +# Inline Autograd saved_tensors_hooks into epilogue of forward graph +# and prologue of backward graph. +# This changes forward graph outputs and inputs. +# Pack hook can return tensors, sym scalars, constants. +# All tensors to save for backward will be grouped together at front. +# Sym scalars grouped on another end. Constants are inlined in the graph. +def maybe_inline_graph_saved_tensors_hooks( + fw_module, # torch.fx.GraphModule + bw_module, # torch.fx.GraphModule + num_inner_fwd_outputs, + inner_meta, + aot_config, + static_input_indices, +): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if not are_inline_hooks(hooks): + return + + pack_hook_gm, unpack_hook_gm = hooks + + structured_logs: list[str] = [] + maybe_log_graph( + fw_module, + "Forward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_forward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_backward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + fw_g = fw_module.graph + bw_g = bw_module.graph + + fw_g_names = {node.name for node in fw_g.nodes} + bw_g_names = {node.name for node in bw_g.nodes} + + def _gen_unused_name(candidate: str): + c = candidate + i = 0 + while c in fw_g_names or c in bw_g_names: + c = f"{candidate}_{i}" + i = i + 1 + return c + + bw_g_inputs = bw_g.find_nodes(op="placeholder") + + fw_out_n = fw_g.output_node() + fw_outs = fw_out_n.args[0] # type: ignore[var-annotated] + fw_outs_inner_set = set(fw_outs[:num_inner_fwd_outputs]) + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + fw_outs_packed_tensors = [] # type: ignore[var-annotated] + fw_outs_packed_syms = [] # type: ignore[var-annotated] + + # The main use case for saved_tensors_hooks is activation quantization, + # for memory usage optimization. + # Desired behavior is to quantize saved activations to free the original saved tensor. + # Saved nodes may include forward inputs, outputs, parameters. + # They may be held by something else and will not be deallocated after quantization. + # Donated buffers are intermediates in the graph invisible for the user, + # this guarantees that they can be deallocated. + # Using this as a default behavior to select saved nodes to apply hooks. + # There is also a config to apply hooks for all saved nodes without any filtering. + # The plan is to propagate meta about the source of the saved node to the user hook function. + mode = torch._functorch.config.saved_tensors_hooks_filtering_mode + allow_set = None + exclude_set = None + + if mode == "donated": + # collect_bw_donated_buffer_idxs requires inner_meta to have num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = len( + [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + ) + bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + fw_donated_idxs = [ + i - inner_meta.num_symints_saved_for_bw for i in bw_donated_idxs + ] + allow_set = {fw_outs_saved_for_bw[i].name for i in fw_donated_idxs} + elif mode == "no_static": + fw_g_inputs = fw_g.find_nodes(op="placeholder") + exclude_set = {fw_g_inputs[i].name for i in static_input_indices} + + if (allow_set is not None) and (not allow_set): + # This means we have empty whitelist, + # No donated (intermediate) saved. + # Do not do anything in this case + return + + if aot_config.enable_log: + structured_logs.append(f"fw_outs_saved_for_bw:{fw_outs_saved_for_bw}") + structured_logs.append(f"mode:{mode}") + structured_logs.append(f"allow_set:{allow_set}") + structured_logs.append(f"exclude_set:{exclude_set}") + + for saved in fw_outs_saved_for_bw: + if ((allow_set is not None) and (saved.name not in allow_set)) or ( + (exclude_set is not None) and (saved.name in exclude_set) + ): + if isinstance(saved.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(saved) + continue + + val = saved.meta["val"] + if not isinstance(val, torch.Tensor): + continue + + def _get_extra_info() -> dict[str, Any]: + return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved} + + with _saved_tensor_hook_context(_get_extra_info()): + pack_out_val = pack_hook_gm(val) + + requires_sc_handling = any( + is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) + ) + if requires_sc_handling: + raise NotImplementedError( + "Tensor subclasses in GraphModule saved tensors hooks are not supported" + "You can workaround it by manually returning subclass's inner tensors" + " in the pack hook, and reconstructing the subclass in the unpack hook" + ) + + with _saved_tensor_hook_context(_get_extra_info()): + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) + + # Install pack hook graph as eiplogue of fw_module. + # Saved tensor output becomes input of pack hook graph. + # Replace saved tensor output with pack hook graph output. + # Outputs symbolic scalars, tensors are accumulated separately. + # Then in forward outputs and backward inputs installed in order + # sym_scalars, packed_saved_tensors. + # Keeping all tensors together allows to preserve + # the same identification at runtime, + # updating only number of saved sym_scalars and tensors. + pack_g_inputs = pack_g.find_nodes(op="placeholder") + assert len(pack_g_inputs) == 1 + env = {pack_g_inputs[0]: saved} + fw_pack_out_args = None + with fw_g.inserting_before(fw_out_n): + for node in pack_g.nodes: + if node.op == "placeholder": + continue + new_n = fw_g.node_copy(node, lambda n: env[n]) + fw_g_names.add(new_n.name) + env[node] = new_n + # Output node is temporarily copied to have remapped arguments. + # Removed in the end. + if node.op == "output": + fw_pack_out_args = new_n.args[0] + fw_g.erase_node(new_n) + + env.clear() + assert fw_pack_out_args + fw_outs_bw_ins_node_names = [] + for out_idx, _n in enumerate(pytree.tree_leaves(fw_pack_out_args)): + if not isinstance(_n, torch.fx.Node): + fw_outs_bw_ins_node_names.append("") + continue + + # This happens when hook is noop and it is either user input or user output. + # Do not do anything with this node. + if _n.op == "placeholder" or _n in fw_outs_inner_set: + # This means the hook returned input primals unchanged + # Do not rename in this case. + n = _n + new_node_name = _n.name + fw_outs_bw_ins_node_names.append(new_node_name) + else: + # We can not specify desired name in node_copy. + # Copying node manually to set specific name, + # to have matching fw_outs, bw_inputs names. + new_node_name = _gen_unused_name(f"{saved.name}_hook_{out_idx}") + with fw_g.inserting_before(_n): + n = fw_g.create_node( + _n.op, + _n.target, + _n.args, + _n.kwargs, + name=new_node_name, + ) + assert n.name == new_node_name + fw_outs_bw_ins_node_names.append(new_node_name) + n.meta = copy.copy(_n.meta) + _n.replace_all_uses_with(n) + fw_g.erase_node(_n) + if isinstance(n.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(n) + elif is_sym_node(n): + fw_outs_packed_syms.append(n) + + # Install unpack hook graph as a prologue of backward graph + # Saved tensors inputs are replaced with packed tensors and packed sym scalars. + # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. + with _saved_tensor_hook_context(_get_extra_info()): + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) + + def find_saved_in_bw_inputs(bw_inputs): + for n in bw_inputs: + if n.name == saved.name: + return n + + bw_g_input = find_saved_in_bw_inputs(bw_g_inputs) + assert bw_g_input + original_bw_g_input_users = list(bw_g_input.users.keys()) + bw_g_input_used_directly = False + + # Replace backward graph saved tensor input with copy of pack graph outputs + # All non-Tensor, non-symscalars outputs are constanted. + + unpack_g_inputs = unpack_g.find_nodes(op="placeholder") + env = {} + for out_idx, (unp_in_n, out_n, val) in enumerate( + zip( + unpack_g_inputs, + pytree.tree_leaves(fw_pack_out_args), + pytree.tree_leaves(pack_out_val), + ) + ): + is_sym = isinstance(val, py_sym_types) + if isinstance(val, torch.Tensor) or is_sym: + # We want forward_outputs names to match backward_inputs, + # Potentially backward may already have "{saved.name}_hook_{idx}", + # In this case fx.Graph will add suffix. + new_node_name = fw_outs_bw_ins_node_names[out_idx] + if bw_g_input.name == new_node_name: + env[unp_in_n] = bw_g_input + bw_g_input_used_directly = True + else: + # Backward calling convention: ctx_symints,ctx_saved_tensors + # Inserting packed sym scalars before first saved tensor input. + # Inserting packed tensors before last saved tensor input. + # Saved tensor inputs between them will be removed. + with ( + bw_g.inserting_before(bw_g_inputs[0]) + if is_sym + else bw_g.inserting_before(bw_g_input) + ): + new_n = bw_g.placeholder(new_node_name) + assert new_n.name == new_node_name + new_n.meta = copy.copy(out_n.meta) + env[unp_in_n] = new_n + else: + # Inline values of non-Tensor, non-SymScalars + env[unp_in_n] = val + + # Inserting unpack hook after placeholders. + bw_unpack_out_n = None + with bw_g.inserting_before(bw_g_inputs[-1].next): + for node in unpack_g.nodes: + if node.op == "placeholder": + continue + new_n = bw_g.node_copy(node, lambda n: env[n]) + bw_g_names.add(new_n.name) + env[node] = new_n + # Temporary insert output, to have remapped by node_copy args. + # Removed in the end. + if node.op == "output": + bw_unpack_out_n = new_n + + assert bw_unpack_out_n + _leaves = pytree.tree_leaves(bw_unpack_out_n.args) + assert len(_leaves) == 1 + unpack_saved_tensor_n = _leaves[0] + + if not bw_g_input_used_directly: + bw_g_input.replace_all_uses_with(unpack_saved_tensor_n) + bw_g.erase_node(bw_g_input) + else: + # Keep usages of bw_g_input in inserted unpacked hook graph. + # Replace other usages of bw_g_input with unpack_saved_tensor_n. + for use_node in original_bw_g_input_users: + use_node._replace_input_with(bw_g_input, unpack_saved_tensor_n) + bw_g.erase_node(bw_unpack_out_n) + + # Changing forward graph outputs, + # Inserting packed_tensors and packed_syms on the place of saved tensors. + # Packed sym_scalars are together with saved symints + symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + fw_new_outs = pytree.tree_leaves( + ( + fw_outs[:num_inner_fwd_outputs], + fw_outs_packed_tensors, + fw_outs_packed_syms, + symint_outs_saved_for_bw, + ) + ) + fw_out_n.args = (tuple(fw_new_outs),) + + # Assert that saved tensors and symints in forward outputs are aligned with backward inputs + _fw_n = num_inner_fwd_outputs + _fw_num_t = len(fw_outs_packed_tensors) + _fw_num_s = len(fw_outs_packed_syms) + len(symint_outs_saved_for_bw) + fw_outs_saved_tensors = fw_new_outs[_fw_n : _fw_n + _fw_num_t] + fw_outs_saved_syms = fw_new_outs[_fw_n + _fw_num_t :] + bw_new_ins = list(bw_g.find_nodes(op="placeholder")) + bw_ins_saved_syms = bw_new_ins[:_fw_num_s] + bw_ins_saved_tensors = bw_new_ins[_fw_num_s : _fw_num_s + _fw_num_t] + + fw_t_names = [n.name for n in fw_outs_saved_tensors] + bw_t_names = [n.name for n in bw_ins_saved_tensors] + fw_s_names = [n.name for n in fw_outs_saved_syms] + bw_s_names = [n.name for n in bw_ins_saved_syms] + + def _log_structured_logs(): + if not aot_config.enable_log: + return + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_saved_tensors_hooks_graphs", + "encoding": "string", + }, + payload_fn=lambda: "\n".join(structured_logs), + ) + + if aot_config.enable_log: + structured_logs.append( + f"fw_outs[:num_inner_fwd_outputs]:{fw_outs[:num_inner_fwd_outputs]}" + ) + structured_logs.append(f"fw_outs_packed_tensors:{fw_outs_packed_tensors}") + structured_logs.append(f"fw_t_names:{fw_t_names}") + structured_logs.append(f"bw_t_names:{bw_t_names}") + structured_logs.append(f"fw_s_names:{fw_s_names}") + structured_logs.append(f"bw_s_names:{bw_s_names}") + structured_logs.append(f"\nfw_g_pre_assert:{fw_g}") + structured_logs.append(f"\nbw_g_pre_assert:{bw_g}") + maybe_log_graph( + fw_module, + "Forward graph after transform pre-assert", + aot_config, + lambda: "aot_forward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph after transform pre-assert", + aot_config, + lambda: "aot_backward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + _log_structured_logs() + + assert fw_t_names == bw_t_names + assert fw_s_names == bw_s_names + + fw_g.lint() + bw_g.lint() + fw_module.recompile() + bw_module.recompile() + + +def _log_joint_graph( + fx_g: torch.fx.GraphModule, + aot_config: AOTConfig, +) -> Optional[str]: + """ + Log the joint graph to the structured logger. + Return a str representation of the graph. + """ + joint_graph_str = None + if aot_config.enable_log: + aot_joint_log.info( + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + joint_graph_str = fx_g.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + trace_structured( + "aot_joint_graph", + payload_fn=lambda: joint_graph_str, + ) + return joint_graph_str + + +def _log_fw_bw_graphs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> tuple[Optional[str], Optional[str]]: + """ + Log the fw and bw graphs to the structured logger. + Return str representations of the graphs. + """ + fw_module_str = None + bw_module_str = None + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(), + ) + aot_graphs_log.info( + "aot_config id: %s, fw_metadata=%s, inner_meta=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(_get_inner_meta(maybe_subclass_meta, fw_metadata)), + ) + + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + fw_module_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + bw_module_str = bw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + expanded_def=True, + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module_str, + ) + trace_structured( + "aot_backward_graph", + payload_fn=lambda: bw_module_str, + ) + return fw_module_str, bw_module_str + + +def _aot_stage2a_partition( + fx_g: torch.fx.GraphModule, + joint_inputs: Union[list[Any], tuple[list[Any], list[Any]]], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule, int, int, list[int], list[Any]]: + """ + Partition the joint graph into a forward graph and a backward graph. Returns: + - the forward and backward graphs + - the number of forward outputs and the number of symints saved for backward + - indices of inputs to detach + - adjusted inputs to forward + """ + disable_amp = torch._C._is_any_autocast_enabled() + inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata) + + with torch.no_grad(): + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(), track_graph_compiling(aot_config, "joint"): + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = ( + compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + ) + num_tokens = len(fw_metadata.tokens) + num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) + num_inner_fwd_outputs = ( + num_mutated_inp_runtime_indices + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] + ) + fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) + + # apply joint_gm callback here + if callable(torch._functorch.config.joint_custom_pass): + # pyrefly: ignore [bad-assignment] + fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs) + + static_lifetime_input_indices = fw_metadata.static_input_indices + fw_module, bw_module = aot_config.partition_fn( + fx_g, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if config.unlift_effect_tokens and ( + num_tokens > 0 or fw_metadata.num_backward_tokens > 0 + ): + unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) + + num_inner_fwd_outputs -= num_tokens + joint_inputs = ( + joint_inputs[0][num_tokens:], + joint_inputs[1], + ) + + maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + fw_metadata.static_input_indices, + ) + static_lifetime_input_indices = fw_metadata.static_input_indices + + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + symint_outs_saved_for_bw = [] + for idx, node in enumerate(fw_outs_saved_for_bw): + if is_sym_node(node): + symint_outs_saved_for_bw.append(node) + elif ( + isinstance(node, torch.fx.Node) + and "val" in getattr(node, "meta", {}) + and isinstance(node.meta["val"], FakeTensor) + ): + # record dynamic tensor activations + dynamic_dims: set[int] = { + dim + for dim, size in enumerate(node.meta["val"].shape) + if not isinstance(size, int) + } + if dynamic_dims: + fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims + + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw + if torch._functorch.config.donated_buffer: + fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs + + # Note [Detaching inputs that never need gradients] + # See https://github.com/pytorch/pytorch/issues/97745 + # Suppose we have a function like this that we want to compile: + # + # def f(x, y): + # return torch.mul(x, y.detach()) + # + # What gradients should we compute for x and y? + # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, + # and so we'll compute: + # x_grad_input = y + # y_grad_input = None + # Does this preserve the semantics of eager mode? + # Unfortunately, no. + # Doing the above will cause autograd to **continue** to backprop the autograd tape + # that was generated from constructing y. + # + # This is **different** from what would have happened in eager mode. + # In eager mode, if we backprop through the output of this function, autograd will only traverse + # the bit of the autograd tape corresponding to "x". + # In particular, if a user had previously backpropped through y's autograd tape, + # And then they try to backprop through the output of the above function, + # then we'll hit the dreaded "Trying to backward through the graph a second time" error. + # + # You might think: If autograd sees that a gradient is None, shouldn't it stop early, + # instead of continuing the backprop through the ancestors of that node in the graph? + # + # Autograd has two passes: + # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed + # (2) a second pass that actually goes ahead and executes each node when it becomes ready, + # propagating gradients + # By the time we're executing a node and we see that it produces a None, the set of nodes to execute + # is already locked-in. + # + # The fix: instead, we can recognize statically that the graph we're compiling will never contribute + # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. + # We can do this by manually detach'ing y before sending it through the `CompiledFunction`. + # + # Note that this solution is not bulletproof. + # It's possible to construct a case where eager may or may not have have tried to autograd through y, + # depending on the actual grad_outputs that were passed in during the backward. + # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, + # allowing autograd to reuse the graph. + # + # An example of this case is: + # def f(x): + # return x.detach() * 2, x * 3 + # If we were to only backprop through outs[0], in eager, we would stop + # If we backward only on the first output, we shouldn't send a grad through x. + # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 + # and we will end up with a zero grad at x. + # If we later backprop through the second output, this will also require backprop'ing through x. + # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. + _indices_of_inps_to_detach: list[int] = [] + + # reversed() since we expect output at end of graph + bw_output = next(reversed(bw_module.graph.find_nodes(op="output"))) + bw_outs: Sequence[torch.fx.Node] = bw_output.args[0] # type: ignore[assignment] + + # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" + # optimization even if we have subclass inputs/outputs (we do not handle this today). + # Computing which our our inputs get None gradients is a bit more complicated, + # if any of our inputs are subclasses. Why? + # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. + # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors, + # so we need to figure out which subclass fw inputs they map to. + if maybe_subclass_meta is None: + num_backward_tokens: int = inner_meta.num_backward_tokens + assert ( + len(bw_outs) + == len(fw_metadata.input_info) + + inner_meta.num_outputs_rng_offset + + num_backward_tokens + ) + bw_outs_no_rng_no_tokens = bw_outs + if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: + bw_outs_no_rng_no_tokens = bw_outs[ + : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) + ] + assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info) + + for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens): + # If our input experiences a metadata mutation inside the graph (e.g. set_()), + # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation + metadata_mutation_in_graph = ( + fw_metadata.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + and fw_metadata.input_info[i].mutates_storage_metadata + ) + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: + _indices_of_inps_to_detach.append(i) + + return ( + fw_module, + bw_module, + num_fw_outs_saved_for_bw, + num_symints_saved_for_bw, + _indices_of_inps_to_detach, + joint_inputs[0], + ) + + +def _aot_stage2b_fw_compile( + fw_module: torch.fx.GraphModule, + adjusted_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + num_fw_outs_saved_for_bw: int, + aot_config: AOTConfig, +) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]: + return _aot_stage2b_compile_forward_or_inference( + fw_module, + adjusted_flat_args, + maybe_subclass_meta, + fw_metadata, + aot_config, + is_inference=False, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ) + + +def _aot_stage2b_bw_compile( + bw_module: torch.fx.GraphModule, + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + fwd_output_strides: Optional[list[Optional[tuple[int, ...]]]], + num_symints_saved_for_bw: int, + aot_config: AOTConfig, +) -> tuple[AutogradLazyBackwardCompileInfo, Optional[Callable]]: + """ + Compile the backward graph. Returns: + - the placeholder list for the backward graph + - the compiled backward function + """ + with torch.no_grad(): + # NB: It's important to compile backwards ahead of time, as this may + # add extra guards which we need to apply to the Dynamo cache at + # forwards + with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast(): + placeholder_list = fx_placeholder_vals(bw_module) + + forward_saved_for_backwards_strides = None + if fwd_output_strides is not None: + inner_meta = _get_inner_meta(maybe_subclass_meta, fw_metadata) + forward_saved_for_backwards_strides = fwd_output_strides[ + inner_meta.tensors_saved_for_backwards_slice + ] + + # saved activations can have different stride to eager if + # the compiler does layout optimization. We should restride the + # tensor passed in for compiling the backward graph using the + # saved tensor's stride. + for i in range(len(placeholder_list)): + ph_arg = placeholder_list[i] + if not isinstance(ph_arg, torch.Tensor): + continue + + if forward_saved_for_backwards_strides is None: + continue + + real_stride = None + # Per all_args calling convention + j = i - num_symints_saved_for_bw + if 0 <= j < len(forward_saved_for_backwards_strides): + real_stride = forward_saved_for_backwards_strides[j] + if real_stride is None: + continue + + # Comparing ph_arg.stride() with real_stride directly may + # cause dynamic dimensions in ph_arg being specialized to static + # value. Using suppress_guards and guard_or_true to avoid that. + + stride_different = False + fake_mode = detect_fake_mode() + suppress_ctx = ( + fake_mode.shape_env.suppress_guards() + if fake_mode is not None and fake_mode.shape_env is not None + else nullcontext() + ) + + # Inductor can choose different strides for activations than + # what backward graph has. if we can't statically tell that + # strides are the same, we assume they are not. + with suppress_ctx: + for k in range(len(ph_arg.stride())): + # real_stride can't be symbolic. + # pyrefly: ignore [index-error] + if guard_or_true(ph_arg.stride()[k] != int(real_stride[k])): + stride_different = True + break + + if stride_different: + # Note that here we use the stride of the real tensor to + # restride a FakeTensor. This does not cause trouble + # for dynamic shape since this code path only get + # executed if layout optimization is enabled. And we + # disable layout optimization for dynamic shape right + # now. + # + # A solution that decide stride order based on real + # tensor's stride and then apply that stride order to + # the FakeTensor does not work smoothly since some + # tensor's layout is not 'dense'. E.g. mixnet_l has a + # tensor with size [8, 64, 112, 112] and strides + # (2408448, 1, 21504, 192). The solution mentioned will + # decide a stride of (802816, 1, 7168, 64) for this + # tensor which is wrong. + + ph_size = ph_arg.size() + + # pyrefly: ignore [bad-argument-type] + placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride) + compiled_bw_func = None + if ( + num_symints_saved_for_bw > 0 + or aot_config.force_non_lazy_backward_lowering + ): + try: + # See Note: [Backward graph lazy lowering] + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + # If bw_module contains lifted constants, they will be real tensors stored as + # GraphModule. Deepcopying tensors under fake mode is not supported and will + # raise when attempting to set storage. + bw_module_copy = copy.deepcopy(bw_module) + compiled_bw_func = aot_config.bw_compiler( + bw_module_copy, placeholder_list + ) + del bw_module_copy + except Exception as e: + if aot_config.force_non_lazy_backward_lowering: + raise + exc = e + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "eager_compile_backwards_failure", + "encoding": "string", + }, + payload_fn=lambda: "\n".join( + traceback.format_exception( + type(exc), exc, exc.__traceback__ + ) + ), + ) + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) + # Compiled autograd will run the bw_module in the backward pass, + # so recompilation need happen anyway if the backward pass is ever + # called. + # + # The reason we do the GraphModule recompilation here is because + # the lazy recompilation will cause issue in the backward pass + # with compiled autograd. + # + # Do the _LazyGraphModule.force_recompile here rather than when + # bw_module is first generated by the partitioner because the bw_module.recompile + # may be called in some code path later and cause the _LazyGraphModule.forward + # becomes the lazy version again. One example is when dynamic shape is enabled + # upfront, the bw_compiler will be called above which can cause extra + # graph module recompilation on bw_module. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(bw_module) + + saved_context = TracingContext.try_get() + saved_compile_context = CompileContext.try_get() + + lazy_backward_info = AutogradLazyBackwardCompileInfo( + bw_module, + placeholder_list, + saved_context, + saved_compile_context, + ) + + return lazy_backward_info, compiled_bw_func + + +def aot_stage2_autograd( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + """ + Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, + and returns a wrapped torch.autograd.Function with a forward and backward. + """ + + fx_g = aot_graph_capture.graph_module + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + fw_metadata = aot_state.fw_metadata + aot_config = aot_state.aot_config + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd") + joint_graph_str = _log_joint_graph(fx_g, aot_config) + + _apply_tensorify_python_scalars(fx_g) + + ( + fw_module, + bw_module, + num_fw_outs_saved_for_bw, + num_symints_saved_for_bw, + _indices_of_inps_to_detach, + adjusted_flat_args, + ) = _aot_stage2a_partition( + fx_g, + aot_graph_capture.updated_flat_args, + maybe_subclass_meta, + fw_metadata, + aot_config, + ) + + fw_module_str, bw_module_str = _log_fw_bw_graphs( + fw_module, bw_module, maybe_subclass_meta, fw_metadata, aot_config + ) + + fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile( + fw_module, + adjusted_flat_args, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + aot_config, + ) + + lazy_backward_info, compiled_bw_func = _aot_stage2b_bw_compile( + bw_module, + maybe_subclass_meta, + fw_metadata, + fwd_output_strides, + num_symints_saved_for_bw, + aot_config, + ) + + try_save_cache_entry, entry = _cache_autograd_info( + aot_config, + aot_state.flat_args, + compiled_fw_func, + compiled_bw_func, + fw_module_str, + bw_module_str, + joint_graph_str, + aot_graph_capture.wrappers, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + bw_module, + ) + + return _aot_stage2c_make_autograd_function( + aot_config, + aot_state.flat_args, + fw_metadata, + maybe_subclass_meta, + aot_graph_capture.wrappers, + compiled_fw_func, + compiled_bw_func, + lazy_backward_info, + try_save_cache_entry, + entry, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + ) + + +def _aot_stage2c_make_autograd_function( + aot_config, + flat_args, + fw_metadata, + maybe_subclass_meta, + wrappers, + compiled_fw_func, + compiled_bw_func, + lazy_backward_info, + try_save_cache_entry, + entry, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, +): + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + disable_amp = torch._C._is_any_autocast_enabled() + compiled_fn = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + maybe_subclass_meta, + num_symints_saved_for_bw, + backward_state_indices, + disable_amp, + _indices_of_inps_to_detach, + lazy_backward_info, + aot_config, + fw_metadata=fw_metadata, + try_save_cache_entry=try_save_cache_entry, + ) + + if entry is not None: + compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry) + + if config.debug_assert: + flat_requires_grad: list[Optional[bool]] = [ + a.requires_grad if isinstance(a, Tensor) else None for a in flat_args + ] + compiled_fn = DebugAssertWrapper( + flat_requires_grad=flat_requires_grad + ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata) + + compiled_fn = post_compile( + wrappers, + compiled_fn, + aot_config, + runtime_metadata=fw_metadata, + ) + return compiled_fn + + +def _cache_autograd_info( + aot_config, + flat_args, + compiled_fw_func, + compiled_bw_func, + fw_module_str, + bw_module_str, + joint_graph_str, + wrappers, + maybe_subclass_meta, + fw_metadata, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + num_symints_saved_for_bw, + bw_module, +): + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + try_save_cache_entry: Optional[Callable] = None + entry: Optional[GenericAOTAutogradResult] = None + + if aot_config.cache_info is not None: + forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns + + # NB: aot_config here is technically not needed as an argument: we could just + # close over aot_config.cache_info, since aot_config never changes. + # But closing over random variables is confusing IMO, so I'm leaving it. + def try_save_cache_entry( # noqa: F811 + compiled_bw_func: Callable, + bw_module: torch.fx.GraphModule, + _fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + ) -> Optional[GenericAOTAutogradResult]: + cache_info = aot_config.cache_info + + def should_save_cache(): + if should_bundle_autograd_cache(): + return True + else: + return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr( + compiled_bw_func, "_fx_graph_cache_key" + ) + + if cache_info is not None and should_save_cache(): + assert forward_time_taken_ns is not None + # TODO: technically, AOTAutograd does a *little* bit of post processing work + # in the backward that isn't measured here. But it's small enough that it's not worth + # the complexity of threading a bunch of times through the code, so we + # use the compiled_bw_func's inductor compile time instead. + # It's possible this changes in the future, in which case we should + # update backward_time_taken_ns to be more inclusive + backward_time_taken_ns = getattr(compiled_bw_func, "_time_taken_ns", 0) + + aot_forward_graph_str: Optional[str] = fw_module_str + aot_backward_graph_str: Optional[str] = bw_module_str + aot_joint_graph_str: Optional[str] = joint_graph_str + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + + entry = AOTAutogradCache.make_entry( + compiled_fw_func, # type: ignore[arg-type] + compiled_bw_func, # type: ignore[arg-type] + aot_joint_graph_str, + aot_forward_graph_str, + aot_backward_graph_str, + _fw_metadata, + wrappers, + maybe_subclass_meta, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + forward_time_taken_ns, + backward_time_taken_ns, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw=num_symints_saved_for_bw, + serialized_bw_module=serialize_graph_module(bw_module), + ) + AOTAutogradCache.save( + cache_info.cache_key, + entry, + remote=should_use_remote_autograd_cache(), + ) + return entry + return None + + if compiled_bw_func is not None: + # If we already compiled the backward, we save its cache entry now + entry = try_save_cache_entry( + compiled_bw_func, bw_module, fw_metadata, aot_config + ) + try_save_cache_entry = None + + return try_save_cache_entry, entry + + +def _aot_stage2b_compile_forward_or_inference( + fw_module: torch.fx.GraphModule, + adjusted_flat_args: list[Any], + maybe_subclass_meta: Optional[SubclassMeta], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + *, + is_inference: bool, + num_fw_outs_saved_for_bw: Optional[int] = None, +) -> tuple[Optional[list[Optional[tuple[int, ...]]]], Callable]: + """ + Compile the forward or inference graph. Returns: + - the output strides of the forward graph + - the compiled forward/inference function + + Args: + fw_module: The forward graph module to compile + adjusted_flat_args: Flattened arguments after adjustments + maybe_subclass_meta: Metadata for tensor subclasses + fw_metadata: View and mutation metadata + aot_config: AOT configuration + is_inference: If True, compile for inference; if False, compile for forward (autograd) + num_fw_outs_saved_for_bw: Number of forward outputs saved for backward (required if not is_inference) + + Before compiling, we run pre_compile for the following wrappers: + - FakifiedOutWrapper + - FunctionalizedRngRuntimeWrapper + After compiling, we run post_compile for the following wrappers: + - EffectTokensWrapper + - AOTDispatchSubclassWrapper + - FunctionalizedRngRuntimeWrapper + - FakifiedOutWrapper + """ + + # Validation + if not is_inference and num_fw_outs_saved_for_bw is None: + raise ValueError( + "num_fw_outs_saved_for_bw must be provided when is_inference=False" + ) + + # Determine grad context, autocast context, tracking mode, compiler + if is_inference: + grad_ctx: Any = nullcontext + autocast_ctx: Any = ( + torch._C._DisableAutocast + if torch._C._is_any_autocast_enabled() + else nullcontext + ) + tracking_mode: str = "inference" + compiler: Any = aot_config.inference_compiler + else: + grad_ctx = torch.no_grad + autocast_ctx = torch._C._DisableAutocast + tracking_mode = "forward" + compiler = aot_config.fw_compiler + + with grad_ctx(), autocast_ctx(), track_graph_compiling(aot_config, tracking_mode): + # Setup wrappers + fakified_out_wrapper = FakifiedOutWrapper() + fakified_out_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Initialize RNG wrapper based on mode + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( + return_new_outs=is_inference + ) + + # Add RNG states for forward mode only + if not is_inference and fw_metadata.num_graphsafe_rng_states > 0: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + + functionalized_rng_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Set tracing context + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = _get_inner_meta( + maybe_subclass_meta, fw_metadata + ) + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw_func = compiler(fw_module, adjusted_flat_args) + + # Make boxed if needed + if not getattr(compiled_fw_func, "_boxed_call", False): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + # Set forward output strides if needed + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + # Apply post-compile wrappers + compiled_fw_func = EffectTokensWrapper().post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = AOTDispatchSubclassWrapper( + fw_only=None, + trace_joint=False, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = functionalized_rng_wrapper.post_compile( + compiled_fw_func, aot_config, runtime_metadata=fw_metadata + ) + + compiled_fw_func = fakified_out_wrapper.post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + return fwd_output_strides, compiled_fw_func diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6325b6e6ab2489c175347afe13e05bfbed3c7e8d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/logging_utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +""" +Contains utils for logging in AOTAutograd, including managing the names of the graphs under +compilation, capturing user-friendly tracebacks, and debug messages. +""" + +import collections +from contextlib import contextmanager + +import torch +import torch.fx.traceback as fx_traceback + + +# This is a list since looking forward, we can have this arbitrarily nested. +graph_being_compiled: list[str] = [] +# TODO: It would be nice to reset the numbering every time aot_id goes +# up, but this is annoying to do right now (because we don't know if +# an aot_id will come back from the dead), so right now this also happens +# to be a globally unique number too (at the cost of wobbling if you change +# how the graphs compile) +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_aot_compilation_context() -> tuple[list[str], str, int]: + return list(graph_being_compiled), model_name, nth_graph + + +def get_aot_graph_name() -> str: + """ + Returns the name of the graph being compiled. + """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}" + + +get_graph_being_compiled = get_aot_graph_name + + +@contextmanager +def track_graph_compiling(aot_config, graph_name): + global graph_being_compiled + # TODO: Don't shove the aot_id in here; set it in the context + graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"] + old_name = None + if tracing_context := torch._guards.TracingContext.try_get(): + old_name = tracing_context.aot_graph_name + tracing_context.aot_graph_name = graph_being_compiled + has_tracing_context = True + else: + has_tracing_context = False + try: + yield + finally: + global nth_graph + nth_graph += 1 + graph_being_compiled = [] + if has_tracing_context: + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.aot_graph_name = old_name + + +# Set up hooks so that during backward the fx's stack_trace is properly set +callback_set = False + + +def setup_stacktrace_preservation_hooks(roots: list): + def iter_graph(roots): + if not roots: + return + seen = set() + q = collections.deque() # type: ignore[var-annotated] + for node in roots: + if node is not None and node not in seen: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _idx in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def get_callback(saved_stack_): + def callback(): + global callback_set + fx_traceback.set_stack_trace(saved_stack_) + callback_set = False + + return callback + + def get_prehook(stack_, seq_nr): + def prehook(grad_output): + global callback_set + + if not callback_set: + torch.autograd.variable.Variable._execution_engine.queue_callback( # type: ignore[attr-defined] + get_callback(fx_traceback.format_stack()) + ) + callback_set = True + + fx_traceback.set_stack_trace(stack_) + fx_traceback.set_grad_fn_seq_nr(seq_nr) + + return prehook + + def get_posthook(special_stack_, seq_nr): + def posthook(grad_input, grad_output): + fx_traceback.set_stack_trace(special_stack_) + fx_traceback.reset_grad_fn_seq_nr() + + return posthook + + for node in iter_graph(roots): + forward_node_stack = node.metadata.get("traceback_", []) + node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr())) + + special_stack = forward_node_stack.copy() + special_stack.append(fx_traceback.GRADIENT_ACC_SPECIAL_STACK) + node.register_hook(get_posthook(special_stack, node._sequence_nr())) + + +def describe_input(i, aot_config): + if i < aot_config.num_params_buffers: + return f"parameter/buffer {i}" + else: + return f"input {i - aot_config.num_params_buffers}" + + +def format_guard_bug_msg(aot_config, expected): + return ( + f"At compilation time, graph {aot_config.aot_id} was compiled under the " + f"assumption that {expected}, but at runtime this was not the case. " + "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch." + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc1deefb9f1b655b320c1623f12f0c814f589c4a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..319e9121913c27166184a316619c0f6e978b6588 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_internal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a88adac508f338b4738391117c930b8a49a546e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/_registrations.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48c1b9a42530d8049ce2026cb5129e531d215a84 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/scribe.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1bf99b6b2aea04cd9fb989e8ec4eaa89a3814c8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_logging/__pycache__/structured.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03172cb8b702a68510d72cbccc755a6e79c63434 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c539e74bd7c5a93a015924ec6cac1f984e8e0652 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa921c7dd4c82ffce37910d1770c8d59fae542c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537da49f1dc511186771e19a13b5dcc228391eac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..358041bde0351f119d24468caab8052601ea9cbd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c2ff3fad41f0856252174d894631d1de84962b5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3606eef027fe5829d42fa6ec85b504437925299e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809229ca3cdeb526b2963eef837587e091821241 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..250c13768e17a7c4f90eec137a2ad788c8ba5976 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2bd18a8a4d5671735df20fd6117173910cc633e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e4d40dc7a38e644d0ab7d7f109300d3d88863f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85996c72b19c02763a5596791058a981b53e4a4e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dce611e370bfc2a6dc766f77b3acd80eb80e6ca Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50477b6bda263d70b8ef79d798f1ff4be91b136c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7108df7fa27bdddb76318dd143e59db956a7233c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db54795c5a0ea6932a682500514e55dd2b73f3c6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be0e9aa675e28fd41698feedad3cb48db0bcec49 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05e73b12e29f8e6608647a3f16fabab39fbfb582 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py @@ -0,0 +1,20 @@ +# mypy: ignore-errors + +from .utils import ( + _gen_alignment_data, + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, + assert_equal, + assert_raises_regex, + assert_warns, + HAS_REFCOUNT, + IS_WASM, + suppress_warnings, +) + + +# from .testing import assert_allclose # FIXME diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ae640b3a6cf25c9ac33afc66cf751b7bf86d80 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cb608ed9ab2fdb34db9b7c52d63de617814af5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc027043b6f55aae572e2fb0ffe1142f6226959 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py @@ -0,0 +1,2451 @@ +# mypy: ignore-errors + +""" +Utility function to facilitate testing. + +""" + +import contextlib +import gc +import operator +import os +import platform +import pprint +import re +import shutil +import sys +import warnings +from functools import wraps +from io import StringIO +from tempfile import mkdtemp, mkstemp +from warnings import WarningMessage + +import torch._numpy as np +from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray + + +__all__ = [ + "assert_equal", + "assert_almost_equal", + "assert_approx_equal", + "assert_array_equal", + "assert_array_less", + "assert_string_equal", + "assert_", + "assert_array_almost_equal", + "build_err_msg", + "decorate_methods", + "print_assert_equal", + "verbose", + "assert_", + "assert_array_almost_equal_nulp", + "assert_raises_regex", + "assert_array_max_ulp", + "assert_warns", + "assert_no_warnings", + "assert_allclose", + "IgnoreException", + "clear_and_catch_warnings", + "temppath", + "tempdir", + "IS_PYPY", + "HAS_REFCOUNT", + "IS_WASM", + "suppress_warnings", + "assert_array_compare", + "assert_no_gc_cycles", + "break_cycles", + "IS_PYSTON", +] + + +verbose = 0 + +IS_WASM = platform.machine() in ["wasm32", "wasm64"] +IS_PYPY = sys.implementation.name == "pypy" +IS_PYSTON = hasattr(sys, "pyston_version_info") +HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON + + +def assert_(val, msg=""): + """ + Assert that works in release mode. + Accepts callable msg to allow deferring evaluation until failure. + + The Python built-in ``assert`` does not work when executing code in + optimized mode (the ``-O`` flag) - no byte-code is generated for it. + + For documentation on usage, refer to the Python documentation. + + """ + __tracebackhide__ = True # Hide traceback for py.test + if not val: + try: + smsg = msg() + except TypeError: + smsg = msg + raise AssertionError(smsg) + + +def gisnan(x): + return np.isnan(x) + + +def gisfinite(x): + return np.isfinite(x) + + +def gisinf(x): + return np.isinf(x) + + +def build_err_msg( + arrays, + err_msg, + header="Items are not equal:", + verbose=True, + names=("ACTUAL", "DESIRED"), + precision=8, +): + msg = ["\n" + header] + if err_msg: + if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): + msg = [msg[0] + " " + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + # r_func = partial(array_repr, precision=precision) + r_func = ndarray.__repr__ + else: + r_func = repr + + try: + r = r_func(a) + except Exception as exc: + r = f"[repr failed for <{type(a).__name__}>: {exc}]" + if r.count("\n") > 3: + r = "\n".join(r.splitlines()[:3]) + r += "..." + msg.append(f" {names[i]}: {r}") + return "\n".join(msg) + + +def assert_equal(actual, desired, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + + Examples + -------- + >>> np.testing.assert_equal([4, 5], [4, 6]) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal: + item=1 + ACTUAL: 5 + DESIRED: 6 + + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. + + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + num_nones = sum([actual is None, desired is None]) + if num_nones == 1: + raise AssertionError(f"Not equal: {actual} != {desired}") + elif num_nones == 2: + return True + # else, carry on + + if isinstance(actual, np.DType) or isinstance(desired, np.DType): + result = actual == desired + if not result: + raise AssertionError(f"Not equal: {actual} != {desired}") + else: + return True + + if isinstance(desired, str) and isinstance(actual, str): + assert actual == desired + return + + if isinstance(desired, dict): + if not isinstance(actual, dict): + raise AssertionError(repr(type(actual))) + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in desired: + if k not in actual: + raise AssertionError(repr(k)) + assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) + return + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) + return + + from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit + + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg, verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except (ValueError, TypeError): + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError(msg) # noqa: B904 + + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + + try: + # Explicitly use __eq__ for comparison, gh-2552 + if not (desired == actual): + raise AssertionError(msg) + + except (DeprecationWarning, FutureWarning) as e: + # this handles the case when the two types are not even comparable + if "elementwise == comparison" in e.args[0]: + raise AssertionError(msg) # noqa: B904 + else: + raise + + +def print_assert_equal(test_string, actual, desired): + """ + Test if two objects are equal, and print an error message if test fails. + + The test is performed with ``actual == desired``. + + Parameters + ---------- + test_string : str + The message supplied to AssertionError. + actual : object + The object to test for equality against `desired`. + desired : object + The expected result. + + Examples + -------- + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 1] + ... ) # doctest: +SKIP + >>> np.testing.print_assert_equal( + ... "Test XYZ of func xyz", [0, 1], [0, 2] + ... ) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Test XYZ of func xyz failed + ACTUAL: + [0, 1] + DESIRED: + [0, 2] + + """ + __tracebackhide__ = True # Hide traceback for py.test + import pprint + + if actual != desired: + msg = StringIO() + msg.write(test_string) + msg.write(" failed\nACTUAL: \n") + pprint.pprint(actual, msg) + msg.write("DESIRED: \n") + pprint.pprint(desired, msg) + raise AssertionError(msg.getvalue()) + + +def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies that the elements of `actual` and `desired` satisfy. + + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + decimal : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> from torch._numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 + + >>> assert_almost_equal( + ... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 9 decimals + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 6.666699636781459e-09 + Max relative difference: 2.8571569790287484e-09 + x: torch.ndarray([1.0000, 2.3333], dtype=float64) + y: torch.ndarray([1.0000, 2.3333], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import imag, iscomplexobj, ndarray, real + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + def _build_err_msg(): + header = f"Arrays are not almost equal to {decimal:d} decimals" + return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_almost_equal(actualr, desiredr, decimal=decimal) + assert_almost_equal(actuali, desiredi, decimal=decimal) + except AssertionError: + raise AssertionError(_build_err_msg()) # noqa: B904 + + if isinstance(actual, (ndarray, tuple, list)) or isinstance( + desired, (ndarray, tuple, list) + ): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(_build_err_msg()) + else: + if not desired == actual: + raise AssertionError(_build_err_msg()) + return + except (NotImplementedError, TypeError): + pass + if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): + raise AssertionError(_build_err_msg()) + + +def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to significant + digits. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + Given two numbers, check that they are approximately equal. + Approximately equal is defined as the number of significant digits + that agree. + + Parameters + ---------- + actual : scalar + The object to check. + desired : scalar + The expected object. + significant : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> np.testing.assert_approx_equal( + ... 0.12345677777777e-20, 0.1234567e-20 + ... ) # doctest: +SKIP + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345671e-20, # doctest: +SKIP + ... significant=8, + ... ) + >>> np.testing.assert_approx_equal( + ... 0.12345670e-20, + ... 0.12345672e-20, # doctest: +SKIP + ... significant=8, + ... ) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal to 8 significant digits: + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 + + the evaluated condition that raises the exception is + + >>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1) + True + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + (actual, desired) = map(float, (actual, desired)) + if desired == actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + scale = 0.5 * (np.abs(desired) + np.abs(actual)) + scale = np.power(10, np.floor(np.log10(scale))) + try: + sc_desired = desired / scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual / scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg( + [actual, desired], + err_msg, + header=f"Items are not equal to {significant:d} significant digits:", + verbose=verbose, + ) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except (TypeError, NotImplementedError): + pass + if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): + raise AssertionError(msg) + + +def assert_array_compare( + comparison, + x, + y, + err_msg="", + verbose=True, + header="", + precision=6, + equal_nan=True, + equal_inf=True, + *, + strict=False, +): + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import all, array, asarray, bool_, inf, isnan, max + + x = asarray(x) + y = asarray(y) + + def array2string(a): + return str(a) + + # original array for output formatting + ox, oy = x, y + + def func_assert_same_pos(x, y, func=isnan, hasval="nan"): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + x_id = func(x) + y_id = func(y) + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implementations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if (x_id == y_id).all().item() is not True: + msg = build_err_msg( + [x, y], + err_msg + f"\nx and y {hasval} location mismatch:", + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. + if isinstance(x_id, bool) or x_id.ndim == 0: + return bool_(x_id) + elif isinstance(y_id, bool) or y_id.ndim == 0: + return bool_(y_id) + else: + return y_id + + try: + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if not cond: + if x.shape != y.shape: + reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" + else: + reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" + msg = build_err_msg( + [x, y], + err_msg + reason, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + + flagged = bool_(False) + + if equal_nan: + flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") + + if equal_inf: + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == +inf, hasval="+inf" + ) + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == -inf, hasval="-inf" + ) + + if flagged.ndim > 0: + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return + + val = comparison(x, y) + + if isinstance(val, bool): + cond = val + reduced = array([val]) + else: + reduced = val.ravel() + cond = reduced.all() + + # The below comparison is a hack to ensure that fully masked + # results, for which val.ravel().all() returns np.ma.masked, + # do not trigger a failure (np.ma.masked != True evaluates as + # np.ma.masked, which is falsy). + if not cond: + n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements + remarks = [ + f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" + ] + + # with errstate(all='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError, RuntimeError): + error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) + max_abs_error = max(error) + remarks.append( + "Max absolute difference: " + array2string(max_abs_error.item()) + ) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + remarks.append( + "Max relative difference: " + array2string(max_rel_error.item()) + ) + + err_msg += "\n" + "\n".join(remarks) + msg = build_err_msg( + [ox, oy], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + except ValueError: + import traceback + + efmt = traceback.format_exc() + header = f"error during assertion:\n\n{efmt}\n\n{header}" + + msg = build_err_msg( + [x, y], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise ValueError(msg) # noqa: B904 + + +def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): + """ + Raises an AssertionError if two array_like objects are not equal. + + Given two array_like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + The usual caution for verifying equality with floating point numbers is + advised. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. This behaviour can be disabled with the `strict` parameter. + + Examples + -------- + The first assert does not raise an exception: + + >>> np.testing.assert_array_equal( + ... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan] + ... ) + + Use `assert_allclose` or one of the nulp (number of floating point values) + functions for these cases instead: + + >>> np.testing.assert_allclose( + ... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0 + ... ) + + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: torch.ndarray([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: torch.ndarray(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes dtype("int64"), dtype("float32") mismatch) + x: torch.ndarray([2, 2, 2]) + y: torch.ndarray([2., 2., 2.]) + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__eq__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not equal", + strict=strict, + ) + + +def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. + + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + decimal : int, optional + Desired precision, default is 6. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + the first assert does not raise an exception + + >>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan]) + + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 5.999999999994898e-05 + Max relative difference: 2.5713661239633743e-05 + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) + + >>> np.testing.assert_array_almost_equal( + ... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5 + ... ) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + x and y nan location mismatch: + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import any as npany, float_, issubdtype, number, result_type + + def compare(x, y): + try: + if npany(gisinf(x)) or npany(gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not (xinfid == yinfid).all(): + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except (TypeError, NotImplementedError): + pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.0) + y = asanyarray(y, dtype) + z = abs(x - y) + + if not issubdtype(z.dtype, number): + z = z.astype(float_) # handle object arrays + + return z < 1.5 * 10.0 ** (-decimal) + + assert_array_compare( + compare, + x, + y, + err_msg=err_msg, + verbose=verbose, + header=f"Arrays are not almost equal to {decimal:d} decimals", + precision=decimal, + ) + + +def assert_array_less(x, y, err_msg="", verbose=True): + """ + Raises an AssertionError if two array_like objects are not ordered by less + than. + + Given two array_like objects, check that the shape is equal and all + elements of the first object are strictly smaller than those of the + second object. An exception is raised at shape mismatch or incorrectly + ordered values. Shape mismatch does not raise if an object has zero + dimension. In contrast to the standard usage in numpy, NaNs are + compared, no assertion is raised if both objects have NaNs in the same + positions. + + + + Parameters + ---------- + x : array_like + The smaller object to check. + y : array_like + The larger object to compare. + err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_array_equal: tests objects for equality + assert_array_almost_equal: test objects for equality up to precision + + + + Examples + -------- + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 1.0 + Max relative difference: 0.5 + x: torch.ndarray([1., 1., nan], dtype=float64) + y: torch.ndarray([1., 2., nan], dtype=float64) + + >>> np.testing.assert_array_less([1.0, 4.0], 3) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 2.0 + Max relative difference: 0.6666666666666666 + x: torch.ndarray([1., 4.], dtype=float64) + y: torch.ndarray(3) + + >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + (shapes (3,), (1,) mismatch) + x: torch.ndarray([1., 2., 3.], dtype=float64) + y: torch.ndarray([4]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__lt__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not less-ordered", + equal_inf=False, + ) + + +def assert_string_equal(actual, desired): + """ + Test if two strings are equal. + + If the given strings are equal, `assert_string_equal` does nothing. + If they are not equal, an AssertionError is raised, and the diff + between the strings is shown. + + Parameters + ---------- + actual : str + The string to test for equality against the expected string. + desired : str + The expected string. + + Examples + -------- + >>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP + >>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP + Traceback (most recent call last): + File "", line 1, in + ... + AssertionError: Differences in strings: + - abc+ abcd? + + + """ + # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test + import difflib + + if not isinstance(actual, str): + raise AssertionError(repr(type(actual))) + if not isinstance(desired, str): + raise AssertionError(repr(type(desired))) + if desired == actual: + return + + diff = list( + difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) + ) + diff_list = [] + while diff: + d1 = diff.pop(0) + if d1.startswith(" "): + continue + if d1.startswith("- "): + l = [d1] + d2 = diff.pop(0) + if d2.startswith("? "): + l.append(d2) + d2 = diff.pop(0) + if not d2.startswith("+ "): + raise AssertionError(repr(d2)) + l.append(d2) + if diff: + d3 = diff.pop(0) + if d3.startswith("? "): + l.append(d3) + else: + diff.insert(0, d3) + if d2[2:] == d1[2:]: + continue + diff_list.extend(l) + continue + raise AssertionError(repr(d1)) + if not diff_list: + return + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" + if actual != desired: + raise AssertionError(msg) + + +import unittest + + +class _Dummy(unittest.TestCase): + def nop(self): + pass + + +_d = _Dummy("nop") + + +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): + """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Alternatively, can be used as a context manager like `assert_raises`. + + Notes + ----- + .. versionadded:: 1.9.0 + + """ + __tracebackhide__ = True # Hide traceback for py.test + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) + + +def decorate_methods(cls, decorator, testmatch=None): + """ + Apply a decorator to all methods in a class matching a regular expression. + + The given decorator is applied to all public methods of `cls` that are + matched by the regular expression `testmatch` + (``testmatch.search(methodname)``). Methods that are private, i.e. start + with an underscore, are ignored. + + Parameters + ---------- + cls : class + Class whose methods to decorate. + decorator : function + Decorator to apply to methods + testmatch : compiled regexp or str, optional + The regular expression. Default value is None, in which case the + nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) + is used. + If `testmatch` is a string, it is compiled to a regular expression + first. + + """ + if testmatch is None: + testmatch = re.compile(rf"(?:^|[\\b_\\.{os.sep}-])[Tt]est") + else: + testmatch = re.compile(testmatch) + cls_attr = cls.__dict__ + + # delayed import to reduce startup time + from inspect import isfunction + + methods = [_m for _m in cls_attr.values() if isfunction(_m)] + for function in methods: + try: + if hasattr(function, "compat_func_name"): + funcname = function.compat_func_name + else: + funcname = function.__name__ + except AttributeError: + # not a function + continue + if testmatch.search(funcname) and not funcname.startswith("_"): + setattr(cls, funcname, decorator(function)) + return + + +def _assert_valid_refcount(op): + """ + Check that ufuncs don't mishandle refcount of object `1`. + Used in a few regression tests. + """ + if not HAS_REFCOUNT: + return True + + import gc + + import numpy as np + + b = np.arange(100 * 100).reshape(100, 100) + c = b + i = 1 + + gc.disable() + try: + rc = sys.getrefcount(i) + for _ in range(15): + d = op(b, c) + assert_(sys.getrefcount(i) >= rc) + finally: + gc.enable() + del d # for pyflakes + + +def assert_allclose( + actual, + desired, + rtol=1e-7, + atol=0, + equal_nan=True, + err_msg="", + verbose=True, + check_dtype=False, +): + """ + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Given two array_like objects, check that their shapes and all elements + are equal (but see the Notes for the special handling of a scalar). An + exception is raised if the shapes mismatch or any values conflict. In + contrast to the standard usage in numpy, NaNs are compared like numbers, + no assertion is raised if both objects have NaNs in the same positions. + + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note + that ``allclose`` has different default values). It compares the difference + between `actual` and `desired` to ``atol + rtol * abs(desired)``. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_array_almost_equal_nulp, assert_array_max_ulp + + Notes + ----- + When one of `actual` and `desired` is a scalar and the other is + array_like, the function checks that each element of the array_like + object is equal to the scalar. + + Examples + -------- + >>> x = [1e-5, 1e-3, 1e-1] + >>> y = np.arccos(np.cos(x)) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + def compare(x, y): + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + actual, desired = asanyarray(actual), asanyarray(desired) + header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" + + if check_dtype: + assert actual.dtype == desired.dtype + + assert_array_compare( + compare, + actual, + desired, + err_msg=str(err_msg), + verbose=verbose, + header=header, + equal_nan=equal_nan, + ) + + +def assert_array_almost_equal_nulp(x, y, nulp=1): + """ + Compare two arrays relatively to their spacing. + + This is a relatively robust method to compare two arrays whose amplitude + is variable. + + Parameters + ---------- + x, y : array_like + Input arrays. + nulp : int, optional + The maximum number of unit in the last place for tolerance (see Notes). + Default is 1. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the spacing between `x` and `y` for one or more elements is larger + than `nulp`. + + See Also + -------- + assert_array_max_ulp : Check that all items of arrays differ in at most + N Units in the Last Place. + spacing : Return the distance between x and the nearest adjacent number. + + Notes + ----- + An assertion is raised if the following condition is not met:: + + abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) + + Examples + -------- + >>> x = np.array([1.0, 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: X and Y are not equal to 1 ULP (max is 2) + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ax = np.abs(x) + ay = np.abs(y) + ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) + if not np.all(np.abs(x - y) <= ref): + if np.iscomplexobj(x) or np.iscomplexobj(y): + msg = f"X and Y are not equal to {nulp:d} ULP" + else: + max_nulp = np.max(nulp_diff(x, y)) + msg = f"X and Y are not equal to {nulp:d} ULP (max is {max_nulp:g})" + raise AssertionError(msg) + + +def assert_array_max_ulp(a, b, maxulp=1, dtype=None): + """ + Check that all items of arrays differ in at most N Units in the Last Place. + + Parameters + ---------- + a, b : array_like + Input arrays to be compared. + maxulp : int, optional + The maximum number of units in the last place that elements of `a` and + `b` can differ. Default is 1. + dtype : dtype, optional + Data-type to convert `a` and `b` to if given. Default is None. + + Returns + ------- + ret : ndarray + Array containing number of representable floating point numbers between + items in `a` and `b`. + + Raises + ------ + AssertionError + If one or more elements differ by more than `maxulp`. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + See Also + -------- + assert_array_almost_equal_nulp : Compare two arrays relatively to their + spacing. + + Examples + -------- + >>> a = np.linspace(0.0, 1.0, 100) + >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ret = nulp_diff(a, b, dtype) + if not np.all(ret <= maxulp): + raise AssertionError( + f"Arrays are not almost equal up to {maxulp:g} " + f"ULP (max difference is {np.max(ret):g} ULP)" + ) + return ret + + +def nulp_diff(x, y, dtype=None): + """For each item in x and y, return the number of representable floating + points between them. + + Parameters + ---------- + x : array_like + first input array + y : array_like + second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. + + Returns + ------- + nulp : array_like + number of representable floating point numbers between each item in x + and y. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + Examples + -------- + # By definition, epsilon is the smallest number such as 1 + eps != 1, so + # there should be exactly one ULP between 1 and 1 + eps + >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP + 1.0 + """ + import numpy as np + + if dtype: + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) + else: + x = np.asarray(x) + y = np.asarray(y) + + t = np.common_type(x, y) + if np.iscomplexobj(x) or np.iscomplexobj(y): + raise NotImplementedError("_nulp not implemented for complex array") + + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan + + if not x.shape == y.shape: + raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") + + def _diff(rx, ry, vdt): + diff = np.asarray(rx - ry, dtype=vdt) + return np.abs(diff) + + rx = integer_repr(x) + ry = integer_repr(y) + return _diff(rx, ry, t) + + +def _integer_repr(x, vdt, comp): + # Reinterpret binary representation of the float as sign-magnitude: + # take into account two-complement representation + # See also + # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + rx = x.view(vdt) + if rx.size != 1: + rx[rx < 0] = comp - rx[rx < 0] + else: + if rx < 0: + rx = comp - rx + + return rx + + +def integer_repr(x): + """Return the signed-magnitude interpretation of the binary representation + of x.""" + import numpy as np + + if x.dtype == np.float16: + return _integer_repr(x, np.int16, np.int16(-(2**15))) + elif x.dtype == np.float32: + return _integer_repr(x, np.int32, np.int32(-(2**31))) + elif x.dtype == np.float64: + return _integer_repr(x, np.int64, np.int64(-(2**63))) + else: + raise ValueError(f"Unsupported dtype {x.dtype}") + + +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): + """ + Fail unless the given callable throws the specified warning. + + A warning of class warning_class should be thrown by the callable when + invoked with arguments args and keyword arguments kwargs. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + warning_class : class + The class defining the warning that `func` is expected to throw. + func : callable, optional + Callable to test + *args : Arguments + Arguments for `func`. + **kwargs : Kwargs + Keyword arguments for `func`. + + Returns + ------- + The value returned by `func`. + + Examples + -------- + >>> import warnings + >>> def deprecated_func(num): + ... warnings.warn("Please upgrade", DeprecationWarning) + ... return num * num + >>> with np.testing.assert_warns(DeprecationWarning): + ... assert deprecated_func(4) == 16 + >>> # or passing a func + >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) + >>> assert ret == 16 + """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter("always") + yield + if len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError(f"Got warnings{name_str}: {l}") + + +def assert_no_warnings(*args, **kwargs): + """ + Fail if the given callable produces any warnings. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + The value returned by `func`. + + """ + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) + + +def _gen_alignment_data(dtype=float32, type="binary", max_size=24): + """ + generator producing data with different alignment and offsets + to test simd vectorization + + Parameters + ---------- + dtype : dtype + data type to produce + type : string + 'unary': create data for unary operations, creates one input + and output array + 'binary': create data for unary operations, creates two input + and output array + max_size : integer + maximum size of data to produce + + Returns + ------- + if type is 'unary' yields one output, one input array and a message + containing information on the data + if type is 'binary' yields one output array, two input array and a message + containing information on the data + + """ + ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" + bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" + for o in range(3): + for s in range(o + 2, max(o + 3, max_size)): + if type == "unary": + + def inp(): + return arange(s, dtype=dtype)[o:] + + out = empty((s,), dtype=dtype)[o:] + yield out, inp(), ufmt % (o, o, s, dtype, "out of place") + d = inp() + yield d, d, ufmt % (o, o, s, dtype, "in place") + yield ( + out[1:], + inp()[:-1], + ufmt + % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp()[1:], + ufmt + % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ), + ) + yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") + yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") + if type == "binary": + + def inp1(): + return arange(s, dtype=dtype)[o:] + + inp2 = inp1 + out = empty((s,), dtype=dtype)[o:] + yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") + d = inp1() + yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") + d = inp2() + yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") + yield ( + out[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + out[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ), + ) + yield ( + inp1()[1:], + inp1()[:-1], + inp2()[:-1], + bfmt + % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ), + ) + yield ( + inp1()[:-1], + inp1()[1:], + inp2()[:-1], + bfmt + % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ), + ) + yield ( + inp1()[:-1], + inp1()[:-1], + inp2()[1:], + bfmt + % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ), + ) + + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP + ... modules=[np.core.fromnumeric] + ... ): + ... warnings.simplefilter("always") + ... warnings.filterwarnings("ignore", module="np.core.fromnumeric") + ... # do something that raises a warning but ignore those in + ... # np.core.fromnumeric + """ + + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super().__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super().__enter__() + + def __exit__(self, *exc_info): + super().__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings: + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + https://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + + + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass + """ + + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings("always", category=category, message=message) + else: + module_regex = module.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=category, message=message, module=module_regex + ) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + else: + self._suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + return self._filter( + category=category, message=message, module=module, record=True + ) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings("always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=cat, message=mess, module=module_regex + ) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarning( + self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs + ): + for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ + ::-1 + ]: + if issubclass(category, cat) and pattern.match(message.args[0]) is not None: + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + +@contextlib.contextmanager +def _assert_no_gc_cycles_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + + # not meaningful to test if there is no refcounting + if not HAS_REFCOUNT: + yield + return + + assert_(gc.isenabled()) + gc.disable() + gc_debug = gc.get_debug() + try: + for _ in range(100): + if gc.collect() == 0: + break + else: + raise RuntimeError( + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?" + ) + + gc.set_debug(gc.DEBUG_SAVEALL) + yield + # gc.collect returns the number of unreachable objects in cycles that + # were found -- we are checking that no cycles were created in the context + n_objects_in_cycles = gc.collect() + objects_in_cycles = gc.garbage[:] + finally: + del gc.garbage[:] + gc.set_debug(gc_debug) + gc.enable() + + if n_objects_in_cycles: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError( + "Reference cycles were found{}: {} objects were collected, " + "of which {} are shown below:{}".format( + name_str, + n_objects_in_cycles, + len(objects_in_cycles), + "".join( + "\n {} object with id={}:\n {}".format( + type(o).__name__, + id(o), + pprint.pformat(o).replace("\n", "\n "), + ) + for o in objects_in_cycles + ), + ) + ) + + +def assert_no_gc_cycles(*args, **kwargs): + """ + Fail if the given callable produces any reference cycles. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_gc_cycles(): + do_something() + + .. versionadded:: 1.15.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + Nothing. The result is deliberately discarded to ensure that all cycles + are found. + + """ + if not args: + return _assert_no_gc_cycles_context() + + func = args[0] + args = args[1:] + with _assert_no_gc_cycles_context(name=func.__name__): + func(*args, **kwargs) + + +def break_cycles(): + """ + Break reference cycles by calling gc.collect + Objects can call other objects' methods (for instance, another object's + __del__) inside their own __del__. On PyPy, the interpreter only runs + between calls to gc.collect, so multiple calls are needed to completely + release all cycles. + """ + + gc.collect() + if IS_PYPY: + # a few more, just to make sure all the finalizers are called + gc.collect() + gc.collect() + gc.collect() + gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = "NPY_AVAILABLE_MEM" + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError( # noqa: B904 + f"Invalid environment variable {env_var}: {exc}" + ) + + msg = ( + f"{free_bytes / 1e9} GB memory required, but environment variable " + f"NPY_AVAILABLE_MEM={env_value} set" + ) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ( + "Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test." + ) + mem_free = -1 + else: + msg = f"{free_bytes / 1e9} GB memory required, but {mem_free / 1e9} GB available" + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) to float""" + suffixes = { + "": 1, + "b": 1, + "k": 1000, + "m": 1000**2, + "g": 1000**3, + "t": 1000**4, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "tb": 1000**4, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + } + + size_re = re.compile( + r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), + re.IGNORECASE, + ) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError(f"value {size_str!r} not a valid size") + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith("linux"): + info = {} + with open("/proc/meminfo") as f: + for line in f: + p = line.split() + info[p[0].strip(":").lower()] = int(p[1]) * 1024 + + if "memavailable" in info: + # Linux >= 3.14 + return info["memavailable"] + else: + return info["memfree"] + info["cached"] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, "gettrace"): + return func + else: + + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + + return wrapper + + +def _get_glibc_version(): + try: + ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] + except Exception: + ver = "0.0" + + return ver + + +_glibcver = _get_glibc_version() + + +def _glibc_older_than(x): + return _glibcver != "0.0" and _glibcver < x diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7439c22d66882d058e617edb85bc4407cfd742a9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py @@ -0,0 +1,35 @@ +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ + +from typing import TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from types import ModuleType + + from torch.ao.nn import ( # noqa: TC004 + intrinsic as intrinsic, + qat as qat, + quantizable as quantizable, + quantized as quantized, + sparse as sparse, + ) + + +__all__ = [ + "intrinsic", + "qat", + "quantizable", + "quantized", + "sparse", +] + + +def __getattr__(name: str) -> "ModuleType": + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e676a786c3f9d33d1ae92517e5264e67339608ae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1ca3d9b7b5f34699c12bc0a3732e870e238a9e6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/__pycache__/_mappings.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0404181bd971f880749b7c3b35799c660b8214f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..241b4e70e8196e66b471e3855033e55e3e426249 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,482 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +__all__ = ["ActivationSparsifier"] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. + The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor. + Example + def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is passed in the + register_layer(). + Note that the mask_fn() definition should contain the sparse arguments that is passed in sparse_config + arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + mask = [mask_fn(reduce_fn(aggregated_fn(input[feature])) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + with the mask_fn() + + Example: + >>> # xdoctest: +SKIP + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer( + ... model.some_layer, + ... aggregate_fn=agg_fn, + ... reduce_fn=reduce_fn, + ... mask_fn=mask_fn, + ... ) + >>> + >>> # start training process + >>> for _ in [...]: + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + + def __init__( + self, + model: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + self.model = model + self.defaults: dict[str, Any] = defaultdict() + self.defaults["sparse_config"] = sparse_config + + # functions + self.defaults["aggregate_fn"] = aggregate_fn + self.defaults["reduce_fn"] = reduce_fn + self.defaults["mask_fn"] = mask_fn + + # default feature and feature_dim + self.defaults["features"] = features + self.defaults["feature_dim"] = feature_dim + + self.data_groups: dict[str, dict] = defaultdict( + dict + ) # contains all relevant info w.r.t each registered layer + + self.state: dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly""" + + # if features are not None, then feature_dim must not be None + features, feature_dim = args["features"], args["feature_dim"] + if features is not None: + if feature_dim is None: + raise AssertionError("need feature dim to select features") + + # all the *_fns should be callable + fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"] + for key in fn_keys: + fn = args[key] + if not callable(fn): + raise AssertionError(f"{fn} must be callable") + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through.""" + + # gather some data + feature_dim = self.data_groups[name]["feature_dim"] + features = self.data_groups[name]["features"] + agg_fn = self.data_groups[name]["aggregate_fn"] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get("data") # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]["mask"] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [ + 0 for _ in range(len(features)) + ] # create one in case of 1st forward + self.state[name]["mask"] = [0 for _ in range(len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + data_feature = torch.index_select( + input_data, feature_dim, feature_tensor + ) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]["mask"][feature_idx] = torch.ones_like( + data_feature + ) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]["data"] = out_data + + return hook + + def register_layer( + self, + layer: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. + + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + if name is None: + raise AssertionError("layer not found in the model") + + if name in self.data_groups: # unregister layer if already present + warnings.warn( + "layer already attached to the sparsifier, deregistering the layer and registering with new config", + stacklevel=2, + ) + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + "aggregate_fn": aggregate_fn, + "reduce_fn": reduce_fn, + "mask_fn": mask_fn, + "features": features, + "feature_dim": feature_dim, + "layer": layer, + } + local_args.update( + (arg, val) for arg, val in update_dict.items() if val is not None + ) + local_args["sparse_config"].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]["mask"] = ( + None # mask will be created when model forward is called. + ) + + # attach agg hook + self.data_groups[name]["hook"] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]["hook_state"] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: str | None = None, layer: nn.Module | None = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + if name is None and layer is None: + raise AssertionError("Need at least name or layer obj to retrieve mask") + + if name is None: + if layer is None: + raise AssertionError("layer must be provided when name is None") + name = module_to_fqn(self.model, layer) + if name is None: + raise AssertionError("layer not found in the specified model") + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get("mask", None) + + if mask is None: + raise ValueError( + "Error: shape unknown, call layer() routine at least once to infer mask" + ) + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer""" + + # detach any hooks attached + self.data_groups[name]["hook"].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer""" + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs["data"] + self.update_mask(name, data, configs) + + self.data_groups[name].pop("data") # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs["sparse_config"] + features = configs["features"] + reduce_fn = configs["reduce_fn"] + mask_fn = configs["mask_fn"] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer""" + mask = self.get_mask(name) + features = self.data_groups[name]["features"] + feature_dim = self.data_groups[name]["feature_dim"] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(len(features)): + feature = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + sparsified = ( + torch.index_select(input_data, feature_dim, feature) + * mask[feature_idx] + ) + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggregate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs["hook"].remove() + configs.pop("hook") + self.data_groups[name]["hook_state"] = "None" + if attach_sparsify_hook: + configs["hook"] = configs["layer"].register_forward_pre_hook( + self._sparsify_hook(name) + ) + configs["hook_state"] = ( + "sparsify" # signals that sparsify hook is now attached + ) + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = { + key: value + for key, value in config.items() + if key not in ["hook", "layer"] + } + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for state in states.values(): + if state["mask"] is not None: + if isinstance(state["mask"], list): + for idx in range(len(state["mask"])): + if sparse_coo: + state["mask"][idx] = state["mask"][idx].to_sparse_coo() + else: + state["mask"][idx] = state["mask"][idx].to_dense() + else: + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + return states + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return {"state": state, "data_groups": data_groups, "defaults": self.defaults} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict["state"] + data_groups, defaults = state_dict["data_groups"], state_dict["defaults"] + + self.__set_state__( + {"state": state, "data_groups": data_groups, "defaults": defaults} + ) + + def __get_state__(self) -> dict[str, Any]: + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": data_groups, + } + + def __set_state__(self, state: dict[str, Any]) -> None: + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + if layer is None: + raise AssertionError(f"layer {name} not found in the model") + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config["hook_state"] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config["layer"] = layer + config["hook"] = hook # type: ignore[possibly-undefined] + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for name, config in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(config.keys()): + if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]: + continue + format_string += f"\t {key}: {config[key]}\n" + format_string += ")" + return format_string diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7564fe408b36e5fb62eb4cb2272ef432095981 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py @@ -0,0 +1,6 @@ +from .base_data_scheduler import BaseDataScheduler + + +__all__ = [ + "BaseDataScheduler", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f48abfc9deec2393816ffd227cc414f9f14a29 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs +import abc +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier + + +__all__ = ["BaseDataScheduler"] + + +class BaseDataScheduler: + r""" + The BaseDataScheduler is the abstract scheduler class specifically for the + BaseDataSparsifier class. This class controls a specific hyperparameter of + the sparsifier class and varies it across the training process (or across time). + + Args: + data_sparsifier (instance of BaseDataSparsifier) + Implemented class data sparsifier class wherein the update_mask is implemented + schedule_param (str) + A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied + last_epoch (int, default=-1) + This is specifically is passed when training needs to be resumed from a particular + point. + verbose (bool, default=False) + Verbosity of the BaseDataScheduler + + The *get_hyperparam()* function needs to be implemented by the user. + """ + + def __init__( + self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False + ): + # Attach sparsifier + if not isinstance(data_sparsifier, BaseDataSparsifier): + raise TypeError( + f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier" + ) + self.data_sparsifier = data_sparsifier + self.schedule_param = schedule_param + + # Initialize epoch and base hyper-params + self.base_param = { + name: config.get(schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.data_sparsifier.step = with_counter(self.data_sparsifier.step) # type: ignore[assignment] + self.data_sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sp_called_within_step: bool = False # sp -> schedule parameter + self.step() + + @abc.abstractmethod + def get_schedule_param(self): + r""" + Abstract method that needs to be implemented by the child class. + The expected return type should is a dictionary of name to schedule_param value + The returned values will be updated in sparsifier when the scheduler step() function + is called. + + Example: + >>> def get_schedule_param(self): + ... new_param = {} + ... for name in self.sparsifier.data_groups.keys(): + ... new_param[name] = ( + ... self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... ) + ... return new_param + + When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] + would be halved + """ + raise NotImplementedError + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Data Sparsifier {self.data_sparsifier}\n" + format_string += f" {self.schedule_param}: {self.base_param}\n" + format_string += ")" + return format_string + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + + Note: + The scheduler class does not track the state of the data_sparsifier. + Make sure to store the state of the sparsifier before storing the + state of the scheduler + """ + return { + key: value + for key, value in self.__dict__.items() + if key != "data_sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Note: + Remember to restore the state of the data_sparsifier before the scheduler. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_param(self): + return self._last_param + + def step(self): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.data_sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `data_sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + stacklevel=2, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `data_sparsifier.step()`. " + "You have to make sure you run the data_sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + stacklevel=2, + ) + self._step_count += 1 + + class _enable_get_sp_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sp_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sp_called_within_step = False + + with _enable_get_sp_call(self): + self.last_epoch += 1 + updated_scheduler_params = self.get_schedule_param() + + for name, param in updated_scheduler_params.items(): + self.data_sparsifier.data_groups[name][self.schedule_param] = param + if self.verbose: + print(f"Adjusting {self.schedule_param} for group {name} to {param}") + + self._last_param = { + name: config.get(self.schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + self.data_sparsifier.enable_mask_update = True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b5b9b96ec96fffdb0b66e21686a927a0c41b4a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py @@ -0,0 +1,8 @@ +from .base_data_sparsifier import BaseDataSparsifier +from .data_norm_sparsifier import DataNormSparsifier + + +__all__ = [ + "BaseDataSparsifier", + "DataNormSparsifier", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc4a538f2ae7f528fcb244c8c3339e31d15c335 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed6beb207df4b9f97fed9dae34e2cdc89844b61b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c257c977f4b5ad4aef778a9c9bc71bd439d7b0a4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec662e81531d723b64c931a4eb8d28d2ae7acf39 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..e76b5ccd7b5c571f636cce6a2f8beb907f50004e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -0,0 +1,334 @@ +# mypy: allow-untyped-defs +import abc +import copy +import sys +import warnings +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.ao.pruning.sparsifier import base_sparsifier, utils +from torch.nn.utils import parametrize + + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") + +__all__ = ["BaseDataSparsifier"] + +EMBEDDING_TYPES = { + nn.Embedding, + nn.EmbeddingBag, +} + +SUPPORTED_TYPES = { + torch.Tensor, + nn.Parameter, + *EMBEDDING_TYPES, +} + + +class _Container(nn.Module): + pass + + +class BaseDataSparsifier(base_sparsifier.BaseSparsifier): + r""" + Base Data Sparsifier class for all Data sparsifiers. + The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) + to prepare for sparsification. + In this case, mask (and parametrizations) is owned by the class and not by the user. + Specifically, the container object inside the class maintains the mask and parametrizations of the input data + + Args: + data_list (list of tuples) + list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES + for type of data. Internally, a container module handles the data sparsification. + + defaults (dict) + default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + Example:: + >>> # xdoctest: +SKIP + >>> data_list = [('tensor_1', torch.randn(3,3)), ('tensor_2', torch.randn(4,4))] + >>> defaults = {'sparsity_level': 0.7} + >>> sparsifier = DerivedDataSparsifier(data_list = data_list, **defaults) # Some sparsifier that inherits BaseDataSparsifier + >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5,5), 'sparsity_level': 0.3} + >>> sparsifier.add_data(**new_tensor_to_add) + >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3 + """ + + def __init__(self, data_list: list[tuple[str, Any]] | None = None, **defaults): + super().__init__(defaults=defaults) + + self._container = _Container() + + self.data_groups: dict[str, dict] = defaultdict(dict) # name -> {**config} + if data_list is not None: + # add data with default config here + [self.add_data(name, data, **self.defaults) for name, data in data_list] + + def prepare(self, model, config): + raise NotImplementedError("this function is undefined for this class") + + def _extract_weight(self, data): + # extract the weight parameter instead of underlying data + if type(data) in [torch.Tensor, nn.Parameter]: + return data + elif type(data) in EMBEDDING_TYPES: + return data.weight + + def add_data(self, name: str, data, reuse_mask=True, **config): + r"""Configures and parametrizes the internal container model with name and data. + + **Note**: + 1. If the data with name already exists, it replaces the data. + 2. While replacing, the old mask is reused when `reuse_mask=True` + 3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data. + 4. By default, the config of the replaced data is used as config for the replacing data, unless something + is specified in the config dictionary. + """ + if type(data) not in SUPPORTED_TYPES: + raise AssertionError( + f"specified data type:{type(data)} not supported at the moment" + ) + local_args = copy.deepcopy(self.defaults) + local_args.update(config) + weight = self._extract_weight(data) + + # Bookkeeping in the container class + mask = local_args.get("mask", torch.ones_like(weight)) + param_class = local_args.get("parametrization", utils.FakeSparsity) + + if name in self.state: + # If the named data already exists - replace + warnings.warn( + "Replacing existing data of the same name. - Did you mean a different name?", + stacklevel=2, + ) + + # reuse old config + old_args = self.data_groups[name] + local_args = copy.deepcopy(old_args) + local_args.update(config) + + if reuse_mask: + current_data = self.get_data(name=name) + if weight.shape != current_data.shape: + raise AssertionError( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) + mask = self.get_mask( + name=name + ) # reuse mask instead of creating a new one + + self._delete_data(name=name) + + # parameter creates a deepcopy of the weight inside, so create a buffer + self._container.register_buffer(name=name, tensor=weight) + parametrize.register_parametrization(self._container, name, param_class(mask)) + self.state[name]["mask"] = mask + self.data_groups[name] = local_args + return getattr(self._container, name) + + def get_data(self, name: str, return_original: bool = True): + r"""Returns weight tensor (or data) + Args: + - name: name of the data to be returned + - return_original returns weight tensor without applying parametrization if True + else - returns the sparsified version (parametrized) + """ + if name not in self.data_groups: + raise ValueError("data with specified name does not exist") + + if return_original: + if not parametrize.is_parametrized(self._container, name): + raise ValueError("mask squashed - original mask value does not exist") + data = getattr(self._container.parametrizations, name).original + return data + else: + return getattr(self._container, name) + + def _convert_mask(self, states, sparse_coo=True): + r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument.""" + states = copy.deepcopy(states) + for state in states.values(): + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + + return states + + def state_dict(self): + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a list containing all sparsity configuration groups + with the key name specifying the name of the data + * container_state_dict - the state dictionary of the internal + container model used for sparsification + """ + state = self._convert_mask(self.state) + return { + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def _load_container_from_state(self, states, data_groups, container_state_dict): + r"""This restores the state of the container specifically based on the data present in state and data_groups + If the data was parametrized, then the data would be added to the container and then parametrized, + else it would just add the attribute the container. + """ + for name, state in states.items(): + config_name = data_groups.get(name, None) + if config_name is None: + raise RuntimeError(f"Error loading {name}") + + # check if the data with such a name was parametrized, if so parametrize + # otherwise just set the attribute and continue + parametrized_name = f"parametrizations.{name}.original" + parametrized = False + data = container_state_dict.get(name, None) + if name in container_state_dict: + # the parametrization was probably removed for this + data = container_state_dict.get(name) + + elif parametrized_name in container_state_dict: + # so the weight was parametrized + data = container_state_dict.get(parametrized_name) + parametrized = True + + else: + raise RuntimeError(f"Error loading {name}") + + self._container.register_buffer(name=name, tensor=data) + + if parametrized: + # register parameter if parametrized + mask = state.get("mask", torch.ones_like(data)) + param_class = data_groups.get( + "parametrization", utils.FakeSparsity + ) # change once public_api for utils is fixed! + parametrize.register_parametrization( + self._container, name, param_class(mask) + ) + + def load_state_dict(self, state_dict, strict=True): + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + * strict - If True - the sparsifier is reset and is restored exactly to the state in state_dict. + If False - the current sparsifier is not reset before loading the state_dict i.e. data added + before loading the state_dict is not erased. + """ + states = copy.deepcopy(state_dict["state"]) + data_groups = copy.deepcopy(state_dict["data_groups"]) + container_state_dict = copy.deepcopy(state_dict["_container"]) + + states = self._convert_mask( + states, sparse_coo=False + ) # convert sparse coo mask to dense + if strict: + # if strict load -> then reset container + self._container = _Container() + + self._load_container_from_state(states, data_groups, container_state_dict) + + if not strict: + states.update(self.state) + data_groups.update(self.data_groups) + + self.__setstate__({"state": states, "data_groups": data_groups}) + + def __setstate__(self, state): + if "_container" in state: # If container object is in state then load model + container_dict = state.pop("_container") + self._container = _Container() + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert sparse coo mask to dense + self._load_container_from_state( + state["state"], state["data_groups"], container_dict + ) + + self.__dict__.update(state) + + def __getstate__(self): + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def __repr__(self): # type:ignore[override] + format_string = self.__class__.__name__ + " (" + for name, sparse_args in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(sparse_args.keys()): + if key == "data": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def get_mask(self, name: str): + if name not in self.state: + raise ValueError("data with specified name does not exist") + return self.state[name]["mask"] + + def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): + r"""Squashes the sparse masks into the appropriate tensors. Also, accepts list of strings + to squash mask for. If none, squashes mask for all the keys + kwargs: + * names: list of strings to squash mask for + * sparsified: if true - applies the mask before squashing + if false - does not apply the mask before squashing + """ + if names is None: + names = list(self.data_groups.keys()) + for name in names: + parametrize.remove_parametrizations( + self._container, name, leave_parametrized=leave_parametrized + ) + + def step(self): # type:ignore[override] + if not self.enable_mask_update: + return + with torch.no_grad(): + for name, config in self.data_groups.items(): + # get non-sparsified data + data = self.get_data(name) + # need name for the mask otherwise can directly pass mask? + self.update_mask(name, data, **config) + + @abc.abstractmethod + def update_mask(self, name, data, **kwargs): # type: ignore[override] + pass + + def _delete_data(self, name): + """Detaches some data from the sparsifier. + + Args: + name (str) + Name of the data to be removed from the sparsifier + + Note: + Currently private. Kind of used as a helper function when replacing data of the same name + """ + self.squash_mask( + names=[name], leave_parametrized=False + ) # do not apply the mask while deleting + delattr(self._container, name) + self.state.pop(name) + self.data_groups.pop(name) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2971cd0b3d0cae7afb8763bee319ac819ad2f8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -0,0 +1,204 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Any + +import torch +from torch.nn import functional as F + +from .base_data_sparsifier import BaseDataSparsifier + + +__all__ = ["DataNormSparsifier"] + + +class DataNormSparsifier(BaseDataSparsifier): + r"""L1-Norm Sparsifier + This sparsifier computes the *L1-norm* of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + Args: + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block + zeros_per_block: Number of zeros in a sparse block + Note:: + All arguments to the DataNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `add_data` step. + """ + + def __init__( + self, + data_list: list[tuple[str, Any]] | None = None, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: int | None = None, + norm: str = "L1", + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + + if norm not in ["L1", "L2"]: + raise AssertionError("only L1 and L2 norm supported at the moment") + + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + self.norm = norm + super().__init__(data_list=data_list, **defaults) + + def __get_scatter_folded_mask( + self, data, dim, indices, output_size, sparse_block_shape + ): + mask = torch.ones_like(data) + mask.scatter_(dim=dim, index=indices, value=0) # zeroing out + mask = F.fold( + mask, + output_size=output_size, + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + mask = mask.to(torch.int8) + return mask + + def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block): + # Assume data is a squeezed tensor + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + values_per_block = block_height * block_width + + # just return zeros if zeroing all elements in block + if values_per_block == zeros_per_block: + return torch.zeros_like(data, dtype=torch.int8) + + # creating additional height and width to support padding + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones( + height + dh, width + dw, dtype=data.dtype, device=data.device + ) + padded_data = ( + padded_data * torch.nan + ) # can also be replaced with 0 to stop the removal of edge data + padded_data[0:height, 0:width] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + _, sorted_idx = torch.sort(unfolded_data, dim=1) + sorted_idx = sorted_idx[ + :, :zeros_per_block, : + ] # zero out zeros_per_block number of elements + + mask = self.__get_scatter_folded_mask( + data=unfolded_data, + dim=1, + indices=sorted_idx, + output_size=padded_data.shape, + sparse_block_shape=sparse_block_shape, + ) + + mask = ( + mask.squeeze(0).squeeze(0)[:height, :width].contiguous() + ) # remove padding and make contiguous + return mask + + def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + data_norm = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + + values_per_block = reduce(operator.mul, sparse_block_shape) + + data_norm = data_norm.flatten() + num_blocks = len(data_norm) + + data_norm = data_norm.repeat( + 1, values_per_block, 1 + ) # get similar shape after unfold + _, sorted_idx = torch.sort(data_norm, dim=2) + + threshold_idx = round(sparsity_level * num_blocks) # number of blocks to remove + sorted_idx = sorted_idx[:, :, :threshold_idx] + + mask = self.__get_scatter_folded_mask( + data=data_norm, + dim=2, + indices=sorted_idx, + output_size=(height + dh, width + dw), + sparse_block_shape=sparse_block_shape, + ) + + mask = mask.squeeze(0).squeeze(0)[ + :height, :width + ] # squeeze only the first 2 dimension + return mask + + def update_mask( # type: ignore[override] + self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than " + "the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + if self.norm == "L1": + data_norm = torch.abs(data).squeeze() # absolute value based (L1) + else: + data_norm = (data * data).squeeze() # square every element for L2 + + if len(data_norm.shape) > 2: # only supports 2 dimensional data at the moment + raise ValueError("only supports 2-D at the moment") + + elif len(data_norm.shape) == 1: # in case the data is bias (or 1D) + data_norm = data_norm[None, :] + + mask = self.get_mask(name) + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + + # Fetch the high level mask that zeros out entire blocks + data_lvl_mask = self.__get_data_level_mask( + data=data_norm, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + + # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block + block_lvl_mask = self.__get_block_level_mask( + data=data_norm, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + + # zero out the entries inside those blocks whose block is sparsified + mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d00214667a87bc65b841d2170bf5caf2903e46 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..792ab2934f8ba41261223b68181f8447cf8a9134 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c0ca531535d482dfc6d8c81138ea108f83c4e5b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50d5684961bc807d5ae1b02615ade168416c9b3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import logging + +from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import ( + SUPPORTED_TYPES, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None): + """Attaches a data sparsifier to all the layers of the module. + Essentially, loop over all the weight parameters in the module and + attach it to the data sparsifier. + Note:: + The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below) + before attaching to the sparsifier. This is because, the data + sparsifier uses a dummy model inside to store the weight parameters. + """ + if config is None: + config = {} + for name, parameter in module.named_parameters(): + if type(parameter) in SUPPORTED_TYPES: + valid_name = _get_valid_name(name) + # will be defaulted to default configs + data_sparsifier.add_data( + name=valid_name, data=parameter, **config.get(valid_name, {}) + ) + + +def _get_valid_name(name): + return name.replace(".", "_") # . is not allowed as a name + + +def _log_sparsified_level(model, data_sparsifier) -> None: + # Show the level of sparsity AFTER step: + for name, parameter in model.named_parameters(): + if type(parameter) not in SUPPORTED_TYPES: + continue + valid_name = _get_valid_name(name) + mask = data_sparsifier.get_mask(name=valid_name) + sparsity_level = 1.0 - mask.float().mean() + logger.info("Sparsity in layer %s = % .2%", name, sparsity_level) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c8a91c5c9dcea9ad5cceaa9ecc80e3f32bd8a7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from copy import deepcopy +from typing import Any, TYPE_CHECKING + +import pytorch_lightning as pl # type: ignore[import] + +from ._data_sparstity_utils import ( + _attach_model_to_data_sparsifier, + _get_valid_name, + _log_sparsified_level, +) + + +if TYPE_CHECKING: + import torch + + +class PostTrainingDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables post-training sparsity. + + This callback aims to sparsify the model inside lightning module after training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + once the training completes. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + Hooks implemented: + on_fit_end() + 1. copies the model and attaches it to the sparsifier + 2. sparsier step() is called + 3. squashes the mask() + """ + + def __init__(self, data_sparsifier_class, data_sparsifier_args): + super().__init__() + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + self.data_sparsifier: Any = None + self.sparsified: torch.nn.Module | None = None + + def on_fit_end(self, trainer, pl_module) -> None: + self.sparsified = deepcopy(pl_module.model).eval() + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) + + self.data_sparsifier.step() + + self.data_sparsifier.squash_mask() # currently squashes params for all mask + + _log_sparsified_level(self.sparsified, self.data_sparsifier) + + +class TrainingAwareDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables in-training sparsity. + + This callback aims to sparsify the model inside lightning module during training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + when the training starts. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + data_scheduler_class (some implemented class of BaseDataScheduler) + The data scheduler of this class is created when the training starts + Note: Objects should not be passed in here as they are created + when the training starts. + + data_scheduler_args(Dict) + Dictionary of args to be passed to the data scheduler. + **Note: data_sparsifier arg should be ignored as the recipe + creates and pass sparsifier object into the class** + + Hooks implemented: + on_train_start() + Data sparsifier and scheduler objects are created. + Pytorch model attached to the sparsifier + + on_train_epoch_start() + Loads the state_dict of the data sparsifier + + on_train_epoch_end() + 1. Copies the model and attaches it to the sparsifier + 2. sparsifier step() and scheduler step() + 3. Dump state_dict of the current sparsifier + + on_train_end() + squash mask + """ + + def __init__( + self, + data_sparsifier_class, + data_sparsifier_args, + data_scheduler_class, + data_scheduler_args, + ): + super().__init__() + # data sparsifier objects + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + + # scheduler objects + self.data_scheduler_class = data_scheduler_class + self.data_scheduler_args = data_scheduler_args + + # fields + self.data_sparsifier: Any = None + self.data_scheduler: Any = None + self.sparsified: torch.nn.Module | None = None + + self.data_sparsifier_state_dict: Any = None + + def on_train_start(self, trainer, pl_module) -> None: + # create sparsifier + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + self.sparsified = deepcopy(pl_module.model) + + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier + ) # just to populate the base_sl in the scheduler + + # create scheduler + args = deepcopy(self.data_scheduler_args) + args["data_sparsifier"] = self.data_sparsifier + self.data_scheduler = self.data_scheduler_class(**args) + + def on_train_epoch_start(self, trainer, pl_module): + if self.data_sparsifier_state_dict is None: + return # probably first epoch + + # load the existing config for each data + self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict) + + def __create_config_based_on_state(self, pl_module): + config: dict = defaultdict() + if self.data_sparsifier_state_dict is None: + return config + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + config[valid_name] = self.data_sparsifier.data_groups[valid_name] + + return config + + def on_train_epoch_end(self, trainer, pl_module): + self.sparsified = deepcopy(pl_module.model) + config = self.__create_config_based_on_state(pl_module) + + # attach model to the data sparsifier + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier, config=config + ) + self.data_sparsifier.step() + self.data_scheduler.step() + + self.data_sparsifier_state_dict = self.data_sparsifier.state_dict() + + def on_train_end(self, trainer, pl_module): + self.data_sparsifier.squash_mask() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b727635d08151abd39c94ee40b0417afea97a05b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -0,0 +1,154 @@ +# mypy: allow-untyped-defs + +import torch +import torch.nn as nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag} + + +def _fetch_all_embeddings(model): + """Fetches Embedding and EmbeddingBag modules from the model""" + embedding_modules = [] + stack = [model] + while stack: + module = stack.pop() + for _, child in module.named_children(): + fqn_name = module_to_fqn(model, child) + if type(child) in SUPPORTED_MODULES: + embedding_modules.append((fqn_name, child)) + else: + stack.append(child) + return embedding_modules + + +def post_training_sparse_quantize( + model, + data_sparsifier_class, + sparsify_first=True, + select_embeddings: list[nn.Module] | None = None, + **sparse_config, +): + """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. + The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. + + Args: + - model (nn.Module) + model whose embeddings needs to be sparsified + - data_sparsifier_class (type of data sparsifier) + Type of sparsification that needs to be applied to model + - sparsify_first (bool) + if true, sparsifies first and then quantizes + otherwise, quantizes first and then sparsifies. + - select_embeddings (List of Embedding modules) + List of embedding modules to in the model to be sparsified & quantized. + If None, all embedding modules with be sparsified + - sparse_config (Dict) + config that will be passed to the constructor of data sparsifier object. + + Note: + 1. When `sparsify_first=False`, quantization occurs first followed by sparsification. + - before sparsifying, the embedding layers are dequantized. + - scales and zero-points are saved + - embedding layers are sparsified and `squash_mask` is applied + - embedding weights are requantized using the saved scales and zero-points + 2. When `sparsify_first=True`, sparsification occurs first followed by quantization. + - embeddings are sparsified first + - quantization is applied on the sparsified embeddings + """ + data_sparsifier = data_sparsifier_class(**sparse_config) + + # if select_embeddings is None, perform it on all embeddings + if select_embeddings is None: + embedding_modules = _fetch_all_embeddings(model) + + else: + embedding_modules = [] + if not isinstance(select_embeddings, list): + raise AssertionError( + "the embedding_modules must be a list of embedding modules" + ) + for emb in select_embeddings: + if type(emb) not in SUPPORTED_MODULES: + raise AssertionError( + "the embedding_modules list must be an embedding or embedding bags" + ) + fqn_name = module_to_fqn(model, emb) + if fqn_name is None: + raise AssertionError( + "the embedding modules must be part of input model" + ) + embedding_modules.append((fqn_name, emb)) + + if sparsify_first: + # sparsify + for name, emb_module in embedding_modules: + valid_name = name.replace(".", "_") + data_sparsifier.add_data(name=valid_name, data=emb_module) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + else: + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + # retrieve scale & zero_points + quantize_params: dict[str, dict] = { + "scales": {}, + "zero_points": {}, + "dequant_weights": {}, + "axis": {}, + "dtype": {}, + } + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") + + quantized_weight = quantized_emb.weight() # type: ignore[operator] + quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() + quantize_params["zero_points"][name] = ( + quantized_weight.q_per_channel_zero_points() + ) + quantize_params["dequant_weights"][name] = torch.dequantize( + quantized_weight + ) + quantize_params["axis"][name] = quantized_weight.q_per_channel_axis() + quantize_params["dtype"][name] = quantized_weight.dtype + + # attach data to sparsifier + data_sparsifier.add_data( + name=name.replace(".", "_"), + data=quantize_params["dequant_weights"][name], + ) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") + requantized_vector = torch.quantize_per_channel( + quantize_params["dequant_weights"][name], + scales=quantize_params["scales"][name], + zero_points=quantize_params["zero_points"][name], + dtype=quantize_params["dtype"][name], + axis=quantize_params["axis"][name], + ) + + quantized_emb.set_weight(requantized_vector) # type: ignore[operator] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..1a89de12bd9345a05acee98309f90d38d70daac1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + + +__all__ = ["FPGMPruner"] + + +class FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prune filter (row) in a tensor according to distances among filters according to + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. + + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__(self, sparsity_level: float = 0.5, dist: Callable | int | None = None): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimension + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filtters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # more similar with other filter indicates large in the sum of row + # pyrefly: ignore [bad-argument-type] + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( + max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a57db6a8d8cde9a89c7cbda4dff6f6075559b59b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py @@ -0,0 +1,5 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .FPGM_pruner import FPGMPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .parametrization import BiasHook, FakeStructuredSparsity +from .saliency_pruner import SaliencyPruner diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2b2a4d6234f74d47005d4364fe8300093a5480 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a8655683b386a147e7dceb7adca1d149912243 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b1ba6701861f22b6ecae1448d1eda325ba1c07 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0641b6e53bb680c0112f07ac8528e2de687c0860 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27d86e30d7d00ba3232477c6dc09f5d2cd8c238 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b3c7905e6c33e38110c9f30bd0a0bde875c429d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d887e09be354e17601953f06131381df2e4279f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21adff2c281360e5ee9dfd8b71f32e1aa31d1234 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..d1676292f7d74c4a620de0a53334d6dcd33aa764 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -0,0 +1,313 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable +from itertools import chain +from operator import getitem + +import torch +import torch.nn.functional as F +from torch import nn +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize + +from .match_utils import apply_match, MatchAllNode +from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param +from .prune_functions import ( + prune_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_linear, + prune_linear_activation_linear, + prune_linear_linear, + prune_lstm_output_layernorm_linear, + prune_lstm_output_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> dict[ + tuple[type[nn.Module] | Callable | MatchAllNode | str, ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: dict[ + tuple[type[nn.Module] | Callable | MatchAllNode | str, ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. + # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: set[type] | None = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + # pyrefly: ignore [bad-assignment] + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) # type: ignore[union-attr, index] + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( # noqa: TRY002 + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" + ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced # type: ignore[return-value] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..f904cc3ab8c4c34a193dd30926fff164010287a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -0,0 +1,54 @@ +from typing import Any, cast + +import torch +from torch import nn + +from .base_structured_sparsifier import BaseStructuredSparsifier +from .parametrization import FakeStructuredSparsity + + +class LSTMSaliencyPruner(BaseStructuredSparsifier): + """ + Prune packed LSTM weights based on saliency. + For each layer {k} inside a LSTM, we have two packed weight matrices + - weight_ih_l{k} + - weight_hh_l{k} + + These tensors pack the weights for the 4 linear layers together for efficiency. + + [W_ii | W_if | W_ig | W_io] + + Pruning this tensor directly will lead to weights being misassigned when unpacked. + To ensure that each packed linear layer is pruned the same amount: + 1. We split the packed weight into the 4 constituent linear parts + 2. Update the mask for each individual piece using saliency individually + + This applies to both weight_ih_l{k} and weight_hh_l{k}. + """ + + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None: + weights = getattr(module, tensor_name) + + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = cast(torch.Tensor, p.mask) + + # select weights based on magnitude + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + # take norm over all but first dim + dims = tuple(range(1, weights.dim())) + saliency = weights.norm(dim=dims, p=1) + + # handle weights in 4 groups + split_size = len(mask) // 4 + masks = torch.split(mask, split_size) + saliencies = torch.split(saliency, split_size) + + for keep_mask, sal in zip(masks, saliencies): + # mask smallest k values to be removed + k = int(len(keep_mask) * kwargs["sparsity_level"]) + prune = sal.topk(k, largest=False, sorted=False).indices + keep_mask.data[prune] = False # modifies underlying p.mask directly diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e22b979ab900c63a9a975b7a07c9b2a64ed8c0b5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -0,0 +1,65 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" + +from typing import Any + +import torch +from torch import nn +from torch.ao.quantization.utils import MatchAllNode +from torch.fx import Node +from torch.nn.utils import parametrize + + +def _match( + modules: dict[str, nn.ModuleDict], + node: Node, + current: nn.Module | Any, +) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index] + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + + +def apply_match( + modules: dict[str, nn.ModuleDict], + pattern: tuple[Any] | Any, + node: Node, + matched_node_pattern: list[Node], +) -> list[Node] | None: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..4256d6fd01750d4408b92342bfb8d12239bf129a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import torch +from torch import nn +from torch.nn.utils.parametrize import is_parametrized + + +def module_contains_param(module, parametrization): + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() + ) + return False + + +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + if not isinstance(self.mask, torch.Tensor): + raise AssertionError("mask must be a torch.Tensor") + if self.mask.shape[0] != x.shape[0]: + raise AssertionError( + f"mask shape[0] ({self.mask.shape[0]}) must match x shape[0] ({x.shape[0]})" + ) + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + if getattr(module, "_bias", None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[~self.param.mask] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..14a1c9a97b07ccb87a5ffab2923a105b7abbd6d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propagation +""" + +from collections.abc import Callable +from typing import cast + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList + +from .parametrization import BiasHook, FakeStructuredSparsity + + +# BIAS PROPAGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks: list[int] = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) # type: ignore[operator] + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) # type: ignore[operator] + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Tensor | None: + r""" + In the case that we need to propagate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + # pyrefly: ignore [bad-assignment] + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propagate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined] + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + # pyrefly: ignore [unbound-name] + return mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Callable[[Tensor], Tensor] | None, + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propagate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined] + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + # pyrefly: ignore [unbound-name] + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] + # adjusted bias that to keep in conv2d_1 + # pyrefly: ignore [unbound-name] + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined] + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Callable[[Tensor], Tensor] | None, + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propagate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Callable[[Tensor], Tensor] | None, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Callable[[Tensor], Tensor] | None, + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Callable[[Tensor], Tensor] | None, + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + if linear_ic % conv2d_oc != 0: + raise AssertionError( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] + + +def prune_lstm_output_linear( + lstm: nn.LSTM, getitem: Callable, linear: nn.Linear +) -> None: + prune_lstm_output_layernorm_linear(lstm, getitem, None, linear) + + +def prune_lstm_output_layernorm_linear( + lstm: nn.LSTM, + getitem: Callable, + layernorm: nn.LayerNorm | None, + linear: nn.Linear, +) -> None: + for i in range(lstm.num_layers): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_ih_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_ih_l{i}", leave_parametrized=True + ) + setattr( + lstm, + f"weight_ih_l{i}", + nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]), + ) + setattr( + lstm, + f"bias_ih_l{i}", + nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]), + ) + + if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_hh_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_hh_l{i}", leave_parametrized=True + ) + # splitting out hidden-hidden masks + W_hi, W_hf, W_hg, W_ho = torch.split( + getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size + ) + M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size) # type: ignore[arg-type] + + # resize each individual weight separately + W_hi = W_hi[M_hi][:, M_hi] + W_hf = W_hf[M_hf][:, M_hf] + W_hg = W_hg[M_hg][:, M_hg] + W_ho = W_ho[M_ho][:, M_ho] + + # concat, use this as new weight + new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho)) + setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight)) + setattr( + lstm, + f"bias_hh_l{i}", + nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]), + ) + + # If this is the final layer, then we need to prune linear layer columns + if i + 1 == lstm.num_layers: + lstm.hidden_size = int(M_hi.sum()) + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast( + nn.ModuleDict, linear.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, M_ho]) + linear.in_features = linear.weight.shape[1] + + # if layernorm module, prune weight and bias + if layernorm is not None: + layernorm.normalized_shape = (linear.in_features,) + layernorm.weight = nn.Parameter(layernorm.weight[M_ho]) + layernorm.bias = nn.Parameter(layernorm.bias[M_ho]) + + # otherwise need to prune the columns of the input of the next LSTM layer + else: + with torch.no_grad(): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"): + parametrization_dict = cast( + nn.ModuleDict, lstm.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, + getattr(parametrization_dict, f"weight_ih_l{i + 1}"), + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + else: + next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}") + setattr( + lstm, + f"weight_ih_l{i + 1}", + nn.Parameter(next_layer_weight[:, M_ho]), + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..11c4652a7f0dafe2d3dd94f85c68fece035fd827 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, which is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + if saliency.shape != mask.shape: + raise AssertionError( + f"saliency shape ({saliency.shape}) must match mask shape ({mask.shape})" + ) + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ac8916713dae6fe008b75e6dca9d63851560ab6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py @@ -0,0 +1,173 @@ +# mypy: allow-untyped-defs + +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier + + +__all__ = ["BaseScheduler"] + + +class BaseScheduler: + def __init__(self, sparsifier, last_epoch=-1, verbose=False): + # Attach sparsifier + if not isinstance(sparsifier, BaseSparsifier): + raise TypeError( + f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier" + ) + self.sparsifier = sparsifier + + # Initialize epoch and base sparsity levels + + self.base_sl = [group["sparsity_level"] for group in sparsifier.groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] + self.sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sl_called_within_step: bool = False + + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + """ + return { + key: value for key, value in self.__dict__.items() if key != "sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_sl(self): + """Return last computed sparsity level by current scheduler.""" + return self._last_sl + + def get_sl(self): + # Compute sparsity level using chainable form of the scheduler + # Note: This method is not intended to be called directly, and is only + # used by the ".step" method. Use .get_last_sl() instead. + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + raise NotImplementedError + + def print_sl(self, is_verbose, group, sl, epoch=None): + """Display the current sparsity level.""" + if is_verbose: + if epoch is None: + print(f"Adjusting sparsity level of group {group} to {sl:.4e}.") + else: + print( + f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}." + ) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Sparsifier {self.sparsifier}\n" + format_string += f" base_sl: {self.base_sl}\n" + format_string += ")" + return format_string + + def step(self, epoch=None): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + stacklevel=2, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `sparsifier.step()`. " + "You have to make sure you run the sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + stacklevel=2, + ) + self._step_count += 1 + + class _enable_get_sl_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sl_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sl_called_within_step = False + + with _enable_get_sl_call(self): + self.last_epoch += 1 + values = self.get_sl() + + for i, data in enumerate(zip(self.sparsifier.groups, values)): + param_group, sl = data + param_group["sparsity_level"] = sl + self.print_sl(self.verbose, i, sl, epoch) + + self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups] + self.sparsifier.enable_mask_update = True + + def _make_sure_a_list(self, var): + r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" + n = len(self.sparsifier.groups) + if not isinstance(var, (list, tuple)): + return [var] * n + else: + if len(var) != n: + raise ValueError(f"Expected variable of length {n}, but got {len(var)}") + return list(var) # We want the result to be in a list, not tuple diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d4706900762adf411eb68dfd7fee3ff9fed36b51 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +import warnings + +from .base_scheduler import BaseScheduler + + +__all__ = ["CubicSL"] + + +def _clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +class CubicSL(BaseScheduler): + r"""Sets the sparsity level of each parameter group to the final sl + plus a given exponential function. + + .. math:: + + s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 + + where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final + sparsity level, :math:`f(i)` is the function to be applied to the current epoch + :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. + :math:`\Delta t` is used to control how often the update of the sparsity level + happens. By default, + + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + init_sl (int, list): Initial level of sparsity + init_t (int, list): Initial step, when pruning starts + delta_t (int, list): Pruning frequency + total_t (int, list): Total number of pruning steps + initially_zero (bool, list): If True, sets the level of sparsity to 0 + before init_t (:math:`t_0`). Otherwise, the sparsity level before + init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__( + self, + sparsifier, + init_sl=0.0, + init_t=0, + delta_t=10, + total_t=100, + initially_zero=False, + last_epoch=-1, + verbose=False, + ): + self.sparsifier = sparsifier + + self.init_sl = self._make_sure_a_list(init_sl) + self.init_t = self._make_sure_a_list(init_t) + self.delta_t = self._make_sure_a_list(delta_t) + self.total_t = self._make_sure_a_list(total_t) + + self.initially_zero = self._make_sure_a_list(initially_zero) + + super().__init__(sparsifier, last_epoch, verbose) + + @staticmethod + def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): + r""" "Computes the current level of sparsity. + + Based on https://arxiv.org/pdf/1710.01878.pdf + + Args: + s_0: Initial level of sparsity, :math:`s_i` + s_f: Target level of sparsity, :math:`s_f` + t: Current step, :math:`t` + t_0: Initial step, :math:`t_0` + dt: Pruning frequency, :math:`\Delta T` + n: Pruning steps, :math:`n` + initially_zero: Sets the level of sparsity to 0 before t_0. + If False, sets to s_0 + + Returns: + The sparsity level :math:`s_t` at the current step :math:`t` + """ + if initially_zero and t < t_0: + return 0 + s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 + s_t = _clamp(s_t, s_0, s_f) + return s_t + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + return [ + self.sparsity_compute_fn( + s_0=initial_sparsity, + s_f=final_sparsity, + t=self.last_epoch, + t_0=initial_epoch, + dt=delta_epoch, + n=interval_epochs, + initially_zero=initially_zero, + ) + for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in zip( + self.init_sl, + self.base_sl, + self.init_t, + self.delta_t, + self.total_t, + self.initially_zero, + ) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5737095bf6662ba13a22a8ee8287d07263c05f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -0,0 +1,64 @@ +import warnings +from collections.abc import Callable + +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier + +from .base_scheduler import BaseScheduler + + +__all__ = ["LambdaSL"] + + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95**epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__( + self, + sparsifier: BaseSparsifier, + sl_lambda: Callable[[int], float] | list[Callable[[int], float]], + last_epoch: int = -1, + verbose: bool = False, + ) -> None: + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError( + f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}" + ) + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call] + + def get_sl(self) -> list[float]: + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.", + stacklevel=2, + ) + return [ + base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1f55d63a26781a3875a5d3ee36fb0ee906a5a0d9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +import abc +import copy +from collections import defaultdict +from typing import Any + +import torch +from torch import nn +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_contains_param, + module_to_fqn, + swap_module, +) + + +__all__ = ["BaseSparsifier"] + +SUPPORTED_MODULES = {nn.Linear} + +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + + +# TODO update desc with new config args +class BaseSparsifier(abc.ABC): + r"""Base class for all sparsifiers. + + Abstract methods that need to be implemented: + + - update_mask: Function to compute a new mask for all keys in the + `groups`. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements should be a dict map that includes + `tensor_fqn` of tensors to sparsify + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + + Example:: + + >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") + >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] + >>> defaults = {'sparsity_level': 0.7} + >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) + >>> sparsifier = BaseSparsifier(config, defaults) + """ + + def __init__(self, defaults: dict[str, Any] | None = None): + super().__init__() + self.defaults: dict[str, Any] = defaults or {} + + self.state: dict[str, dict] = defaultdict(dict) + self.groups: list[dict[str, Any]] = [] + self.enable_mask_update = True + + def __getstate__(self) -> dict[str, Any]: + return { + "defaults": self.defaults, + "state": self.state, + "groups": self.groups, + } + + def __setstate__(self, state: dict[str, dict[str, Any]]) -> None: + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, sparse_args in enumerate(self.groups): + module = sparse_args["module"] + format_string += "\n" + format_string += f"\tGroup {i}\n" + format_string += f"\t module: {module}\n" + for key in sorted(sparse_args.keys()): + if key == "module": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - current state of the sparsification. + * groups - a list containing all sparsity configuration groups + with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model + + TODO: Need a clean way of loading the state of the "prepared" module + """ + + groups: list[dict[str, Any]] = [ + dict( + filter( + lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, + mg.items(), + ) + ) + for mg in self.groups + ] + + return { + "state": self.state, + "groups": groups, + } + + def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True): + groups = copy.deepcopy(state_dict["groups"]) + states = state_dict["state"] + for tensor_fqn, s in states.items(): + arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] + if strict and module is None: + raise RuntimeError(f"Error loading {tensor_fqn} into the model") + + found = False + for p in module.parametrizations[tensor_name]: + if isinstance(p, FakeSparsity): + found = True + break + if not found: + p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) + parametrize.register_parametrization(module, tensor_name, p) + if s.get("mask", None) is not None: + mask = s.pop("mask") + p.mask = mask + + for mg in groups: + if mg["tensor_fqn"] == tensor_fqn: + mg.update(arg_info) + self.__setstate__({"state": states, "groups": groups}) + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: set[type[nn.Linear]] = SUPPORTED_MODULES, + ) -> None: + self.config = [] + stack = [model] + while stack: + module = stack.pop() + for _name, child in module.named_children(): + if type(child) in SUPPORTED_MODULES: + module_fqn = module_to_fqn(model, child) + if not isinstance(module_fqn, str): + raise AssertionError("module_fqn must be a string") + self.config.append({"tensor_fqn": module_fqn + ".weight"}) + else: + stack.append(child) + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations. + + Note:: + + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + + # If no config -- try getting all the supported layers + if self.config is None: + self.make_config_from_model(model) + + # TODO: Remove the configuration by reference ('module') + # pyrefly: ignore [not-iterable] + for module_config in self.config: + if not isinstance(module_config, dict): + raise AssertionError( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + if not isinstance(self.defaults, dict): + raise AssertionError("defaults must be a dict") + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + + tensor_fqn = local_args.get("tensor_fqn", None) + if tensor_fqn is None: + raise AssertionError( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn: + if key in local_args: + if not ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ): + raise AssertionError( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ) + local_args.update(info_from_tensor_fqn) + self.groups.append(local_args) + self._prepare() + + def _prepare(self, *args, **kwargs): + r"""Adds mask parametrization to the layer weight""" + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeSparsity) + mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + def squash_mask( + self, + params_to_keep: tuple[str, ...] | None = None, + params_to_keep_per_layer: dict[str, tuple[str, ...]] | None = None, + *args, + **kwargs, + ): + r"""Squashes the sparse masks into the appropriate tensors. + + If either the `params_to_keep` or `params_to_keep_per_layer` is set, + the module will have a `sparse_params` dict attached to it. + + Args: + params_to_keep: List of keys to save in the module or a dict + representing the modules and keys that will have + sparsity parameters saved + params_to_keep_per_layer: Dict to specify the params that should be + saved for specific layers. The keys in the dict + should be the module fqn, while the values should + be a list of strings with the names of the variables + to save in the `sparse_params` + + Examples: + >>> # xdoctest: +SKIP("locals are undefined") + >>> # Don't save any sparse params + >>> sparsifier.squash_mask() + >>> hasattr(model.submodule1, "sparse_params") + False + + >>> # Keep sparse params per layer + >>> sparsifier.squash_mask( + ... params_to_keep_per_layer={ + ... "submodule1.linear1": ("foo", "bar"), + ... "submodule2.linear42": ("baz",), + ... } + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'baz': 0.1} + + >>> # Keep sparse params for all layers + >>> sparsifier.squash_mask(params_to_keep=("foo", "bar")) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24} + + >>> # Keep some sparse params for all layers, and specific ones for + >>> # some other layers + >>> sparsifier.squash_mask( + ... params_to_keep=("foo", "bar"), + ... params_to_keep_per_layer={"submodule2.linear42": ("baz",)}, + ... ) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24, 'baz': 0.1} + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True + ) + sparse_params = {} + if params_to_keep is not None: + global_params = {k: config[k] for k in params_to_keep} + sparse_params.update(global_params) + if params_to_keep_per_layer is not None: + params = params_to_keep_per_layer.get(config["module_fqn"], None) + if params is not None: + per_layer_params = {k: config[k] for k in params} + sparse_params.update(per_layer_params) + if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? + module.sparse_params = sparse_params + + def convert( + self, + module: nn.Module, + mapping: dict[type[nn.Module], type[nn.Module]] | None = None, + inplace: bool = False, + parameterization: type[nn.Module] = FakeSparsity, + ): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_dense` method on the target module class + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + """ + if mapping is None: + raise NotImplementedError("Need to auto generate mapping ") + if not inplace: + module = copy.deepcopy(module) + + reassign = {} + for name, mod in module.named_children(): + # leaf node + if ( + module_contains_param(mod, parameterization) + and type_before_parametrizations(mod) in mapping + ): + reassign[name] = swap_module(mod, mapping) + else: + # recurse + reassign[name] = self.convert( + mod, + mapping=mapping, + inplace=True, + parameterization=parameterization, + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + def step(self, use_path: bool = True) -> None: + if not self.enable_mask_update: + return + with torch.no_grad(): + for config in self.groups: + self.update_mask(**config) + + @abc.abstractmethod + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..26fb3a98b8fb7d37e6bd5965d1d41b091d3e4818 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch + +from . import base_sparsifier + + +class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): + r"""Nearly Diagonal Sparsifier + + This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. + Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. + An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. + 1 1 0 0 1 1 1 0 + 1 1 1 0 1 1 1 1 + 0 1 1 1 1 1 1 1 + 0 0 1 1 0 1 1 1 + Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated + + This sparsifier is controlled by one variable: + 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. + Currently - supports only odd number + + Note: + This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix + feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy + + Args: + nearliness: The degree of nearliness (default = 1) + + """ + + def __init__(self, nearliness: int = 1): + defaults = {"nearliness": nearliness} + super().__init__(defaults=defaults) + + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): + mask = getattr(module.parametrizations, tensor_name)[0].mask + mask.data = torch.zeros_like(mask) + if nearliness <= 0: + return + + tensor = getattr(module, tensor_name) + height, width = tensor.shape + + if nearliness % 2 == 0: + raise ValueError("nearliness can only be an odd number") + dist_to_diagonal = nearliness // 2 + # check + if dist_to_diagonal >= min(height, width): + raise ValueError( + "nearliness cannot be larger than the dimensions of tensor." + ) + + for row in range(height): + # Bounds of entries that needs to be set to 1 + low = max(0, row - dist_to_diagonal) + high = min(width, row + dist_to_diagonal + 1) + mask[row, low:high].fill_(1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..97461630bc3ae9ce60cd02ce13a2371d9ba05536 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from itertools import chain +from typing import Any + +from torch import nn +from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations + + +__all__ = [ + "module_contains_param", + "swap_module", + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + + +def module_contains_param(module: nn.Module, parametrization: type[nn.Module]) -> bool: + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] + ) + return False + + +def swap_module( + mod: nn.Module, mapping: dict[type[nn.Module], type[nn.Module]] +) -> nn.Module: + r"""Swaps the module using from_dense according to the mapping passed in. + Args: + mod: input module + mapping: a dictionary that maps from nn module to sparse nn module + Return: + The corresponding sparse module of `mod` according to mapping, created using from_dense + """ + if type_before_parametrizations(mod) in mapping: + sparse_mod = mapping[type_before_parametrizations(mod)] + + # TODO Fix this typing, as Type[Module] has no attribute "from_dense" + new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] + + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + # pyrefly: ignore [bad-argument-type] + devices = {p.device for p in chain(mod.parameters(), mod.buffers())} + if len(devices) > 1: + raise AssertionError( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + + return new_mod + + else: + return mod + + +def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> str | None: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" + for name, child in model.named_children(): + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn + return None + + +def fqn_to_module(model: nn.Module | None, path: str) -> nn.Module | None: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. + """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) + return model + + +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ + # string manip to split tensor_fqn into module_fqn and tensor_name + # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' + # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' + tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] + + module = fqn_to_module(model, module_fqn) + + return { + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, + } + + +# Parametrizations +class FakeSparsity(nn.Module): + r"""Parametrization for the weights. Should be attached to the 'weight' or + any other parameter that requires a mask applied to it. + + Note:: + + Once the mask is passed, the variable should not change the id. The + contents of the mask can change, but the mask reference itself should + not. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + if self.mask.shape != x.shape: + raise AssertionError( + f"mask shape ({self.mask.shape}) must match x shape ({x.shape})" + ) + return self.mask * x + + def state_dict(self, *args, **kwargs): + # We don't want to let the parametrizations to save the mask. + # That way we make sure that the linear module doesn't store the masks + # alongside their parametrizations. + return {} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd0368f156744f1af362670fa73baf505f50251 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Callable +from functools import reduce + +import torch +import torch.nn.functional as F + +from .base_sparsifier import BaseSparsifier + + +__all__ = ["WeightNormSparsifier"] + + +def _flat_idx_to_2d(idx, shape): + rows = idx // shape[1] + cols = idx % shape[1] + return rows, cols + + +class WeightNormSparsifier(BaseSparsifier): + r"""Weight-Norm Sparsifier + + This sparsifier computes the norm of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + + Args: + + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block (see note below) + zeros_per_block: Number of zeros in a sparse block + norm: Norm to use. Could be either `int` or a callable. + If `int`, only L1 and L2 are implemented. + + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + + Note:: + All arguments to the WeightNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `prepare` step. + """ + + def __init__( + self, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: int | None = None, + norm: Callable | int | None = None, + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + if norm is None: + norm = 2 + if callable(norm): + self.norm_fn = norm + elif norm == 1: + self.norm_fn = lambda T: T.abs() + elif norm == 2: + self.norm_fn = lambda T: T * T + else: + raise NotImplementedError(f"L-{norm} is not yet implemented.") + super().__init__(defaults=defaults) + + def _scatter_fold_block_mask( + self, + output_shape, + dim, + indices, + block_shape, + mask=None, + input_shape=None, + device=None, + ): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + if input_shape is None: + raise AssertionError("input_shape must be provided when mask is None") + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold( + mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape + ) + return mask + + def _make_tensor_mask( + self, data, input_shape, sparsity_level, sparse_block_shape, mask=None + ): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h + dh, w + dw, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce(operator.mul, sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = round(sparsity_level * num_blocks) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, + output_shape=(h + dh, w + dw), + indices=sorted_idx, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce(operator.mul, sparse_block_shape) + + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk( + unfolded_data, k=zeros_per_block, dim=1, largest=False + ) + + self._scatter_fold_block_mask( + dim=1, + indices=sorted_idx, + output_shape=padded_data.shape, + block_shape=sparse_block_shape, + mask=mask_reshape, + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() + return mask + + def update_mask( # type: ignore[call-override, override] + self, + module, + tensor_name, + sparsity_level, + sparse_block_shape, + zeros_per_block, + **kwargs, + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + mask = getattr(module.parametrizations, tensor_name)[0].mask + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + else: + ww = self.norm_fn(getattr(module, tensor_name)) + tensor_mask = self._make_tensor_mask( + data=ww, + # pyrefly: ignore [missing-attribute] + input_shape=ww.shape, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask( + data=ww, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed339c55da3c5953f4ea3313909d04591b75237 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91b357e29d951d63c36d27286cacb84720286c6a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4f1912a987639fc987a5061d7ef68c3af1cff0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d2376b8009fd385c4b2415f3b417d1d1f28157c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50f07653ea8e01f54bd6750cba77db0e7d234ac1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcf6cea9a6f9b7056c04a11ff379654f0c162097 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89f17b0cf1671956a7966e27247ec1bccccef369 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c217cac8b2134a2d6ecd7a2a7366af639b4d3e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c0b70a48d42d2dd9e1b57fec8ae928e2a82d76 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dbf67a3c7f0cfe53e0290fe30ad59eb20c46c3c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fffea6adda1adefd8af828007f692b23d6e709d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acd1e7bf3fe45cb7d7b13a003c11df895d9d0e68 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f3f383e876f0aada64eed7af5be2397b1721533 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccbd9d8346f07b58eb7d5852ef46e4617e07118a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a052c1d398317e2e5d8ee1e6950239e60821348f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6005be8e14ba8bcddd512f678cc3a6a23708e81c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9ce914af6b81f96c14c6e1db3fc81943dd5725 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c95493a899fdf95541d164644cc55ec954c87cf4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..671f92d91aa661ce936312a98c4712f6b810fb92 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/convert.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dcf79036af3d8ae9cb7e83a9dd1617cea80e4d1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b88076c7f20f3caa623f83f74fb7a9ac39d7c92 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08630992a0d5c93b290067bdc57ec5608bae861d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ead04ade3ec129994674bd9c268904086cd66d5a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffde27f01d77c147ba41808018eb4eac769ab54c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6511c15e5615e753668afc58abb1611196a4f24 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4537f24bc26d497b31601d89bb9aa2c93fe4b03 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96716916020b949fe28e565b2490deefce7e73e3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8e1575a75316d0cd4e6119c0add4c68473e41a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4731c0f27c763ced32599ab562017a891dce4a40 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb104be7e28c56b4efd99d0f1943ad7e00932cd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1a4528b81e6e917e4d51af229482177b6f4285 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb87a38b660710e13f7df88b7f0f3a57fcff7d09 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/tracer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a3c4d2b35d5b926f9e9b240204d154e2251076 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d33055a0ee3052dfa8e0a78f98dbcbab89b35c9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c82a99f06ef08601e185b4a583a285751b0abdcf Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cf90de037ce1bf5c858434aed71d0b309726455 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59acb80622fdc4e849265d30c7249f2a9e326523 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc4aeffce98ce7821929625cb8c286e63a2b3c9c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..0a48bbbaaee901871d41396e0583642c4d486dce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py @@ -0,0 +1,1743 @@ +# mypy: allow-untyped-defs +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import torch +import torch.ao.nn.qat as nnqat +import torch.nn as nn +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.fx._equalize import ( + default_equalization_qconfig, + EqualizationQConfig, +) +from torch.ao.quantization.fx._model_report.model_report_observer import ( + ModelReportObserver, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ( + _is_activation_post_process, + default_dynamic_quant_observer, + default_observer, + default_per_channel_weight_observer, + default_weight_observer, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + _assert_valid_qconfig, + default_qconfig, + QConfig, +) + + +# Names for observer insert keys +DETECTOR_TARGET_NODE_KEY = "target_node" +DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert" +DETECTOR_IS_POST_OBS_KEY = "is_post_observer" +DETECTOR_OBS_ARGS_KEY = "observer_args" + + +# Mapping related code +class DetectorQConfigInfo: + r""" + This class contains the QConfig information for a single module. + The list of variables / values this contains can grow depending on the + extensibility of the qconfig mapping feature set but this currently includes: + - if activation observer is dynamic + - if weight observer is per channel + + + Args: + module_fqn (str): The fully qualified name (fqn) of the module that this + information contains info relevant to qconfig for + """ + + def __init__(self, module_fqn: str): + super().__init__() + self.module_fqn = module_fqn + + # populate this section with all the variables we might find important + # change from none if your detector is actually using this + self.is_activation_dynamic = False + self.is_weight_per_channel = False + + # equalization related options + self.is_equalization_recommended = False + + def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig: + r""" + Args: + module (torch.nn.Module) The module we are generating + the qconfig for + + Returns the generated quantization QConfig according to what a valid configuration is + """ + # Apply suggestions to new qconfig + module_qconfig = default_qconfig + + # keep track of dynamic and per_channel recommendations + recommendations_list = [] + # append as if a list of combinations + recommendations_list.append( + (self.is_activation_dynamic, self.is_weight_per_channel) + ) + recommendations_list.append( + (self.is_activation_dynamic, False) + ) # only trying dynamic rec + recommendations_list.append( + (False, self.is_weight_per_channel) + ) # only trying dynamic + + # now we try each of the combinations + for rec in recommendations_list: + # rec[0] -> dynamic recommended + # rec[1] -> per channel recommended + activation = default_dynamic_quant_observer if rec[0] else default_observer + weight = ( + default_per_channel_weight_observer + if rec[1] + else default_weight_observer + ) + test_config = QConfig(activation, weight) + try: + _assert_valid_qconfig(test_config, module) + module_qconfig = test_config + break + except AssertionError: + # if not a valid configuration, we move on to the next one in priority + continue + + # return the QConfig chosen + return module_qconfig + + def generate_equalization_qconfig(self) -> EqualizationQConfig: + r""" + This returns the equalization configuration for a module. + + For now, it just returns the default, but as more equalization options become + possible, this method can get more fleshed out with more nuanced granularity. + + + Returns the generated equalization QConfig according to what a valid configuration is + """ + # in this case, we just return default equalization config + # we know this is valid because only valid modules would even + # have this option + return default_equalization_qconfig + + +# Adding base class for detectors +class DetectorBase(ABC): + r"""Base Detector Module + Any detector class should derive from this class. + + Concrete detectors should follow the same general API, which includes: + - A method to calculate and return observer insertion points + - Should return both the fqns and the Observer class to insert + - A method to return a report based on the detector + - Should return a str-based report and dict info in Tuple[str,Dict] format + """ + + def __init__(self) -> None: + super().__init__() + self.detector_config_info = None + + @abstractmethod + def determine_observer_insert_points(self, model) -> dict: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict. + This dict maps string keys to detector specific information + """ + + @abstractmethod + def get_detector_name(self) -> str: + r"""Returns the name of the current detector""" + + @abstractmethod + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + + def _get_targeting_node( + self, prepared_fx_model: GraphModule, target_fqn: str + ) -> torch.fx.node.Node: + r""" + Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn. + + If it's not found, it means it is most likely inside a fused layer + We just go one layer up in terms of the fqn we are searching for until we find parent node + If we get to empty string, then we know that it doesn't exist + + The reason for the recursion is that if the model that we are looking for got fused, + we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module, + which would have fqn as x.linear so they will not match. + To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear, + or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module + even in cases with fusion + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + target_fqn (str): The fqn of the layer we are trying to target + + Returns the node object we are trying to add observers around + """ + for node in prepared_fx_model.graph.nodes: + # if the node's target is our target, return it + if node.target == target_fqn: + return node + + # getting here means node not found + # if no "." we are already at base and failed + parent_fqn_sep_index = target_fqn.rfind(".") + if parent_fqn_sep_index == -1: + raise ValueError("passed in target_fqn not found in graph's targets.") + else: + # recursively call it with parent fqn + return self._get_targeting_node( + prepared_fx_model, target_fqn[:parent_fqn_sep_index] + ) + + @abstractmethod + def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Tuple of two elements: + Str: string report of the suggested improvements + Dict: contains useful data collected by the observer pertinent to this report + """ + + +class PerChannelDetector(DetectorBase): + r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + per_channel quantization can lead to major benefits in the form of accuracy. + Therefore, if the backend used by the user supports it, it is recommended to use + + Args: + backend (str, optional): the backend the user wishes to use in production + Default value is current torch.backends.quantized.engine + """ + + # Keys for return dictionary + BACKEND_KEY = "backend" + PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported" + PER_CHAN_USED_KEY = "per_channel_quantization_used" + + # Default map for representing supported per channel quantization modules for different backends + DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: dict[str, set[Any]] = { + "fbgemm": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "qnnpack": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "onednn": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + "x86": { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + }, + } + + def __init__(self, backend: str = torch.backends.quantized.engine): + super().__init__() + + # store the backend information + self.backend_chosen = backend + self.supported_modules = set() + if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: + self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[ + self.backend_chosen + ] + else: + raise ValueError( + f"Not configured to work with {self.backend_chosen}. Try a different default backend" + ) + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "per_channel_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in per_channel_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + per_chan_supported: bool = per_channel_info[module_fqn][ + self.PER_CHAN_SUPPORTED_KEY + ] + detector_qconfig_info.is_weight_per_channel = per_chan_supported + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points(self, model: nn.Module) -> dict: + r""" + There is no observers inserted for the PerChannelDetector. + + Returns an empty dictionary since no observers are added or needed + """ + return {} + + def _detect_per_channel_helper(self, model: nn.Module): + r""" + determines if per_channel quantization is supported in modules and submodules. + + Returns a dictionary in the higher level _detect_per_channel function. + Each entry maps the fully-qualified-name to information on whether per_channel quantization. + + Args: + model: The current module that is being checked to see if it is per_channel quantizable + + Returns dictionary mapping fqns to if per_channel quantization is possible + """ + # create dict we will return + per_channel_info: dict = {} + + # get the fully qualified name and check if in list of modules to include and list of modules to ignore + for fqn, module in model.named_modules(): + is_in_include_list = any( + isinstance(module, x) for x in self.supported_modules + ) + + # check if the module per_channel is supported + # based on backend + per_channel_supported = False + + if is_in_include_list: + per_channel_supported = True + + # assert statement for MyPy + q_config_file = module.qconfig + if not isinstance(q_config_file, QConfig): + raise AssertionError("module.qconfig must be a QConfig") + + # this object should either be fake quant or observer + q_or_s_obj = module.qconfig.weight.p.func() + if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)): + raise AssertionError( + "module.qconfig.weight must be a FakeQuantize or ObserverBase" + ) + + per_channel_used = False # will be true if found in qconfig + + if hasattr( + q_or_s_obj, "ch_axis" + ): # then we know that per_channel quantization used + # all fake quants have channel axis so need to check is_per_channel + if isinstance(q_or_s_obj, FakeQuantize): + if ( + hasattr(q_or_s_obj, "is_per_channel") + and q_or_s_obj.is_per_channel + ): + per_channel_used = True + elif isinstance(q_or_s_obj, ObserverBase): + # should be an observer otherwise + per_channel_used = True + else: + raise ValueError("Should be either observer or fake quant") + + per_channel_info[fqn] = { + self.PER_CHAN_SUPPORTED_KEY: per_channel_supported, + self.PER_CHAN_USED_KEY: per_channel_used, + self.BACKEND_KEY: self.backend_chosen, + } + + return per_channel_info + + def generate_detector_report(self, model: nn.Module) -> tuple[str, dict[str, Any]]: + r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + Looks at q_config format and backend to determine if per_channel can be utilized. + Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support + + Args: + model: The prepared and calibrated model we want to check if using per_channel + + Returns a tuple with two elements: + String report of potential actions to improve model (if per_channel quantization is available in backend) + Dictionary mapping per_channel quantizable elements to: + whether per_channel quantization is supported by the backend + if it is being utilized in the current model + """ + + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # String to let the user know of further optimizations + further_optims_str = ( + f"Further Optimizations for backend {self.backend_chosen}: \n" + ) + + optimizations_possible = False + for fqn in per_channel_info: + fqn_dict = per_channel_info[fqn] + if ( + fqn_dict[self.PER_CHAN_SUPPORTED_KEY] + and not fqn_dict[self.PER_CHAN_USED_KEY] + ): + optimizations_possible = True + further_optims_str += ( + f"Module {fqn} can be configured to use per_channel quantization.\n" + ) + + if optimizations_possible: + further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer." + else: + further_optims_str += "No further per_channel optimizations possible." + + # return the string and the dictionary form of same information + return (further_optims_str, per_channel_info) + + +class DynamicStaticDetector(DetectorBase): + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + Args: + tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5 + """ + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer" + DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer" + + # naming conventions for stationary vs non-stationary data + STATIONARY_STR = "stationary" + NON_STATIONARY_STR = "non-stationary" + + # naming for activation + INPUT_ACTIVATION_PREFIX = "input_activation_" + OUTPUT_ACTIVATION_PREFIX = "output_activation_" + + # naming conventions for the keys of the return module info + TOLERANCE_KEY = "dynamic_static_tolerance" + DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended" + PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat" + PRE_OBS_DATA_DIST_KEY = ( + INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + POST_OBS_DATA_DIST_KEY = ( + OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification" + ) + IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported" + + # modules that are supported both dynamic and static for this report function + DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear} + + # modules that will be supported soon for both + DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d} + + def __init__(self, tolerance=0.5): + super().__init__() + + # set tolerance level and initialize a set to keep track of useful fqn locations + self.tolerance = tolerance + self.useful_observer_fqns: set[str] = set() + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r""" + Determines where observers need to be inserted for the Dynamic vs Static detector. + For this detector, we want to place observers on either side of linear layers in the model. + + Currently inserts observers for: + linear layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # make sure module is supported + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + # add entry for post-observer + post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME + + obs_fqn_to_info[post_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(), + DETECTOR_IS_POST_OBS_KEY: True, + DETECTOR_OBS_ARGS_KEY: (targeted_node,), + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""returns the string name of this detector""" + return "dynamic_vs_static_detector" + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + dynamic_static_info = self._generate_dict_info(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in dynamic_static_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + dynamic_static_recommended: bool = dynamic_static_info[module_fqn][ + self.DEFAULT_DYNAMIC_REC_KEY + ] + detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any( + isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # check if it will be supported + future_supported_type = any( + isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED + ) + + # supported + supported = is_supported_type or future_supported_type + + # this is check for observer insertion + if insert: + return supported + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr( + module, self.DEFAULT_POST_OBSERVER_NAME + ) + return supported and has_obs + + def _generate_dict_info(self, model: GraphModule) -> dict[str, Any]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + # store modules dynamic vs static information + module_dynamic_static_info = {} + + # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info + # This information primary includes whether the data distributions around a supported module is stationary or not + # Based on this, it is recorded whether dynamic or static quantization is recommended + + # loop through all submodules included nested ones + for fqn, module in model.named_modules(): + # if module is Linear has the ModelReportObserver attached to it + if self._is_supported(module): + # get pre and post observers for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME) + + # get the statistics for each module + pre_stat = pre_obs.get_batch_to_epoch_ratio() + post_stat = post_obs.get_batch_to_epoch_ratio() + + # record module, pre and post stat, and whether to do dynamic or static based off it + # true if post observer data distribution is non-stationary, false if it's stationary + dynamic_recommended = post_stat <= self.tolerance + + # specify the classifications for whether data distributions considered stationary or non-stationary + pre_obs_dist_classif = ( + self.STATIONARY_STR + if pre_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + post_obs_dist_classif = ( + self.STATIONARY_STR + if post_stat > self.tolerance + else self.NON_STATIONARY_STR + ) + + # check if current support or future support + is_supported_type = any( + isinstance(module, x) + for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED + ) + + # store the set of important information for this module + module_info = { + self.TOLERANCE_KEY: self.tolerance, + self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended, + self.PRE_OBS_COMP_STAT_KEY: pre_stat, + self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif, + self.POST_OBS_COMP_STAT_KEY: post_stat, + self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif, + self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type, + } + + module_dynamic_static_info[fqn] = module_info + + return module_dynamic_static_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. + + If the distribution of data right after the module is non-stationary, recommend dynamic quantization + Otherwise recommend static quantization + + This will then generate suggestions for dynamic vs static quantization focused around Linear. + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether dynamic or static quantization is recommended for certain modules + Dictionary mapping modules with ModelReportObservers around them to: + whether dynamic quantization is recommended + their S metric of input to module + whether input to module is stationary or non-stationary + their S metric of output of module + whether output of module is stationary or non-stationary + the tolerance level to decided whether input/output is stationary or non-stationary + whether it is currently supported or planned for the future + """ + + # get the dictionary of the information to format the string report + module_dynamic_static_info = self._generate_dict_info(model) + + dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n" + + modules_added: bool = False # check to make sure at least 1 module added. + + dynamic_benefit = ( + " You will get more accurate results if you use dynamic quantization" + ) + static_benefit = ( + " You can increase model efficiency if you use static quantization" + ) + future_support_str = ( + ". This layer is not yet supported for dynamic quantization" + ) + # This for loop goes through the information collected in module_dynamic_static_info and: + # Populates the string based report with the information from module_dynamic_static_info + # Compiles the complete report by appending relevant formatted strings + + for module_fqn in module_dynamic_static_info: + # there is at least 1 module for suggestion + modules_added = True + module_info = module_dynamic_static_info[module_fqn] + suggestion_string_template = ( + "For module {} it is suggested to use {} quantization because {}.\n" + ) + + # decide what string formatting values will be + quantization_type = "" + quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}." + + benefit_str = "" + + # strings for if dynamic quantized per tensor is needed + recommend_per_tensor = ( + ". We recommend to add a {} before this module if it is static." + ) + rec_lay_to_add = "dynamic quantize per tensor layer" + dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add) + dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution" + + # start composing explanation + if module_info[self.DEFAULT_DYNAMIC_REC_KEY]: + quantization_type = "dynamic" + # check if currently supported or future supported + benefit_str = dynamic_benefit + if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]: + benefit_str += future_support_str + else: + quantization_type = "static" + benefit_str = static_benefit + + # now set the quantization explanation string + quantization_reasoning = ( + quantization_reasoning.format( + module_fqn, + module_info[self.PRE_OBS_DATA_DIST_KEY], + module_info[self.POST_OBS_DATA_DIST_KEY], + ) + + benefit_str + ) + + # if we have a non-stationary input -> linear -> stationary we suggested static + # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made + if ( + module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR + and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR + ): + quantization_reasoning = ( + quantization_reasoning + + dynamic_per_tensor_string + + dynamic_per_tensor_reasoning_string + ) + + # format the overall suggestion string with the specific inputs + module_suggestion_string = suggestion_string_template.format( + module_fqn, quantization_type, quantization_reasoning + ) + + # append to overall suggestion + dynamic_vs_static_string += module_suggestion_string + + if not modules_added: + dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n" + + # return the string as well as the dictionary of information + return (dynamic_vs_static_string, module_dynamic_static_info) + + +class InputWeightEqualizationDetector(DetectorBase): + r""" + Determines whether input-weight equalization can help improve quantization for certain modules. + + Specifically, this list of modules includes: + linear + conv + + Determines whether input-weight equalization is recommended based on the comp stat: + s_c = sqrt(w_c/W)/sqrt(i_c/I) + where: + w_c is range of weight for channel c, W is range of weight over all channels + i_c is range of input for channel c, I is range of input over all channels + + if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization + + Args: + ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 (both non-inclusive) + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization + + * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + SUPPORTED_MODULES: set[Callable] = { + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Linear, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + } + + # names for the pre and post observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # weight / activation prefix for each of the below info + WEIGHT_PREFIX = "weight_" + ACTIVATION_PREFIX = "input_activation_" + + # string names for keys of info dictionaries + PER_CHANNEL_MAX_KEY = "per_channel_max" + PER_CHANNEL_MIN_KEY = "per_channel_min" + GLOBAL_MAX_KEY = "global_max" + GLOBAL_MIN_KEY = "global_min" + + # keys for return dict of recommendations + RECOMMENDED_KEY = "input_weight_equalization_recommended" + COMP_METRIC_KEY = "input_weight_channel_comparison_metrics" + THRESHOLD_KEY = "input_weight_threshold" + CHANNEL_KEY = "input_weight_channel_axis" + + # default weight and info strings + WEIGHT_STR = "weight" + INPUT_STR = "input" + + # default for what ratio we recommend input weight + DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4 + + def __init__(self, ratio_threshold: float, ch_axis: int = 1): + # ensure passed in inputs are valid + if ratio_threshold <= 0 or ratio_threshold >= 1: + raise ValueError("Make sure threshold is > 0 and < 1") + + # initialize attributes based on args + self.ratio_threshold: float = ratio_threshold + self.ch_axis: int = ch_axis + + def _is_supported(self, module: nn.Module, insert: bool = False) -> bool: + r"""Returns whether the given module is supported for observers + + Args + module: The module to check and ensure is supported + insert: True if this is check for observer insertion, false if for report gen + + Returns True if the module is supported by observer, False otherwise + """ + # check to see if module is of a supported type + is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES) + + # this is check for observer insertion + if insert: + return is_supported_type + else: + # this is for report gen and we also need to check if it contains observers + has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + return is_supported_type and has_obs + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in input_weight_equalization_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + input_weight_recommended: bool = input_weight_equalization_info[module_fqn][ + self.RECOMMENDED_KEY + ] + detector_qconfig_info.is_equalization_recommended = input_weight_recommended + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Input Weight Equalization Detector. + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + linear layers + conv layers + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._is_supported(module, insert=True): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "input_weight_equalization_detector" + + def _extract_input_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. + It then extracts the input information for each observer returns it + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns (str) to a dict with keys: + "input_activation_per_channel_max" : maps to the per_channel max values + "input_activation_per_channel_min" : maps to the per_channel min values + "input_activation_global_max" : maps to the global max recorded + "input_activation_global_min" : maps to the global min recorded + """ + + # return dictionary mapping observer fqns to desired info + input_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # get pre observer for the module + pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + input_info[fqn] = { + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val, + self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val, + self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val), + self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val), + } + + return input_info + + def _extract_weight_info(self, model: GraphModule) -> dict[str, dict]: + r""" + Takes in a calibrated GraphModule and then finds the relevant observers. + It then extracts the weight information for each layer an observer is attached to. + + Args + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping module fqns (str) to a dict with keys: + "per_channel_max" : maps to the per_channel max values + "per_channel_min" : maps to the per_channel min values + "global_max" : maps to the global max recorded + "global_min" : maps to the global min recorded + """ + # return dictionary mapping observer fqns to desired info + weight_info: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._is_supported(module): + # we don't need actual observer, just the module weights + # calculate min and max vals + device = module.weight.device + min_val: torch.Tensor = torch.tensor([float("inf")], device=device) + max_val: torch.Tensor = torch.tensor([float("-inf")], device=device) + x_copy = module.weight + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + weight_info[fqn] = { + self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val, + self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val, + self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val), + self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val), + } + + return weight_info + + def _calculate_range_ratio( + self, info_dict: dict, info_str: str, module_fqn: str + ) -> torch.Tensor: + r""" + Takes in an info dict and calculates the s_c matrix. + + Args: + info_dict (dict): A dictionary of either input or weight range info + info_str (str): A str describing whether currently looking at weight or input info + Either "weight" or "input" + module_fqn (str): The fqn of the module we are looking at + + Returns a tensor of values, where each value is the s_c stat for a different channel + """ + # calculate the ratios of the info + # get the prefix str + prefix_str = ( + self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX + ) + + per_channel_range = ( + info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] + - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY] + ) + global_range = ( + info_dict[prefix_str + self.GLOBAL_MAX_KEY] + - info_dict[prefix_str + self.GLOBAL_MIN_KEY] + ) + + if global_range == 0: + range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information." + raise ValueError( + f"The range of the {info_str} data for module {module_fqn} is 0, " + f"which means you have a constant value channel. {range_zero_explanation}" + ) + + ratio = per_channel_range / global_range + + return ratio + + def _generate_comparison_values( + self, input_info: dict, weight_info: dict + ) -> dict[str, torch.Tensor]: + r""" + Takes in the information on the min and max values of the inputs and weights and: + Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I) + + Args: + input_info (dict): A dict mapping each observer to input range information + weight_info (dict): A dict mapping each observer to weight range information + + Returns a dict mapping relevant observer fqns (str) to a 1-D tensor. + Each value is a different s_c value for a different channel + """ + # create return dictionary for each observer + module_fqn_to_channel: dict[str, torch.Tensor] = {} + + # for each module (both passed in dicts should have same keys) + for module_fqn in input_info: + # raise error if not in weight info + if module_fqn not in weight_info: + raise KeyError( + f"Unable to find weight range stats for module {module_fqn}" + ) + + # calculate the ratios of the weight info and input info + weight_ratio = self._calculate_range_ratio( + weight_info[module_fqn], self.WEIGHT_STR, module_fqn + ) + input_ratio = self._calculate_range_ratio( + input_info[module_fqn], self.INPUT_STR, module_fqn + ) + + # if mismatched size, because of grouping, we want to replicate weight enough times + weight_channels = len(weight_ratio) + input_channels = len(input_ratio) + if weight_channels != input_channels: + # we try to replicate + if input_channels % weight_channels != 0: + raise AssertionError( + "input channels should be divisible by weight channels." + ) + # get replication factor + rep_factor: int = input_channels // weight_channels + + # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n + weight_ratio = weight_ratio.repeat(rep_factor) + + # calculate the s metric per channel + s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio) + module_fqn_to_channel[module_fqn] = s + + # return compiled observer ratios + return module_fqn_to_channel + + def _generate_dict_info( + self, input_info: dict, weight_info: dict, comp_stats: dict + ) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + input_info (dict): A dict mapping each module to input range information + weight_info (dict): A dict mapping each module to weight range information + comp_stats (dict): A dict mapping each module to its corresponding comp stat + + Returns a dictionary mapping each module with relevant ModelReportObservers around them to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + # store modules input weight equalization info + input_weight_equalization_info: dict[str, dict] = {} + + # for each module we add separate set of suggestions + for module_fqn in input_info: + # get relevant info for this module + mod_input_info: dict = input_info[module_fqn] + mod_weight_info: dict = weight_info[module_fqn] + mod_comp_stat: dict = comp_stats[module_fqn] + + # decide if each channel should have input weight equalization or not + channel_rec_vals: list = [] + + for val in mod_comp_stat: + float_rep: float = val.item() + + # decide if recommending input weight equalization + recommended: bool = ( + float_rep >= self.ratio_threshold + and float_rep <= 1 / self.ratio_threshold + ) + channel_rec_vals.append(recommended) + + # build the return dict input + # also unpack input and weight dicts into it + input_weight_equalization_info[module_fqn] = { + self.RECOMMENDED_KEY: channel_rec_vals, + self.COMP_METRIC_KEY: mod_comp_stat, + self.THRESHOLD_KEY: self.ratio_threshold, + self.CHANNEL_KEY: self.ch_axis, + **mod_input_info, + **mod_weight_info, + } + + # return our compiled info for each module + return input_weight_equalization_info + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records per channel information of input range + It then uses the passed in weight info inconjunction to compute the desired ratio + Finally, it gives suggestions based on this information for each module of interest + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether input weight equalization is recommended for certain modules + Dictionary mapping modules of interest to: + whether input weight equalization is recommended + their s_c metric compared to the threshold + the threshold used to make the recommendation + the channel used for recording data + the input channel range info + the weight channel range info + """ + + # find the range of inputs + input_values: dict[str, dict] = self._extract_input_info(model) + + # find the range of weights + weight_values: dict[str, dict] = self._extract_weight_info(model) + + # calculate per_channel comparison statistic s_c + comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values( + input_values, weight_values + ) + + # generate the return dictionary + input_weight_equalization_info: dict[str, dict] = self._generate_dict_info( + input_values, weight_values, comp_stats + ) + + # now we can generate report based on this information + input_weight_string = "Input-Weight Equalization suggestions: \n" + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = ( + "\tWe suggest {} input weight equalization because {}\n" + ) + use_str = "to use" + no_use_str = "to not use" + input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error." + input_weight_non_benefit_reasoning = ( + "{}/{} channels benefitting from input-weight equalization being applied." + ) + input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}" + + # added module check + added_module: bool = False + + # compile the suggestion string + for module_fqn in input_weight_equalization_info: + # we added at least 1 module + added_module = True + # add the module level description + input_weight_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + + mod_info: dict[str, Any] = input_weight_equalization_info[module_fqn] + + # gather info on how many channels would benefit from input weight and + recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY] + num_recs = sum(recommendation_per_channel) + + if ( + num_recs / len(recommendation_per_channel) + >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO + ): + input_benefit_formatted = input_weight_benefit_str.format( + num_recs, len(recommendation_per_channel) + ) + channel_str = channel_suggestion_str.format( + use_str, input_benefit_formatted + ) + input_weight_string += channel_str + else: + non_benefit_reason_formatted = ( + input_weight_non_benefit_reasoning.format( + num_recs, len(recommendation_per_channel) + ) + ) + non_benefit_str = input_weight_non_benefit_str.format( + non_benefit_reason_formatted + ) + channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str) + input_weight_string += channel_str + + # if no modules looked at, amend return string + if not added_module: + input_weight_string += ( + "No applicable layers for suggestions. Only linear and conv valid.\n" + ) + + # return a tuple with the string explanation and the compiled dict info + return (input_weight_string, input_weight_equalization_info) + + +class OutlierDetector(DetectorBase): + r""" + Determines whether there are significant outliers in activation data around a certain layer. + + This is ideally used in conjunction with information on stationary vs. non-stationary distribution: + If the data is stationary, and there are significant outliers, then we want to flag them + We want to do this on a per channel basis for detecting outliers + + Determines whether activation data is flagged as outlier based on if data is stationary and: + p_r = avg(100th percentile / "reference_percentile"th percentile) + where: + p_r is average percentile ratio across all batches in the epoch + reference_percentile is a percentile values between 0 and 100 exclusive + + if p_r is above some threshold, then we consider the activations to have significant outliers + + Args: + ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations + Should be >= 1 + Default: 3.5 + reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile + Should be between 0 and 1 + Default: 0.975 + fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier + If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user + regardless of whether we detected outliers or not in channel to take a closer look at channel results + Should be between 0 and 1 + Default: 0.95 + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations + The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold + If it is significantly greater, then we consider it an outlier + This threshold was calculated based on the ratio of the percentiles in a normal distribution + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile + Should be between 0 and 1 + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this + Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine outliers + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + # names for the pre observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # pre activation prefix + INPUT_ACTIVATION_PREFIX = "input_activation_" + + # names for dict keys + OUTLIER_KEY = "outliers_detected" + NUM_BATCHES_KEY = "outlier_detection_batches_used" + IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches" + COMP_METRIC_KEY = "outlier_detection_percentile_ratios" + RATIO_THRES_KEY = "outlier_detection_ratio_threshold" + REF_PERCENTILE_KEY = "outlier_detection_reference_percentile" + CHANNEL_AXIS_KEY = "outlier_detection_channel_axis" + MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max" + CONSTANT_COUNTS_KEY = "constant_batch_counts" + + def __init__( + self, + ratio_threshold: float = 3.5, + reference_percentile: float = 0.975, + fraction_batches_used_threshold: float = 0.95, + ch_axis: int = 1, + ): + # initialize the variables of interest + self.ratio_threshold = ratio_threshold + + # make sure passed in percentile is valid + if reference_percentile < 0 or reference_percentile > 1: + raise AssertionError("reference_percentile must be between 0 and 1") + if not ( + fraction_batches_used_threshold >= 0 + and fraction_batches_used_threshold <= 1 + ): + raise AssertionError( + "fraction_batches_used_threshold must be between 0 and 1" + ) + self.reference_percentile = reference_percentile + self.fraction_batches_used_threshold = fraction_batches_used_threshold + self.ch_axis = ch_axis + + def get_detector_name(self) -> str: + r"""Returns the name of this detector""" + return "outlier_detector" + + def _supports_insertion(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for observers insertion + + Any module that doesn't have children and isn't an observer itself is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + # case for insertion of module + # check if the module has any children and isn't observer + num_children = len(list(module.children())) + return num_children == 0 and not _is_activation_post_process(module) + + def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]: + r"""Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # currently doesn't do anything for outlier detector + return {} + + def _supports_report_gen(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for report generation + + Any module that has a model report pre-observer is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + def determine_observer_insert_points( + self, prepared_fx_model: GraphModule + ) -> dict[str, dict[str, Any]]: + r"""Determines where observers need to be inserted for the Outlier Detector. + + For this detector, we want to place observers in front of supported layers. + + Currently inserts observers for: + all layers that do not have children (leaf level layers) + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: dict[str, dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._supports_insertion(module): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr( + ch_axis=self.ch_axis, comp_percentile=self.reference_percentile + ), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def _calculate_outlier_info( + self, + percentile_ratios: torch.Tensor, + counted_batches: torch.Tensor, + total_batches: int, + ) -> dict[str, list[bool]]: + r""" + Gives info on whether the percentile ratios calculated would be considered outliers + Also gives information on whether the collected data is statistically significant to make this claim + + Args: + percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer + counted_batches (torch.Tensor): The number of batches used for average calculation per tensor + total_batches (int): The total number of batches that passed through observer in this epoch + + Returns a dictionary mapping: + "outliers_detected" : list of bools per channel that are true if it is considered an outlier + "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold: + where o_r = counted_batches / total_batches + """ + outlier_dict: dict[str, list[bool]] = { + self.OUTLIER_KEY: [], + self.IS_SUFFICIENT_BATCHES_KEY: [], + } + + # get both as flattened lists for easy mapping + ratios_list: list = percentile_ratios.tolist() + num_batches_list: list = counted_batches.tolist() + + # calculate whether channels were statistically significant + significant_size = [ + batch_size / total_batches >= self.fraction_batches_used_threshold + for batch_size in num_batches_list + ] + outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size + + # calculate for each channel whether it's an outlier or not based on ratio + outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list] + outlier_dict[self.OUTLIER_KEY] = outlier_detected + + # return the dictionary with the two lists + return outlier_dict + + def _generate_info_dict(self, model: GraphModule) -> dict[str, dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # return dictionary mapping observer fqns to desired info + info_dict: dict[str, dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._supports_report_gen(module): + # get pre observer for the module + pre_obs: ModelReportObserver = getattr( + module, self.DEFAULT_PRE_OBSERVER_NAME + ) + + # get the number of batches and calculated ratio thresholds + num_batches: torch.Tensor = pre_obs.percentile_batches_tracked + average_ratios: torch.Tensor = pre_obs.average_percentile_ratio + channel_batch_cnts: torch.Tensor = pre_obs.constant_channels + total_batches: int = pre_obs.num_batches_tracked + + # also get the max values + max_vals: torch.Tensor = pre_obs.max_val + + # we have to specifically modify how we are recording negative ratio for pre-relu layers + for index, ratio_val in enumerate(average_ratios): + # check if we have a negative ratio + # a ratio might be negative if we have a situation where the 100th percentile is + # > 0 while the nth percentile is < 0, in which case this would not be detected + # as an outlier. Since we care more about magnitude, we make it positive. + if ratio_val.item() < 0: + # first make it positive + average_ratios[index] = -ratio_val + + if ratio_val.item() < 1: + # if it's less than 1 we have the flip it as well + average_ratios[index] = 1 / ratio_val + + outlier_calcs = self._calculate_outlier_info( + average_ratios, num_batches, total_batches + ) + + # calculate whether ratios were outliers + info_dict[fqn] = { + self.CHANNEL_AXIS_KEY: self.ch_axis, + self.REF_PERCENTILE_KEY: self.reference_percentile, + self.RATIO_THRES_KEY: self.ratio_threshold, + self.COMP_METRIC_KEY: average_ratios, + self.NUM_BATCHES_KEY: num_batches, + self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY], + self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[ + self.IS_SUFFICIENT_BATCHES_KEY + ], + self.CONSTANT_COUNTS_KEY: channel_batch_cnts, + self.MAX_VALS_KEY: max_vals, + } + + return info_dict + + def generate_detector_report( + self, model: GraphModule + ) -> tuple[str, dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. + + Takes advantage of the ModelReport Observer which records the relevant percentile information + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether there are outliers in the activations around certain modules + Dictionary mapping modules of interest to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # generate the information dictionary of outlier information + info_dict = self._generate_info_dict(model) + + # now we can generate report based on this information + outlier_string = "Outlier detection report: \n" + + # added module check + added_module: bool = False + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n" + channel_max_value_str = "a max value across all batches of {}" + note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results." + note_distribution = "stationary distributions" + note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary" + + # suggestion for constant batch check since that can make it no outliers + constant_str = "\tFor channel {}, we found {} constant value batches. {}\n" + constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why." + + # compile the suggestion string + for module_fqn in info_dict: + # get module specific info + mod_info: dict[str, Any] = info_dict[module_fqn] + # check to see if we already added high level model desc + added_model_desc = False + # look at each individual channel and add a suggestion + for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]): + if outlier_detected: + # we found at least 1 outlier + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + # we mark that we found at least one outlier + added_module = True + max_value_found_str = channel_max_value_str.format( + mod_info[self.MAX_VALS_KEY][index] + ) + channel_str = channel_suggestion_str.format( + index, max_value_found_str + ) + outlier_string += channel_str + + # also check if we found constant batch + if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0: + # make sure we add a module level highlight. + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format( + module_fqn, self.ch_axis + ) + added_model_desc = True + + constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][ + index + ] + formatted_str = constant_str.format( + index, constant_values_for_channel, constant_suggestion + ) + outlier_string += formatted_str + # we also added at least one thing to description + added_module = True + + # if found outlier, give suggestion, else give default response + if added_module: + # compose the note string + note_composed = note_string.format(note_distribution, note_rec) + outlier_string += note_composed + else: + outlier_string += "There were no outliers found in the activations.\n" + + return (outlier_string, info_dict) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffbff88dd2d80dc237ae779eddd6fad5d26daee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py @@ -0,0 +1,666 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict +from collections.abc import Callable +from typing import Any + +import torch +from torch.ao.quantization.fx._equalize import EqualizationQConfig +from torch.ao.quantization.fx._model_report.detector import ( + DETECTOR_IS_POST_OBS_KEY, + DETECTOR_OBS_ARGS_KEY, + DETECTOR_OBS_TO_INSERT_KEY, + DETECTOR_TARGET_NODE_KEY, + DetectorBase, + DetectorQConfigInfo, +) +from torch.ao.quantization.fx._model_report.model_report_visualizer import ( + ModelReportVisualizer, +) +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping + + +class ModelReport: + r""" + The ModelReport class aims to provide users an easy way to diagnose issues that they run into + with their models. The class works with all traceable GraphModules to help diagnose issues, + though the requirements on the type of model more-so depends on the specific report the user + is trying to generate. With respect to the reports, the ModelReport class is initialized with + a set of Detector classes, each of which generate reports on quantization configuration + issues a use might have. + + Currently supports generating reports on: + - Suggestions for per-channel vs. per-tensor quantization (nn.Module) + - Suggestions for dynamic vs static quantization for linear layers (Graph Modules) + - Suggestions for input-weight equalization for linear and conv layers (Graph Modules) + - Suggestions for outlier detection for all layers (Graph Modules) + + The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) + where needed for each detector to gather the information it needs, and then after calibration, the ModelReport + class compiles the report generated by each Detector class into a single report to return to the user. It also + has the capability to remove all the observers it inserted as well. + + * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule + + * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class + Make sure that these are all unique types of detectors [do not have more than 1 of the same class] + + * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors. + This set is generated by calling the get_detector_name() of each detector + + * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest + The purpose of this is to keep track of what observers were inserted for each detector, so that they + can be removed at the end if desired + + * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not + This is to ensure we only insert observers once with the ModelReport instance + + * :attr:`_removed_observers` A boolean to track if we have removed observers already + The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport + instance. This also allows the functionality where we can generate the report multiple times + as long as we haven't removed the observers yet. + + Note: + This class was initially designed to work with the Fx Graph Mode workflow in mind. However, + full functionality is available as long as there is a traceable GraphModule that is being used. + One method to get a traceable GraphModule without going through the Fx workflow is to use + the QuantizationTracer class. + + General Flow for Fx workflow: + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration to add relevant observers + 4.) Calibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + Optional + 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance + 7.) To help in parsing report information and debugging, view report info as a: + - Table + - Histogram + - Line plot + 8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions + + Example (with QuantizationTracer): + >>> # xdoctest: +SKIP + >>> # get the necessary qconfig + >>> config = PrepareCustomConfig() + >>> skipped_module_names, skipped_module_classes = ( + ... get_skipped_module_name_and_classes(config, False) + ... ) + + >>> # initialize our model and get GraphModule + >>> model = SomeModel() + >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) + >>> graph_module = GraphModule(model, tracer.trace(model)) + + >>> # get our set of detectors and ModelReport instance + >>> detector_set = set( + ... [ + ... DynamicStaticDetector(tolerance=0.5), + ... InputWeightEqualizationDetector(ratio_threshold=0.7), + ... ] + ... ) + >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) + + >>> # now we insert the observers and calibrate the model + >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() + >>> for i in range(num_callibration_batches): + >>> example_input = get_callibration_input() + >>> tracer_model_with_observers(example_input) + + >>> # finally we generate the reports and optionally remove the observers we inserted + >>> reports = tracer_reporter.generate_model_report( + ... remove_inserted_observers=True + ... ) + + >>> # Optional: we can generate the qconfig mapping based on the suggestions + >>> qconfigs = model_report.generate_qconfig_mapping() + + >>> # Optional: we can generate the equalization mapping based on the suggestions + >>> qconfigs = model_report.generate_equalization_mapping() + + >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired + >>> model_report_visualizer = tracer_reporter.generate_visualizer() + + """ + + def __init__(self, model: GraphModule, desired_report_detectors: set[DetectorBase]): + if len(desired_report_detectors) == 0: + raise ValueError("Should include at least 1 desired report") + + # keep track of the model we wish to generate report for + self._model: GraphModule = model + + # keep the reports private so they can't be modified + self._desired_report_detectors = desired_report_detectors + self._desired_detector_names = { + detector.get_detector_name() for detector in desired_report_detectors + } + + # keep a mapping of desired reports to observers of interest + # this is to get the readings, and to remove them, can create a large set + # this set can then be used to traverse the graph and remove added observers + self._detector_name_to_observer_fqns: dict[str, set[str]] = {} + + # initialize each report to have empty set of observers of interest + for desired_report in self._desired_detector_names: + self._detector_name_to_observer_fqns[desired_report] = set() + + # flags to ensure that we can only prepare and remove observers once + self._prepared_flag = False + self._removed_observers = False + + # store the reports that we generated for visualization purposes + # initially empty since no reports generated + self._generated_reports: dict[str, dict] = {} + + def get_desired_reports_names(self) -> set[str]: + """Returns a copy of the desired reports for viewing""" + return self._desired_detector_names.copy() + + def get_observers_of_interest(self) -> dict[str, set[str]]: + """Returns a copy of the observers of interest for viewing""" + return self._detector_name_to_observer_fqns.copy() + + def prepare_detailed_calibration(self) -> GraphModule: + r""" + Takes in a graph model and inserts the following observers: + - ModelReportObserver + + Each observer is inserted based on the desired_reports into the relevant locations + + Right now, each report in self._desired_detector_names has independent insertions + However, if a module already has a Observer of the same type, the insertion will not occur + This is because all of the same type of Observer collect same information, so redundant + + Returns the same GraphModule with the observers inserted + """ + + # if already prepared once, cannot prepare again + if self._prepared_flag: + raise ValueError( + "Already ran preparing detailed calibration. Run the report generation next after calibration." + ) + + # loop through each detector, find where placements should be, and keep track + insert_observers_fqns: dict[str, Any] = {} + + for detector in self._desired_report_detectors: + # determine observer points for each detector + obs_fqn_to_info = detector.determine_observer_insert_points(self._model) + # map each insert point to the observer to use + insert_observers_fqns.update(obs_fqn_to_info) + # update the set of observers this report cares about + self._detector_name_to_observer_fqns[detector.get_detector_name()] = set( + obs_fqn_to_info.keys() + ) + + # now insert all the observers at their desired locations + for observer_fqn in insert_observers_fqns: + target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY] + insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY] + insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY] + observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY] + self._insert_observer_around_module( + observer_fqn, target_node, insert_obs, observer_args, insert_post + ) + + self._prepared_flag = True + + return self._model + + def _insert_observer_around_module( + self, + obs_fqn: str, + target_node: torch.fx.node.Node, + obs_to_insert: ObserverBase, + observer_args: tuple, + insert_post: bool, + ): + r""" + Helper function that inserts the observer into both the graph structure and the module of the model + + Args + node_fqn (str): The fully qualified name of the observer we want to insert + target_node (torch.fx.node.Node): The node in model we are inserting observers around + obs_to_insert (ObserverBase): The observer we are inserting around target_node + observer_args (Tuple): The arguments we want to pass into the observer + insert_post (bool): whether this is meant to be a post observer for this node + """ + # if we are inserting post, then our target node is the next node + if insert_post: + target_node = target_node.next + + with self._model.graph.inserting_before(target_node): + self._model.add_submodule(obs_fqn, obs_to_insert) + self._model.graph.create_node( + op="call_module", target=obs_fqn, args=observer_args + ) + + # recompile model after inserts are made + self._model.recompile() + + def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a node fqn and returns the node based on the fqn + + Args + node_fqn (str): The fully qualified name of the node we want to find in model + + Returns the Node object of the given node_fqn otherwise returns None + """ + node_to_return = None + for node in self._model.graph.nodes: + # if the target matches the fqn, it's the node we are looking for + if node.target == node_fqn: + node_to_return = node + break + + if node_to_return is None: + raise ValueError("The node_fqn is was not found within the module.") + + # assert for MyPy + if not isinstance(node_to_return, torch.fx.node.Node): + raise AssertionError("node_to_return must be a torch.fx.node.Node") + + return node_to_return + + def generate_model_report( + self, remove_inserted_observers: bool + ) -> dict[str, tuple[str, dict]]: + r""" + Generates all the requested reports. + + Note: + You should have calibrated the model with relevant data before calling this + + The reports generated are specified by the desired_reports specified in desired_reports + + Can optionally remove all the observers inserted by the ModelReport instance + + Args: + remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance + + Returns a mapping of each desired report name to a tuple with: + The textual summary of that report information + A dictionary containing relevant statistics or information for that report + + Note: + Throws exception if we try to generate report on model we already removed observers from + Throws exception if we try to generate report without preparing for calibration + """ + # if we haven't prepped model for calibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for calibration" + ) + + # if we already removed the observers, we cannot generate report + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of all the reports of interest and their outputs + reports_of_interest = {} + + for detector in self._desired_report_detectors: + # generate the individual report for the detector + report_output = detector.generate_detector_report(self._model) + reports_of_interest[detector.get_detector_name()] = report_output + + # if user wishes to remove inserted observers, go ahead and remove + if remove_inserted_observers: + self._removed_observers = True + # get the set of all Observers inserted by this instance of ModelReport + all_observers_of_interest: set[str] = set() + for desired_report in self._detector_name_to_observer_fqns: + observers_of_interest = self._detector_name_to_observer_fqns[ + desired_report + ] + all_observers_of_interest.update(observers_of_interest) + + # go through all_observers_of_interest and remove them from the graph and model + for observer_fqn in all_observers_of_interest: + # remove the observer from the model + self._model.delete_submodule(observer_fqn) + + # remove the observer from the graph structure + node_obj = self._get_node_from_fqn(observer_fqn) + + if node_obj: + self._model.graph.erase_node(node_obj) + else: + raise ValueError("Node no longer exists in GraphModule structure") + + # remember to recompile the model + self._model.recompile() + + # save the generated reports for visualization purposes + saved_reports: dict[str, dict] = { + report_name: report_tuple[1] + for report_name, report_tuple in reports_of_interest.items() + } + + self._generated_reports = saved_reports + + # return the reports of interest + return reports_of_interest + + def _is_same_info_for_same_key(self, info_dict_a: dict, info_dict_b: dict) -> bool: + r""" + Takes in two dictionaries and ensures that any common keys between the two have the same + values. + + Args: + info_dict_a (Dict): First dictionary we wish to compare + info_dict_b (Dict): Second dictionary we wish to compare + + Returns True if all shared keys have same values, false otherwise + """ + # get the set of keys for both + dict_a_keys: set = set(info_dict_a.keys()) + dict_b_keys: set = set(info_dict_b.keys()) + + # get the insersection keys and check if same value for both dicts + intersecting_keys: set = dict_a_keys.intersection(dict_b_keys) + + for key in intersecting_keys: + dict_a_val = info_dict_a[key] + dict_b_val = info_dict_b[key] + + # if it's a tensor we have to handle separately + if type(dict_a_val) is torch.Tensor: + # if dict_b_val not tensor, automatically false + if ( + type(dict_b_val) is not torch.Tensor + or sum(dict_a_val != dict_b_val) != 0 + ): + return False + else: + # for non-tensor vals + if dict_a_val != dict_b_val: + return False + + # if no non matching shared keys found, return true + return True + + def _reformat_reports_for_visualizer(self) -> OrderedDict: + r""" + Takes the generated reports and reformats them into the format that is desired by the + ModelReportVisualizer + + Returns an OrderedDict mapping module_fqns to their features + """ + # we want to reorder and reformat the information so it is ordered in terms of order + # found in the model + + # first create new dict with all modules as keys and features under respective module + module_fqns_to_features: dict[str, dict] = {} + + for report_name in self._generated_reports: + # get mod -> feature dict and go through + module_info = self._generated_reports[report_name] + + for module_fqn in module_info: + # check if already in our accumulation dict + if module_fqn in module_fqns_to_features: + # we merge all the features together + new_info: dict = module_info[module_fqn] + present_info: dict = module_fqns_to_features[module_fqn] + + # merge them together into the new unioned dict + # same features keys -> same info, so okay if override + + # do safety check to make sure shared keys have same info + if self._is_same_info_for_same_key(new_info, present_info): + module_fqns_to_features[module_fqn] = { + **new_info, + **present_info, + } + else: + error_str = "You have the same key with different values across detectors. " + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." + raise ValueError(error_str) + else: + # we just set it + module_fqns_to_features[module_fqn] = module_info[module_fqn] + + # our ordered dict so that modules can be ordered in order of how they appear in model + features_by_module: OrderedDict[str, dict] = OrderedDict() + + # we loop through modules in graph in order + for fqn, _module in self._model.named_modules(): + # find that fqn in fqns_to_features + if fqn in module_fqns_to_features: + # add it to our ordered dict + features_by_module[fqn] = module_fqns_to_features[fqn] + + # return the ordered dict of info we created + return features_by_module + + def generate_visualizer(self) -> ModelReportVisualizer: + r""" + Generates a ModelReportVisualizer instance using the reports generated + by the generate_model_report() method. + + Returns the generated ModelReportVisualizer instance initialized + + Note: + Throws exception if attempt to get visualizers without generating report + """ + # check if user has generated reports at least once + if len(self._generated_reports) == 0: + raise Exception( # noqa: TRY002 + "Unable to generate visualizers without first generating reports" + ) + + # get the ordered dict mapping modules to their full set of collected features / stats + module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer() + + # create and return ModelReportVisualizer instance + visualizer: ModelReportVisualizer = ModelReportVisualizer( + module_fqns_to_features + ) + + return visualizer + + def _generate_qconfig_mapping_helper( + self, + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo], + generation_function: Callable, + ) -> QConfigMapping: + r""" + This helper takes in the compiled detector qconfig info that + has been compiled together and merges it into a QConfigMapping + """ + # keep track of the qconfigmapping + qconfig_mapping = QConfigMapping() + + # loop through each module / fqn and attempt to create QConfigMapping + for fqn, module in self._model.named_modules(): + # if we have a qconfig info for this module + if fqn in detector_qconfig_info_combined: + qconfig_info_compiled = detector_qconfig_info_combined[fqn] + + # now generate the qconfig and add it to the mapping + generated_qconfig = generation_function(qconfig_info_compiled, module) + + # add to our config + qconfig_mapping.set_module_name(fqn, generated_qconfig) + + # return compiled mapping + return qconfig_mapping + + def _update_detector_quantizaiton_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + combined_info.is_activation_dynamic = ( + combined_info.is_activation_dynamic or new_info.is_activation_dynamic + ) + combined_info.is_weight_per_channel = ( + combined_info.is_weight_per_channel or new_info.is_weight_per_channel + ) + + def _update_detector_equalization_qconfig_info( + self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo + ): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + is_equalization_recommended = ( + combined_info.is_equalization_recommended + or new_info.is_equalization_recommended + ) + combined_info.is_equalization_recommended = is_equalization_recommended + + def _generate_module_fqn_to_detector_info_mapping( + self, update_qconfig_info_function: Callable + ) -> dict[str, DetectorQConfigInfo]: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Args: + update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo + and updates the one that is being compiled + + Returns a Dict mapping module_fqns to DetectorQConfigInfo objects + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for calibration + """ + # if we haven't prepped model for calibration, then we shouldn't generate mapping yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for calibration" + ) + + # if we already removed the observers, we cannot mapping + if self._removed_observers: + raise Exception( # noqa: TRY002 + "Cannot generate report on model you already removed observers from" + ) + + # keep track of qconfig info for each module across detectors + detector_qconfig_info_combined: dict[str, DetectorQConfigInfo] = {} + + for detector in self._desired_report_detectors: + # get the info from the detector + detector_info: dict[str, DetectorQConfigInfo] = detector.get_qconfig_info( + self._model + ) + + # we go through the modules + for module_fqn in detector_info: + # see if we already have info on it + if module_fqn in detector_qconfig_info_combined: + # we combine the current options with what is there + current_options = detector_qconfig_info_combined[module_fqn] + detector_options = detector_info[module_fqn] + + update_qconfig_info_function(current_options, detector_options) + else: + # we just use this for now + detector_qconfig_info_combined[module_fqn] = detector_info[ + module_fqn + ] + + return detector_qconfig_info_combined + + def generate_qconfig_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the quantization configuration + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for calibration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_quantizaiton_qconfig_info + ) + ) + + # we will do a bit of processing and remove fqns that don't have input weight recommended + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._quantization_config_generator + ) + + # return the generated mapping + return mapping + + def _quantization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> QConfig: + r""" + Returns the quantization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_quantization_qconfig(module) + + def _equalization_config_generator( + self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module + ) -> EqualizationQConfig: + r""" + We ignore the module argument here, and only focus on thedetector_qconfig_info + + Returns the equalization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_equalization_qconfig() + + def generate_equalization_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API for equalization. The generated mapping encompasses all the + different types of feedback from the input-weight equalization detector. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the equalization configuration + """ + # get the mapping info + detector_qconfig_info_combined = ( + self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_equalization_qconfig_info + ) + ) + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, self._equalization_config_generator + ) + + # return the generated mapping + return mapping diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..a809dc60838e574e0bd484ee9698e9d1a0a5ee47 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import torch +from torch.ao.quantization.observer import ObserverBase + + +class ModelReportObserver(ObserverBase): + r"""This observer is used to record additional information regarding keeping track + of S = average_batch_activation_range/epoch_activation_range. + + The purpose of this information is to prepare a report to present to users on whether + Dynamic or Static Quantization is more appropriate for their model given the general + distributions of their data. + + Args: + ch_axis (int, optional): The channel axis for which the range and outlier stats are computed + Default: 1 + comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers + Should be between 0 and 1 exclusive + Default: 0.9 + + * :attr:`num_batches_tracked` specifies number of batches passed through the observer + + * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through + + * :attr:`epoch_activation_min` defines the minimum value passed through the observer + + * :attr:`epoch_activation_max` defines the maximum value passed through the observer + + * :attr:`ch_axis` defines the channel being used to compute per channel min max stats + + * :attr:`min_val` defines the per channel minimum values passed through + + * :attr:`max_val` defines the per channel maximum values passed through + + * :attr:`comp_percentile` defines comparison percentile to find outliers + + * :attr:`average_percentile_ratio` defines the per channel average percentile ratios + + * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel + + * :attr:`constant_channels` defines the number of batches that aren't constant channels per channel + + Note: this tool is meant for FX Graph Mode Quantization + """ + + epoch_activation_min: torch.Tensor + epoch_activation_max: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + comp_percentile: torch.Tensor + average_percentile_ratio: torch.Tensor + percentile_batches_tracked: torch.Tensor + constant_channels: torch.Tensor + + def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9): + super().__init__(torch.qint8) + self.num_batches_tracked = 0 + + # keep track of the min and mix of the range for average batch and epoch as a whole + self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0)) + self.register_buffer("epoch_activation_min", torch.tensor(float("inf"))) + self.register_buffer("epoch_activation_max", torch.tensor(float("-inf"))) + + # keep track of per channel min max information using the given channel + self.ch_axis: int = ch_axis + self.register_buffer("min_val", torch.tensor([])) + self.register_buffer("max_val", torch.tensor([])) + + # keep track of percentile ratio information per channel + self.register_buffer("comp_percentile", torch.tensor([comp_percentile])) + self.register_buffer("average_percentile_ratio", torch.tensor([])) + self.register_buffer("percentile_batches_tracked", torch.tensor([])) + self.register_buffer("constant_channels", torch.tensor([])) + + def forward(self, x): + x_copy = x.detach() # avoid keeping autograd tape + x_copy = x_copy.to(self.epoch_activation_min.dtype) + + x_copy = self._calculate_range_stats(x_copy) + x_copy = self._calculate_min_max_stats(x_copy) + x_copy = self._calculate_percentile_stats(x_copy) + + # return the passed in the value + return x + + def _calculate_range_stats(self, x_copy): + r"""Calculates and stores range stats with forward values. + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the min, max values of the data + min_val_cur, max_val_cur = torch.aminmax(x_copy) + + # calculate new epoch range values + epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur) + epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur) + + self.epoch_activation_min.copy_(epoch_min_val) + self.epoch_activation_max.copy_(epoch_max_val) + + # calculate the average batch activation range + current_batch_range = max_val_cur - min_val_cur + new_range = ( + self.average_batch_activation_range * self.num_batches_tracked + + current_batch_range + ) / (self.num_batches_tracked + 1) + + self.average_batch_activation_range = new_range + self.num_batches_tracked += 1 # new batch was processed + + return x_copy + + def _calculate_min_max_stats(self, x_copy): + r"""Calculates and stores the per_channel min, max stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the current min and max vals + min_val = self.min_val + max_val = self.max_val + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x_copy + + def _calculate_percentile_stats(self, x_copy): + r"""Calculates and stores the per_channel percentile stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the dimension of the copy + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + y = y.to(dtype=self.min_val.dtype, device="cpu") + + # find the percentile values along the axis + # we want both 100th percentile and comp_percentile + # we also want to find 0th quartile to see if we have constant channel + quantiles_list = [0, self.comp_percentile, 1.00] + quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype) + + # find the quantiles + desired_quantiles = torch.quantile( + y, quantiles_to_find, dim=self.ch_axis, interpolation="lower" + ) + zero_quantile = desired_quantiles[0] + comp_quantile = desired_quantiles[1] + hundreth_quartile = desired_quantiles[2] + + # if any of the channels have 0s, we ignore that channel for this calculation + any_non_zero_quantile_value: torch.Tensor = ( + comp_quantile != torch.tensor([0]) + ) | (hundreth_quartile != torch.tensor([0])) + any_non_zero_quantile_value = ( + any_non_zero_quantile_value.int() + ) # transform boolean values to int values + + # we also check if we have a constant channel + any_constant_channels: torch.Tensor = ( + hundreth_quartile - zero_quantile + ) == torch.tensor([0]) + any_constant_channels = ( + any_constant_channels.int() + ) # transform boolean values to int values + + # possibilities to get nan as an answer + # will ignore any of these three cases with 0s and just not deal with them for now + # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative + # case (2) 0 in denominator: is possible unless case 3, we just ignore + # case (3) 0 in both: not outlier, channel just kinda useless, ignore + + # get the ratio and get rid of nan values + quantile_ratios = hundreth_quartile / comp_quantile + quantile_ratios = torch.nan_to_num(quantile_ratios) + # update averages, remembering to only update if didn't have zeros + ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios + + # if num_batches and average_ratio are not initialized, we want to initialize them + if ( + self.percentile_batches_tracked.shape[0] == 0 + or self.average_percentile_ratio.shape[0] == 0 + ): + self.percentile_batches_tracked = torch.zeros_like( + any_non_zero_quantile_value + ) + self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero) + + # also initialize the constant channel var if that is not initialized separately + if self.constant_channels.shape[0] == 0: + self.constant_channels = torch.zeros_like(any_constant_channels) + + # get current num batches and average ratio + num_batches = self.percentile_batches_tracked + average_ratio = self.average_percentile_ratio + + # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches + new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value + new_ratios: torch.Tensor = ( + (average_ratio * num_batches) + ratio_if_not_zero + ) / new_number_of_batches + new_ratios = torch.nan_to_num(new_ratios) + + # update the number of non-constant channels + new_constant_count: torch.Tensor = ( + self.constant_channels + any_constant_channels + ) + + # update the values locally + self.percentile_batches_tracked.copy_(new_number_of_batches) + self.average_percentile_ratio.copy_(new_ratios) + self.constant_channels.copy_(new_constant_count) + + return x_copy + + @torch.jit.export + def get_batch_to_epoch_ratio(self): + epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min + + if epoch_activation_range == torch.tensor(float(0)): + raise ValueError("Range for Epoch is 0") + elif epoch_activation_range == torch.tensor(float("inf")): + raise ValueError( + "No data has been run through observer or infinity value present" + ) + else: + return self.average_batch_activation_range / epoch_activation_range + + @torch.jit.export + def reset_batch_and_epoch_values(self): + # set all the values back to their original defaults for a new epoch + # keep device + device = self.max_val.device + self.num_batches_tracked = 0 + self.average_batch_activation_range = torch.tensor(float(0), device=device) + self.epoch_activation_min = torch.tensor(float("inf"), device=device) + self.epoch_activation_max = torch.tensor(float("-inf"), device=device) + self.min_val = torch.tensor([], device=device) + self.max_val = torch.tensor([], device=device) + self.average_percentile_ratio = torch.tensor([], device=device) + self.percentile_batches_tracked = torch.tensor([], device=device) + self.constant_channels = torch.tensor([], device=device) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ModelReportObserver" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2e58772660c5a9067f727bf066b5519f65f37637 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -0,0 +1,712 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict, OrderedDict as OrdDict +from typing import Any + +import torch + + +# try to import tablate +got_tabulate = True +try: + from tabulate import tabulate +except ImportError: + got_tabulate = False + + +# var to see if we could import matplotlib +got_matplotlib = True +try: + import matplotlib.pyplot as plt +except ImportError: + got_matplotlib = False + + +class ModelReportVisualizer: + r""" + The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics + that were generated by the ModelReport API. However, at a higher level, the class aims to provide + some level of visualization of statistics to PyTorch in order to make it easier to parse data and + diagnose any potential issues with data or a specific model. With respect to the visualizations, + the ModelReportVisualizer class currently supports several methods of visualizing data. + + Supported Visualization Methods Include: + - Table format + - Plot format (line graph) + - Histogram format + + For all of the existing visualization methods, there is the option to filter data based on: + - A module fqn prefix + - Feature [required for the plot and histogram] + + * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below + Ensure sure that features that are the same across different report contain the same name + Ensure that objects representing the same features are the same type / dimension (where applicable) + + Note: + Currently, the ModelReportVisualizer class supports visualization of data generated by the + ModelReport class. However, this structure is extensible and should allow the visualization of + other information as long as the information is structured in the following general format: + + Report Structure + -- module_fqn [module with attached detectors] + | + -- feature keys [not every detector extracts same information] + [same collected info has same keys, unless can be specific to detector] + + + The goal behind the class is that the generated visualizations can be used in conjunction with the generated + report for people to get a better understanding of issues and what the fix might be. It is also just to provide + a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as + that grows in size. + + General Use Flow Expected + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers + 4.) Calibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance + 7.) Use instance to view different views of data as desired, applying filters as needed + 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram + + """ + + # keys for table dict + TABLE_TENSOR_KEY = "tensor_level_info" + TABLE_CHANNEL_KEY = "channel_level_info" + + # Constants for header vals + NUM_NON_FEATURE_TENSOR_HEADERS = 2 + NUM_NON_FEATURE_CHANNEL_HEADERS = 3 + + # Constants for row index in header + CHANNEL_NUM_INDEX = 2 + + def __init__(self, generated_reports: OrderedDict[str, Any]): + r""" + Initializes the ModelReportVisualizer instance with the necessary reports. + + Args: + generated_reports (Dict[str, Any]): The reports generated by the ModelReport class + can also be a dictionary generated in another manner, as long as format is same + """ + self.generated_reports = generated_reports + + def get_all_unique_module_fqns(self) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all module_fqns so that if + they wish to use some of the filtering capabilities of the ModelReportVisualizer class, + they don't need to manually parse the generated_reports dictionary to get this information. + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + # returns the keys of the ordered dict + return set(self.generated_reports.keys()) + + def get_all_unique_feature_names( + self, plottable_features_only: bool = True + ) -> set[str]: + r""" + The purpose of this method is to provide a user the set of all feature names so that if + they wish to use the filtering capabilities of the generate_table_view(), or use either of + the generate_plot_view() or generate_histogram_view(), they don't need to manually parse + the generated_reports dictionary to get this information. + + Args: + plottable_features_only (bool): True if the user is only looking for plottable features, + False otherwise + plottable features are those that are tensor values + Default: True (only return those feature names that are plottable) + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + unique_feature_names = set() + for module_fqn in self.generated_reports: + # get dict of the features + feature_dict: dict[str, Any] = self.generated_reports[module_fqn] + + # loop through features + for feature_name in feature_dict: + # if we need plottable, ensure type of val is tensor + if ( + not plottable_features_only + or type(feature_dict[feature_name]) is torch.Tensor + ): + unique_feature_names.add(feature_name) + + # return our compiled set of unique feature names + return unique_feature_names + + def _get_filtered_data( + self, feature_filter: str, module_fqn_filter: str + ) -> OrderedDict[str, Any]: + r""" + Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed. + + Args: + feature_filter (str): The feature filter, if we want to filter the set of data to only include + a certain set of features that include feature_filter + If feature = "", then we do not filter based on any features + module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with + this prefix will be included + If module_fqn_filter = "" we do not filter based on module fqn, and include all modules + + First, the data is filtered based on module_fqn, and then filtered based on feature + Returns an OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + """ + # create return dict + filtered_dict: OrderedDict[str, Any] = OrdDict() + + for module_fqn in self.generated_reports: + # first filter based on module + if module_fqn_filter == "" or module_fqn_filter in module_fqn: + # create entry for module and loop through features + filtered_dict[module_fqn] = {} + module_reports = self.generated_reports[module_fqn] + for feature_name in module_reports: + # check if filtering on features and do so if desired + if feature_filter == "" or feature_filter in feature_name: + filtered_dict[module_fqn][feature_name] = module_reports[ + feature_name + ] + + # we have populated the filtered dict, and must return it + + return filtered_dict + + def _generate_tensor_table( + self, + filtered_data: OrderedDict[str, dict[str, Any]], + tensor_features: list[str], + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the tensor headers and table + + Currently meant to generate the headers and table for both the tensor information. + + Args: + filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + tensor_features (List[str]): A list of the tensor level features + + Returns a tuple with: + A list of the headers of the tensor table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the tensor information table + tensor_table: list[list[Any]] = [] + tensor_headers: list[str] = [] + + # append the table row to the table only if we have features + if len(tensor_features) > 0: + # now we add all the data + for index, module_fqn in enumerate(filtered_data): + # we make a new row for the tensor table + tensor_table_row = [index, module_fqn] + for feature in tensor_features: + # we iterate in same order of added features + + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if isinstance(feature_val, torch.Tensor): + feature_val = feature_val.item() + + # we add to our list of values + # pyrefly: ignore [bad-argument-type] + tensor_table_row.append(feature_val) + + tensor_table.append(tensor_table_row) + + # add row of headers of we actually have something, otherwise just empty + if len(tensor_table) != 0: + tensor_headers = ["idx", "layer_fqn"] + tensor_features + + return (tensor_headers, tensor_table) + + def _generate_channels_table( + self, + filtered_data: OrderedDict[str, Any], + channel_features: list[str], + num_channels: int, + ) -> tuple[list, list]: + r""" + Takes in the filtered data and features list and generates the channels headers and table + + Currently meant to generate the headers and table for both the channels information. + + Args: + filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + channel_features (List[str]): A list of the channel level features + num_channels (int): Number of channels in the channel data + + Returns a tuple with: + A list of the headers of the channel table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the table for the channel information table + channel_table: list[list[Any]] = [] + channel_headers: list[str] = [] + + # counter to keep track of number of entries in + channel_table_entry_counter: int = 0 + + if len(channel_features) > 0: + # now we add all channel data + for module_fqn in filtered_data: + # we iterate over all channels + for channel in range(num_channels): + # we make a new row for the channel + new_channel_row = [channel_table_entry_counter, module_fqn, channel] + for feature in channel_features: + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature][channel] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if type(feature_val) is torch.Tensor: + feature_val = feature_val.item() + + # add value to channel specific row + # pyrefly: ignore [bad-argument-type] + new_channel_row.append(feature_val) + + # add to table and increment row index counter + channel_table.append(new_channel_row) + channel_table_entry_counter += 1 + + # add row of headers of we actually have something, otherwise just empty + if len(channel_table) != 0: + channel_headers = ["idx", "layer_fqn", "channel"] + channel_features + + return (channel_headers, channel_table) + + def generate_filtered_tables( + self, feature_filter: str = "", module_fqn_filter: str = "" + ) -> dict[str, tuple[list, list]]: + r""" + Takes in optional filter values and generates two tables with desired information. + + The generated tables are presented in both a list-of-lists format + + The reason for the two tables are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Returns a dictionary with two keys: + (Dict[str, Tuple[List, List]]) A dict containing two keys: + "tensor_level_info", "channel_level_info" + Each key maps to a tuple with: + A list of the headers of each table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_filtered_tables( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) # generates table with per_channel_min info for all modules in block 1 of the model + """ + # first get the filtered data + filtered_data: OrderedDict[str, Any] = self._get_filtered_data( + feature_filter, module_fqn_filter + ) + + # now we split into tensor and per-channel data + tensor_features: set[str] = set() + channel_features: set[str] = set() + + # keep track of the number of channels we have + num_channels: int = 0 + + for module_fqn in filtered_data: + for feature_name in filtered_data[module_fqn]: + # get the data for that specific feature + feature_data = filtered_data[module_fqn][feature_name] + + # check if not zero dim tensor + is_tensor: bool = isinstance(feature_data, torch.Tensor) + is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0 + + if is_not_zero_dim or isinstance(feature_data, list): + # works means per channel + channel_features.add(feature_name) + num_channels = len(feature_data) + else: + # means is per-tensor + tensor_features.add(feature_name) + + # we make them lists for iteration purposes + tensor_features_list: list[str] = sorted(tensor_features) + channel_features_list: list[str] = sorted(channel_features) + + # get the tensor info + tensor_headers, tensor_table = self._generate_tensor_table( + filtered_data, tensor_features_list + ) + + # get the channel info + channel_headers, channel_table = self._generate_channels_table( + filtered_data, channel_features_list, num_channels + ) + + # let's now create the dictionary to return + table_dict = { + self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table), + self.TABLE_CHANNEL_KEY: (channel_headers, channel_table), + } + + # return the two tables + return table_dict + + def generate_table_visualization( + self, feature_filter: str = "", module_fqn_filter: str = "" + ): + r""" + Takes in optional filter values and prints out formatted tables of the information. + + The reason for the two tables printed out instead of one large one are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_table_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # prints out neatly formatted table with per_channel_min info + >>> # for all modules in block 1 of the model + """ + # see if we got tabulate + if not got_tabulate: + print("Make sure to install tabulate and try again.") + return None + + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # get the table string and print it out + # now we have populated the tables for each one + # let's create the strings to be returned + table_str = "" + # the tables will have some headers columns that are non-feature + # ex. table index, module name, channel index, etc. + # we want to look at header columns for features, that come after those headers + if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS: + # if we have at least one tensor level feature to be added we add tensor table + table_str += "Tensor Level Information \n" + table_str += tabulate(tensor_table, headers=tensor_headers) + if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS: + # if we have at least one channel level feature to be added we add tensor table + table_str += "\n\n Channel Level Information \n" + table_str += tabulate(channel_table, headers=channel_headers) + + # if no features at all, let user know + if table_str == "": + table_str = "No data points to generate table with." + + print(table_str) + + def _get_plottable_data( + self, feature_filter: str, module_fqn_filter: str + ) -> tuple[list, list[list], bool]: + r""" + Takes in the feature filters and module filters and outputs the x and y data for plotting + + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str): Only includes modules that contains this string + + Returns a tuple of three elements + The first is a list containing relevant x-axis data + The second is a list containing the corresponding y-axis data + If the data is per channel + """ + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # make sure it is only 1 feature that is being plotted + # get the number of features in each of these + tensor_info_features_count = ( + len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + ) + channel_info_features_count = ( + len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + + # see if valid tensor or channel plot + is_valid_per_tensor_plot: bool = tensor_info_features_count == 1 + is_valid_per_channel_plot: bool = channel_info_features_count == 1 + + # offset should either be one of tensor or channel table or neither + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + table = tensor_table + + # if a per_channel plot, we have different offset and table + if is_valid_per_channel_plot: + feature_column_offset = ( + ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + ) + table = channel_table + + x_data: list = [] + y_data: list[list] = [] + # the feature will either be a tensor feature or channel feature + if is_valid_per_tensor_plot: + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if type(row_value) is not str: + x_data.append(x_val_to_append) + y_data.append(row_value) + elif is_valid_per_channel_plot: + # gather the x_data and multiple y_data + # calculate the number of channels + num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1 + + # separate data list per channel + y_data.extend([] for _ in range(num_channels)) + + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + current_channel = row[ + self.CHANNEL_NUM_INDEX + ] # initially chose current channel + new_module_index: int = table_row_num // num_channels + x_val_to_append = new_module_index + + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if type(row_value) is not str: + # only append if new index we are appending + if len(x_data) == 0 or x_data[-1] != x_val_to_append: + x_data.append(x_val_to_append) + + # append value for that channel + y_data[current_channel].append(row_value) + else: + # more than one feature was chosen + error_str = "Make sure to pick only a single feature with your filter to plot a graph." + error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names." + error_str += " Pick one of those features to plot." + raise ValueError(error_str) + + # return x, y values, and if data is per-channel + return (x_data, y_data, is_valid_per_channel_plot) + + def generate_plot_visualization( + self, feature_filter: str, module_fqn_filter: str = "" + ): + r""" + Takes in a feature and optional module_filter and plots of the desired data. + + For per channel features, it averages the value across the channels and plots a point + per module. The reason for this is that for models with hundreds of channels, it can + be hard to differentiate one channel line from another, and so the point of generating + a single average point per module is to give a sense of general trends that encourage + further deep dives. + + Note: + Only features in the report that have tensor value data are plottable by this class + When the tensor information is plotted, it will plot: + idx as the x val, feature value as the y_val + When the channel information is plotted, it will plot: + the first idx of each module as the x val, feature value as the y_val [for each channel] + The reason for this is that we want to be able to compare values across the + channels for same layer, and it will be hard if values are staggered by idx + This means each module is represented by only 1 x value + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + >>> # outputs line plot of per_channel_min information for all + >>> # modules in block1 of model each channel gets it's own line, + >>> # and it's plotted across the in-order modules on the x-axis + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_ylabel(feature_filter) + ax.set_title(feature_filter + " Plot") + plt.xticks(x_data) # only show ticks for actual points + + if data_per_channel: + ax.set_xlabel("First idx of module") + # set the legend as well + # plot a single line that is average of the channel values + num_modules = len( + y_data[0] + ) # all y_data have same length, so get num modules + num_channels = len( + y_data + ) # we want num channels to be able to calculate average later + + avg_vals = [ + sum(y_data[:][index]) / num_channels for index in range(num_modules) + ] + + # plot the three things we measured + ax.plot( + x_data, avg_vals, label=f"Average Value Across {num_channels} Channels" + ) + ax.legend(loc="upper right") + else: + ax.set_xlabel("idx") + ax.plot(x_data, y_data) + + # actually show the plot + plt.show() + + def generate_histogram_visualization( + self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10 + ): + r""" + Takes in a feature and optional module_filter and plots the histogram of desired data. + + Note: + Only features in the report that have tensor value data can be viewed as a histogram + If you want to plot a histogram from all the channel values of a specific feature for + a specific model, make sure to specify both the model and the feature properly + in the filters and you should be able to see a distribution of the channel data + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + num_bins (int, optional): The number of bins to create the histogram with + Default = 10, the values will be split into 10 equal sized bins + + Example Use: + >>> # xdoctest: +SKIP + >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( + ... feature_filter="per_channel_min", module_fqn_filter="block1" + ... ) + # outputs histogram of per_channel_min information for all modules in block1 of model + information is gathered across all channels for all modules in block 1 for the + per_channel_min and is displayed in a histogram of equally sized bins + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + _x_data, y_data, data_per_channel = self._get_plottable_data( + feature_filter, module_fqn_filter + ) + + # for histogram, we just care about plotting the y data + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_xlabel(feature_filter) + ax.set_ylabel("Frequency") + ax.set_title(feature_filter + " Histogram") + + if data_per_channel: + # set the legend as well + # combine all the data + all_data = [] + for channel_info in y_data: + all_data.extend(channel_info) + + _val, bins, _ = plt.hist( + all_data, + bins=num_bins, + stacked=True, + rwidth=0.8, + ) + plt.xticks(bins) + else: + _val, bins, _ = plt.hist( + y_data, + bins=num_bins, + stacked=False, + rwidth=0.8, + ) + plt.xticks(bins) + + plt.show() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6b16d3c2f3ce5055dd73f33abaaeb351757a83e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35bbdcb8ac6c804c1b3d82d58f7afc6ace3a10c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a7580aefe14a2aa339effc04953d659714b891 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690de9b4f6dda9be8138ff6b3b07664c0074c2c4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc755825ee56e83029a17af4683413dfd31dc520 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/lowering.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3ba34cc0fe696e2cf71a51183025c527754f4fb Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1dba100a53305269a99cbcefcd012c95c27cef8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c53dad4ec1e00d8d398dec83990f726470cf47 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d98199844a30d3fc64e6c9ab114dd75b6745bd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92418634eb5532ccd6b0a88e27cc9d69842deab Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ba4fd328394360d64c0a1518c75c5224d02eb6a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66bab43581ce1c727552d254616c42c739a1b67f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4aca22230a79070f94710a986663406fc00cd5d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df3f9e85b7de171c66b0a474d7dd3c2b0821551d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69d98ef317086dc1ea90f46a29134682425af376 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a6c68d76039adce7d6e651ef2744c2b9333664a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eab9385db9f5c81f97b978237f98d99add013faa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83deaeec9c64a56709b063231629b0ce0b0d816d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f05a40b77c70cb7ba06b542ac931b1af23ba4c3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18be2c877b3f663dccc087d8bf07f3ceb16602b0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bf3981c268328c3369a595c4d556f50a3c2ca87 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e703b1b1a30b9ed461de5b0c2d6b7c2c8f9658f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/anomaly_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b88bd4d330324606ad77752a90dc41f43b58de32 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/functional.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a047f4c411c07ff9f7e3ae1fb58f797e8f377dd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/grad_mode.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e9806a09f5505bb6779c5bdfc7405fa3caa8072 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6255886d9323e3f283f0829697f388b49dd8ff0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/profiler_legacy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cf3b682d7ec6fcffdf255ed254ba9270b77bc07 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/__pycache__/variable.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5ed607d13b9d40b5d1afc8acf8a37bfab4eb86 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e51d650f6b55c84ced7786843a76995db111755 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15491cdd7613528ebd99b5b81a59786c2e9aec6d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/autograd/_functions/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7f9539cccf2e91a7469801891e37d4e34d0f489 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..3180e56a6baf96b56c88a712a4426108d8c8e2fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +import hashlib +import json + +import coremltools as ct # type: ignore[import] +from coremltools.converters.mil.input_types import TensorType # type: ignore[import] +from coremltools.converters.mil.mil import types # type: ignore[import] +from coremltools.models.neural_network import quantization_utils # type: ignore[import] + +import torch + + +CT_METADATA_VERSION = "com.github.apple.coremltools.version" +CT_METADATA_SOURCE = "com.github.apple.coremltools.source" + + +class ScalarType: + Float = 0 + Double = 1 + Int = 2 + Long = 3 + Undefined = 4 + + +# Supported Tensor types in coremltools: +# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28 +torch_to_mil_types = { + ScalarType.Float: types.fp32, + ScalarType.Double: types.fp64, + ScalarType.Int: types.int32, + ScalarType.Long: types.int64, +} + + +class CoreMLComputeUnit: + CPU = "cpuOnly" + CPUAndGPU = "cpuAndGPU" + ALL = "all" + + +class CoreMLQuantizationMode: + LINEAR = "linear" + LINEAR_SYMMETRIC = "linear_symmetric" + NONE = "none" + + +def TensorSpec(shape, dtype=ScalarType.Float): + return (shape, dtype) + + +def CompileSpec( + inputs, + outputs, + backend=CoreMLComputeUnit.CPU, + allow_low_precision=True, + quantization_mode=CoreMLQuantizationMode.NONE, + mlmodel_export_path=None, + convert_to=None, +): + return ( + inputs, + outputs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + convert_to, + ) + + +def _check_enumerated_shape(shape): + for s in shape: + if not isinstance(s, (list, tuple)): + return False + return True + + +def _convert_to_mil_type(shape, dtype, name: str): + mil_shape = shape + if _check_enumerated_shape(shape): + mil_shape = ct.EnumeratedShapes(shape) + ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype]) + ml_type.name = name + return ml_type + + +def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]): + spec = compile_spec["forward"] + ( + input_specs, + output_specs, + backend, + allow_low_precision, + quantization_mode, + mlmodel_export_path, + convert_to, + ) = spec + mil_inputs = [] + inputs = [] + for index, input in enumerate(input_specs): + shape, dtype = input + name = "input_" + str(index) + inputs.append([name, str(dtype), str(shape)]) + ml_type = _convert_to_mil_type(shape, dtype, name) + mil_inputs.append(ml_type) + model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None) + mlmodel = ct.convert(model, inputs=mil_inputs, convert_to=convert_to) + + if quantization_mode != CoreMLQuantizationMode.NONE: + quant_model_spec = quantization_utils.quantize_weights( + mlmodel, nbits=8, quantization_mode=quantization_mode + ) + mlmodel = ct.models.MLModel(quant_model_spec) + + spec = mlmodel.get_spec() + assert len(spec.description.output) == len(output_specs) # type: ignore[attr-defined] + outputs = [] + for index, output in enumerate(output_specs): + shape, dtype = output + name = spec.description.output[index].name # type: ignore[attr-defined] + outputs.append([name, str(dtype), str(shape)]) + mlmodel = ct.models.model.MLModel(spec) + print(mlmodel) + + if mlmodel_export_path is not None: + print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}") + mlmodel.save(mlmodel_export_path) + + config = { + "spec_ver": str(spec.specificationVersion), # type: ignore[attr-defined] + "backend": backend, + "allow_low_precision": str(allow_low_precision), + } + metadata = { + "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION], + "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE], + } + coreml_compile_spec = { + "inputs": inputs, + "outputs": outputs, + "config": config, + "metadata": metadata, + } + mlmodel = spec.SerializeToString() # type: ignore[attr-defined] + + return { + "model": mlmodel, + "hash": str(hashlib.sha256(mlmodel).hexdigest()), + "extra": json.dumps(coreml_compile_spec), + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc48d711111ffd417fa1c544bd4b2362e75cf16 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch.backends._nnapi.serializer import _NnapiSerializer + + +ANEURALNETWORKS_PREFER_LOW_POWER = 0 +ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1 +ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2 + + +class NnapiModule(torch.nn.Module): + """Torch Module that wraps an NNAPI Compilation. + + This module handles preparing the weights, initializing the + NNAPI TorchBind object, and adjusting the memory formats + of all inputs and outputs. + """ + + # _nnapi.Compilation is defined + comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined] + weights: list[torch.Tensor] + out_templates: list[torch.Tensor] + + def __init__( + self, + shape_compute_module: torch.nn.Module, + ser_model: torch.Tensor, + weights: list[torch.Tensor], + inp_mem_fmts: list[int], + out_mem_fmts: list[int], + compilation_preference: int, + relax_f32_to_f16: bool, + ): + super().__init__() + self.shape_compute_module = shape_compute_module + self.ser_model = ser_model + self.weights = weights + self.inp_mem_fmts = inp_mem_fmts + self.out_mem_fmts = out_mem_fmts + self.out_templates = [] + self.comp = None + self.compilation_preference = compilation_preference + self.relax_f32_to_f16 = relax_f32_to_f16 + + @torch.jit.export + def init(self, args: list[torch.Tensor]): + assert self.comp is None + self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator] + self.weights = [w.contiguous() for w in self.weights] + comp = torch.classes._nnapi.Compilation() + comp.init2( + self.ser_model, + self.weights, + self.compilation_preference, + self.relax_f32_to_f16, + ) + + self.comp = comp + + def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]: + if self.comp is None: + self.init(args) + comp = self.comp + assert comp is not None + outs = [torch.empty_like(out) for out in self.out_templates] + + assert len(args) == len(self.inp_mem_fmts) + fixed_args = [] + for idx in range(len(args)): + fmt = self.inp_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt == 0: + fixed_args.append(args[idx].contiguous()) + elif fmt == 1: + fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous()) + else: + raise ValueError("Invalid mem_fmt") + comp.run(fixed_args, outs) + assert len(outs) == len(self.out_mem_fmts) + for idx in range(len(self.out_templates)): + fmt = self.out_mem_fmts[idx] + # These constants match the values in DimOrder in serializer.py + # TODO: See if it's possible to use those directly. + if fmt in (0, 2): + pass + elif fmt == 1: + outs[idx] = outs[idx].permute(0, 3, 1, 2) + else: + raise ValueError("Invalid mem_fmt") + return outs + + +def convert_model_to_nnapi( + model, + inputs, + serializer=None, + return_shapes=None, + use_int16_for_qint16=False, + compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED, + relax_f32_to_f16=False, +): + ( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + retval_count, + ) = process_for_nnapi( + model, inputs, serializer, return_shapes, use_int16_for_qint16 + ) + + nnapi_model = NnapiModule( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + compilation_preference, + relax_f32_to_f16, + ) + + class NnapiInterfaceWrapper(torch.nn.Module): + """NNAPI list-ifying and de-list-ifying wrapper. + + NNAPI always expects a list of inputs and provides a list of outputs. + This module allows us to accept inputs as separate arguments. + It returns results as either a single tensor or tuple, + matching the original module. + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + wrapper_model_py = NnapiInterfaceWrapper(nnapi_model) + wrapper_model = torch.jit.script(wrapper_model_py) + # TODO: Maybe make these names match the original. + arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs))) + if retval_count < 0: + ret_expr = "retvals[0]" + else: + ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count)) + wrapper_model.define( + f"def forward(self, {arg_list}):\n" + f" retvals = self.mod([{arg_list}])\n" + f" return {ret_expr}\n" + ) + return wrapper_model + + +def process_for_nnapi( + model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False +): + model = torch.jit.freeze(model) + + if isinstance(inputs, torch.Tensor): + inputs = [inputs] + + serializer = serializer or _NnapiSerializer( + config=None, use_int16_for_qint16=use_int16_for_qint16 + ) + ( + ser_model, + used_weights, + inp_mem_fmts, + out_mem_fmts, + shape_compute_lines, + retval_count, + ) = serializer.serialize_model(model, inputs, return_shapes) + ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32) + + # We have to create a new class here every time this function is called + # because module.define adds a method to the *class*, not the instance. + class ShapeComputeModule(torch.nn.Module): + """Code-gen-ed module for tensor shape computation. + + module.prepare will mutate ser_model according to the computed operand + shapes, based on the shapes of args. Returns a list of output templates. + """ + + shape_compute_module = torch.jit.script(ShapeComputeModule()) + real_shape_compute_lines = [ + "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n", + ] + [f" {line}\n" for line in shape_compute_lines] + shape_compute_module.define("".join(real_shape_compute_lines)) + + return ( + shape_compute_module, + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + retval_count, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff09959f840c4b8c61147cc2180abc8d5d25b13 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py @@ -0,0 +1,2231 @@ +# mypy: allow-untyped-defs +import array +import enum +import functools +import logging +import operator +import struct +import sys +from typing import NamedTuple, Optional + +import torch + + +# TODO: Add type annotations +# TODO: Check tensor types for ops + + +LOG = logging.getLogger("nnapi_serialize") + + +class NNAPI_OperandCode: + FLOAT32 = 0 + INT32 = 1 + UINT32 = 2 + TENSOR_FLOAT32 = 3 + TENSOR_INT32 = 4 + TENSOR_QUANT8_ASYMM = 5 + BOOL = 6 + TENSOR_QUANT16_SYMM = 7 + TENSOR_FLOAT16 = 8 + TENSOR_BOOL8 = 9 + FLOAT16 = 10 + TENSOR_QUANT8_SYMM_PER_CHANNEL = 11 + TENSOR_QUANT16_ASYMM = 12 + + +class NNAPI_OperationCode: + ADD = 0 + AVERAGE_POOL_2D = 1 + CONCATENATION = 2 + CONV_2D = 3 + DEPTHWISE_CONV_2D = 4 + DEPTH_TO_SPACE = 5 + DEQUANTIZE = 6 + EMBEDDING_LOOKUP = 7 + FLOOR = 8 + FULLY_CONNECTED = 9 + HASHTABLE_LOOKUP = 10 + L2_NORMALIZATION = 11 + L2_POOL_2D = 12 + LOCAL_RESPONSE_NORMALIZATION = 13 + LOGISTIC = 14 + LSH_PROJECTION = 15 + LSTM = 16 + MAX_POOL_2D = 17 + MUL = 18 + RELU = 19 + RELU1 = 20 + RELU6 = 21 + RESHAPE = 22 + RESIZE_BILINEAR = 23 + RNN = 24 + SOFTMAX = 25 + SPACE_TO_DEPTH = 26 + SVDF = 27 + TANH = 28 + BATCH_TO_SPACE_ND = 29 + DIV = 30 + MEAN = 31 + PAD = 32 + SPACE_TO_BATCH_ND = 33 + SQUEEZE = 34 + STRIDED_SLICE = 35 + SUB = 36 + TRANSPOSE = 37 + ABS = 38 + ARGMAX = 39 + ARGMIN = 40 + AXIS_ALIGNED_BBOX_TRANSFORM = 41 + BIDIRECTIONAL_SEQUENCE_LSTM = 42 + BIDIRECTIONAL_SEQUENCE_RNN = 43 + BOX_WITH_NMS_LIMIT = 44 + CAST = 45 + CHANNEL_SHUFFLE = 46 + DETECTION_POSTPROCESSING = 47 + EQUAL = 48 + EXP = 49 + EXPAND_DIMS = 50 + GATHER = 51 + GENERATE_PROPOSALS = 52 + GREATER = 53 + GREATER_EQUAL = 54 + GROUPED_CONV_2D = 55 + HEATMAP_MAX_KEYPOINT = 56 + INSTANCE_NORMALIZATION = 57 + LESS = 58 + LESS_EQUAL = 59 + LOG = 60 + LOGICAL_AND = 61 + LOGICAL_NOT = 62 + LOGICAL_OR = 63 + LOG_SOFTMAX = 64 + MAXIMUM = 65 + MINIMUM = 66 + NEG = 67 + NOT_EQUAL = 68 + PAD_V2 = 69 + POW = 70 + PRELU = 71 + QUANTIZE = 72 + QUANTIZED_16BIT_LSTM = 73 + RANDOM_MULTINOMIAL = 74 + REDUCE_ALL = 75 + REDUCE_ANY = 76 + REDUCE_MAX = 77 + REDUCE_MIN = 78 + REDUCE_PROD = 79 + REDUCE_SUM = 80 + ROI_ALIGN = 81 + ROI_POOLING = 82 + RSQRT = 83 + SELECT = 84 + SIN = 85 + SLICE = 86 + SPLIT = 87 + SQRT = 88 + TILE = 89 + TOPK_V2 = 90 + TRANSPOSE_CONV_2D = 91 + UNIDIRECTIONAL_SEQUENCE_LSTM = 92 + UNIDIRECTIONAL_SEQUENCE_RNN = 93 + RESIZE_NEAREST_NEIGHBOR = 94 + + +class NNAPI_FuseCode: + FUSED_NONE = 0 + FUSED_RELU = 1 + FUSED_RELU1 = 2 + FUSED_RELU6 = 3 + + +class OperandValueSourceType: + IMMEDIATE = 0 + NUMBERED_BUFFER = 2 + NUMBERED_MEMORY = 3 + + +# Scalar types that appear explicitly in models. +# These must be kept in sync with +# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. +# TODO: Expose these directly to Python to avoid maintaining this list. +class TorchScalarTypes(enum.Enum): + QUINT8 = 13 + + +def approx_equal(lhs, rhs, tolerance=1e-6): + return abs(lhs - rhs) <= tolerance * min(lhs, rhs) + + +def tensor_size(op_type, dims): + ITEM_SIZES = { + NNAPI_OperandCode.TENSOR_FLOAT32: 4, + NNAPI_OperandCode.TENSOR_INT32: 4, + NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2, + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2, + } + size = ITEM_SIZES[op_type] + for d in dims: + size *= d + return size + + +def change_element(tup, index, value): + ls = list(tup) + ls[index] = value + return tuple(ls) + + +class ConvPoolArgs2d(NamedTuple): + """Configuration arguments for a convolution.""" + + kernel_h: int + kernel_w: int + stride_h: int + stride_w: int + pad_t: int + pad_b: int + pad_l: int + pad_r: int + dilation_h: int + dilation_w: int + group: int + + +class DimOrder(enum.Enum): + PRESUMED_CONTIGUOUS = 0 + CHANNELS_LAST = 1 + SCALAR_OR_VECTOR = 2 + UNKNOWN_CONSTANT = 999 + + +class Operand(NamedTuple): + """Representation of an NNAPI operand.""" + + # NNAPI operand type. One of NNAPI_OperandCode. + # TODO: Make this an enum. + op_type: int + + # This is always the PyTorch shape, which is NCHW for feature maps. + # The actual NNAPI operand might have a transposed shape. + # we use 0 for load time dynamic shapes & -1 for runtime dynamic shapes + shape: tuple[int, ...] + + # Specifies how the shape of the operand that we define in NNAPI + # relates to the shape we track above. + # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match + # the shape of the PyTorch tensor. + # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and + # the NNAPI operand will be represented explicitly as NHWC. + dim_order: DimOrder + + # Quantization params + scale: float + zero_point: int + + def use_nchw(self): + if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return True + if self.dim_order is DimOrder.CHANNELS_LAST: + return False + raise Exception("Unknown dim order") # noqa: TRY002 + + +def broadcast_shapes(shape1, shape2): + assert len(shape1) > 0 + assert len(shape2) > 0 + s1 = list(shape1) + s2 = list(shape2) + # TODO: Support non-equal-rank broadcast where semantics match. + # This can be tricky for NHWC tensors because dimension orders + # don't match between PT and NNAPI, even though semantics match. + if len(s1) > len(s2): + # s2 = [1] * (len(s1) - len(s2)) + s2 + raise Exception( # noqa: TRY002 + "Non-equal-rank broadcast is not supported yet." + ) # noqa: TRY002 + if len(s2) > len(s1): + # s3 = [1] * (len(s2) - len(s1)) + s1 + raise Exception( # noqa: TRY002 + "Non-equal-rank broadcast is not supported yet." + ) # noqa: TRY002 + ret = [] + for d1, d2 in zip(s1, s2): + if d1 == 1: + ret.append(d2) + elif d2 == 1: + ret.append(d1) + elif d1 == d2: + ret.append(d1) + else: + raise Exception( # noqa: TRY002 + f"Cannot broadcast shapes: {shape1} and {shape2}" + ) # noqa: TRY002 + return tuple(ret) + + +def get_conv_pool_shape(image_shape, args, out_ch, transpose): + batch, _in_c, in_h, in_w = image_shape + + # TODO: Handle dilation + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("Dilation not supported yet.") # noqa: TRY002 + + if transpose: + out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b + out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l + else: + out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1 + out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1 + + # Handle variable-sized tensors. + if in_h == 0: + out_h = 0 + if in_w == 0: + out_w = 0 + + out_shape = (batch, out_ch, out_h, out_w) + return out_shape + + +def fix_shape(shape, dim_order): + # Return the actual shape that an operand should have in NNAPI, + # given a PyTorch shape and dimension order. This is where we + # convert from PyTorch's "always NCHW" shape to explicit NHWC. + if dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return shape + if dim_order is DimOrder.CHANNELS_LAST: + return tuple([shape[0]] + list(shape[2:]) + [shape[1]]) + if dim_order is DimOrder.SCALAR_OR_VECTOR: + assert len(shape) == 0 or len(shape) == 1 + return shape + if dim_order is DimOrder.UNKNOWN_CONSTANT: + # XXX think this through + return shape + raise Exception(f"Bad dim_order: {dim_order!r}.") # noqa: TRY002 + + +def reverse_map_dim(dim_order, d): + # Return the original PyTorch dimension position for a given dimension. + # d should be the dimension that NNAPI will see. + # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x + # reverse_map_dim(CHANNELS_LAST, 3) == 1 + if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR): + return d + assert dim_order is DimOrder.CHANNELS_LAST + return [0, 2, 3, 1][d] + + +def flex_name(op_id, dim): + # Return the local variable name for the computed flexible size + # for a given op and dimension. + return f"s_{op_id}_{dim}" + + +class _NnapiSerializer: + def __init__(self, config, use_int16_for_qint16=False): + self.operands = [] + self.values = [] + self.operations = [] + self.value_data = [] + self.operation_args = [] + self.inputs = [] + self.outputs = [] + self.flexible_shape_computation_lines = [] + + self.modules = {} + self.constants = {} + self.tensor_sequences = {} + self.jitval_operand_map = {} + self.cached_immediates = {} + self.used_weights = [] + self.weight_offset = 0 + self.use_int16_for_qint16 = use_int16_for_qint16 + + if config is None: + config = {} + + def get_next_operand_id(self): + return len(self.operands) + + # Add a tensor operand corresponding to a JIT Value. + # Returns the NNAPI operand ID. Can be looked up later with + # get_tensor_operand_by_jitval. + def add_tensor_operand(self, jitval, oper): + assert isinstance(oper, Operand) + if jitval in self.jitval_operand_map: + raise Exception(f"Duplicate tensor: {jitval!r}") # noqa: TRY002 + + operand_id = self.get_next_operand_id() + self.operands.append(oper) + self.jitval_operand_map[jitval] = operand_id + return operand_id + + # Add a tensor operand that does not correspond to a JIT Value. + # Useful for cases where multiple NNAPI operands are required + # to implement one JIT IR node. Returns the NNAPI operand ID. + def add_anonymous_tensor_operand(self, oper): + assert isinstance(oper, Operand) + operand_id = self.get_next_operand_id() + self.operands.append(oper) + return operand_id + + def torch_tensor_to_operand(self, tensor, dim_order): + dtype = str(tensor.dtype).replace("torch.", "") + scale = 0.0 + zero_point = 0 + if dtype == "float32": + op_type = NNAPI_OperandCode.TENSOR_FLOAT32 + elif dtype == "int32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + elif dtype == "quint8": + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + elif dtype == "qint32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + assert zero_point == 0 + elif dtype == "int16": + if self.use_int16_for_qint16: + nnapi_dtype = getattr(tensor, "nnapi_dtype", None) + op_codes = ( + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + ) + if nnapi_dtype in op_codes: + op_type = nnapi_dtype + scale = tensor.nnapi_scale + zero_point = tensor.nnapi_zero_point + else: + raise Exception( # noqa: TRY002 + f"`nnapi_type` needs to be one of {op_codes} for `int16`" + ) + else: + raise Exception( # noqa: TRY002 + "`int16` isn't supported. If you're trying to represent NNAPI" + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) + else: + raise Exception( # noqa: TRY002 + f"Can't handle input with dtype '{tensor.dtype}'" + ) # noqa: TRY002 + return Operand( + shape=tuple(tensor.shape), + # pyrefly: ignore [bad-argument-type] + op_type=op_type, + dim_order=dim_order, + scale=scale, + zero_point=zero_point, + ) + + def add_tensor_operand_for_input(self, arg_idx, jitval, tensor): + dim_order = ( + DimOrder.CHANNELS_LAST + if getattr(tensor, "nnapi_nhwc", False) + else DimOrder.PRESUMED_CONTIGUOUS + ) + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = self.add_tensor_operand(jitval, toper) + self.inputs.append(operand_id) + for dim, size in enumerate(tensor.shape): + if size == 0: + self.compute_operand_shape( + operand_id, dim, f"args[{arg_idx}].shape[{dim}]" + ) + return operand_id + + def add_tensor_operand_for_weight( + self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT + ): + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = len(self.operands) + self.operands.append(toper) + tsize = tensor_size(toper.op_type, toper.shape) + self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) + buf_num = len(self.used_weights) + offset = 0 + self.value_data.append(struct.pack("iii", buf_num, offset, tsize)) + # For NHWC NNAPI op, lay out data in the same dim order by permuting torch tensor + if dim_order == DimOrder.CHANNELS_LAST: + tensor = tensor.permute(0, 2, 3, 1) + self.used_weights.append(tensor) + return operand_id + + def add_immediate_operand(self, code, value, dims): + assert isinstance(dims, tuple) + cache_key = (code, value) + if cache_key not in self.cached_immediates: + operand_id = len(self.operands) + self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) + self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) + self.value_data.append(value) + self.cached_immediates[cache_key] = operand_id + return self.cached_immediates[cache_key] + + def add_immediate_int_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.INT32, struct.pack("i", value), () + ) + + def add_immediate_float_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.FLOAT32, struct.pack("f", value), () + ) + + def add_immediate_bool_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", () + ) + + def add_immediate_int_vector(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.TENSOR_INT32, + array.array("i", value).tobytes(), + (len(value),), + ) + + def has_operand_for_jitval(self, jitval): + return jitval in self.jitval_operand_map + + def get_tensor_operand_by_jitval(self, jitval): + operand_id = self.jitval_operand_map[jitval] + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_by_jitval_fixed_size(self, jitval): + op_id, oper = self.get_tensor_operand_by_jitval(jitval) + for s in oper.shape: + if s == 0: + # TODO: Improve this error message, possibly after converting + # many callsites to support flexible size. + raise Exception( # noqa: TRY002 + "Flexible size is not supported for this operand." + ) # noqa: TRY002 + if s < 0: + # runtime flex + LOG.warning("Operand %s has runtime flex shape", oper) + return op_id, oper + + def get_tensor_operand_or_constant( + self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ): + operand_id = self.jitval_operand_map.get(jitval) + if operand_id is None: + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value, dim_order) + return (operand_id, self.operands[operand_id]) + + def get_tensor_operand_for_weight(self, jitval): + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + def add_operation(self, opcode, inputs, outputs): + self.operations.append((opcode, len(inputs), len(outputs))) + self.operation_args.extend(inputs + outputs) + + def add_tensor_sequence(self, jitval, values): + assert jitval not in self.tensor_sequences + self.tensor_sequences[jitval] = values + + def add_constant_value(self, jitval, ctype, value): + assert jitval not in self.constants + self.constants[jitval] = (ctype, value) + + def get_constant_value(self, jitval, typekind=None): + record = self.constants.get(jitval) + if record is None: + raise Exception( # noqa: TRY002 + f"Could not find constant value for '{jitval!r}'." + ) # noqa: TRY002 + ctype, _ = record + if typekind is not None and ctype.kind() != typekind: + raise Exception( # noqa: TRY002 + f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'" + ) + return record + + def operand_to_template_torchscript(self, op_id, oper, shape=None): + """Return a TorchScript expression to build a template for a given operand.""" + if shape is None: + shape = oper.shape + else: + assert len(shape) == len(oper.shape) + + shape_parts = ["("] + for d, s in enumerate(shape): + if s > 0: + # Fixed shape dimension: just add the value. + shape_parts.append(str(s)) + elif s == 0: + # Load time flexible shape dimension: it should have been computed in a variable. + shape_parts.append(flex_name(op_id, d)) + elif s == -1: + # Runtime flexible shape + shape_parts.append("0") + else: + raise Exception( # noqa: TRY002 + "Unknown dim value, dimensions should be >= -1" + ) # noqa: TRY002 + shape_parts.append(",") + shape_parts.append(")") + shape_code = "".join(shape_parts) + if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + return f"torch.zeros({shape_code}, dtype=torch.float32)" + elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32: + return f"torch.zeros({shape_code}, dtype=torch.int32)" + elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + return ( + f"torch.quantize_per_tensor(" + f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)" + f".expand({shape_code}).contiguous()" + ) + elif oper.op_type in ( + NNAPI_OperandCode.TENSOR_QUANT16_ASYMM, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM, + ): + if self.use_int16_for_qint16: + return f"torch.zeros({shape_code}, dtype=torch.int16)" + else: + raise Exception( # noqa: TRY002 + "`int16` isn't supported. If you're trying to represent NNAPI" + " qint16 with Pytorch int16, set `use_int16_for_qint16 = True`" + ) + + raise Exception( # noqa: TRY002 + f"Unsupported output operand type: {oper.op_type}" + ) # noqa: TRY002 + + def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim): + self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) + + def compute_operand_shape(self, op_id, dim, expr): + self.flexible_shape_computation_lines.append( + f"{flex_name(op_id, dim)} = {expr}" + ) + + def transpose_to_nhwc(self, in_id, oper): + if oper.shape[2:] != (1, 1): + raise Exception( # noqa: TRY002 + "Automatic transpose only supported for H,W == 1,1" + ) # noqa: TRY002 + + out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1]) + + outputs = [None] * 1 + outputs[0] = self.add_anonymous_tensor_operand(out_oper) + + self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs) + + return outputs[0], out_oper + + # Transpose inputs as necessary to allow broadcasting. + def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper): + if in0_oper.dim_order == in1_oper.dim_order: + return in0_id, in0_oper, in1_id, in1_oper + + # Assume NHWC is preferred if there is a mismatch. + orders = (in0_oper.dim_order, in1_oper.dim_order) + if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST): + return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) + if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS): + return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) + + raise Exception( # noqa: TRY002 + f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}" + ) + + def get_size_arg(self, jitval): + ctype, value = self.get_constant_value(jitval) + if ctype.kind() == "ListType": + assert ctype.getElementType().kind() == "IntType" + return value + raise Exception( # noqa: TRY002 + f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'" + ) # noqa: TRY002 + + def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): + pc = [i.item() for i in packed_config] + assert pc[0] == 2 + strides = [pc[1], pc[2]] + paddings = [pc[3], pc[4]] + dilations = [pc[5], pc[6]] + output_padding = [pc[7], pc[8]] + group_num = pc[9] + + assert len(pc) == 11 + assert output_padding == [0, 0] + + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) + + def get_conv_pool_args_2d_from_jit( + self, kernel_size, stride, padding, dilation=None, group=None + ): + strides = self.get_size_arg(stride) + paddings = self.get_size_arg(padding) + if dilation is None: + dilations = [1, 1] + else: + dilations = self.get_size_arg(dilation) + if group is not None: + _, group_num = self.get_constant_value(group, "IntType") + else: + group_num = None + return self.get_conv_pool_args_2d_common( + kernel_size, strides, paddings, dilations, group_num + ) + + def get_conv_pool_args_2d_common( + self, kernel_size, strides, paddings, dilations, group_num + ): + kernels = list(kernel_size) + + assert len(kernels) == 2 + assert len(strides) == 2 + assert len(paddings) == 2 + assert len(dilations) == 2 + + # NNAPI uses 4 values for padding. + ph, pw = paddings + real_paddings = [ph, ph, pw, pw] + + return ConvPoolArgs2d( + *(kernels + strides + real_paddings + dilations + [group_num]) + ) + + def serialize_model(self, model, inputs, return_shapes=None): + self.add_immediate_bool_scalar(False) + self.add_immediate_bool_scalar(True) + + inp_dim_orders = [] + out_dim_orders = [] + + self_jitval = next(model.graph.inputs()) + self.add_constant_value(self_jitval, self_jitval.type(), model) + + for arg_idx, (input_value, input_tensor) in enumerate( + zip(list(model.graph.inputs())[1:], inputs) + ): + op_id = self.add_tensor_operand_for_input( + arg_idx, input_value, input_tensor + ) + inp_dim_orders.append(self.operands[op_id].dim_order.value) + + for idx, node in enumerate(model.graph.nodes()): + LOG.debug("Processing node #%d: %r", idx, node) + self.add_node(node) + + retn = model.graph.return_node() + assert retn.inputsSize() == 1 + assert retn.outputsSize() == 0 + retn_input = retn.inputsAt(0) + template_return_lines = ["return ["] + if retn_input.type().kind() == "TensorType": + return_values = [retn_input] + retval_count = -1 + elif retn_input.type().kind() == "TupleType": + return_values = self.tensor_sequences[retn_input] + retval_count = len(return_values) + else: + raise Exception( # noqa: TRY002 + f"Unsupported return type: {retn_input.type()}" + ) # noqa: TRY002 + + if return_shapes is not None: + assert len(return_shapes) == len(return_values) + for i, v in enumerate(return_values): + op_id = self.jitval_operand_map[v] + self.outputs.append(op_id) + out_dim_orders.append(self.operands[op_id].dim_order.value) + shape = return_shapes[i] if return_shapes else None + template_return_lines.append( + self.operand_to_template_torchscript(op_id, self.operands[op_id], shape) + + "," + ) + template_return_lines.append("]") + + model = [] + + version = 1 + header = struct.pack( + "iiiiii", + version, + len(self.operands), + len(self.values), + len(self.operations), + len(self.inputs), + len(self.outputs), + ) + model.append(header) + + serialized_values, serialized_value_data = self.serialize_values() + + model.extend( + struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands + ) + model.extend(serialized_values) + model.extend(struct.pack("iii", *x) for x in self.operations) + + # Compact the model so we can get its length so far. + model = [b"".join(model)] + model_offset = len(model[0]) + # Model offset is the index into the model (in 32-bit words, not bytes) + # of the next dimension we're about to serialize. If it's 0, + # generate code to mutate it before passing to NNAPI. + assert model_offset % 4 == 0 + model_offset = int(model_offset / 4) + + for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands): + shape = fix_shape(dims, dim_order) + for d, s in enumerate(shape): + if s == 0: + pt_d = reverse_map_dim(dim_order, d) + self.flexible_shape_computation_lines.append( + f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}" + ) + model_offset += 1 + + # convert runtime flex shape from -1 to 0 + shape = tuple(d if d != -1 else 0 for d in shape) + model.append(self.serialize_ints(shape)) + + model.extend(serialized_value_data) + model.append(self.serialize_ints(self.operation_args)) + model.append(self.serialize_ints(self.inputs)) + model.append(self.serialize_ints(self.outputs)) + + self.flexible_shape_computation_lines.extend(template_return_lines) + + return ( + array.array("i", b"".join(model)), + self.used_weights, + inp_dim_orders, + out_dim_orders, + self.flexible_shape_computation_lines, + retval_count, + ) + + def serialize_values(self): + serialized_values = [] + serialized_value_data = [] + assert len(self.values) == len(self.value_data) + for (op_index, source_type), data in zip(self.values, self.value_data): + source_length = len(data) + + # Pad with 0 bytes out to a multiple of 4 for alignment. + physical_length = ((source_length - 1) | 0x3) + 1 + padded_data = data + (b"\0" * (physical_length - source_length)) + + serialized_values.append( + struct.pack("iii", op_index, source_type, source_length) + ) + serialized_value_data.append(padded_data) + + return serialized_values, serialized_value_data + + @staticmethod + def serialize_ints(ints): + return array.array("i", ints).tobytes() + + ADDER_MAP = { + "prim::GetAttr": lambda self, node: self.add_getattr(node), + "prim::Constant": lambda self, node: self.add_constant_node(node), + "prim::ListConstruct": lambda self, node: self.add_list_construct(node), + "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node), + "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node), + "aten::to": lambda self, node: self.add_to(node), + "aten::detach": lambda self, node: self._identity(node), + "aten::reshape": lambda self, node: self.add_reshape(node), + "aten::flatten": lambda self, node: self.add_flatten(node), + "aten::slice": lambda self, node: self.add_slice(node), + "aten::size": lambda self, node: self.add_size(node), + "aten::cat": lambda self, node: self.add_cat(node), + "aten::mean": lambda self, node: self.add_mean(node), + "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node), + "aten::dequantize": lambda self, node: self.add_dequantize(node), + "aten::add": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "aten::sub": lambda self, node: self.add_add_sub_op( + node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE + ), + "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), + "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op( + node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE + ), + "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.RELU + ), + "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op( + node, NNAPI_OperationCode.LOGISTIC + ), + "aten::softmax": lambda self, node: self.add_softmax(node), + "aten::hardtanh": lambda self, node: self.add_hardtanh(node), + "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node), + "aten::max_pool2d": lambda self, node: self.add_pool2d_node( + node, NNAPI_OperationCode.MAX_POOL_2D + ), + "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d( + node + ), + "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d( + node + ), + "aten::prelu": lambda self, node: self.add_prelu_op(node), + "aten::addmm": lambda self, node: self.add_addmm(node), + "aten::linear": lambda self, node: self.add_linear(node), + "aten::_convolution": lambda self, node: self.add_conv_underscore(node), + "aten::conv2d": lambda self, node: self.add_conv2d(node), + "aten::log_softmax": lambda self, node: self.add_log_softmax(node), + "quantized::linear": lambda self, node: self.add_qlinear(node), + "quantized::conv2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::conv2d_relu": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d( + node, NNAPI_FuseCode.FUSED_NONE, transpose=True + ), + "quantized::add": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE + ), + "quantized::add_relu": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU + ), + "quantized::mul": lambda self, node: self.add_qadd( + node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE + ), + } + + def add_node(self, node): + adder = self.ADDER_MAP.get(node.kind()) + if not adder: + raise Exception( # noqa: TRY002 + f"Unsupported node kind ({node.kind()!r}) in node {node!r}" + ) # noqa: TRY002 + adder(self, node) + + def _identity(self, node): + in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + jitval = node.outputsAt(0) + self.jitval_operand_map[jitval] = in_id + + def add_getattr(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) + assert str(obj_ctype).startswith("__torch__.") + name = node.s("name") + value = getattr(obj, name) + output = node.outputsAt(0) + ctype = output.type() + self.add_constant_value(output, ctype, value) + + def add_constant_node(self, node): + assert node.inputsSize() == 0 + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + value = output.toIValue() + self.add_constant_value(output, ctype, value) + + def add_list_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + const_vals: Optional[list] = [] + tensors: Optional[list] = [] + for inp in node.inputs(): + if const_vals is not None and inp in self.constants: + _, val = self.get_constant_value(inp) + const_vals.append(val) + else: + const_vals = None + if tensors is not None and inp.type().kind() == "TensorType": + tensors.append(inp) + else: + tensors = None + + if const_vals is not None: + # NOTE: Now that TorchScript supports list constants, + # this code path might not be used anymore. + self.add_constant_value(output, ctype, const_vals) + if tensors is not None: + self.add_tensor_sequence(output, tensors) + if const_vals is None and tensors is None: + raise Exception( # noqa: TRY002 + f"Unable to handle ListConstruct node. Neither all constants nor all tensors. {node!r}" + ) + + def add_tuple_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + values = list(node.inputs()) + self.add_tensor_sequence(output, values) + + def add_unsqueeze(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + + _, dim = self.get_constant_value(node.inputsAt(1), "IntType") + assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS + + real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1 + out_shape_list = list(in_oper.shape) + out_shape_list.insert(real_dim, 1) + out_shape = tuple(out_shape_list) + out_oper = in_oper._replace(shape=out_shape) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_scalar(dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs) + + def add_to(self, node): + # Handle to("cpu") / to("gpu") case + self._identity(node) + + def add_reshape(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + + shape_ctype, shape = self.get_constant_value(node.inputsAt(1)) + assert shape_ctype.kind() == "ListType" + assert shape_ctype.getElementType().kind() == "IntType" + is_trivial_reshape = len(shape) == 2 and shape[1] == -1 + + if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape: + raise Exception( # noqa: TRY002 + "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]." + ) + + # Bit of a hack here. Use a real tensor to infer the output shape. + out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape + out_oper = in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(shape) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) + + def add_flatten(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") + _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") + + # channels last with channels == 1 or (height & width both 1) + is_trivial_flatten = len(in_oper.shape) == 4 and ( + in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1) + ) + if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten: + raise Exception( # noqa: TRY002 + "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1" + ) + + if start_dim < 0: + start_dim += len(in_oper.shape) + if end_dim < 0: + end_dim += len(in_oper.shape) + + out_shape = ( + in_oper.shape[:start_dim] + + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),) + + in_oper.shape[end_dim + 1 :] + ) + + if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): + raise Exception( # noqa: TRY002 + "Flattening flexible dims is not supported yet" + ) # noqa: TRY002 + non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] + if non_flattened_dims.count(0) > 1: + raise Exception("Only 1 dim can be flexible") # noqa: TRY002 + + out_oper = in_oper._replace( + shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS + ) + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + + for idx, dim in enumerate(out_shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0)) + + inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape) + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(inputs_1) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs) + + def add_slice(self, node): + assert node.inputsSize() == 5 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + _, dim_value = self.get_constant_value(node.inputsAt(1)) + _, start_value = self.get_constant_value(node.inputsAt(2)) + _, stop_value = self.get_constant_value(node.inputsAt(3)) + _, step_value = self.get_constant_value(node.inputsAt(4)) + + if start_value is None: + start_value = 0 + if stop_value is None: + stop_value = sys.maxsize + + if start_value < 0: + start_value += in_oper.shape[dim_value] + elif start_value == sys.maxsize: + start_value = 0 + + if start_value == 0 and stop_value == sys.maxsize: + self._identity(node) + return + + if in_oper.shape[dim_value] == 0: + raise Exception("Unable to slice with flexible shape") # noqa: TRY002 + + if stop_value < 0: + stop_value += in_oper.shape[dim_value] + elif stop_value == sys.maxsize: + stop_value = in_oper.shape[dim_value] + + if start_value >= stop_value: + raise Exception( # noqa: TRY002 + "Slice start value should be less than stop value" + ) # noqa: TRY002 + + out_len = (stop_value - start_value) // step_value + out_shape = tuple( + out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape) + ) + out_id = self.add_tensor_operand( + node.outputsAt(0), in_oper._replace(shape=out_shape) + ) + + # flex inputs + end_mask = 0 + for idx, dim in enumerate(out_shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, idx) + end_mask |= 1 << idx + + inputs = [None] * 7 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector( + [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))] + ) + inputs[2] = self.add_immediate_int_vector( + [ + stop_value if i == dim_value else dim + for i, dim in enumerate(in_oper.shape) + ] + ) + inputs[3] = self.add_immediate_int_vector( + [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))] + ) + inputs[4] = self.add_immediate_int_scalar(0) # begin mask + inputs[5] = self.add_immediate_int_scalar(end_mask) + inputs[6] = self.add_immediate_int_scalar(0) # shrink axis mas + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs) + + def add_size(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + _, value = self.constants[node.inputsAt(1)] + res = in_oper.shape[value] + output = node.outputsAt(0) + self.add_constant_value(output, output.type(), res) + + def add_cat(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + tensors = self.tensor_sequences[node.inputsAt(0)] + _, dim = self.get_constant_value(node.inputsAt(1), "IntType") + + assert len(tensors) > 0 + in_ids = [] + out_oper = None + out_dim_size = 0 + for inp in tensors: + in_id, in_oper = self.get_tensor_operand_by_jitval(inp) + if out_oper is None: + out_shape = change_element(in_oper.shape, dim, -1) + out_oper = in_oper._replace(shape=out_shape) + assert in_oper.op_type == out_oper.op_type + assert in_oper.dim_order == out_oper.dim_order + assert change_element(in_oper.shape, dim, -1) == change_element( + out_oper.shape, dim, -1 + ) + # TODO: Possibly check scale and zero point. + in_ids.append(in_id) + # TODO: Possibly support variable-sized inputs. + out_dim_size += in_oper.shape[dim] + + assert out_oper is not None + out_oper = out_oper._replace( + shape=change_element(out_oper.shape, dim, out_dim_size) + ) + + if in_oper.dim_order == DimOrder.CHANNELS_LAST: # type: ignore[possibly-undefined] + assert len(out_oper.shape) == 4 + nnapi_dim = [0, 3, 1, 2][dim] + else: + nnapi_dim = dim + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + for idx, d in enumerate(out_oper.shape): + if d == 0: + if idx == dim: + shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids) + self.compute_operand_shape(out_id, idx, shape) + else: + self.forward_operand_shape(out_id, idx, in_ids[0], idx) + + inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)] + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs) + + def add_mean(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + dim_ctype, dim = self.get_constant_value(node.inputsAt(1)) + assert dim_ctype.kind() == "ListType" + assert dim_ctype.getElementType().kind() == "IntType" + _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType") + # Expect None for dtype + self.get_constant_value(node.inputsAt(3), "NoneType") + + if in_oper.dim_order == DimOrder.CHANNELS_LAST: + assert len(in_oper.shape) == 4 + nnapi_dim = [[0, 3, 1, 2][d] for d in dim] + else: + nnapi_dim = dim + + collapsed_dims = set() + for d in dim: + if d < 0: + d += len(in_oper.shape) + collapsed_dims.add(d) + + if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim: + assert collapsed_dims.issuperset({2, 3}) + out_dim_order = DimOrder.PRESUMED_CONTIGUOUS + else: + out_dim_order = in_oper.dim_order + + out_shape = [] + for i, s in enumerate(in_oper.shape): + if i not in collapsed_dims: + out_shape.append(s) + elif keep_dim: + out_shape.append(1) + + out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order) + + inputs = [None] * 3 + inputs[0] = in_id + inputs[1] = self.add_immediate_int_vector(nnapi_dim) + inputs[2] = self.add_immediate_int_scalar(keep_dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs) + + def add_quantize(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + if in_oper.dim_order != DimOrder.CHANNELS_LAST: + raise Exception( # noqa: TRY002 + "Most hardware backends prefer NHWC quantized tensors. " + "Try setting `t.nnapi_nhwc = True` on your tensor inputs. " + ) + _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") + _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") + if scalar_type != TorchScalarTypes.QUINT8.value: + raise Exception( # noqa: TRY002 + "PyTorch NNAPI export only supports quantized tensors " + "with the quint8 dtype." + ) + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + + out_oper = in_oper._replace( + op_type=op_type, + scale=scale, + zero_point=zero_point, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) + + def add_dequantize(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + out_oper = in_oper._replace( + op_type=NNAPI_OperandCode.TENSOR_FLOAT32, + scale=0.0, + zero_point=0, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) + + def add_pointwise_simple_unary_op(self, node, opcode): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + out_oper = in_oper + if opcode == NNAPI_OperationCode.LOGISTIC: + # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale + # must be 1.f / 256 and the zeroPoint must be 0. + # https://fburl.com/h52stoog + if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256) + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + + for idx, dim in enumerate(in_oper.shape): + if dim == 0: + self.forward_operand_shape(out_id, idx, in_id, idx) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(opcode, inputs, outputs) + + def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D401 + """Helper for pointwise binary broadcast ops with superfluous extra args.""" + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + if self.has_operand_for_jitval(node.inputsAt(0)): + in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in1_id, in1_oper = self.get_tensor_operand_or_constant( + node.inputsAt(1), in0_oper.dim_order + ) + elif self.has_operand_for_jitval(node.inputsAt(1)): + in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + in0_id, in0_oper = self.get_tensor_operand_or_constant( + node.inputsAt(0), in1_oper.dim_order + ) + else: + raise Exception( # noqa: TRY002 + f"Can't do a NNAPI binary op: {opcode} on two constants" + ) # noqa: TRY002 + + assert in0_oper.op_type == in1_oper.op_type + in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( + in0_id, in0_oper, in1_id, in1_oper + ) + # NOTE: PyTorch and NNAPI have the same broadcast semantics. + out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) + out_oper = in0_oper._replace(shape=out_shape) + if qparams is not None: + scale, zp = qparams + out_oper = out_oper._replace(scale=scale, zero_point=zp) + + out_id = self.add_tensor_operand(node.outputsAt(0), out_oper) + for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)): + if d0 == 1 and d1 == 0: + self.forward_operand_shape(out_id, idx, in1_id, idx) + elif d0 == 0 and d1 == 1: + self.forward_operand_shape(out_id, idx, in0_id, idx) + elif d0 == 0 and d1 == 0: + self.flexible_shape_computation_lines.append( + f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}" + ) + self.forward_operand_shape(out_id, idx, in0_id, idx) + + inputs = [None] * 3 + inputs[0] = in0_id + inputs[1] = in1_id + inputs[2] = self.add_immediate_int_scalar(fuse_code) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(opcode, inputs, outputs) + + def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 2 + self._do_add_binary(node, opcode, fuse_code) + + def add_add_sub_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 3 + + _, alpha = self.get_constant_value(node.inputsAt(2), "IntType") + if alpha != 1: + raise Exception( # noqa: TRY002 + "NNAPI does not support add/sub with alpha." + ) # noqa: TRY002 + + self._do_add_binary(node, opcode, fuse_code) + + def add_qadd(self, node, opcode, fuse_code): + assert node.inputsSize() == 4 + + _, scale = self.get_constant_value(node.inputsAt(2), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") + + self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) + + def add_softmax(self, node): + assert node.inputsSize() == 3 + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType") + + out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) + for dim, size in enumerate(in_oper.shape): + if size == 0: + self.forward_operand_shape(out_id, dim, in_id, dim) + + inputs = [None] * 3 + inputs[0] = in_id + inputs[1] = self.add_immediate_float_scalar( + 1.0 + ) # positive scaling factor of exponent, beta + inputs[2] = self.add_immediate_int_scalar(softmax_dim) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs) + + def add_hardtanh(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) + _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType") + _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType") + + op_map = { + (-1, 1): NNAPI_OperationCode.RELU1, + (0, 6): NNAPI_OperationCode.RELU6, # noqa: E201 + } + + opcode = op_map.get((min_val, max_val)) + if opcode is None: + raise Exception( # noqa: TRY002 + "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)." + ) # noqa: TRY002 + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(opcode, inputs, outputs) + + def add_prelu_op(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) + assert len(w_oper.shape) == 1 + assert w_oper.shape[0] > 0 + if w_oper.shape[0] > 1: + if in_oper.use_nchw(): + # TODO: Support this by adding trailing 1 dims. + raise Exception( # noqa: TRY002 + "Per-channel PReLU only supports channels_last right now." + ) + + out_id = self.add_tensor_operand(node.outputsAt(0), in_oper) + for dim, size in enumerate(in_oper.shape): + if size > 0: + pass + elif dim <= 1: + raise Exception( # noqa: TRY002 + "PReLU requires fixed size for dim 0 and dim 1." + ) # noqa: TRY002 + else: + self.forward_operand_shape(out_id, dim, in_id, dim) + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = w_id + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) + + def add_pool2d_node(self, node, opcode): + assert node.inputsSize() == 6 + assert node.outputsSize() == 1 + image, kernel, stride, padding, dilation, _ceil_mode = node.inputs() + + stride = stride or kernel + + # TODO: Validate ceil_mode semantics. + + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding, dilation + ) + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("NNAPI does not support dilated pooling.") # noqa: TRY002 + + image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image) + assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + self.add_operation(opcode, inputs, outputs) + + def add_avg_pool2d(self, node): + assert node.inputsSize() == 7 + assert node.outputsSize() == 1 + ( + image, + kernel, + stride, + padding, + _ceil_mode, + count_include_pad, + divisor_override, + ) = node.inputs() + + _, count_include_pad_value = self.get_constant_value(count_include_pad) + _, divisor_override_value = self.get_constant_value(divisor_override) + if not count_include_pad_value or divisor_override_value: + raise Exception( # noqa: TRY002 + "NNAPI doesn't support count_include_pad=False or divisor_override" + ) + + args = self.get_conv_pool_args_2d_from_jit( + self.get_size_arg(kernel), stride, padding + ) + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape( + image_oper.shape, args, image_oper.shape[1], False + ) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + self._handle_conv_pool_flexible_input(out_id, image, args, False) + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) + + def add_adaptive_avg_pool2d(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size( + node.inputsAt(0) + ) + assert len(image_oper.shape) == 4 + + size_ctype, size_arg = self.get_constant_value(node.inputsAt(1)) + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + if size_arg != [1, 1]: + raise Exception( # noqa: TRY002 + "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)." + ) + + out_shape = image_oper.shape[0:2] + tuple(size_arg) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(0) + inputs[2] = self.add_immediate_int_scalar(0) + inputs[3] = self.add_immediate_int_scalar(0) + inputs[4] = self.add_immediate_int_scalar(0) + inputs[5] = self.add_immediate_int_scalar(1) + inputs[6] = self.add_immediate_int_scalar(1) + inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3]) + inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2]) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs) + + def add_upsample_nearest2d(self, node): + assert node.inputsSize() == 3 or node.inputsSize() == 4 + assert node.outputsSize() == 1 + if node.inputsSize() == 3: + image, size_jit, scale_jit = node.inputs() + else: + image, size_jit, scale_h_jit, scale_w_jit = node.inputs() + size_ctype, size_arg = self.get_constant_value(size_jit) + + if node.inputsSize() == 3: + scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined] + else: + scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined] + scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] + + # The only way for the 4-argument overload of upsample_nearest2d to + # have been added to the graph without error is if the scale_h and + # scale_w arguments are None + assert scale_h_ctype.kind() == "NoneType" + assert scale_w_ctype.kind() == "NoneType" + + scale_ctype = scale_h_ctype + scale_arg = scale_h_arg + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": + raise Exception("Size and scale cannot both be non-None.") # noqa: TRY002 + elif size_ctype.kind() != "NoneType": + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + assert scale_ctype.kind() == "NoneType" + assert scale_arg is None + assert isinstance(size_arg, list) + assert size_arg + assert all(isinstance(val, int) for val in size_arg) + if len(size_arg) == 1: + size_arg = size_arg * 2 + assert len(size_arg) == 2 + out_h = size_arg[0] + out_w = size_arg[1] + arg_h = self.add_immediate_int_scalar(out_h) + arg_w = self.add_immediate_int_scalar(out_w) + elif scale_ctype.kind() != "NoneType": + assert scale_ctype.kind() == "ListType" + assert scale_ctype.getElementType().kind() == "FloatType" + assert size_ctype.kind() == "NoneType" + assert size_arg is None + assert isinstance(scale_arg, list) + assert scale_arg + assert all(isinstance(val, float) for val in scale_arg) + if len(scale_arg) == 1: + scale_arg = scale_arg * 2 + assert len(scale_arg) == 2 + out_h = int(scale_arg[0] * image_oper.shape[2]) + out_w = int(scale_arg[1] * image_oper.shape[3]) + arg_h = self.add_immediate_float_scalar(scale_arg[0]) + arg_w = self.add_immediate_float_scalar(scale_arg[1]) + else: + raise Exception("Size and scale cannot both be None.") # noqa: TRY002 + + out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) + use_nchw = image_oper.use_nchw() + out_id = self.add_tensor_operand( + node.outputsAt(0), image_oper._replace(shape=out_shape) + ) + + if image_oper.shape[0] == 0 or image_oper.shape[1] == 0: + raise Exception("Flexible batch or channels not supported") # noqa: TRY002 + + # Handle variable input size + for dim in (2, 3): # h, w indices + if image_oper.shape[dim] == 0: + if size_ctype.kind() != "NoneType": + # pyrefly: ignore [unsupported-operation] + self.compute_operand_shape(out_id, dim, size_arg[dim - 2]) + elif scale_ctype.kind() != "NoneType": + self.compute_operand_shape( + out_id, + dim, + # pyrefly: ignore [unsupported-operation] + f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})", + ) + else: + raise Exception( # noqa: TRY002 + "Size and scale cannot both be None." + ) # noqa: TRY002 + + inputs = [None] * 4 + inputs[0] = image_id + inputs[1] = arg_w + inputs[2] = arg_h + inputs[3] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) + + def add_addmm(self, node): + assert node.inputsSize() == 5 + assert node.outputsSize() == 1 + jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs() + + for jitval in (jit_beta, jit_alpha): + scale_ctype, scale_value = self.get_constant_value(jitval) + assert scale_ctype.kind() in ("IntType", "FloatType") + if scale_value != 1: + raise Exception( # noqa: TRY002 + "NNAPI Fully-Connected does not support alpha and beta." + ) + + self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias) + + def add_linear(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + jit_input, jit_weight, jit_bias = node.inputs() + + self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias) + + def add_addmm_or_linear( + self, node, transpose_weight, jit_input, jit_weight, jit_bias + ): + input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input) + bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) + + assert len(input_oper.shape) == 2 + assert len(bias_oper.shape) == 1 + + # TODO: Transform at load time to share weights with CPU model. + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + assert len(weight_tensor.shape) == 2 + if transpose_weight: + nnapi_weight_tensor = weight_tensor.t().contiguous() + else: + nnapi_weight_tensor = weight_tensor.contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + out_id = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) + + if input_oper.shape[0] == 0: + self.forward_operand_shape(out_id, 0, input_id, 0) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = out_id + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def add_qlinear(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + ( + jit_input, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) + # TODO: Support automatic reshape + assert len(input_oper.shape) == 2 + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "LinearPackedParamsBase" + raw_weight, raw_bias = packed_weight.__getstate__()[0] + assert raw_bias is not None + + assert len(raw_weight.shape) == 2 + assert len(raw_bias.shape) == 1 + assert raw_bias.shape[0] == raw_weight.shape[0] + assert raw_weight.shape[1] == input_oper.shape[1] + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128, + ) + weight_scale = unsigned_weight.q_scale() + bias_scale = input_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = input_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( # noqa: TRY002 + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. " + ) + + # TODO: Transform at load time to share weights with CPU model. + nnapi_weight_tensor = unsigned_weight.contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + out_shape = (input_oper.shape[0], weight_oper.shape[0]) + out_oper = input_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + + inputs = [None] * 4 + inputs[0] = input_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) + + def get_optional_bias(self, jit_bias, weight_tensor, transpose=False): + ctype, _value = self.get_constant_value(jit_bias) + if ctype.kind() == "NoneType": + bias_idx = 1 if transpose else 0 + nnapi_bias_tensor = torch.zeros( + weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype + ) + bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor) + bias_oper = self.operands[bias_id] + return bias_id, bias_oper + else: + return self.get_tensor_operand_for_weight(jit_bias) + + def add_conv2d(self, node): + assert node.inputsSize() == 7 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_groups, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + False, # transpose + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_conv_underscore(self, node): + assert node.inputsSize() == 13 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_transpose, + _, + jit_groups, + _, + _, + _, + _, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + _, transpose = self.get_constant_value(jit_transpose) + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) + args = self.get_conv_pool_args_2d_from_jit( + weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups + ) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + NNAPI_FuseCode.FUSED_NONE, + ) + + def add_log_softmax(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + + jit_input, jit_dim, _jit_half_to_float = node.inputs() + input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) + _, dim = self.get_constant_value(jit_dim, "IntType") + + out_shape = input_oper.shape + + inputs = [None] * 3 + inputs[0] = input_id + # specifying 1 as the scaling factor for the exponent, beta + inputs[1] = self.add_immediate_float_scalar(1) + inputs[2] = self.add_immediate_int_scalar(dim) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand( + node.outputsAt(0), input_oper._replace(shape=out_shape) + ) + self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs) + + def add_qconv2d(self, node, fuse_code, transpose=False): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "Conv2dPackedParamsBase" + ( + pack_version, + tensors, + opt_tensors, + ) = packed_weight.__getstate__()[0] + assert pack_version == "2" + packed_config, raw_weight = tensors + (raw_bias,) = opt_tensors + assert raw_bias is not None + args = self.get_conv_pool_args_2d_from_pack( + raw_weight.shape[2:4], packed_config + ) + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128, + ) + weight_scale = unsigned_weight.q_scale() + _, image_oper = self.get_tensor_operand_by_jitval(jit_image) + bias_scale = image_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = image_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( # noqa: TRY002 + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. " + ) + + return self.add_conv2d_common( + node.outputsAt(0), + out_scale, + out_zero_point, + jit_image, + unsigned_weight, + bias_id, + args, + transpose, + fuse_code, + ) + + def add_conv2d_common( + self, + jit_out, + out_scale, + out_zero_point, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + fuse_code, + ): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + in_c = image_oper.shape[1] + + if args.group == 1: + # Full convolution + depthwise = False + if transpose: + weight_permutation = (1, 2, 3, 0) + else: + weight_permutation = (0, 2, 3, 1) + elif args.group == in_c: + # Depthwise convolution + depthwise = True + weight_permutation = (1, 2, 3, 0) + else: + raise Exception("Group convolution not supported yet.") # noqa: TRY002 + + # TODO: Transform at load time to share weights with CPU model. + nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + bias_oper = self.operands[bias_id] + + if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32 + assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale) + assert bias_oper.zero_point == 0 + else: + raise Exception( # noqa: TRY002 + f"Unsupported input type for conv2d: {image_oper.op_type}" + ) # noqa: TRY002 + + assert len(image_oper.shape) == 4 + assert len(weight_oper.shape) == 4 + assert len(bias_oper.shape) == 1 + + if depthwise: + # Depthwise convolution + one, _kern_h, _kern_w, out_c = weight_oper.shape + assert one == 1 + assert out_c % in_c == 0 + channel_multiplier = out_c // in_c + assert channel_multiplier == 1 # Don't support multiplier + assert out_c == in_c + else: + # Full convolution + out_c, _kern_h, _kern_w, kern_d = weight_oper.shape + assert kern_d == in_c + + assert out_c == bias_oper.shape[0] + + use_nchw = image_oper.use_nchw() + + if depthwise: + num_args = 12 + opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D + else: + num_args = 11 + if transpose: + opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D + else: + opcode = NNAPI_OperationCode.CONV_2D + + inputs = [None] * num_args + inputs[0] = image_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(args.pad_l) + inputs[4] = self.add_immediate_int_scalar(args.pad_r) + inputs[5] = self.add_immediate_int_scalar(args.pad_t) + inputs[6] = self.add_immediate_int_scalar(args.pad_b) + inputs[7] = self.add_immediate_int_scalar(args.stride_w) + inputs[8] = self.add_immediate_int_scalar(args.stride_h) + if depthwise: + inputs[9] = self.add_immediate_int_scalar(1) + inputs[10] = self.add_immediate_int_scalar(fuse_code) + inputs[11] = self.add_immediate_bool_scalar(use_nchw) + else: + inputs[9] = self.add_immediate_int_scalar(fuse_code) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose) + out_oper = image_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + out_id = self.add_tensor_operand(jit_out, out_oper) + self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose) + + outputs[0] = out_id + self.add_operation(opcode, inputs, outputs) + + def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + batch, in_ch, in_h, in_w = image_oper.shape + + if batch == 0: + self.forward_operand_shape(out_id, 0, image_id, 0) + if in_ch == 0: + raise Exception("Input channels can't be flexible") # noqa: TRY002 + # H & W + if transpose: + if in_h == 0: + self.compute_operand_shape( + out_id, + 2, + f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}", + ) + if in_w == 0: + self.compute_operand_shape( + out_id, + 3, + f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}", + ) + else: + if in_h == 0: + self.compute_operand_shape( + out_id, + 2, + f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1", + ) + if in_w == 0: + self.compute_operand_shape( + out_id, + 3, + f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1", + ) + + +def serialize_model( + module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False +): + """Convert to NNAPI and serialize torchscript module. + + Parameters: + module: Torchscript module to convert + inputs: Tensors used to specify input details for NNAPI + config (optional): Optional config to attach to module + return_shapes (optional): Specify shape of outputs if + your module uses runtime flexible shapes to set output + buffer size for NNAPI + use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values + """ + return _NnapiSerializer(config, use_int16_for_qint16).serialize_model( + module, inputs, return_shapes + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82dc52cd4904c1cda023c876c586550a5a33ff7a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py @@ -0,0 +1,21 @@ +import torch + + +__all__ = [ + "get_cpu_capability", +] + + +def get_cpu_capability() -> str: + r"""Return cpu capability as a string value. + + Possible values: + - "DEFAULT" + - "VSX" + - "Z VECTOR" + - "NO AVX" + - "AVX2" + - "AVX512" + - "SVE256" + """ + return torch._C._get_cpu_capability() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d62c2b05a1ea1f3ecc5ceb0fbc17f5a714d87941 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py @@ -0,0 +1,593 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Any, Union +from typing_extensions import deprecated + +import torch + + +__all__ = [ + "is_built", + "cuFFTPlanCacheAttrContextProp", + "cuFFTPlanCache", + "cuFFTPlanCacheManager", + "cuBLASModule", + "preferred_linalg_library", + "preferred_blas_library", + "preferred_rocm_fa_library", + "cufft_plan_cache", + "matmul", + "SDPAParams", + "enable_cudnn_sdp", + "cudnn_sdp_enabled", + "enable_flash_sdp", + "flash_sdp_enabled", + "enable_mem_efficient_sdp", + "mem_efficient_sdp_enabled", + "math_sdp_enabled", + "enable_math_sdp", + "allow_fp16_bf16_reduction_math_sdp", + "fp16_bf16_reduction_math_sdp_allowed", + "is_flash_attention_available", + "can_use_flash_attention", + "can_use_efficient_attention", + "can_use_cudnn_attention", + "sdp_kernel", +] + + +def is_built(): + r""" + Return whether PyTorch is built with CUDA support. + + Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch + binary were run on a machine with working CUDA drivers and devices, we would be able to use it. + """ + return torch._C._has_cuda + + +class cuFFTPlanCacheAttrContextProp: + # Like regular ContextProp, but uses the `.device_index` attribute from the + # calling object as the first argument to the getter and setter. + def __init__(self, getter, setter): + self.getter = getter + self.setter = setter + + def __get__(self, obj, objtype): + return self.getter(obj.device_index) + + def __set__(self, obj, val): + if isinstance(self.setter, str): + raise RuntimeError(self.setter) + self.setter(obj.device_index, val) + + +class cuFFTPlanCache: + r""" + Represent a specific plan cache for a specific `device_index`. + + The attributes `size` and `max_size`, and method `clear`, can fetch and/ or + change properties of the C++ cuFFT plan cache. + """ + + def __init__(self, device_index): + self.device_index = device_index + + size = cuFFTPlanCacheAttrContextProp( + torch._cufft_get_plan_cache_size, + ".size is a read-only property showing the number of plans currently in the " + "cache. To change the cache capacity, set cufft_plan_cache.max_size.", + ) + + max_size = cuFFTPlanCacheAttrContextProp( + torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size + ) + + def clear(self): + return torch._cufft_clear_plan_cache(self.device_index) + + +class cuFFTPlanCacheManager: + r""" + Represent all cuFFT plan caches, return the cuFFTPlanCache for a given device when indexed. + + Finally, this object, when used directly as a `cuFFTPlanCache` object (e.g., + setting the `.max_size`) attribute, the current device's cuFFT plan cache is + used. + """ + + __initialized = False + + def __init__(self): + self.caches = [] + self.__initialized = True + + def __getitem__(self, device): + index = torch.cuda._utils._get_device_index(device) + if index < 0 or index >= torch.cuda.device_count(): + raise RuntimeError( + f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got " + f"device with index {index}" + ) + if len(self.caches) == 0: + self.caches.extend( + cuFFTPlanCache(index) for index in range(torch.cuda.device_count()) + ) + return self.caches[index] + + def __getattr__(self, name): + return getattr(self[torch.cuda.current_device()], name) + + def __setattr__(self, name, value): + if self.__initialized: + return setattr(self[torch.cuda.current_device()], name, value) + else: + return super().__setattr__(name, value) + + +class cuBLASModule: + @staticmethod + def _parse_reduction_setting(value: Any, attr_name: str) -> tuple[bool, bool]: + def _ensure_bool(obj: Any, which: str) -> bool: + if isinstance(obj, bool): + return obj + raise TypeError( + f"{attr_name} expects a bool for {which}, but got {type(obj)!r}" + ) + + if isinstance(value, bool): + return value, True + if isinstance(value, (list, tuple)): + if not value: + raise TypeError(f"{attr_name} expects at least one boolean argument") + if len(value) > 2: + raise TypeError(f"{attr_name} expects at most two boolean arguments") + allow_reduced_precision = _ensure_bool(value[0], "allow_reduced_precision") + if len(value) == 1: + return allow_reduced_precision, True + allow_splitk = _ensure_bool(value[1], "allow_splitk") + return allow_reduced_precision, allow_splitk + raise TypeError( + f"{attr_name} expects a bool or a tuple/list of bools, but got {type(value)!r}" + ) + + def __getattr__(self, name): + if name == "allow_tf32": + return torch._C._get_cublas_allow_tf32() + elif name == "allow_fp16_reduced_precision_reduction": + allow_reduced_precision, _ = ( + torch._C._get_cublas_allow_fp16_reduced_precision_reduction() + ) + return allow_reduced_precision + elif name == "allow_fp16_reduced_precision_reduction_split_k": + _, allow_splitk = ( + torch._C._get_cublas_allow_fp16_reduced_precision_reduction() + ) + return allow_splitk + elif name == "allow_bf16_reduced_precision_reduction": + allow_reduced_precision, _ = ( + torch._C._get_cublas_allow_bf16_reduced_precision_reduction() + ) + return allow_reduced_precision + elif name == "allow_bf16_reduced_precision_reduction_split_k": + _, allow_splitk = ( + torch._C._get_cublas_allow_bf16_reduced_precision_reduction() + ) + return allow_splitk + elif name == "allow_fp16_accumulation": + return torch._C._get_cublas_allow_fp16_accumulation() + elif name == "fp32_precision": + return torch._C._get_fp32_precision_getter("cuda", "matmul") + raise AttributeError("Unknown attribute " + name) + + def __setattr__(self, name, value): + if name == "allow_tf32": + return torch._C._set_cublas_allow_tf32(value) + elif name == "allow_fp16_reduced_precision_reduction": + allow_reduced_precision, allow_splitk = self._parse_reduction_setting( + value, "allow_fp16_reduced_precision_reduction" + ) + return torch._C._set_cublas_allow_fp16_reduced_precision_reduction( + allow_reduced_precision, + allow_splitk, + ) + elif name == "allow_bf16_reduced_precision_reduction": + allow_reduced_precision, allow_splitk = self._parse_reduction_setting( + value, "allow_bf16_reduced_precision_reduction" + ) + return torch._C._set_cublas_allow_bf16_reduced_precision_reduction( + allow_reduced_precision, + allow_splitk, + ) + elif name == "allow_fp16_accumulation": + return torch._C._set_cublas_allow_fp16_accumulation(value) + elif name == "fp32_precision": + return torch._C._set_fp32_precision_setter("cuda", "matmul", value) + raise AttributeError("Unknown attribute " + name) + + +_LinalgBackends = { + "default": torch._C._LinalgBackend.Default, + "cusolver": torch._C._LinalgBackend.Cusolver, + "magma": torch._C._LinalgBackend.Magma, +} +_LinalgBackends_str = ", ".join(_LinalgBackends.keys()) + + +def preferred_linalg_library( + backend: Union[None, str, torch._C._LinalgBackend] = None, +) -> torch._C._LinalgBackend: + r""" + Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations. + + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries, + and if both are available it decides which to use with a heuristic. + This flag (a :class:`str`) allows overriding those heuristics. + + * If `"cusolver"` is set then cuSOLVER will be used wherever possible. + * If `"magma"` is set then MAGMA will be used wherever possible. + * If `"default"` (the default) is set then heuristics will be used to pick between + cuSOLVER and MAGMA if both are available. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER + globally. + This flag only sets the initial value of the preferred library and the preferred library + may still be overridden by this function call later in your script. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's heuristic library selection is incorrect + for your application's inputs. + + Currently supported linalg operators: + + * :func:`torch.linalg.inv` + * :func:`torch.linalg.inv_ex` + * :func:`torch.linalg.cholesky` + * :func:`torch.linalg.cholesky_ex` + * :func:`torch.cholesky_solve` + * :func:`torch.cholesky_inverse` + * :func:`torch.linalg.lu_factor` + * :func:`torch.linalg.lu` + * :func:`torch.linalg.lu_solve` + * :func:`torch.linalg.qr` + * :func:`torch.linalg.eigh` + * :func:`torch.linalg.eighvals` + * :func:`torch.linalg.svd` + * :func:`torch.linalg.svdvals` + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _LinalgBackends: + raise RuntimeError( + f"Unknown input value. Choose from: {_LinalgBackends_str}." + ) + torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) + elif isinstance(backend, torch._C._LinalgBackend): + torch._C._set_linalg_preferred_backend(backend) + else: + raise RuntimeError("Unknown input value type.") + + return torch._C._get_linalg_preferred_backend() + + +_BlasBackends = { + "default": torch._C._BlasBackend.Default, + "cublas": torch._C._BlasBackend.Cublas, + "hipblas": torch._C._BlasBackend.Cublas, # alias + "cublaslt": torch._C._BlasBackend.Cublaslt, + "hipblaslt": torch._C._BlasBackend.Cublaslt, # alias + "ck": torch._C._BlasBackend.Ck, +} +_BlasBackends_str = ", ".join(_BlasBackends.keys()) + + +def preferred_blas_library( + backend: Union[None, str, torch._C._BlasBackend] = None, +) -> torch._C._BlasBackend: + r""" + Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. + + .. warning:: This flag is experimental and subject to change. + + When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available. + For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance. + This flag (a :class:`str`) allows overriding which BLAS library to use. + + * If `"cublas"` is set then cuBLAS will be used wherever possible. + * If `"cublaslt"` is set then cuBLASLt will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * If `"default"` (the default) is set then heuristics will be used to pick between the other options. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt + globally. + This flag only sets the initial value of the preferred library and the preferred library + may still be overridden by this function call later in your script. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _BlasBackends: + raise RuntimeError( + f"Unknown input value. Choose from: {_BlasBackends_str}." + ) + torch._C._set_blas_preferred_backend(_BlasBackends[backend]) + elif isinstance(backend, torch._C._BlasBackend): + torch._C._set_blas_preferred_backend(backend) + else: + raise RuntimeError("Unknown input value type.") + + return torch._C._get_blas_preferred_backend() + + +_ROCmFABackends = { + "default": torch._C._ROCmFABackend.Default, + "aotriton": torch._C._ROCmFABackend.AOTriton, + "ck": torch._C._ROCmFABackend.Ck, +} +_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys()) + + +from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend + + +def preferred_rocm_fa_library( + backend: Union[None, str, torch._C._ROCmFABackend] = None, +) -> torch._C._ROCmFABackend: + r""" + [ROCm-only] + Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK + + .. warning:: This flag is experimental and subject to change. + + When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend. + This flag (a :class:`str`) allows users to override this backend to use composable_kernel + + * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton. + * If `"aotriton"` is set then AOTriton will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. + * When no input is given, this function returns the currently preferred library. + * User may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK + globally. + + Note: When a library is preferred other libraries may still be used if the preferred library + doesn't implement the operation(s) called. + This flag may achieve better performance if PyTorch's library selection is incorrect + for your application's inputs. + """ + if backend is None: + pass + elif isinstance(backend, str): + if backend not in _ROCmFABackends: + raise RuntimeError( + f"Unknown input value. Choose from: {_ROCmFABackends_str}." + ) + torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) + elif isinstance(backend, torch._C._ROCmFABackend): + torch._C._set_rocm_fa_preferred_backend(backend) + else: + raise ValueError(f"Unknown input value. Choose from: {_ROCmFABackends_str}.") + + return torch._C._get_rocm_fa_preferred_backend() + + +# Set the __module__ attribute +SDPAParams.__module__ = "torch.backends.cuda" +SDPAParams.__name__ = "SDPAParams" + + +def flash_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether flash scaled dot product attention is enabled or not. + """ + return torch._C._get_flash_sdp_enabled() + + +def enable_flash_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables flash scaled dot product attention. + """ + torch._C._set_sdp_use_flash(enabled) + + +def mem_efficient_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether memory efficient scaled dot product attention is enabled or not. + """ + return torch._C._get_mem_efficient_sdp_enabled() + + +def enable_mem_efficient_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables memory efficient scaled dot product attention. + """ + torch._C._set_sdp_use_mem_efficient(enabled) + + +def math_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether math scaled dot product attention is enabled or not. + """ + return torch._C._get_math_sdp_enabled() + + +def enable_math_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables math scaled dot product attention. + """ + torch._C._set_sdp_use_math(enabled) + + +def allow_fp16_bf16_reduction_math_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables fp16/bf16 reduction in math scaled dot product attention. + """ + torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled) + + +def fp16_bf16_reduction_math_sdp_allowed(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not. + """ + return torch._C._get_math_sdp_allow_fp16_bf16_reduction() + + +def is_flash_attention_available() -> bool: + r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention. + + Returns: + True if FlashAttention is built and available; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._is_flash_attention_available() + + +def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if FlashAttention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn debug information as to why FlashAttention could not be run. + Defaults to False. + + Returns: + True if FlashAttention can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_flash_attention(params, debug) + + +def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if efficient_attention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why efficient_attention could not be run. + Defaults to False. + + Returns: + True if efficient_attention can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_mem_efficient_attention(params, debug) + + +def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool: + r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention. + + Args: + params: An instance of SDPAParams containing the tensors for query, + key, value, an optional attention mask, dropout rate, and + a flag indicating if the attention is causal. + debug: Whether to logging.warn with information as to why cuDNN attention could not be run. + Defaults to False. + + Returns: + True if cuDNN can be used with the given parameters; otherwise, False. + + Note: + This function is dependent on a CUDA-enabled build of PyTorch. It will return False + in non-CUDA environments. + """ + return torch._C._can_use_cudnn_attention(params, debug) + + +def cudnn_sdp_enabled(): + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether cuDNN scaled dot product attention is enabled or not. + """ + return torch._C._get_cudnn_sdp_enabled() + + +def enable_cudnn_sdp(enabled: bool): + r""" + .. warning:: This flag is beta and subject to change. + + Enables or disables cuDNN scaled dot product attention. + """ + torch._C._set_sdp_use_cudnn(enabled) + + +@contextlib.contextmanager +@deprecated( + ( + "`torch.backends.cuda.sdp_kernel()` is deprecated. " + "In the future, this context manager will be removed. " + "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, " + "with updated signature." + ), + category=FutureWarning, +) +def sdp_kernel( + enable_flash: bool = True, + enable_math: bool = True, + enable_mem_efficient: bool = True, + enable_cudnn: bool = True, +): + r""" + .. warning:: This flag is beta and subject to change. + + This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention. + Upon exiting the context manager, the previous state of the flags will be restored. + """ + from torch.nn.attention import sdpa_kernel + + backend_list = [] + if enable_flash: + backend_list.append(SDPBackend.FLASH_ATTENTION) + if enable_mem_efficient: + backend_list.append(SDPBackend.EFFICIENT_ATTENTION) + if enable_math: + backend_list.append(SDPBackend.MATH) + if enable_cudnn: + backend_list.append(SDPBackend.CUDNN_ATTENTION) + + with sdpa_kernel(backend_list) as context: + try: + yield context + finally: + pass + + +cufft_plan_cache = cuFFTPlanCacheManager() +matmul = cuBLASModule() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02d68c086f86adf15cd3d0839ec3d1f09c672c24 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd6ec297c7a8a21c407e12112ba961b76624a6f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py @@ -0,0 +1,248 @@ +# mypy: allow-untyped-defs +import os +import sys +import warnings +from contextlib import contextmanager +from typing import Optional + +import torch +from torch.backends import ( + __allow_nonbracketed_mutation, + _FP32Precision, + _get_fp32_precision_getter, + _set_fp32_precision_setter, + ContextProp, + PropModule, +) + + +try: + from torch._C import _cudnn +except ImportError: + _cudnn = None # type: ignore[assignment] + +# Write: +# +# torch.backends.cudnn.enabled = False +# +# to globally disable CuDNN/MIOpen + +__cudnn_version: Optional[int] = None + +if _cudnn is not None: + + def _init(): + global __cudnn_version + if __cudnn_version is None: + # pyrefly: ignore [missing-attribute] + __cudnn_version = _cudnn.getVersionInt() + # pyrefly: ignore [missing-attribute] + runtime_version = _cudnn.getRuntimeVersion() + # pyrefly: ignore [missing-attribute] + compile_version = _cudnn.getCompileVersion() + runtime_major, runtime_minor, _ = runtime_version + compile_major, compile_minor, _ = compile_version + # Different major versions are always incompatible + # Starting with cuDNN 7, minor versions are backwards-compatible + # Not sure about MIOpen (ROCm), so always do a strict check + if runtime_major != compile_major: + cudnn_compatible = False + # pyrefly: ignore [missing-attribute] + elif runtime_major < 7 or not _cudnn.is_cuda: + cudnn_compatible = runtime_minor == compile_minor + else: + cudnn_compatible = runtime_minor >= compile_minor + if not cudnn_compatible: + if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1": + return True + base_error_msg = ( + f"cuDNN version incompatibility: " + f"PyTorch was compiled against {compile_version} " + f"but found runtime version {runtime_version}. " + f"PyTorch already comes bundled with cuDNN. " + f"One option to resolving this error is to ensure PyTorch " + f"can find the bundled cuDNN. " + ) + + if "LD_LIBRARY_PATH" in os.environ: + ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + if any( + substring in ld_library_path for substring in ["cuda", "cudnn"] + ): + raise RuntimeError( + f"{base_error_msg}" + f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. " + f"Please either remove it from the path or install cudnn {compile_version}" + ) + else: + raise RuntimeError( + f"{base_error_msg}" + f"one possibility is that there is a " + f"conflicting cuDNN in LD_LIBRARY_PATH." + ) + else: + raise RuntimeError(base_error_msg) + # Check if cuDNN version is compatible with available CUDA devices + if torch.cuda.is_available() and not torch.version.hip: + min_cc = min( + [ + torch.cuda.get_device_capability(i) + for i in range(torch.cuda.device_count()) + ] + ) + if __cudnn_version >= 91100 and min_cc < (7, 5): + raise RuntimeError( + f"cuDNN version {__cudnn_version} is not compatible with devices with SM < 7.5. " + f"Please install a version of PyTorch with a compatible cuDNN version. " + f"https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix" + ) + + return True + +else: + + def _init(): + return False + + +def version(): + """Return the version of cuDNN.""" + if not _init(): + return None + return __cudnn_version + + +CUDNN_TENSOR_DTYPES = { + torch.half, + torch.float, + torch.double, +} + + +def is_available(): + r"""Return a bool indicating if CUDNN is currently available.""" + return torch._C._has_cudnn + + +def is_acceptable(tensor): + if not torch._C._get_cudnn_enabled(): + return False + if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES: + return False + if not is_available(): + warnings.warn( + "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild " + "PyTorch making sure the library is visible to the build system.", + stacklevel=2, + ) + return False + if not _init(): + warnings.warn( + "cuDNN/MIOpen library not found. Check your {libpath}".format( + libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get( + sys.platform, "LD_LIBRARY_PATH" + ) + ), + stacklevel=2, + ) + return False + return True + + +def set_flags( + _enabled=None, + _benchmark=None, + _benchmark_limit=None, + _deterministic=None, + _allow_tf32=None, + _fp32_precision="none", +): + orig_flags = ( + torch._C._get_cudnn_enabled(), + torch._C._get_cudnn_benchmark(), + None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), + torch._C._get_cudnn_deterministic(), + torch._C._get_cudnn_allow_tf32(), + torch._C._get_fp32_precision_getter("cuda", "all"), + ) + if _enabled is not None: + torch._C._set_cudnn_enabled(_enabled) + if _benchmark is not None: + torch._C._set_cudnn_benchmark(_benchmark) + if _benchmark_limit is not None and is_available(): + torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit) + if _deterministic is not None: + torch._C._set_cudnn_deterministic(_deterministic) + if _allow_tf32 is not None: + torch._C._set_cudnn_allow_tf32(_allow_tf32) + if _fp32_precision is not None: + torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision) + return orig_flags + + +@contextmanager +def flags( + enabled=False, + benchmark=False, + benchmark_limit=10, + deterministic=False, + allow_tf32=True, + fp32_precision="none", +): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags( + enabled, + benchmark, + benchmark_limit, + deterministic, + allow_tf32, + fp32_precision, + ) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends..enabled = True + + +class CudnnModule(PropModule): + enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) + deterministic = ContextProp( + torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic + ) + benchmark = ContextProp( + torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark + ) + benchmark_limit = None + if is_available(): + benchmark_limit = ContextProp( + torch._C._cuda_get_cudnn_benchmark_limit, + torch._C._cuda_set_cudnn_benchmark_limit, + ) + allow_tf32 = ContextProp( + torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32 + ) + conv = _FP32Precision("cuda", "conv") + rnn = _FP32Precision("cuda", "rnn") + fp32_precision = ContextProp( + _get_fp32_precision_getter("cuda", "all"), + _set_fp32_precision_setter("cuda", "all"), + ) + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__) + +# Add type annotation for the replaced module +enabled: bool +deterministic: bool +benchmark: bool +allow_tf32: bool +benchmark_limit: int diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc9ca80aa6fd10efc41910d38ba33d00852729c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py @@ -0,0 +1,69 @@ +# mypy: allow-untyped-defs +import torch.cuda + + +try: + from torch._C import _cudnn +except ImportError: + # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(), + # so it's safe to not emit any checks here. + _cudnn = None # type: ignore[assignment] + + +def get_cudnn_mode(mode): + if mode == "RNN_RELU": + # pyrefly: ignore [missing-attribute] + return int(_cudnn.RNNMode.rnn_relu) + elif mode == "RNN_TANH": + # pyrefly: ignore [missing-attribute] + return int(_cudnn.RNNMode.rnn_tanh) + elif mode == "LSTM": + # pyrefly: ignore [missing-attribute] + return int(_cudnn.RNNMode.lstm) + elif mode == "GRU": + # pyrefly: ignore [missing-attribute] + return int(_cudnn.RNNMode.gru) + else: + raise Exception(f"Unknown mode: {mode}") # noqa: TRY002 + + +# NB: We don't actually need this class anymore (in fact, we could serialize the +# dropout state for even better reproducibility), but it is kept for backwards +# compatibility for old models. +class Unserializable: + def __init__(self, inner): + self.inner = inner + + def get(self): + return self.inner + + def __getstate__(self): + # Note: can't return {}, because python2 won't call __setstate__ + # if the value evaluates to False + return "" + + def __setstate__(self, state): + self.inner = None + + +def init_dropout_state(dropout, train, dropout_seed, dropout_state): + dropout_desc_name = "desc_" + str(torch.cuda.current_device()) + dropout_p = dropout if train else 0 + if (dropout_desc_name not in dropout_state) or ( + dropout_state[dropout_desc_name].get() is None + ): + if dropout_p == 0: + dropout_state[dropout_desc_name] = Unserializable(None) + else: + dropout_state[dropout_desc_name] = Unserializable( + torch._cudnn_init_dropout_state( # type: ignore[call-arg] + dropout_p, + train, + dropout_seed, + # pyrefly: ignore [unexpected-keyword] + self_ty=torch.uint8, + device=torch.device("cuda"), + ) + ) + dropout_ts = dropout_state[dropout_desc_name].get() + return dropout_ts diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9b9df2acf144e00a193ee312f728ec30327f8a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py @@ -0,0 +1,57 @@ +from typing import Optional + +import torch + + +__all__ = [ + "version", + "is_available", + "get_max_alg_id", +] + +try: + from torch._C import _cusparselt +except ImportError: + _cusparselt = None # type: ignore[assignment] + +__cusparselt_version: Optional[int] = None +__MAX_ALG_ID: Optional[int] = None + +if _cusparselt is not None: + + def _init() -> bool: + global __cusparselt_version + global __MAX_ALG_ID + if __cusparselt_version is None: + # pyrefly: ignore [missing-attribute] + __cusparselt_version = _cusparselt.getVersionInt() + if __cusparselt_version == 400: + __MAX_ALG_ID = 4 + elif __cusparselt_version == 502: + __MAX_ALG_ID = 5 + elif __cusparselt_version == 602: + __MAX_ALG_ID = 37 + return True + +else: + + def _init() -> bool: + return False + + +def version() -> Optional[int]: + """Return the version of cuSPARSELt""" + if not _init(): + return None + return __cusparselt_version + + +def is_available() -> bool: + r"""Return a bool indicating if cuSPARSELt is currently available.""" + return torch._C._has_cusparselt + + +def get_max_alg_id() -> Optional[int]: + if not _init(): + return None + return __MAX_ALG_ID diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdff84be79e14964ea9c61b1b21ba992aa81a268 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a681b77ef58ce1f390232b82c4a9843d5559ca3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with KleidiAI support.""" + return torch._C._has_kleidiai diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1dd2ebd688805bdf3359cb56b64d0854cf258c4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py @@ -0,0 +1,25 @@ +# Config options to enable/disable C++ kernel for nn.functional.MHA +# and nn.TransformerEncoder +import torch + + +_is_fastpath_enabled: bool = True + + +def get_fastpath_enabled() -> bool: + """Returns whether fast path for TransformerEncoder and MultiHeadAttention + is enabled, or ``True`` if jit is scripting. + + .. note:: + The fastpath might not be run even if ``get_fastpath_enabled`` returns + ``True`` unless all conditions on inputs are met. + """ + if not torch.jit.is_scripting(): + return _is_fastpath_enabled + return True + + +def set_fastpath_enabled(value: bool) -> None: + """Sets whether fast path is enabled""" + global _is_fastpath_enabled + _is_fastpath_enabled = value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b270b658e31a91dfb37380abec383009dfc5bfa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/miopen/__init__.py @@ -0,0 +1,50 @@ +# mypy: allow-untyped-defs +import sys +from contextlib import contextmanager + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +def set_flags( + _immediate=None, +): + orig_flags = (torch._C._get_miopen_immediate(),) + if _immediate is not None: + torch._C._set_miopen_immediate(_immediate) + return orig_flags + + +@contextmanager +def flags( + immediate=False, +): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags( + immediate, + ) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends..immediate = True + + +class MiopenModule(PropModule): + immediate = ContextProp( + torch._C._get_miopen_immediate, torch._C._set_miopen_immediate + ) + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = MiopenModule(sys.modules[__name__], __name__) + +# Add type annotation for the replaced module +immediate: bool diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae16922761afeafa53766757641bcc532b4d5ef4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py @@ -0,0 +1,58 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with MKL support.""" + return torch._C.has_mkl + + +VERBOSE_OFF = 0 +VERBOSE_ON = 1 + + +class verbose: + """ + On-demand oneMKL verbosing functionality. + + To make it easier to debug performance issues, oneMKL can dump verbose + messages containing execution information like duration while executing + the kernel. The verbosing functionality can be invoked via an environment + variable named `MKL_VERBOSE`. However, this methodology dumps messages in + all steps. Those are a large amount of verbose messages. Moreover, for + investigating the performance issues, generally taking verbose messages + for one single iteration is enough. This on-demand verbosing functionality + makes it possible to control scope for verbose message dumping. In the + following example, verbose messages will be dumped out for the second + inference only. + + .. highlight:: python + .. code-block:: python + + import torch + + model(data) + with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + """ + + def __init__(self, enable): + self.enable = enable + + def __enter__(self): + if self.enable == VERBOSE_OFF: + return + st = torch._C._verbose.mkl_set_verbose(self.enable) + assert st, ( + "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkl_set_verbose(VERBOSE_OFF) + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58e6b2c595e9853942b0a3a58a6e5ab2627d3608 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING + +import torch +from torch.backends import ( + __allow_nonbracketed_mutation, + _FP32Precision, + _get_fp32_precision_getter, + _set_fp32_precision_setter, + ContextProp, + PropModule, +) + + +def is_available(): + r"""Return whether PyTorch is built with MKL-DNN support.""" + return torch._C._has_mkldnn + + +def is_acl_available(): + r"""Return whether PyTorch is built with MKL-DNN + ACL support.""" + # pyrefly: ignore [missing-attribute] + return torch._C._has_mkldnn_acl + + +VERBOSE_OFF = 0 +VERBOSE_ON = 1 +VERBOSE_ON_CREATION = 2 + + +class verbose: + """ + On-demand oneDNN (former MKL-DNN) verbosing functionality. + + To make it easier to debug performance issues, oneDNN can dump verbose + messages containing information like kernel size, input data size and + execution duration while executing the kernel. The verbosing functionality + can be invoked via an environment variable named `DNNL_VERBOSE`. However, + this methodology dumps messages in all steps. Those are a large amount of + verbose messages. Moreover, for investigating the performance issues, + generally taking verbose messages for one single iteration is enough. + This on-demand verbosing functionality makes it possible to control scope + for verbose message dumping. In the following example, verbose messages + will be dumped out for the second inference only. + + .. highlight:: python + .. code-block:: python + + import torch + + model(data) + with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON): + model(data) + + Args: + level: Verbose level + - ``VERBOSE_OFF``: Disable verbosing + - ``VERBOSE_ON``: Enable verbosing + - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation + """ + + def __init__(self, level): + self.level = level + + def __enter__(self): + if self.level == VERBOSE_OFF: + return + st = torch._C._verbose.mkldnn_set_verbose(self.level) + assert st, ( + "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF) + return False + + +def set_flags( + _enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none" +): + orig_flags = ( + torch._C._get_mkldnn_enabled(), + torch._C._get_mkldnn_deterministic(), + torch._C._get_onednn_allow_tf32(), + torch._C._get_fp32_precision_getter("mkldnn", "all"), + ) + if _enabled is not None: + torch._C._set_mkldnn_enabled(_enabled) + if _deterministic is not None: + torch._C._set_mkldnn_deterministic(_deterministic) + if _allow_tf32 is not None: + torch._C._set_onednn_allow_tf32(_allow_tf32) + if _fp32_precision is not None: + torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision) + return orig_flags + + +@contextmanager +def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +class MkldnnModule(PropModule): + def is_available(self): + return is_available() + + enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled) + deterministic = ContextProp( + torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic + ) + allow_tf32 = ContextProp( + torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32 + ) + matmul = _FP32Precision("mkldnn", "matmul") + conv = _FP32Precision("mkldnn", "conv") + rnn = _FP32Precision("mkldnn", "rnn") + fp32_precision = ContextProp( + _get_fp32_precision_getter("mkldnn", "all"), + _set_fp32_precision_setter("generic", "all"), + ) + + +if TYPE_CHECKING: + enabled: ContextProp + deterministic: ContextProp + allow_tf32: ContextProp + +sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3c507428cfff85a02e1d9939b4951d7e8b84bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py @@ -0,0 +1,78 @@ +from functools import lru_cache as _lru_cache +from typing import Optional, TYPE_CHECKING + +import torch +from torch.library import Library as _Library + + +__all__ = [ + "get_core_count", + "get_name", + "is_built", + "is_available", + "is_macos13_or_newer", + "is_macos_or_newer", +] + + +def is_built() -> bool: + r"""Return whether PyTorch is built with MPS support. + + Note that this doesn't necessarily mean MPS is available; just that + if this PyTorch binary were run a machine with working MPS drivers + and devices, we would be able to use it. + """ + return torch._C._has_mps + + +@_lru_cache +def is_available() -> bool: + r"""Return a bool indicating if MPS is currently available.""" + return torch._C._mps_is_available() + + +@_lru_cache +def is_macos_or_newer(major: int, minor: int) -> bool: + r"""Return a bool indicating whether MPS is running on given MacOS or newer.""" + return torch._C._mps_is_on_macos_or_newer(major, minor) + + +@_lru_cache +def is_macos13_or_newer(minor: int = 0) -> bool: + r"""Return a bool indicating whether MPS is running on MacOS 13 or newer.""" + return torch._C._mps_is_on_macos_or_newer(13, minor) + + +@_lru_cache +def get_name() -> str: + r"""Return Metal device name""" + return torch._C._mps_get_name() + + +@_lru_cache +def get_core_count() -> int: + r"""Return GPU core count. + + According to the documentation, one core is comprised of 16 Execution Units. + One execution Unit has 8 ALUs. + And one ALU can run 24 threads, i.e. one core is capable of executing 3072 threads concurrently. + """ + return torch._C._mps_get_core_count() + + +_lib: Optional[_Library] = None + + +def _init() -> None: + r"""Register prims as implementation of var_mean and group_norm.""" + global _lib + + if _lib is not None or not is_built(): + return + + from torch._decomp.decompositions import native_group_norm_backward + from torch._refs import native_group_norm + + _lib = _Library("aten", "IMPL") # noqa: TOR901 + _lib.impl("native_group_norm", native_group_norm, "MPS") + _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8a72f3cda9b0da16702c0d7c6fe92ae8f3f153 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +__all__ = ["is_available", "flags", "set_flags"] + + +def is_available(): + r"""Return whether PyTorch is built with NNPACK support.""" + return torch._nnpack_available() + + +def set_flags(_enabled): + r"""Set if nnpack is enabled globally""" + orig_flags = (torch._C._get_nnpack_enabled(),) + torch._C._set_nnpack_enabled(_enabled) + return orig_flags + + +@contextmanager +def flags(enabled=False): + r"""Context manager for setting if nnpack is enabled globally""" + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(orig_flags[0]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aff8d46cd4ac2d9ff49942542d99ac2afbb85896 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with OpenMP support.""" + return torch._C.has_openmp diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..264be78aa9a1c24a4624a87782e2b2c5afd29c05 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +import sys +import warnings +from contextlib import contextmanager +from functools import lru_cache as _lru_cache +from typing import Any + +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +try: + import opt_einsum as _opt_einsum # type: ignore[import] +except ImportError: + _opt_einsum = None + + +@_lru_cache +def is_available() -> bool: + r"""Return a bool indicating if opt_einsum is currently available. + + You must install opt-einsum in order for torch to automatically optimize einsum. To + make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]`` + or by itself: ``pip install opt-einsum``. If the package is installed, torch will import + it automatically and use it accordingly. Use this function to check whether opt-einsum + was installed and properly imported by torch. + """ + return _opt_einsum is not None + + +def get_opt_einsum() -> Any: + r"""Return the opt_einsum package if opt_einsum is currently available, else None.""" + return _opt_einsum + + +def _set_enabled(_enabled: bool) -> None: + if not is_available() and _enabled: + raise ValueError( + f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap " + "the benefits of calculating an optimal path for einsum. torch.einsum will " + "fall back to contracting from left to right. To enable this optimal path " + "calculation, please install opt-einsum." + ) + global enabled + enabled = _enabled + + +def _get_enabled() -> bool: + return enabled + + +def _set_strategy(_strategy: str) -> None: + if not is_available(): + raise ValueError( + f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please install opt_einsum or unset `strategy`." + ) + if not enabled: + raise ValueError( + f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. " + "torch.einsum will bypass path calculation and simply contract from left to right. " + "Please set `enabled` to `True` as well or unset `strategy`." + ) + if _strategy not in ["auto", "greedy", "optimal"]: + raise ValueError( + f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}" + ) + global strategy + strategy = _strategy + + +def _get_strategy() -> str: + # pyrefly: ignore [bad-return] + return strategy + + +def set_flags(_enabled=None, _strategy=None): + orig_flags = (enabled, None if not is_available() else strategy) + if _enabled is not None: + _set_enabled(_enabled) + if _strategy is not None: + _set_strategy(_strategy) + return orig_flags + + +@contextmanager +def flags(enabled=None, strategy=None): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled, strategy) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +# The magic here is to allow us to intercept code like this: +# +# torch.backends.opt_einsum.enabled = True + + +class OptEinsumModule(PropModule): + global enabled + enabled = ContextProp(_get_enabled, _set_enabled) + global strategy + strategy = None + if is_available(): + strategy = ContextProp(_get_strategy, _set_strategy) + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__) + +enabled = bool(is_available()) +strategy = "auto" if is_available() else None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..caabfdf243783f2161a201c6a6ec9bd6eca83b18 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import sys +import types + +import torch + + +# This function should correspond to the enums present in c10/core/QEngine.h +def _get_qengine_id(qengine: str) -> int: + if qengine == "none" or qengine == "" or qengine is None: + ret = 0 + elif qengine == "fbgemm": + ret = 1 + elif qengine == "qnnpack": + ret = 2 + elif qengine == "onednn": + ret = 3 + elif qengine == "x86": + ret = 4 + else: + ret = -1 + raise RuntimeError(f"{qengine} is not a valid value for quantized engine") + return ret + + +# This function should correspond to the enums present in c10/core/QEngine.h +def _get_qengine_str(qengine: int) -> str: + all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"} + return all_engines.get(qengine, "*undefined") + + +class _QEngineProp: + def __get__(self, obj, objtype) -> str: + return _get_qengine_str(torch._C._get_qengine()) + + def __set__(self, obj, val: str) -> None: + torch._C._set_qengine(_get_qengine_id(val)) + + +class _SupportedQEnginesProp: + def __get__(self, obj, objtype) -> list[str]: + qengines = torch._C._supported_qengines() + return [_get_qengine_str(qe) for qe in qengines] + + def __set__(self, obj, val) -> None: + raise RuntimeError("Assignment not supported") + + +class QuantizedEngine(types.ModuleType): + def __init__(self, m, name): + super().__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + + engine = _QEngineProp() + supported_engines = _SupportedQEnginesProp() + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__) +engine: str +supported_engines: list[str] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b6bdf78991dcc140d9fedf2be2ea3ba6dedf74 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py @@ -0,0 +1,947 @@ +# mypy: allow-untyped-defs +""" +This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations. + +Single instance inference, multi-instance inference are enabled. + +Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes +multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this +context. + +Illustrated as below: + +:: + + +-----------------------------+----------------------+-------+ + | process | thread | core | + +=============================+======================+=======+ + | torch.backends.xeon.run_cpu | instance 0: thread 0 | 0 | + | | thread 1 | 1 | + | +----------------------+-------+ + | | instance 1: thread 0 | 2 | + | | thread 1 | 3 | + | +----------------------+-------+ + | | ... | ... | + | +----------------------+-------+ + | | instance N: thread 0 | M | + | | thread 1 | M+1 | + +-----------------------------+----------------------+-------+ + +To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory +management. For thread management, the script configures thread affinity and the preload of Intel OMP library. +For memory management, it configures NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc). + +Environment variables that will be set by this script: + ++------------------+-------------------------------------------------------------------------------------------------+ +| Environ Variable | Value | ++==================+=================================================================================================+ +| LD_PRELOAD | Depending on knobs you set, /libiomp5.so, /libjemalloc.so, /libtcmalloc.so might | +| | be appended to LD_PRELOAD. | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_AFFINITY | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0". | ++------------------+-------------------------------------------------------------------------------------------------+ +| KMP_BLOCKTIME | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1". | ++------------------+-------------------------------------------------------------------------------------------------+ +| OMP_NUM_THREADS | value of ncores_per_instance | ++------------------+-------------------------------------------------------------------------------------------------+ +| MALLOC_CONF | If libjemalloc.so is preloaded, MALLOC_CONF will be set to | +| | "oversize_threshold:1,background_thread:true,metadata_thp:auto". | ++------------------+-------------------------------------------------------------------------------------------------+ + +*Note*: This script respects environment variables set preliminarily. I.e. If you set the environment variables +mentioned above before running the script, the script will not overwrite the values in the script. + +How to use this module: +~~~~~~~~~~~~~~~~~~~~~~~ + +Single instance inference +------------------------- + +1. Run single-instance inference on a single node with all CPU nodes. + +:: + + python -m torch.backends.xeon.run_cpu --throughput-mode script.py args + +2. Run single-instance inference on a single CPU node. + +:: + + python -m torch.backends.xeon.run_cpu --node-id 1 script.py args + +Multi-instance inference +------------------------ + +1. Multi-instance + By default this tool runs one process per node. If you want to set the instance numbers and core per instance, + --ninstances and --ncores-per-instance should be set. + +:: + + python -m torch.backends.xeon.run_cpu -- python_script args + + eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instance, 4 cores per instance + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores-per-instance 4 python_script args + +2. Run single-instance inference among multiple instances. + By default, runs all ninstances. If you want to independently run a single instance among ninstances, specify rank. + + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 0-27) + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args + + eg: run 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance (i.e., numactl -C 28-55) + +:: + + python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args + + eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instance, 2 cores per instance, + first four cores (i.e., numactl -C 0-1) + +:: + + python -m torch.backends.xeon.run_cpu --core-list "0, 1, 2, 3" --ninstances 2 --ncores-per-instance 2 + --rank 0 python_script args + +3. To look up what optional arguments this module offers: + +:: + + python -m torch.backends.xeon.run_cpu --help + +Memory allocator +---------------- + +"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allocator. + +""" + +import glob +import logging +import os +import platform +import re +import subprocess +import sys +from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER +from os.path import expanduser + +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs as _DefaultLogsSpecs, + start_processes, + Std, +) + + +format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +logging.basicConfig(level=logging.INFO, format=format_str) +logger = logging.getLogger(__name__) + + +class _CPUinfo: + """Get CPU information, such as cores list and NUMA information.""" + + def __init__(self, test_input=""): + self.cpuinfo = [] + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + elif platform.system() == "Linux": + # Sample output of: `lscpu --parse=CPU,Core,Socket,Node` + # + # # The following is the parsable format, which can be fed to other + # # programs. Each different item in every column has an unique ID + # # starting from zero. + # # CPU,Core,Socket,Node + # 0,0,0,0 + # 1,1,0,0 + # ... + if test_input == "": + lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"] + lscpu_info = subprocess.check_output( + lscpu_cmd, universal_newlines=True + ).split("\n") + else: + lscpu_info = test_input.split("\n") + + # Get information about cpu, core, socket and node + for line in lscpu_info: + pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)" + regex_out = re.search(pattern, line) + if regex_out: + self.cpuinfo.append(regex_out.group(1).strip().split(",")) + + # physical cores := core column in lscpu output + # logical cores := cPU column in lscpu output + self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1 + self.node_physical_cores: list[list[int]] = [] # node_id is index + self.node_logical_cores: list[list[int]] = [] # node_id is index + self.physical_core_node_map = {} # physical core to numa node id + self.logical_core_node_map = {} # logical core to numa node id + + for node_id in range(self.node_nums): + cur_node_physical_core = [] + cur_node_logical_core = [] + for cpuinfo in self.cpuinfo: + nid = cpuinfo[3] if cpuinfo[3] != "" else "0" + if node_id == int(nid): + if int(cpuinfo[1]) not in cur_node_physical_core: + cur_node_physical_core.append(int(cpuinfo[1])) + self.physical_core_node_map[int(cpuinfo[1])] = int(node_id) + cur_node_logical_core.append(int(cpuinfo[0])) + self.logical_core_node_map[int(cpuinfo[0])] = int(node_id) + self.node_physical_cores.append(cur_node_physical_core) + self.node_logical_cores.append(cur_node_logical_core) + + def _physical_core_nums(self): + return len(self.node_physical_cores) * len(self.node_physical_cores[0]) + + def _logical_core_nums(self): + return len(self.node_logical_cores) * len(self.node_logical_cores[0]) + + def get_node_physical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) + return self.node_physical_cores[node_id] + + def get_node_logical_cores(self, node_id): + if node_id < 0 or node_id > self.node_nums - 1: + raise ValueError( + f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}" + ) + return self.node_logical_cores[node_id] + + def get_all_physical_cores(self): + all_cores = [] + for cores in self.node_physical_cores: + all_cores.extend(cores) + return all_cores + + def get_all_logical_cores(self): + all_cores = [] + for cores in self.node_logical_cores: + all_cores.extend(cores) + return all_cores + + def numa_aware_check(self, core_list): + """ + Check whether all cores in core_list are in the same NUMA node. + + Cross NUMA will reduce performance. + We strongly advice to not use cores on different nodes. + """ + cores_numa_map = self.logical_core_node_map + numa_ids = [] + for core in core_list: + numa_id = cores_numa_map[core] + if numa_id not in numa_ids: + numa_ids.append(numa_id) + if len(numa_ids) > 1: + logger.warning( + "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \ +this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\ +instance. Alternatively, please use --skip-cross-node-cores knob.", + str(core_list), + str(numa_ids), + ) + if len(numa_ids) == 0: + raise RuntimeError( + "invalid number of NUMA nodes; please make sure numa_ids >= 1" + ) + return numa_ids + + +class _Launcher: + r"""Class for launcher.""" + + msg_lib_notfound = ( + f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ +or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ +{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." + ) + + def __init__(self) -> None: + self.cpuinfo = _CPUinfo() + + def add_lib_preload(self, lib_type): + """Enable TCMalloc/JeMalloc/intel OpenMP.""" + library_paths = [] + if "CONDA_PREFIX" in os.environ: + library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib") + if "VIRTUAL_ENV" in os.environ: + library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib") + + library_paths += [ + f"{expanduser('~')}/.local/lib", + "/usr/local/lib", + "/usr/local/lib64", + "/usr/lib", + "/usr/lib64", + ] + + lib_find = False + lib_set = False + for item in os.getenv("LD_PRELOAD", "").split(":"): + if item.endswith(f"lib{lib_type}.so"): + lib_set = True + break + if not lib_set: + for lib_path in library_paths: + library_file = os.path.join(lib_path, f"lib{lib_type}.so") + matches = glob.glob(library_file) + if len(matches) > 0: + ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")] + os.environ["LD_PRELOAD"] = os.pathsep.join( + [p.strip(os.pathsep) for p in ld_preloads if p] + ) + lib_find = True + break + return lib_set or lib_find + + def is_numactl_available(self): + numactl_available = False + try: + cmd = ["numactl", "-C", "0", "-m", "0", "hostname"] + r = subprocess.run( + cmd, + env=os.environ, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=False, + ) + if r.returncode == 0: + numactl_available = True + except Exception: + pass + return numactl_available + + def set_memory_allocator( + self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False + ): + """ + Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc. + + By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better + memory reuse and reduce page fault to improve performance. + """ + if enable_tcmalloc and enable_jemalloc: + raise RuntimeError( + "Unable to enable TCMalloc and JEMalloc at the same time." + ) + + if enable_tcmalloc: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if not find_tc: + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}' + logger.warning(msg.format("TCmalloc", "tcmalloc")) # noqa: G001 + else: + logger.info("Use TCMalloc memory allocator") + + elif enable_jemalloc: + find_je = self.add_lib_preload(lib_type="jemalloc") + if not find_je: + msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}' + logger.warning(msg.format("Jemalloc", "jemalloc")) # noqa: G001 + else: + logger.info("Use JeMalloc memory allocator") + self.set_env( + "MALLOC_CONF", + "oversize_threshold:1,background_thread:true,metadata_thp:auto", + ) + + elif use_default_allocator: + pass + + else: + find_tc = self.add_lib_preload(lib_type="tcmalloc") + if find_tc: + logger.info("Use TCMalloc memory allocator") + return + find_je = self.add_lib_preload(lib_type="jemalloc") + if find_je: + logger.info("Use JeMalloc memory allocator") + return + logger.warning( + """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib + or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or + %s/.local/lib/ so the LD_PRELOAD environment variable will not be set. + This may drop the performance""", + expanduser("~"), + ) + + def log_env_var(self, env_var_name=""): + if env_var_name in os.environ: + logger.info("%s=%s", env_var_name, os.environ[env_var_name]) + + def set_env(self, env_name, env_value): + if not env_value: + logger.warning("%s is None", env_name) + if env_name not in os.environ: + os.environ[env_name] = env_value + elif os.environ[env_name] != env_value: + logger.warning( + "Overriding value with the one set in environment variable: %s. \ +Value applied: %s. Value ignored: %s", + env_name, + os.environ[env_name], + env_value, + ) + self.log_env_var(env_name) + + # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not. + # In scenario that use all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores. + # In this case, KMP_AFFINITY should not be set. + def set_multi_thread_and_allocator( + self, + ncores_per_instance, + disable_iomp=False, + set_kmp_affinity=True, + enable_tcmalloc=True, + enable_jemalloc=False, + use_default_allocator=False, + ): + """ + Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc. + + By default, GNU openMP and PTMalloc are used in PyTorch. but Intel openMP and TCMalloc/JeMalloc are better alternatives + to get performance benefit. + """ + self.set_memory_allocator( + enable_tcmalloc, enable_jemalloc, use_default_allocator + ) + self.set_env("OMP_NUM_THREADS", str(ncores_per_instance)) + if not disable_iomp: + find_iomp = self.add_lib_preload(lib_type="iomp5") + if not find_iomp: + msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}' + logger.warning(msg.format("iomp", "iomp5")) # noqa: G001 + else: + logger.info("Using Intel OpenMP") + if set_kmp_affinity: + self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0") + self.set_env("KMP_BLOCKTIME", "1") + self.log_env_var("LD_PRELOAD") + + r""" + Launcher for single instance and multi-instance + """ + + def launch(self, args): + cores = [] + set_kmp_affinity = True + enable_taskset = False + if args.core_list: # user specify what cores will be used by params + cores = [int(x) for x in args.core_list.split(",")] + if args.ncores_per_instance == -1: + raise RuntimeError( + 'please specify the "--ncores-per-instance" if you have pass the --core-list params' + ) + elif ( + args.ninstances > 1 + and args.ncores_per_instance * args.ninstances < len(cores) + ): + logger.warning( + "only first %s cores will be used, \ +but you specify %s cores in core_list", + args.ncores_per_instance * args.ninstances, + len(cores), + ) + else: + args.ninstances = len(cores) // args.ncores_per_instance + + else: + if args.use_logical_core: + if args.node_id != -1: + cores = self.cpuinfo.get_node_logical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_logical_cores() + # When using all cores on all nodes, including logical cores, + # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set. + set_kmp_affinity = False + else: + if args.node_id != -1: + cores = self.cpuinfo.get_node_physical_cores(args.node_id) + else: + cores = self.cpuinfo.get_all_physical_cores() + if ( + not args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): + args.ninstances = 1 + args.ncores_per_instance = len(cores) + elif ( + args.multi_instance + and args.ninstances == -1 + and args.ncores_per_instance == -1 + ): + args.throughput_mode = True + elif args.ncores_per_instance == -1 and args.ninstances != -1: + if args.ninstances > len(cores): + raise RuntimeError( + f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \ +please make sure ninstances <= total_cores)" + ) + else: + args.ncores_per_instance = len(cores) // args.ninstances + elif args.ncores_per_instance != -1 and args.ninstances == -1: + if not args.skip_cross_node_cores: + args.ninstances = len(cores) // args.ncores_per_instance + else: + ncore_per_node = len(self.cpuinfo.node_physical_cores[0]) + num_leftover_cores = ncore_per_node % args.ncores_per_instance + if args.ncores_per_instance > ncore_per_node: + # too many ncores_per_instance to skip cross-node cores + logger.warning( + "there are %s core(s) per socket, but you specify %s ncores_per_instance and \ +skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \ +socket", + ncore_per_node, + args.ncores_per_instance, + ) + sys.exit(-1) + elif num_leftover_cores == 0: + # aren't any cross-node cores + logger.info( + "--skip-cross-node-cores is set, but there are no cross-node cores." + ) + args.ninstances = len(cores) // args.ncores_per_instance + else: + # skip cross-node cores + if args.ninstances != -1: + logger.warning( + "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \ +won't take effect even if it is set explicitly." + ) + + i = 1 + leftover_cores = set() + while ncore_per_node * i <= len(cores): + leftover_cores.update( + cores[ + ncore_per_node * i + - num_leftover_cores : ncore_per_node * i + ] + ) + i += 1 + cores = list(set(cores) - leftover_cores) + assert len(cores) % args.ncores_per_instance == 0 + args.ninstances = len(cores) // args.ncores_per_instance + else: + if args.ninstances * args.ncores_per_instance > len(cores): + raise RuntimeError( + "Please make sure ninstances * ncores_per_instance <= total_cores" + ) + if args.latency_mode: + logger.warning( + "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) + args.ncores_per_instance = 4 + cores = self.cpuinfo.get_all_physical_cores() + args.ninstances = len(cores) // args.ncores_per_instance + + if args.throughput_mode: + logger.warning( + "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \ +--use-logical-core. They won't take effect even they are set explicitly." + ) + args.ninstances = self.cpuinfo.node_nums + cores = self.cpuinfo.get_all_physical_cores() + args.ncores_per_instance = len(cores) // args.ninstances + + if args.ninstances > 1 and args.rank != -1: + logger.info( + "assigning %s cores for instance %s", + args.ncores_per_instance, + args.rank, + ) + + if not args.disable_numactl: + numactl_available = self.is_numactl_available() + if not numactl_available: + if not args.disable_taskset: + logger.warning( + "Core binding with numactl is not available. Disabling numactl and using taskset instead. \ + This may affect performance in multi-socket system; please use numactl if memory binding is needed." + ) + args.disable_numactl = True + enable_taskset = True + else: + logger.warning( + "Core binding with numactl is not available, and --disable_taskset is set. \ + Please unset --disable_taskset to use taskset instead of numactl." + ) + sys.exit(-1) + + if not args.disable_taskset: + enable_taskset = True + + self.set_multi_thread_and_allocator( + args.ncores_per_instance, + args.disable_iomp, + set_kmp_affinity, + args.enable_tcmalloc, + args.enable_jemalloc, + args.use_default_allocator, + ) + entrypoint = "" + launch_args = {} + launch_envs: dict[int, dict] = {} + launch_tee = {} + # check whether is launched from torchrun with --nproc-per-node + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + for i in range(args.ninstances): + cmd = [] + cur_process_cores = "" + if not args.disable_numactl or enable_taskset: + if not args.disable_numactl: + cmd = ["numactl"] + elif enable_taskset: + cmd = ["taskset"] + cores = sorted(cores) + if ( + args.rank == -1 + ): # sequentially assign ncores_per_instance to ninstances + core_list = cores[ + i * args.ncores_per_instance : (i + 1) + * args.ncores_per_instance + ] + else: # assign ncores_per_instance from rank + core_list = cores[ + args.rank * args.ncores_per_instance : (args.rank + 1) + * args.ncores_per_instance + ] + + core_ranges: list[dict] = [] + if local_size > 1: + total_num_cores = len(core_list) + cores_per_rank = total_num_cores // local_size + assert cores_per_rank >= 1, ( + "At least one core needs to be assigned to each rank" + ) + core_list = core_list[ + cores_per_rank * local_rank : cores_per_rank * (local_rank + 1) + ] + for core in core_list: + if len(core_ranges) == 0: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + else: + if core - core_ranges[-1]["end"] == 1: + core_ranges[-1]["end"] = core + else: + range_elem = {"start": core, "end": core} + core_ranges.append(range_elem) + for r in core_ranges: + cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']}," + cur_process_cores = cur_process_cores[:-1] + if not args.disable_numactl: + numa_params = f"-C {cur_process_cores} " + numa_ids = ",".join( + [ + str(numa_id) + for numa_id in self.cpuinfo.numa_aware_check(core_list) + ] + ) + numa_params += f"-m {numa_ids}" + cmd.extend(numa_params.split()) + elif enable_taskset: + taskset_params = f"-c {cur_process_cores} " + cmd.extend(taskset_params.split()) + with_python = not args.no_python + if with_python: + cmd.append(sys.executable) + cmd.append("-u") + if args.module: + cmd.append("-m") + cmd.append(args.program) + cmd.extend(args.program_args) + cmd_s = " ".join(cmd) + logger.info(cmd_s) + if entrypoint == "": + entrypoint = cmd[0] + del cmd[0] + launch_args[i] = tuple(cmd) + launch_envs[i] = {} + launch_tee[i] = Std.ALL + + if args.rank != -1: # launches single instance, rank, only + break + + ctx = start_processes( + name=args.log_file_prefix, + entrypoint=entrypoint, + args=launch_args, + envs=launch_envs, + logs_specs=_DefaultLogsSpecs(log_dir=args.log_path, tee=launch_tee), + ) + ctx.wait() + + +def _add_memory_allocator_params(parser): + group = parser.add_argument_group("Memory Allocator Parameters") + # allocator control + group.add_argument( + "--enable-tcmalloc", + "--enable_tcmalloc", + action="store_true", + default=False, + help="Enable tcmalloc allocator", + ) + group.add_argument( + "--enable-jemalloc", + "--enable_jemalloc", + action="store_true", + default=False, + help="Enable jemalloc allocator", + ) + group.add_argument( + "--use-default-allocator", + "--use_default_allocator", + action="store_true", + default=False, + help="Use default memory allocator", + ) + + +def _add_multi_instance_params(parser): + group = parser.add_argument_group("Multi-instance Parameters") + # multi-instance control + group.add_argument( + "--ncores-per-instance", + "--ncores_per_instance", + metavar="\b", + default=-1, + type=int, + help="Cores per instance", + ) + group.add_argument( + "--ninstances", + metavar="\b", + default=-1, + type=int, + help="For multi-instance, you should give the cores number you used for per instance.", + ) + group.add_argument( + "--skip-cross-node-cores", + "--skip_cross_node_cores", + action="store_true", + default=False, + help="If specified --ncores-per-instance, skips cross-node cores.", + ) + group.add_argument( + "--rank", + metavar="\b", + default="-1", + type=int, + help="Specify instance index to assign ncores_per_instance for rank; \ +otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \ +https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md", + ) + group.add_argument( + "--latency-mode", + "--latency_mode", + action="store_true", + default=False, + help="By default 4 core per instance and use all physical cores", + ) + group.add_argument( + "--throughput-mode", + "--throughput_mode", + action="store_true", + default=False, + help="By default one instance per node and use all physical cores", + ) + group.add_argument( + "--node-id", + "--node_id", + metavar="\b", + default=-1, + type=int, + help="node id for multi-instance, by default all nodes will be used", + ) + group.add_argument( + "--use-logical-core", + "--use_logical_core", + action="store_true", + default=False, + help="Whether only use physical cores", + ) + group.add_argument( + "--disable-numactl", + "--disable_numactl", + action="store_true", + default=False, + help="Disable numactl", + ) + group.add_argument( + "--disable-taskset", + "--disable_taskset", + action="store_true", + default=False, + help="Disable taskset", + ) + group.add_argument( + "--core-list", + "--core_list", + metavar="\b", + default=None, + type=str, + help='Specify the core list as "core_id, core_id, ....", otherwise, all the cores will be used.', + ) + group.add_argument( + "--log-path", + "--log_path", + metavar="\b", + default="", + type=str, + help="The log file directory. Default path is " + ", which means disable logging to files.", + ) + group.add_argument( + "--log-file-prefix", + "--log_file_prefix", + metavar="\b", + default="run", + type=str, + help="log file prefix", + ) + + +def _add_kmp_iomp_params(parser): + group = parser.add_argument_group("IOMP Parameters") + group.add_argument( + "--disable-iomp", + "--disable_iomp", + action="store_true", + default=False, + help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD", + ) + + +def create_args(parser=None): + """ + Parse the command line options. + + @retval ArgumentParser + """ + # pyrefly: ignore [missing-attribute] + parser.add_argument( + "--multi-instance", + "--multi_instance", + action="store_true", + default=False, + help="Enable multi-instance, by default one instance per node", + ) + + # pyrefly: ignore [missing-attribute] + parser.add_argument( + "-m", + "--module", + default=False, + action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + '"python -m".', + ) + + # pyrefly: ignore [missing-attribute] + parser.add_argument( + "--no-python", + "--no_python", + default=False, + action="store_true", + help='Do not prepend the --program script with "python" - just exec ' + "it directly. Useful when the script is not a Python script.", + ) + + _add_memory_allocator_params(parser) + _add_kmp_iomp_params(parser) + + _add_multi_instance_params(parser) + # positional + # pyrefly: ignore [missing-attribute] + parser.add_argument( + "program", + type=str, + help="The full path to the program/script to be launched. " + "followed by all the arguments for the script", + ) + + # rest from the training program + # pyrefly: ignore [missing-attribute] + parser.add_argument("program_args", nargs=REMAINDER) + + +def main(args): + env_before = set(os.environ.keys()) + if platform.system() in ["Windows", "Darwin"]: + raise RuntimeError(f"{platform.system()} is not supported!!!") + + if args.log_path: + os.makedirs(args.log_path, exist_ok=True) + else: + args.log_path = os.devnull + + if args.latency_mode and args.throughput_mode: + raise RuntimeError( + "Either args.latency_mode or args.throughput_mode should be set" + ) + + if not args.no_python and not args.program.endswith(".py"): + raise RuntimeError( + 'For non Python script, you should use "--no-python" parameter.' + ) + + # Verify LD_PRELOAD + if "LD_PRELOAD" in os.environ: + lst_valid = [] + tmp_ldpreload = os.environ["LD_PRELOAD"] + for item in tmp_ldpreload.split(":"): + matches = glob.glob(item) + if len(matches) > 0: + lst_valid.append(item) + else: + logger.warning("%s doesn't exist. Removing it from LD_PRELOAD.", item) + if len(lst_valid) > 0: + os.environ["LD_PRELOAD"] = ":".join(lst_valid) + else: + os.environ["LD_PRELOAD"] = "" + + launcher = _Launcher() + launcher.launch(args) + for x in sorted(set(os.environ.keys()) - env_before): + logger.debug("%s=%s", x, os.environ[x]) + + +if __name__ == "__main__": + parser = ArgumentParser( + description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable " + "Processors with optimal configurations. Single instance inference, " + "multi-instance inference are enable. To get the peak performance on Intel(R) " + "Xeon(R) Scalable Processors, the script optimizes the configuration " + "of thread and memory management. For thread management, the script configures thread " + "affinity and the preload of Intel OMP library. For memory management, it configures " + "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) " + "\n################################# Basic usage ############################# \n" + "\n 1. single instance\n" + "\n >>> python -m torch.backends.xeon.run_cpu python_script args \n" + "\n2. multi-instance \n" + "\n >>> python -m torch.backends.xeon.run_cpu --ninstances xxx " + "--ncores-per-instance xx python_script args\n" + "\n############################################################################# \n", + formatter_class=RawTextHelpFormatter, + ) + create_args(parser) + args = parser.parse_args() + main(args) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31e69876927d01878a9d1cb836d72fd14adf95e9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py @@ -0,0 +1,29 @@ +# mypy: allow-untyped-defs +import sys +import types + +import torch + + +class _XNNPACKEnabled: + def __get__(self, obj, objtype): + return torch._C._is_xnnpack_enabled() + + def __set__(self, obj, val): + raise RuntimeError("Assignment not supported") + + +class XNNPACKEngine(types.ModuleType): + def __init__(self, m, name): + super().__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + + enabled = _XNNPACKEnabled() + + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h new file mode 100644 index 0000000000000000000000000000000000000000..5e3d0ffbd71a5a4dacd80594e3c7222fe2be0a8e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/FlushDenormal.h @@ -0,0 +1,19 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/// Flush-To-Zero and Denormals-Are-Zero mode +/// +/// Flush-To-Zero (FTZ) and Denormals-Are-Zero (DAZ) are modes that bypass +/// IEEE 754 methods of dealing with denormal floating-point numbers on x86-64 +/// and some x86 CPUs. They result in reduced precision for values near zero, +/// but increased performance. +/// +/// See https://software.intel.com/en-us/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz + +namespace at::cpu { + +bool set_flush_denormal(bool on); + +} // namespace at::cpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b2b9a3e9c1051bcf10b0ac1ea57364771062f2b6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/Utils.h @@ -0,0 +1,38 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include + +namespace at::cpu { + +TORCH_API bool is_avx2_supported(); +TORCH_API bool is_avx512_supported(); + +// Detect if CPU support Vector Neural Network Instruction. +TORCH_API bool is_avx512_vnni_supported(); + +// Detect if CPU supports AVX512_BF16 ISA +TORCH_API bool is_avx512_bf16_supported(); + +// Detect if CPU support Advanced Matrix Extension. +TORCH_API bool is_amx_tile_supported(); + +// Detect if CPU support Advanced Matrix Extension for fp16. +TORCH_API bool is_amx_fp16_supported(); + +// Enable the system to use AMX instructions. +TORCH_API bool init_amx(); + +// Get the L1 cache size per core in Byte +TORCH_API uint32_t L1d_cache_size(); + +// Get the L2 cache size per core in Byte +TORCH_API uint32_t L2_cache_size(); + +} // namespace at::cpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h new file mode 100644 index 0000000000000000000000000000000000000000..600c38cfe964817f99c81c8c5c4edbeaabee3fea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vml.h @@ -0,0 +1,175 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +// This header implements various unary operations using a MKL VML style +// interface. + +// It implements various functions with a simple interface +// For example it enables the user to call vsin(float* out, const float* in, +// size) This functions takes a pointer to a continuous output array of floats and +// a constant input array. It will then apply sin to each value in the input +// array and write the result into the output array. out and in may point to the +// same memory, i.e. this fully supports in-place operations. These functions +// also implement their own parallelization, so take precautions when calling +// these from threaded functions. + +// When MKL is available it will call into MKL's VML library similar to NumPy +// If MKL is not available it will use SLEEF. + +// This file might be compiled under AVX or AVX2 when called from e.g. +// UnaryOpsKernel.cpp + +#include +#include +#include +#include +#include + +#if AT_MKL_ENABLED() && !defined(__APPLE__) +#include +#endif + + +namespace at::vml { +inline namespace CPU_CAPABILITY { + +using namespace vec; + +template +inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { + parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { + map( + [](const Vectorized& x) { + return Vectorized((scalar_t)1) / x.sqrt(); + }, + out + begin, + in + begin, + end - begin); + }); +} + +// NB: We ignore numerical errors by convention and leave them to the user + +#define IMPLEMENT_VML(op) \ + template \ + inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \ + using vec_t = Vectorized>; \ + vec::map([](vec_t x) { return x.op(); }, out, in, size); \ + } \ + +IMPLEMENT_VML(abs) +IMPLEMENT_VML(acos) +IMPLEMENT_VML(asin) +IMPLEMENT_VML(atan) +IMPLEMENT_VML(atanh) +IMPLEMENT_VML(ceil) +IMPLEMENT_VML(cos) +// IMPLEMENT_VML(cosh) +IMPLEMENT_VML(erf) +IMPLEMENT_VML(erfc) +IMPLEMENT_VML(erfinv) +IMPLEMENT_VML(exp) +IMPLEMENT_VML(expm1) +IMPLEMENT_VML(floor) +IMPLEMENT_VML(i0) +IMPLEMENT_VML(i0e) +IMPLEMENT_VML(digamma) +IMPLEMENT_VML(reciprocal) +IMPLEMENT_VML(log) +IMPLEMENT_VML(log10) +IMPLEMENT_VML(log1p) +IMPLEMENT_VML(log2) +IMPLEMENT_VML(neg) +IMPLEMENT_VML(sin) +// IMPLEMENT_VML(sinh) +IMPLEMENT_VML(sqrt) +IMPLEMENT_VML(round) +IMPLEMENT_VML(rsqrt) +IMPLEMENT_VML(tan) +IMPLEMENT_VML(tanh) +IMPLEMENT_VML(trunc) +IMPLEMENT_VML(lgamma) + + +#if AT_MKL_ENABLED() && !defined(__APPLE__) + +// NB: LP64 MKL is the most commonly used and thus we assume it here. That means +// we need to expect MKL_INT to be of type int, which implies int32_t or int64_t in most +// cases. +static_assert( + std::is_same_v || std::is_same_v, + "MKL_INT is assumed to be int32_t or int64_t"); +#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype) \ + template <> \ + inline void v##op(type * out, const type * in, int64_t size) { \ + auto constexpr max_mkl_ind = std::numeric_limits::max(); \ + if (size <= static_cast(max_mkl_ind)) { \ + vm##mkltype##mklop( \ + size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } else { \ + int64_t ind = 0; \ + int64_t chunks = size / max_mkl_ind; \ + int64_t rest = size % max_mkl_ind; \ + for (; ind < chunks; ind++) { \ + vm##mkltype##mklop( \ + max_mkl_ind, \ + in + ind * max_mkl_ind, \ + out + ind * max_mkl_ind, \ + VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } \ + vm##mkltype##mklop( \ + rest, \ + in + ind * max_mkl_ind, \ + out + ind * max_mkl_ind, \ + VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \ + } \ + } + +#define IMPLEMENT_VML_MKL(op, mklop) \ + IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \ + IMPLEMENT_VML_MKL_STUB(op, mklop, double, d) + +// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple +// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1 +IMPLEMENT_VML_MKL(acos, Acos) +IMPLEMENT_VML_MKL(asin, Asin) +IMPLEMENT_VML_MKL(atan, Atan) +IMPLEMENT_VML_MKL(cos, Cos) +// IMPLEMENT_VML_MKL(cosh, Cosh) +IMPLEMENT_VML_MKL(erf, Erf) +IMPLEMENT_VML_MKL(erfc, Erfc) +IMPLEMENT_VML_MKL(erfinv, ErfInv) +IMPLEMENT_VML_MKL(exp, Exp) +// IMPLEMENT_VML_MKL(expm1, Expm1) +IMPLEMENT_VML_MKL(log, Ln) +IMPLEMENT_VML_MKL(log10, Log10) +IMPLEMENT_VML_MKL(sin, Sin) +// IMPLEMENT_VML_MKL(sinh, Sinh) +IMPLEMENT_VML_MKL(sqrt, Sqrt) +IMPLEMENT_VML_MKL(tan, Tan) +IMPLEMENT_VML_MKL(tanh, Tanh) +IMPLEMENT_VML_MKL(trunc, Trunc) + +// Not vectorized in MKL version tested +// IMPLEMENT_VML_MKL(abs, Abs) +// IMPLEMENT_VML_MKL(log1p, Log1p) + +#if INTEL_MKL_VERSION >= 20180406 +IMPLEMENT_VML_MKL(log2, Log2) +#endif + +#endif + +} // namespace +} // namespace at::vml + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h new file mode 100644 index 0000000000000000000000000000000000000000..29060f6488fd51cd14b636d4e4bf53e5292723f3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/BLASConstants.h @@ -0,0 +1,16 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::cuda::detail { + +float *get_cublas_device_one(); +float *get_cublas_device_zero(); +float *get_user_alpha_ptr(); + +} // namespace at::cuda::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c6f4b3941744ded405f7273f76f4bf182d008c3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh @@ -0,0 +1,39 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +namespace at::cuda::philox { + +// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether +// that instance was created with graph capture underway or not. +// See Note [CUDA Graph-safe RNG states]. +// +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +__host__ __device__ __forceinline__ std::tuple +unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +// Adapted from TE +// extract seed and offset from PhiloxCudaState +__global__ void unpack_cudnn(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr); + +void unpack_cudnn_wrapper(at::PhiloxCudaState arg, int64_t* seed_ptr, int64_t* offset_ptr, cudaStream_t stream); + +} // namespace at::cuda::philox + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h new file mode 100644 index 0000000000000000000000000000000000000000..e0f972da4ea1a69db273756e32b04b9c09309729 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Descriptors.h @@ -0,0 +1,210 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include +#include + +namespace at { namespace native { + +std::string miopenTypeToString(miopenDataType_t dtype); + +inline int dataSize(miopenDataType_t dataType) +{ + switch (dataType) { + case miopenHalf: return 2; + case miopenFloat: return 4; + case miopenBFloat16: return 2; + default: return 8; + } +} + +// See NOTE [ cudnn fixSizeOneDimStride ] in aten/src/ATen/cudnn/Descriptors.h +template +static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) { + int64_t z = 1; + int index = 0; + std::vector permutation(dim); + + if (nhwc) { + permutation[index++] = 1; + } + for (int d = dim-1; d > 1; d--) { + permutation[index++] = d; + } + if (!nhwc) { + permutation[index++] = 1; + } + permutation[index++] = 0; + for (int d : permutation) { + if (size[d] == 1) { + stride[d] = z; + } else { + z *= size[d]; + } + } +} + +template +struct DescriptorDeleter { + void operator()(T* x) { + if (x != nullptr) { + MIOPEN_CHECK(dtor(x)); + } + } +}; + +// A generic class for wrapping MIOpen descriptor types. All you need +// is to give the underlying type the Descriptor_t points to (usually, +// if it's miopenTensorDescriptor_t it points to miopenTensorStruct), +// the constructor and the destructor. Subclasses are responsible +// for defining a set() function to actually set the descriptor. +// +// Descriptors default construct to a nullptr, and have a descriptor +// initialized the first time you call set() or any other initializing +// function. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_HIP_CPP_API Descriptor { + public: + // Use desc() to access the underlying descriptor pointer in + // a read-only fashion. Most client code should use this. + // If the descriptor was never initialized, this will return + // nullptr. + T* desc() const { return desc_.get(); } + T* desc() { return desc_.get(); } + + // Use mut_desc() to access the underlying descriptor pointer + // if you intend to modify what it points to (e.g., using + // miopenSetFooDescriptor). This will ensure that the descriptor + // is initialized. Code in this file will use this function. + T* mut_desc() { init(); return desc_.get(); } +protected: + void init() { + if (desc_ == nullptr) { + T* raw_desc = nullptr; + MIOPEN_CHECK(ctor(&raw_desc)); + desc_.reset(raw_desc); + } + } +private: + std::unique_ptr> desc_; +}; + +class TORCH_HIP_CPP_API TensorDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; + explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { + set(t, pad); + } + + // See Note [CuDNN broadcast padding] + void set(const at::Tensor &t, size_t pad = 0); + void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0); + void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0); + + void print(); + +private: + void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc); + + void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) { + std::vector strides_copy(stride, stride + dim); + fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc); + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data())); + } +}; + +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); + +class TORCH_HIP_CPP_API FilterDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + void set(const at::Tensor &t, int64_t pad = 0) { + set(t, at::MemoryFormat::Contiguous, pad); + } + + void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0); + +private: + void set(miopenDataType_t dataType, int dim, int* size, int* stride, bool nhwc) { + std::vector strides_copy(stride, stride + dim); + fixSizeOneDimStride(dim, size, strides_copy.data(), nhwc); + MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, strides_copy.data())); + } +}; + +struct TORCH_HIP_CPP_API ConvolutionDescriptor + : public Descriptor< + miopenConvolutionDescriptor, + &miopenCreateConvolutionDescriptor, + &miopenDestroyConvolutionDescriptor> { + void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool benchmark, bool deterministic) { + MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode)); + MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups)); + MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 1 : 0)); + if (benchmark) { + MIOPEN_CHECK(miopenSetConvolutionFindMode(mut_desc(), miopenConvolutionFindModeNormal)); + } + } +}; + +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_HIP_CPP_API DropoutDescriptor + : public Descriptor< + miopenDropoutDescriptor, + &miopenCreateDropoutDescriptor, + &miopenDestroyDropoutDescriptor> { + void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } + + void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } +}; + +struct TORCH_HIP_CPP_API RNNDescriptor + : public Descriptor +{ + void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode, + miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } + + void setWithDropout(DropoutDescriptor& dropout_desc, int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, + miopenRNNMode_t rnn_mode, miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) { + MIOPEN_CHECK(miopenSetRNNDescriptor_V2(mut_desc(), hidden_size, num_layers, dropout_desc.mut_desc(), input_mode, direction, rnn_mode, bias_mode, algorithm, datatype)); + } +}; + +union Constant +{ + float f; + double d; + Constant(miopenDataType_t dataType, double value) { + if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) { + f = static_cast(value); + } else { + d = value; + } + } +}; + +}} // namespace + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..c7bc662c92c808373571254e524fde9a5c7aaadc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Exceptions.h @@ -0,0 +1,46 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { namespace native { + +class miopen_exception : public std::runtime_error { +public: + miopenStatus_t status; + miopen_exception(miopenStatus_t status, const char* msg) + : std::runtime_error(msg) + , status(status) {} + miopen_exception(miopenStatus_t status, const std::string& msg) + : std::runtime_error(msg) + , status(status) {} +}; + +inline void MIOPEN_CHECK(miopenStatus_t status) +{ + if (status != miopenStatusSuccess) { + if (status == miopenStatusNotImplemented) { + throw miopen_exception(status, std::string(miopenGetErrorString(status)) + + ". This error may appear if you passed in a non-contiguous input."); + } + throw miopen_exception(status, miopenGetErrorString(status)); + } +} + +inline void HIP_CHECK(hipError_t error) +{ + if (error != hipSuccess) { + std::string msg("HIP error: "); + msg += hipGetErrorString(error); + throw std::runtime_error(msg); + } +} + +}} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h new file mode 100644 index 0000000000000000000000000000000000000000..f5a3577f06a7a5f8416dadc1188f4882de306ae9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Handle.h @@ -0,0 +1,14 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::native { + +TORCH_HIP_CPP_API miopenHandle_t getMiopenHandle(); +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..98423302b3479e383e50d05f60c37041cc139b89 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Types.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_HIP_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); + +int64_t miopen_version(); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..790aaf5b11c0df8c1f17ccb3ca51930f248cdcea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/Utils.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at { namespace native { + +// This function makes tensors which have zero stride contiguous, by +// setting the strides to 1. +inline Tensor contiguousIfZeroInStrides(const Tensor& t) { + for (auto s : t.strides()) { + if (s == 0) return t.contiguous(); + } + return t; +} + +}} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..8b76fb0e6fe92b2033c6d8ba8caf006d29f767b1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/miopen/miopen-wrapper.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#if MIOPEN_VERSION_MAJOR > 3 || (MIOPEN_VERSION_MAJOR == 3 && MIOPEN_VERSION_MINOR >= 4) +// miopen 3.4 moved find mode from private header to public header +#else +// from miopen_internal.h +extern "C" { + +typedef enum +{ + miopenConvolutionFindModeNormal = 1, /*!< Normal mode */ +} miopenConvolutionFindMode_t; + +miopenStatus_t miopenSetConvolutionFindMode( + miopenConvolutionDescriptor_t convDesc, + miopenConvolutionFindMode_t findMode); +} +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AdaptivePooling.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AdaptivePooling.h new file mode 100644 index 0000000000000000000000000000000000000000..ce4328d20c33660d7356bd7d96772d46964cf5ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AdaptivePooling.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native { + +using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel) +DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel) + +using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel) +DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel) + +using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel) +DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel) + +using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel) +DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel) + +inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { + int64_t ndim = gradOutput_.ndimension(); + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK(gradOutput_.size(i) > 0, + arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, " + "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i, + " being empty"); + } +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AmpKernels.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AmpKernels.h new file mode 100644 index 0000000000000000000000000000000000000000..78a936a046d528c5128ff3293826a804fd0ec5c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/AmpKernels.h @@ -0,0 +1,33 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)( + TensorList, + Tensor&, + const Tensor&); + +using _amp_update_scale_cpu__fn = Tensor& (*)( + Tensor&, + Tensor&, + const Tensor&, + double, + double, + int64_t); + +DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub) +DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub) + +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..e5708a286f81bb23689ef96ad2181156951c915f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h @@ -0,0 +1,337 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +// Forward declare TI +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { +enum class TransposeType; +} + +} + +namespace at::native { + +enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss}; + +#if AT_BUILD_WITH_LAPACK() +// Define per-batch functions to be used in the implementation of batched +// linear algebra operations + +template +void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); + +template +void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); + +template +void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); + +template +void lapackGels(char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info); + +template +void lapackGelsd(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t* work, int lwork, + value_t *rwork, int* iwork, int *info); + +template +void lapackGelsy(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + int *jpvt, value_t rcond, int *rank, + scalar_t *work, int lwork, value_t* rwork, int *info); + +template +void lapackGelss(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t *work, int lwork, + value_t *rwork, int *info); + +template +struct lapackLstsq_impl; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGels( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsy( + m, n, nrhs, + a, lda, b, ldb, + jpvt, rcond, rank, + work, lwork, rwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsd( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, iwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelss( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, info); + } +}; + +template +void lapackLstsq( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackLstsq_impl::call( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info, + jpvt, rcond, rank, rwork, + s, + iwork); +} + +template +void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); + +template +void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); + +template +void lapackLdlHermitian( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSymmetric( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSolveHermitian( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template +void lapackLdlSolveSymmetric( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template +void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); +#endif + +#if AT_BUILD_WITH_BLAS() +template +void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); +#endif + +using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); +DECLARE_DISPATCH(cholesky_fn, cholesky_stub) + +using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/); + +DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub) + +using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); + +DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub) + +// Converts LAPACK's real-valued eigenvector encoding to complex eigenvectors +TORCH_API void linalg_eig_make_complex_eigenvectors( + const Tensor& complex_vectors, + const Tensor& complex_values, + const Tensor& real_vectors); + +DECLARE_DISPATCH( + void(*)(const Tensor&, const Tensor&, const Tensor&), + linalg_eig_make_complex_eigenvectors_stub) + + +using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(geqrf_fn, geqrf_stub) + +using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(orgqr_fn, orgqr_stub) + +using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/); +DECLARE_DISPATCH(ormqr_fn, ormqr_stub) + +using linalg_eigh_fn = void (*)( + const Tensor& /*eigenvalues*/, + const Tensor& /*eigenvectors*/, + const Tensor& /*infos*/, + bool /*upper*/, + bool /*compute_eigenvectors*/); +DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub) + +using lstsq_fn = void (*)( + const Tensor& /*a*/, + Tensor& /*b*/, + Tensor& /*rank*/, + Tensor& /*singular_values*/, + Tensor& /*infos*/, + double /*rcond*/, + std::string /*driver_name*/); +DECLARE_DISPATCH(lstsq_fn, lstsq_stub) + +using triangular_solve_fn = void (*)( + const Tensor& /*A*/, + const Tensor& /*B*/, + bool /*left*/, + bool /*upper*/, + TransposeType /*transpose*/, + bool /*unitriangular*/); +DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub) + +using lu_factor_fn = void (*)( + const Tensor& /*input*/, + const Tensor& /*pivots*/, + const Tensor& /*infos*/, + bool /*compute_pivots*/); +DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub) + +using unpack_pivots_fn = void(*)( + TensorIterator& iter, + const int64_t dim_size, + const int64_t max_pivot); +DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub) + +using lu_solve_fn = void (*)( + const Tensor& /*LU*/, + const Tensor& /*pivots*/, + const Tensor& /*B*/, + TransposeType /*trans*/); +DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub) + +using ldl_factor_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*info*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub) + +using svd_fn = void (*)( + const Tensor& /*A*/, + const bool /*full_matrices*/, + const bool /*compute_uv*/, + const std::optional& /*driver*/, + const Tensor& /*U*/, + const Tensor& /*S*/, + const Tensor& /*Vh*/, + const Tensor& /*info*/); +DECLARE_DISPATCH(svd_fn, svd_stub) + +using ldl_solve_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*result*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub) +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUBlas.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..7cdf10c6d4fdccafa542cbc37da584a0fe17835d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CPUBlas.h @@ -0,0 +1,319 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace at::native::cpublas { + +namespace internal { +void normalize_last_dims( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t *lda, int64_t *ldb, int64_t *ldc); +} // namespace internal + +using gemm_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_fn, gemm_stub) + +using gemm_no_downcast_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub) + +template +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + at::opmath_type alpha, + const scalar_t *a, int64_t lda, + const scalar_t *b, int64_t ldb, + at::opmath_type beta, + scalar_t *c, int64_t ldc) { + internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + gemm_stub( + kCPU, c10::CppTypeToScalarType::value, + transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + double alpha, + const double *a, int64_t lda, + const double *b, int64_t ldb, + double beta, + double *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const float *a, int64_t lda, + const float *b, int64_t ldb, + float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + float beta, + at::BFloat16 *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::BFloat16 *a, int64_t lda, + const at::BFloat16 *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + float beta, + at::Half *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const float alpha, + const at::Half *a, int64_t lda, + const at::Half *b, int64_t ldb, + const float beta, + float *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex *b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + c10::complex alpha, + const c10::complex *a, int64_t lda, + const c10::complex *b, int64_t ldb, + c10::complex beta, + c10::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + int64_t alpha, + const int64_t *a, int64_t lda, + const int64_t *b, int64_t ldb, + int64_t beta, + int64_t *c, int64_t ldc); + +template +void gemm_batched( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t * const *a, int64_t lda, + const scalar_t * const *b, int64_t ldb, + const scalar_t beta, + scalar_t * const *c, int64_t ldc); + +template +void gemm_batched_with_stride( + TransposeType transa, TransposeType transb, + int64_t batch_size, int64_t m, int64_t n, int64_t k, + scalar_t alpha, + const scalar_t *a, int64_t lda, int64_t batch_stride_a, + const scalar_t *b, int64_t ldb, int64_t batch_stride_b, + scalar_t beta, + scalar_t *c, int64_t ldc, int64_t batch_stride_c); + +using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(axpy_fn, axpy_stub) + +template +void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){ + if(n == 1) + { + incx = 1; + incy = 1; + } + axpy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, a, x, incx, y, incy); +} + +void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy); +void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy); + +DECLARE_DISPATCH(copy_fn, copy_stub) + +template +void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) { + if(n == 1) + { + incx = 1; + incy = 1; + } + copy_stub( + kCPU, c10::CppTypeToScalarType::value, + n, x, incx, y, incy); +} + +void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy); +void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); +void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); + +// Batch-reduce GEMM +// Operates by the following formula: +// C = SUM(A[i] x B[i]) + C if add_C is true, i = 0 to batch size +// A Base pointer to a tensor A. +// B Base pointer to a tensor B. +// C Pointer to a tensor C (accumulation buffer). +// Note only batch size 1 is used currently + +// Define macros for available brgemm APIs +// so that callers can determine which APIs are available +#define CPUBLAS_BRGEMM_F16F16F32 // half * half -> float +#define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float +#define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float +#define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32 +#define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32 +#define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32 + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const at::Half* A, + const at::Half* B, + float* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const at::BFloat16* A, + const at::BFloat16* B, + float* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const float* A, + const float* B, + float* C, + bool is_vnni = false); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const unsigned char* A, + const unsigned char* B, + int32_t* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const unsigned char* A, + const signed char* B, + int32_t* C, + bool is_vnni = true); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const bool add_C, + const signed char* A, + const signed char* B, + int32_t* C, + bool is_vnni = true); + +// Release brgemm hardware context +TORCH_API void brgemm_release(bool is_vnni = true); + +// Pack B matrix to get better performance if needed +TORCH_API void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out); + +// Whether pack is supported in the platform. +TORCH_API bool could_pack(ScalarType dt_in); + +} // namespace at::native::cpublas + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h new file mode 100644 index 0000000000000000000000000000000000000000..71cb3ddfcfb5f1634532e29798dd6868bb8b5a3c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CanUse32BitIndexMath.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +TORCH_API bool canUse32BitIndexMath(const at::TensorBase &t, int64_t max_elem=std::numeric_limits::max()); + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ComplexHelper.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ComplexHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..3ba035b9ea17a17603bc3b445014b4e5f31aa5b3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ComplexHelper.h @@ -0,0 +1,102 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include + +#include +#endif + +// WARNING: this header contains non-inline functions and should be only +// included from ONE cpp file + +namespace at::native { + +// View tensor with new dtype, storage offset, sizes and strides +inline Tensor view_tensor( + const Tensor &tensor, ScalarType dtype, + c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) { + Storage storage = tensor.storage(); + auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); + auto new_tensor = detail::make_tensor( + c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); + auto * impl = new_tensor.unsafeGetTensorImpl(); + impl->set_sizes_and_strides(sizes, strides, offset); + return new_tensor; +} + +inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) { + SymDimVector res(oldstride.size() + 1); + for (const auto i : c10::irange(oldstride.size())) { + res[i] = oldstride[i] * 2; + } + res.back() = 1; + return res; +} + +inline Tensor _view_as_real_physical(const Tensor& self) { + TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); + auto old_sizes = self.sym_sizes(); + SymDimVector new_sizes(old_sizes.size() + 1); + std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); + // last dimension will always have two elements containing the real and imag vals + new_sizes.back() = 2; + auto new_strides = computeStrideForViewAsReal(self.sym_strides()); + auto new_storage_offset = self.sym_storage_offset() * 2; + const auto float_type = c10::toRealValueType(self.scalar_type()); + auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides); + return real_tensor; +} + +// expects as input a complex tensor and returns back a tensor +// with corresponding real dtype containing the complex values +// in the last two dimensions +Tensor view_as_real(const Tensor& self) { + TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original."); + return _view_as_real_physical(self); +} + +inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) { + const auto dim = oldstride.size(); + TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1"); + + SymDimVector res(dim - 1); + for (const auto i : c10::irange(res.size())) { + TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); + res[i] = oldstride[i] / 2; + } + return res; +} + +// expects as input a float or double tensor with last dimension of size 2 +// and returns back a tensor with corresponding complex dtype +Tensor view_as_complex(const Tensor& self) { + TORCH_CHECK( + self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, + "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); + + auto old_sizes = self.sym_sizes(); + TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions"); + TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); + SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); + + const auto new_strides = computeStrideForViewAsComplex(self.sym_strides()); + const auto complex_type = c10::toComplexType(self.scalar_type()); + + TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); + const auto new_storage_offset = self.sym_storage_offset() / 2; + + return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..ab99b0ce5496931d41695af614e67afdfd6af437 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h @@ -0,0 +1,39 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::native { + +struct TupleInfoCPU { + template + using tuple = std::tuple; + + template + static constexpr auto tie(Types&... args) noexcept { + return std::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return std::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(std::get(rh.data())) { + return std::get(rh.data()); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvolutionMM3d.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvolutionMM3d.h new file mode 100644 index 0000000000000000000000000000000000000000..75342eb5d919c4068ba2d912c138c6f02cd523a9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ConvolutionMM3d.h @@ -0,0 +1,19 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +namespace at::native { + +std::tuple slow_conv3d_backward_cpu( + const Tensor& grad_output, + const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + std::array output_mask); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Copy.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..33ed6063a7a2fe721525855481a2b62e98760029 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Copy.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at { + +class Tensor; +struct TensorIterator; +class TensorBase; + +namespace native { + +using copy_fn = void (*)(TensorIterator&, bool non_blocking); + +DECLARE_DISPATCH(copy_fn, copy_stub) + +TORCH_API void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src); + +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..fda88a764833121e58019f9c6142dfa016343ef6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DilatedConvolutionUtils.h @@ -0,0 +1,234 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include + +#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \ + TORCH_CHECK( \ + T.dim() == DIM && T.size(DIM_SIZE) == SIZE, \ + "Need " #T " of dimension ", \ + DIM, \ + " and " #T ".size[", \ + DIM_SIZE, \ + "] == ", \ + SIZE, \ + " but got input to be of shape ", \ + T.sizes()) + +namespace at::native::internal { +namespace { +inline bool all_positive(IntArrayRef& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item > 0; }); +} + +inline bool all_nonnegative(std::vector& arr) { + return std::all_of( + arr.begin(), arr.end(), [](int64_t item) { return item >= 0; }); +} + +} // namespace + +// calculate the rear part of output tensor sizes +template +std::vector get_output_size( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + std::vector sizes; + for (const auto index : c10::irange(dim)) { + sizes.push_back( + div_rtn( + input.size(index + input.dim() - dim) + 2 * pad_size[index] - + (dilation_size[index] * (kernel_size[index] - 1) + 1), + stride_size[index]) + + 1); + } + return sizes; +} + +// calculate the sizes of output tensor +template +std::vector get_output_size( + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + output_size.insert(output_size.begin(), weight.size(0)); + if (input.dim() == dim + 2) { + output_size.insert(output_size.begin(), input.size(0)); + } + return output_size; +} +/* + slow_conv_dilated_shape_check - check user-input to dilated convolution + forward and backward functions. +*/ +template +void slow_conv_dilated_shape_check( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef stride_size, + IntArrayRef pad_size, + IntArrayRef dilation_size) { + /* + When the following tensors are defined: + + bias, grad_weight, grad_output + + then these are assumed to be contiguous without checking + because of these tensors are made contiguous by calling + .contiguous() method or by resizing of zero-sized tensors in + forward/backward functions. + + When grad_weight is defined then it is assumed without + checking to have the same shape as weight, see backward + functions. + */ + // Check size arguments + TORCH_CHECK( + kernel_size.size() == dim, + "kernel sizes length should be ", + dim, + ", but got ", + kernel_size.size()); + TORCH_CHECK( + stride_size.size() == dim, + "strides length should be ", + dim, + ", but got ", + stride_size.size()); + TORCH_CHECK( + dilation_size.size() == dim, + "dilations length should be ", + dim, + ", but got ", + dilation_size.size()); + TORCH_CHECK( + pad_size.size() == dim, + "pads length should be ", + dim, + ", but got ", + pad_size.size()); + + TORCH_CHECK( + all_positive(kernel_size), + "kernel size should be greater than zero, but got ", + kernel_size); + TORCH_CHECK( + all_positive(stride_size), + "stride should be greater than zero, but got ", + stride_size); + TORCH_CHECK( + all_positive(dilation_size), + "dilation should be greater than zero, but got ", + dilation_size); + + // check input + TORCH_CHECK(input.defined(), "input must be defined"); + bool is_batch = input.dim() == dim + 2; + int64_t n = (is_batch ? 2 : 1); + int64_t ndim = n + dim; + if (!is_batch) { + // input dim has to be dim + 1 if not batched + TORCH_CHECK( + input.dim() == dim + 1, + "input must be 4D or 5D tensor but got ", + input.dim(), + "D tensor"); + } + + // check output sizes + auto output_size = get_output_size( + input, kernel_size, stride_size, pad_size, dilation_size); + + TORCH_CHECK( + all_nonnegative(output_size), + "calculated output size ", + output_size, + " is too small (all sizes must be non-negative)"); + + // check weight + TORCH_CHECK(weight.defined(), "weight must be defined"); + TORCH_CHECK( + weight.dim() == dim + 2, + "weight must be ", + dim + 2, + "D tensor but got ", + weight.dim(), + "D tensor dim=", + dim); + TORCH_CHECK( + weight.sizes().slice(2) == kernel_size, + "weight[2:] shape ", + weight.sizes().slice(2), + " must be equal to kernel_size ", + kernel_size); + + TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1)); + + // check bias when present + if (bias.defined()) { + TORCH_CHECK( + bias.dim() == 1, + "bias must be 1D tensor but got ", + bias.dim(), + "D tensor"); + TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0)); + } + + // check grad_output when present + if (grad_output.defined()) { + TORCH_CHECK( + grad_output.dim() == ndim, + "grad_output must be ", + ndim, + "D tensor but got ", + grad_output.dim(), + "D tensor"); + if (is_batch) { + TORCH_CHECK( + grad_output.size(0) == input.size(0), + "grad_output.size(0)=", + grad_output.size(0), + " must be input.size(0)=", + input.size(0)); + } + TORCH_CHECK( + grad_output.size(n - 1) == weight.size(0), + "grad_output.size(", + n - 1, + ")=", + grad_output.size(n - 1), + " must be weight.size(0)=", + weight.size(0)); + TORCH_CHECK( + grad_output.sizes().slice(n) == output_size, + "grad_output[", + n, + ":] shape", + grad_output.sizes().slice(n), + " must be equal to output size ", + output_size); + } +} + +} // namespace at::native::internal + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DistributionTemplates.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..f97f4fb0443c400011a9af0d8ba896c8a5c1c2f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/DistributionTemplates.h @@ -0,0 +1,399 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native::templates { + +// ==================================================== Random ======================================================== + +// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`. +// The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype(scalar_t). +// This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance: +// +// auto actual = torch::empty({3, 3}, torch::half); +// actual.random_(0, 65504); +// +// If random's uint64_t arithmetic produces 65503 as a random value after casting to torch::half it becomes 65504 +// and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to` +// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to +// the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous +// available number for torch::half dtype. +template +int64_t update_from(int64_t from) { + static_assert( + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, "scalar_t must be floating-point type"); + const auto from_plus_1 = static_cast(static_cast(from + 1)); + if (from_plus_1 < from) { + int64_t from_ = std::abs(from + 1); + int n = 0; + while (from_ >>= 1) ++n; + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + from = from_plus_1 + (1LL << (n - std::numeric_limits::digits + 1)); + } + return from; +} + +template +int64_t update_to(int64_t to) { + static_assert( + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, "scalar_t must be floating-point type"); + const auto to_minus_1 = static_cast(static_cast(to - 1)); + if (to_minus_1 >= to) { + int64_t to_ = std::abs(to - 1); + int n = 0; + while (to_ >>= 1) ++n; + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + to = to_minus_1 - (1LL << (n - std::numeric_limits::digits + 1)); + } + return to; +} + +// Return earlier for not invoking kernel. +// See https://github.com/pytorch/pytorch/issues/103418 for more details +#define CHECK_EMPTY_AND_RETURN(tensor) \ + if (tensor.numel() == 0) { \ + return tensor; \ + } + +template class random_kernel, typename RNG> +at::Tensor& random_impl(at::Tensor& self, std::optional generator) { + CHECK_EMPTY_AND_RETURN(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); + random_kernel()(iter, generator); + return self; +} + +#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \ + TORCH_CHECK(var >= min && var <= max, name , " is out of bounds for ", dtype); \ + +#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \ + if (var < -(1LL << digits) || var > (1LL << digits)) { \ + TORCH_WARN(name , " is out of bounds [-(2^", digits, "), 2^", digits, "]. ", \ + "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \ + "This warning will become an error in version 1.7 release, please fix the code in advance"); \ + } + +inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) { + const auto scalar_type = typeMetaToScalarType(dtype); + if (isFloatingType(scalar_type)) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] { + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); + + constexpr auto digits = std::numeric_limits::digits; + WARN_OUT_OF_BOUNDS(from, "from", digits, dtype); + WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype); + }); + } else if (scalar_type == kUInt64) { + // When you do a comparison between int64_t and uint64_t, the usual + // arithmetic conversions say that the int64_t value is promoted to + // unsigned. But this conversion wraps around: if I had -1 as my int64_t, + // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never + // the right thing to do. + CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype); + } else if (isIntegralType(scalar_type, /*includeBool=*/true)) { + AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() { + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype); + }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool); + } else { + TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types"); + } +} + +template class random_from_to_kernel, typename RNG> +at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional to_opt, std::optional generator) { + uint64_t range = 0; + auto iter = at::TensorIterator::borrowing_nullary_op(self); + if (to_opt.has_value()) { + // [from, to) + int64_t to = *to_opt; + TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to); + if (isFloatingType(iter.dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] { + from = update_from(from); + to = update_to(to); + TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to); + }); + } + check_from_to_in_range(from, to - 1, self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + range = static_cast(to) - static_cast(from); + random_from_to_kernel()(iter, range, from, generator); + } else if (from != std::numeric_limits::lowest()) { + // [from, std::numeric_limits::max()] + int64_t to_inc = 0; + if (isFloatingType(iter.dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] { + constexpr int64_t scalar_t_max = static_cast(1) << std::numeric_limits::digits; + to_inc = scalar_t_max > std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(scalar_t_max); + from = update_from(from); + TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc); + }); + } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) { + AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] { + if constexpr (std::is_same_v) { + to_inc = static_cast(true); + } else { + to_inc = static_cast(std::numeric_limits::max()); + } + }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool); + } else { + TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types"); + } + check_from_to_in_range(from, to_inc, self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + range = static_cast(to_inc) - static_cast(from) + 1; + random_from_to_kernel()(iter, range, from, generator); + } else { + // [std::numeric_limits::lowest(), std::numeric_limits::max()] + // range = 2^64 + CHECK_EMPTY_AND_RETURN(self); + random_from_to_kernel()(iter, generator); + } + return self; +} + +// ==================================================== Normal ======================================================== + +#define CHECK_NORMAL_TENSOR_STD(std) \ + do { \ + TORCH_CHECK( \ + !std.is_complex(), \ + "normal expects standard deviation to be non-complex"); \ + TORCH_CHECK( \ + std.numel() == 0 || std.is_meta() || std.min().ge(0).item(), \ + "normal expects all elements of std >= 0.0"); \ + } while (0) + +#define CHECK_NORMAL_STD(std) \ + TORCH_CHECK(std >= 0.0, "normal expects std >= 0.0, but found std ", std); + +template class normal_kernel, typename RNG> +Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + CHECK_EMPTY_AND_RETURN(self); + + if (self.is_complex()) { + auto float_tensor = at::view_as_real(self); + // variance for normal distribution of the real and imaginary values + // is half of the input variance + normal_kernel()(float_tensor, mean, std/(std::sqrt(2)), gen); + } else { + normal_kernel()(self, mean, std, gen); + } + return self; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous); + auto shape = at::infer_size(mean.sizes(), std_tensor.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, std, gen); + output.add_(mean); + return output; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto mean_tensor = at::full({}, mean, output.options()); + auto shape = at::infer_size(mean_tensor.sizes(), std.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, 1, gen); + // CUDA NB: addcmul_out copies the tensor to be added into the output. + // The previous function here was addcmul_out(output, mean_tensor, output, std, 1); + // The third argument is not a constant reference and hence the samples in output are overwritten. + // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std + output.mul_(std).add_(mean_tensor); + return output; +} + +template class normal_kernel, typename RNG> +Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto shape = at::infer_size(mean.sizes(), std.sizes()); + at::native::resize_output(output, shape); + normal_impl_(output, 0, 1, gen); + // CUDA NB: addcmul_out copies the tensor to be added into the output. + // The previous function here was addcmul_out(output, mean, output, std, 1); + // The third argument is not a constant reference and hence the samples in output are overwritten. + // Consequently, the computation performed is mean + mean * std instead of mean + output * std + output.mul_(std).add_(mean); + return output; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(const Tensor& mean, double std, std::optional gen) { + CHECK_NORMAL_STD(std); + Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(double mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + Tensor ret = at::empty_like(std, MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +template class normal_kernel, typename RNG> +Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional gen) { + CHECK_NORMAL_TENSOR_STD(std); + auto shape = at::infer_size(mean.sizes(), std.sizes()); + Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous); + normal_out_impl(ret, mean, std, gen); + return ret; +} + +// ==================================================== Uniform ======================================================= + +template class uniform_kernel, typename RNG> +at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional generator) { + if (self.is_complex()) { + CHECK_EMPTY_AND_RETURN(self); + auto float_tensor = at::view_as_real(self); + uniform_impl_(float_tensor, from, to, generator); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] { + [[maybe_unused]] const auto dtype = self.dtype(); + const auto min = static_cast(std::numeric_limits::lowest()); + const auto max = static_cast(std::numeric_limits::max()); + CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype); + CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype); + TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to); + TORCH_CHECK((to - from) <= std::numeric_limits::max(), + "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()), + ">::max(), but found to=", to, " and from=", from, + " which result in to-from to exceed the limit"); + from = std::min(std::max(from, min), max); + to = std::max(std::min(to, max), min); + }); + CHECK_EMPTY_AND_RETURN(self); + auto iter = at::TensorIterator::borrowing_nullary_op(self); + uniform_kernel()(iter, from, to, generator); + } + return self; +} + +// ================================================== LogNormal ======================================================= + +template class log_normal_kernel, typename RNG> +at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional gen) { + TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + log_normal_kernel()(iter, mean, std, gen); + return self; +} + +// =================================================== Geometric ====================================================== + +template class geometric_kernel, typename RNG> +Tensor& geometric_impl_(Tensor& self, double p, std::optional gen) { + TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + geometric_kernel()(iter, p, gen); + return self; +} + +// ================================================== Exponential ===================================================== + +template class exponential_kernel, typename RNG> +Tensor& exponential_impl_(Tensor& self, double lambda, std::optional gen) { + TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + exponential_kernel()(iter, lambda, gen); + return self; +} + +// ==================================================== Cauchy ======================================================== + +template class cauchy_kernel, typename RNG> +Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional gen) { + // TODO: instead of variable name 'sigma', use 'gamma' or 'scale' + // the variance, squared sigma, is undefined for cauchy distribution + TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma); + TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype()); + CHECK_EMPTY_AND_RETURN(self); + auto iter = TensorIterator::borrowing_nullary_op(self); + cauchy_kernel()(iter, median, sigma, gen); + return self; +} + +// ==================================================== Bernoulli ===================================================== + +template class bernoulli_tensor_kernel, typename RNG> +Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional gen) { + CHECK_EMPTY_AND_RETURN(self); + NoNamesGuard guard; + at::assert_no_internal_overlap(self); + bernoulli_tensor_kernel()(self, p_, gen); + return self; +} + +template class bernoulli_scalar_kernel, typename RNG> +Tensor& bernoulli_impl_(Tensor& self, double p, std::optional gen) { + TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p); + CHECK_EMPTY_AND_RETURN(self); + at::assert_no_internal_overlap(self); + bernoulli_scalar_kernel()(self, p, gen); + return self; +} + +template class bernoulli_tensor_kernel, typename RNG> +Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional gen) { + // result.resize_as_(self) requires self to have same dtype as result, so we + // use resize_ instead. + // TODO: Fix resize_as_. See pytorch/pytorch#11665. + result.resize_(self.sizes()); + bernoulli_impl_(result, self, gen); + namedinference::propagate_names(result, self); + return result; +} + +#undef CHECK_OUT_OF_BOUNDS +#undef WARN_OUT_OF_BOUNDS + +} // namespace at::native::templates + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Fill.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Fill.h new file mode 100644 index 0000000000000000000000000000000000000000..e4809d59a3e560a6fa4d4f0613bf6bbfffdf7c86 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Fill.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Functions that fill Tensors with constants. Implementations are in Fill.cpp. + +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { + +DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub) + +Tensor& fill_out(Tensor& self, const Scalar& value); + +}} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ForeachUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ForeachUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..eefc95259f662496b08fc0109d0d6c5e007521f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ForeachUtils.h @@ -0,0 +1,385 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include + +namespace at::native { +namespace { +// Check if tensor list has either a boolean tensor or a integer tensor +inline bool has_integral_tensor(TensorList tensors, const bool includeBool) { + return std::any_of( + tensors.begin(), tensors.end(), [includeBool](const auto& t) { + return at::isIntegralType(t.scalar_type(), includeBool); + }); +} +// check if tensor list has bool tensors +inline bool has_bool_tensor(TensorList tensors) { + return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool { + return t.scalar_type() == ScalarType::Bool; + }); +} + +// Check foreach API restrictions +// - Tensor lists must be non-empty. +// - All TensorLists and ScalarLists must have the same number of elements. +// - Corresponding tensors must have the same size. +inline void check_foreach_api_restrictions(TensorList tensors) { + TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors); + TORCH_CHECK( + tensors.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list."); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2) { + check_foreach_api_restrictions(tensors1); + check_foreach_api_restrictions(tensors2); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3) { + check_foreach_api_restrictions(tensors1, tensors2); + check_foreach_api_restrictions(tensors1, tensors3); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + TensorList tensors3, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2, tensors3); + check_foreach_api_restrictions(tensors1, scalars); +} + +inline void check_foreach_api_restrictions( + TensorList tensors1, + TensorList tensors2, + ArrayRef scalars) { + check_foreach_api_restrictions(tensors1, tensors2); + check_foreach_api_restrictions(tensors1, scalars); +} + +// Helper function called in check_fast_path_restrictions to check whether all +// corresponding tensors (aligning in index across the tensorLists) share the +// same device and dtype. +inline bool _check_tensors_share_device_and_dtype( + ArrayRef tensorLists, + const bool skip_dtype_check = false) { + const auto expected_dtype = tensorLists[0][0].dtype(); + const auto expected_device = tensorLists[0][0].device(); + + auto is_tensor_okay = [&](const Tensor& tensor) { + return (skip_dtype_check || tensor.dtype() == expected_dtype) && + tensor.device() == expected_device && tensor.layout() == at::kStrided && + tensor.is_non_overlapping_and_dense(); + }; + + return std::all_of( + tensorLists.cbegin(), + tensorLists.cend(), + [&](const TensorList& tensorList) { + return std::all_of( + tensorList.cbegin(), tensorList.cend(), is_tensor_okay); + }); +} + +// Helper function called in check_fast_path_restrictions to check if +// corresponding tensors in tensor lists have the same sizes and strides. +inline bool _check_tensors_share_sizes_and_strides( + ArrayRef tensorLists) { + auto is_diff_stride = [](const IntArrayRef& size, + const IntArrayRef& left_stride, + const IntArrayRef& right_stride) -> bool { + const size_t size_size = size.size(); + for (const auto dim : c10::irange(size_size)) { + if (size[dim] == 1) + continue; + if (left_stride[dim] != right_stride[dim]) { + return true; + } + } + return false; + }; + for (const auto i : c10::irange(1, tensorLists.size())) { + for (const auto j : c10::irange(tensorLists[0].size())) { + if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() || + is_diff_stride( + tensorLists[0][j].sizes(), + tensorLists[0][j].strides(), + tensorLists[i][j].strides())) { + return false; + } + } + } + + return true; +} + +// Helper function called in check_fast_path_restrictions to check whether +// all tensors type promote properly with the scalars in scalarList. This +// function assumes that _check_tensors_share_device_and_dtype has already been +// called so that all corresponding tensors in tensorLists have the same dtype. +// Then, it is sufficient to check the type promotion with just one tensorList. +inline bool _check_tensors_do_type_promotion_with_scalars( + TensorList tensorList, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + for (const auto i : c10::irange(tensorList.size())) { + // For division, integer inputs will result in float. + if (does_op_promote_integer_inputs_to_float && + at::isIntegralType(tensorList[i].scalar_type(), /*includeBool*/ true)) { + return false; + } + if (!scalarList.empty()) { + const auto& scalar = + scalarList.size() == 1 ? scalarList[0] : scalarList[i]; + const auto& tensor = tensorList[i]; + // note(mkozuki): This check might be responsible for + // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path. + if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) { + return false; + } + } + } + + return true; +} + +// To go via 'fast' path, several conditions must be satisfied +// - All tensors in all lists must have the same dtype. +// - All tensors must be on the same device +// - All tensors must have strided layout +// - All tensors must be non-overlapping and dense +// - Resulting tensor must have the same dtype as the input one + +// [note: what's ``does_op_promote_integer_inputs_to_float=true``?] +// ``does_op_promote_integer_inputs_to_float=true`` means that the result of +// the op will be float even if inputs are integer or boolean, which +// currently fast path does not support. In short, this flag, when +// turned on, gatekeeps the op from going down the fastpath. + +// Please, make sure to call check_foreach_api_restrictions before calling this +// method. There is a set of preconditions that have to be satisfied. +inline bool check_fast_path_restrictions( + ArrayRef tensorLists, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + return _check_tensors_share_device_and_dtype(tensorLists) && + _check_tensors_share_sizes_and_strides(tensorLists) && + _check_tensors_do_type_promotion_with_scalars( + tensorLists[0], + scalarList, + does_op_promote_integer_inputs_to_float); +} + +inline std::vector convert_tensor_to_scalar_list( + const Tensor& scalarList_, + int64_t expect_length) { + std::vector scalarList; + TORCH_CHECK( + scalarList_.device() == c10::kCPU, + "Expected scalars to be on CPU, got ", + scalarList_.device(), + " instead."); + TORCH_CHECK( + scalarList_.is_contiguous(), "Expected scalars to be contiguous."); + TORCH_CHECK( + scalarList_.dim() == 1, + "Expected packed scalar Tensor to be of dimension 1. Got ", + scalarList_.dim(), + " instead."); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + kComplexHalf, + kHalf, + kBool, + kBFloat16, + scalarList_.scalar_type(), + "convert_tensor_to_scalar_list", + [&]() { + const scalar_t* scalar_data = scalarList_.const_data_ptr(); + TORCH_CHECK( + (expect_length == scalarList_.size(0)), + "Expected length of scalars to match input of length ", + expect_length, + " but got ", + scalarList_.size(0), + " instead."); + for (int64_t i = 0; i < scalarList_.size(0); i++) { + scalarList.emplace_back(scalar_data[i]); + } + }); + return scalarList; +} + +// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?] +inline bool can_use_fast_route( + ArrayRef tensorLists, + ArrayRef scalarList = {}, + bool does_op_promote_integer_inputs_to_float = false) { + return check_fast_path_restrictions( + tensorLists, scalarList, does_op_promote_integer_inputs_to_float); +} + +// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?] +inline bool can_use_fast_route( + TensorList tensors1, + TensorList tensors2, + bool does_op_promote_integer_inputs_to_float = false) { + return can_use_fast_route( + {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float); +} + +using DeviceDtypeKey = std::pair; +using IndicesT = std::vector; +using nested_optional_tensorvec_t = + std::vector>>; +using TensorsAndIndicesT = std::pair; +using FlatMap = std::unordered_map< + DeviceDtypeKey, + TensorsAndIndicesT, + ParamsHash>; + +inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( + const nested_optional_tensorvec_t& nested_tensorlist, + const bool with_indices) { + FlatMap grouped_tensors_with_indices; + + TORCH_CHECK(!nested_tensorlist.empty()); + TORCH_CHECK(!nested_tensorlist[0].empty()); + const auto num_lists = nested_tensorlist.size(); + const auto num_tensors = nested_tensorlist[0].size(); + + TORCH_CHECK(std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + // note(crcrpar): Allow empty tensorlists following + // ref: + // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24 + return tensorlist.size() == num_tensors || tensorlist.size() == 0; + })); + + for (const auto& tensor_index : c10::irange(num_tensors)) { + const auto key = [&]() -> DeviceDtypeKey { + const auto t = nested_tensorlist[0][tensor_index]; + TORCH_CHECK( + t.has_value(), + "Tensors of the first list of nested Tensor lists are supposed to be defined but ", + "the ", + tensor_index, + "-th Tensor is not."); + return {t->device(), t->scalar_type()}; + }(); + TORCH_CHECK( + std::all_of( + nested_tensorlist.cbegin(), + nested_tensorlist.cend(), + [&](const auto& tensorlist) -> bool { + if (tensorlist.size() == 0) { + return true; + } + const auto& tensor = tensorlist[tensor_index]; + // note(crcrpar): Currently the scope of this function is + // optimizers so there could be `state_steps` and other scalars + // whose elements are float tensors no matter what the parameter's + // dtype is. + if (!tensor.has_value()) { + return true; + } else { + const auto s = tensor->scalar_type(); + const auto d = tensor->device(); + // Note: `step` or `state_step` is float32 by default. + if (key.first == d) { + return key.second == s || s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else if (d.is_cpu()) { + // note(crcrpar): There are some test cases (e.g. + // TestOptim::test_adam) where state_steps are on CPU and the + // others are on CUDA. Currently a state_step Tensor has the + // dtype of float. + return s == at::ScalarType::Float || + s == at::ScalarType::Double; + } else { + return false; + } + } + }), + "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding"); + grouped_tensors_with_indices.try_emplace( + key, + TensorsAndIndicesT{ + [&]() -> nested_optional_tensorvec_t { + nested_optional_tensorvec_t nested_tensorvec; + nested_tensorvec.reserve(num_lists); + for (const auto& i : c10::irange(num_lists)) { + std::vector> tensors; + if (!nested_tensorlist[i].empty()) { + // NB: num_tensors is the max possible length for any of + // the inner lists of tensor references. Reserving the max + // trades memory for perf. This should not have significant + // impact. + tensors.reserve(num_tensors); + } + nested_tensorvec.emplace_back(std::move(tensors)); + } + return nested_tensorvec; + }(), + [&]() -> IndicesT { + if (!with_indices) { + return {}; + } else { + IndicesT indices; + indices.reserve(num_tensors); + return indices; + } + }()}); + for (const auto& list_index : c10::irange(num_lists)) { + if (!nested_tensorlist[list_index].empty()) { + grouped_tensors_with_indices[key].first[list_index].emplace_back( + nested_tensorlist[list_index][tensor_index]); + } + } + if (with_indices) { + grouped_tensors_with_indices[key].second.emplace_back(tensor_index); + } + } + + return grouped_tensors_with_indices; +} + +} // namespace +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FractionalMaxPooling.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FractionalMaxPooling.h new file mode 100644 index 0000000000000000000000000000000000000000..c52c41a314400cf3db633ff7f1f1a0827a724a68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FractionalMaxPooling.h @@ -0,0 +1,85 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace at::native { + +template +inline std::vector generate_intervals( + scalar_t sample, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + std::vector sequence(outputSize); + if (outputSize > 1) { + scalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + + for (const auto i : c10::irange(outputSize - 1)) { + sequence[i] = + static_cast((i + sample) * alpha) - static_cast(sample * alpha); + } + } + if (outputSize > 0) { + sequence[outputSize - 1] = inputSize - poolSize; + } + return sequence; +} + +template +inline void fractional_max_pool_check_shape( + const Tensor& input, + const Tensor& randomSamples) { + + TORCH_CHECK( + input.scalar_type() == randomSamples.scalar_type(), + "Expect _random_samples to have the same dtype as input"); + + int64_t ndimension = randomSamples.ndimension(); + TORCH_CHECK( + ndimension == 3, + "Expect _random_samples to have 3 dimensions, got ", ndimension); + + int64_t N = randomSamples.size(0); + int64_t C = randomSamples.size(1); + int64_t D = randomSamples.size(2); + + int64_t input_batch = 0, input_channel = 0; + if (ndim == 2) { + // fractional_max_pool2d + if (input.ndimension() == 3) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } else { + // factional_max_pool3d + if (input.ndimension() == 4) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } + + TORCH_CHECK( + N >= input_batch, + "Expect _random_samples.size(0) no less then input batch size."); + TORCH_CHECK( + C == input_channel, + "Expect _random_samples.size(1) equals to input channel size."); + TORCH_CHECK( + D == ndim, + "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, "."); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdagrad.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdagrad.h new file mode 100644 index 0000000000000000000000000000000000000000..a43863b5b0dca62c5c648dc086d698a24b4a91d0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdagrad.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include + +namespace at::native { + +using fused_adagrad_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& state_sum, + const at::Tensor& state_step, + const double lr, + const double lr_decay, + const double weight_decay, + const double eps, + const bool maximize, + const float* grad_scale_ptr); + +DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdam.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdam.h new file mode 100644 index 0000000000000000000000000000000000000000..0ec3c0e854270d9609bc6eae5e471662db724d10 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedAdam.h @@ -0,0 +1,32 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include + +namespace at::native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +using fused_adam_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& exp_avg, + const at::Tensor& exp_avg_sq, + const at::Tensor& max_exp_avg_sq, + const at::Tensor& state_step, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const float* grad_scale_ptr, + const ADAM_MODE); + +DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedSGD.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedSGD.h new file mode 100644 index 0000000000000000000000000000000000000000..ed97d3c19054c62b65d7a0cd31e8eb1f3cfa61e1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/FusedSGD.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include + +namespace at::native { + +using fused_sgd_fn = void (*)( + const at::Tensor& param, + const at::Tensor& grad, + const at::Tensor& momentum_buffer, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const float* grad_scale_ptr); + +DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Gelu.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Gelu.h new file mode 100644 index 0000000000000000000000000000000000000000..e2bf52c98dd06d8764fb9dccd99a8843cb3bcd5f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Gelu.h @@ -0,0 +1,38 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::native { +// These constants control the approximation behavior of gelu function. +enum class GeluType { + None, // Baseline Gelu + Tanh, // Tanh Gelu Approximation + END +}; + +inline GeluType get_gelutype_enum(const std::string_view approximate) { + if (approximate == "none") { + return GeluType::None; + } else if (approximate == "tanh") { + return GeluType::Tanh; + } else { + TORCH_CHECK(false, "approximate argument must be either none or tanh."); + } +} + +inline std::string gelutype_to_string(const GeluType type) { + switch(type) { + case GeluType::None: return "none"; + case GeluType::Tanh: return "tanh"; + default: TORCH_CHECK(false, "unknown GELU type: ", static_cast(type)); + } +} + + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSamplerUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSamplerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..bd68648b3ebd1e1ae48f334659652323984c2dd6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GridSamplerUtils.h @@ -0,0 +1,116 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// See NOTE: [Tensor vs. TensorBase] +// https://github.com/pytorch/pytorch/pull/66979 +#include +#include +#include + +namespace at::native { + +namespace detail { + +enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic}; +enum class GridSamplerPadding {Zeros, Border, Reflection}; + +} // namespace detail + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// See NOTE [ grid_sampler Native Functions ]. +inline void check_grid_sampler_common( + const TensorBase& input, + const TensorBase& grid +) { + auto input_opt = input.options(); + auto grid_opt = grid.options(); + + TORCH_CHECK( + input.defined(), + "grid_sampler(): expected input to not be undefined"); + TORCH_CHECK( + grid.defined(), + "grid_sampler(): expected grid to not be undefined"); + TORCH_CHECK( + input_opt.device() == grid_opt.device(), + "grid_sampler(): expected input and grid to be on same device, but input " + "is on ", input_opt.device(), " and grid is on ", grid_opt.device()); + TORCH_CHECK( + input_opt.layout() == kStrided && grid_opt.layout() == kStrided, + "grid_sampler(): expected input and grid to have torch.strided layout, but " + "input has ", input_opt.layout(), " and grid has ", grid_opt.layout()); + TORCH_CHECK( + input.size(0) == grid.size(0), + "grid_sampler(): expected grid and input to have same batch size, but got " + "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes()); + TORCH_CHECK( + grid.size(-1) == input.dim() - 2, + "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " + "dimension, but got grid with sizes ", grid.sizes()); + + for (const auto i : c10::irange(2, input.dim())) { + TORCH_CHECK(input.size(i) > 0, + "grid_sampler(): expected input to have non-empty spatial dimensions, " + "but input has sizes ", input.sizes(), " with dimension ", i, " being " + "empty"); + } +} + +// See NOTE [ grid_sampler Native Functions ]. +inline void check_grid_sampler_2d( + const TensorBase& input, + const TensorBase& grid +) { + TORCH_CHECK( + input.dim() == 4 && input.dim() == grid.dim(), + "grid_sampler(): expected 4D input and grid with same number of " + "dimensions, but got input with sizes ", input.sizes(), + " and grid with sizes ", grid.sizes()); +} + +// See NOTE [ grid_sampler Native Functions ]. +inline void check_grid_sampler_3d( + const TensorBase& input, + const TensorBase& grid, + int64_t interpolation_mode +) { + TORCH_CHECK( + input.dim() == 5 && input.dim() == grid.dim(), + "grid_sampler(): expected 5D input and grid with same number of " + "dimensions, but got input with sizes ", input.sizes(), + " and grid with sizes ", grid.sizes()); + TORCH_CHECK( + !(input.dim() == 5 && + static_cast(interpolation_mode) == + GridSamplerInterpolation::Bicubic), + "grid_sampler(): bicubic interpolation only supports 4D input"); +} + +// See NOTE [ grid_sampler Native Functions ]. +// cudnn does not support inputs larger than 1024. +inline bool cond_cudnn_grid_sampler( + const TensorBase& input, + const TensorBase& grid +) { + auto st = input.scalar_type(); + if (!(st == kDouble || st == kFloat || st == kHalf)) + return false; + st = grid.scalar_type(); + if (!(st == kDouble || st == kFloat || st == kHalf)) + return false; + return ( + at::native::cudnn_is_acceptable(input) && + at::native::cudnn_is_acceptable(grid) && + at::native::canUse32BitIndexMath(input) && + at::native::canUse32BitIndexMath(grid) && + input.dim() == 4 && + input.sym_size(1) <= 1024); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GroupedMMUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GroupedMMUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e63c1acf163bddb4c0b1b5fb7c8ec658b3c0dd84 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/GroupedMMUtils.h @@ -0,0 +1,172 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native { + +inline bool check_valid_strides_and_return_transposed(const Tensor& mat) { + IntArrayRef tensor_strides = mat.strides(); + IntArrayRef tensor_sizes = mat.sizes(); + int end_dim = mat.dim() - 1; + int alignment = 16 / mat.element_size(); + TORCH_CHECK(uint64_t(mat.data_ptr()) % 16 ==0, "expected data_ptr to be aligned to 16 bytes\n"); + if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max(1, tensor_sizes[end_dim - 1]))) { + TORCH_CHECK(tensor_strides[end_dim] % alignment == 0, "strides should be multiple of 16 bytes"); + return true; + } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max(1, tensor_sizes[end_dim]))) { + TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); + return false; + } else { + TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); + } +} + +inline at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a, +const Tensor& mat_b, +const std::optional& offs, +c10::ScalarType out_dtype +) { + c10::SmallVector out_size; + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + if (a_is_2d) { + if (b_is_2d) { + out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)}; + } else { + TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match"); + out_size = {mat_a.size(0), mat_b.size(-1)}; + } + } else { + if (b_is_2d) { + // this case is not actually encountered for MoE gemms + TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match"); + out_size = {mat_a.size(1), mat_b.size(1)}; + } else { // regular bmm + TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match"); + out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; + } + } + + #ifndef USE_ROCM + // For TMA transfers, strides of output tensor have to be either + // 1, or aligned to 16 bytes. + const auto last_dim = out_size.size() - 1; + const auto alignment = 16 / c10::elementSize(out_dtype); + const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment; + std::vector out_stride; + if (a_is_2d != b_is_2d) { + out_stride = {size_padded, 1}; + } else { + out_stride = {out_size[1] * size_padded, size_padded, 1}; + } + return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype)); + #else + return at::empty(out_size, mat_a.options().dtype(out_dtype)); + #endif +} + +inline void _grouped_mm_validate_inputs(const Tensor& mat_a, const Tensor& mat_b, +const std::optional& offs, +const std::optional& bias, +std::optional out_dtype) { + TORCH_CHECK((mat_a.dtype() == at::kBFloat16) || (mat_a.dtype() == at::kFloat) || (mat_a.dtype() == at::kHalf), "Expected mat_a to be Float32, BFloat16 or Float16 matrix, got ", mat_a.scalar_type()); + TORCH_CHECK((mat_b.dtype() == at::kBFloat16) || (mat_b.dtype() == at::kFloat) || (mat_b.dtype() == at::kHalf), "Expected mat_b to be Float32, BFloat16 or Float16 matrix, got ", mat_b.scalar_type()); + TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); + TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + if (!a_is_2d || !b_is_2d) { + TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); + } + + // check that the strides are valid, the fn will throw an error if not + check_valid_strides_and_return_transposed(mat_a); + check_valid_strides_and_return_transposed(mat_b); + TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d"); + + if (offs.has_value()) { + TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); + TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); + } + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); +} + +inline c10::ScalarType _resolve_grouped_mm_out_dtype(const Tensor& mat_a, const Tensor& mat_b, +std::optional out_dtype) { + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + // TODO(future PR): enable float32 output dtype for bfloat16 and float16 inputs + TORCH_CHECK(out_dtype_ == mat_a.dtype(), "Grouped gemm output dtype must match `mat_a` dtype"); + return out_dtype_; +} + + +inline void _grouped_mm_fallback(const Tensor& mat_a, const Tensor& mat_b, +const std::optional& offs, +const std::optional& bias, +std::optional out_dtype, +Tensor out) { + LOG(INFO) << "fallback path for `torch._grouped_mm`, performance may not be optimal"; + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + if (a_is_2d && !b_is_2d) { + // 2d x 3d with offsets + int group_start_idx = 0; + auto offs_cpu = offs.value().cpu(); + for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) { + int group_end_idx = offs_cpu[group_idx].item(); + auto mat_a_slice = mat_a.slice(0, group_start_idx, group_end_idx); + auto out_slice = out.slice(0, group_start_idx, group_end_idx); + at::mm_out(out_slice, mat_a_slice, mat_b[group_idx]); + group_start_idx = group_end_idx; + } + + } else if (!a_is_2d && b_is_2d) { + // 3d x 2d with offsets + int group_start_idx = 0; + auto offs_cpu = offs.value().cpu(); + for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) { + int group_end_idx = offs_cpu[group_idx].item(); + auto mat_b_slice = mat_b.slice(1, group_start_idx, group_end_idx); + auto out_slice = out.slice(1, group_start_idx, group_end_idx); + at::mm_out(out_slice, mat_a[group_idx], mat_b_slice); + group_start_idx = group_end_idx; + } + + } else if (a_is_2d && b_is_2d) { + // 2d x 2d with offsets + int group_start_idx = 0; + auto offs_cpu = offs.value().cpu(); + for (int group_idx = 0; group_idx < offs_cpu.size(0); group_idx++) { + int group_end_idx = offs_cpu[group_idx].item(); + auto mat_a_slice = mat_a.slice(1, group_start_idx, group_end_idx); + auto mat_b_slice = mat_b.slice(0, group_start_idx, group_end_idx); + auto out_slice = out[group_idx]; + at::mm_out(out_slice, mat_a_slice, mat_b_slice); + group_start_idx = group_end_idx; + } + + } else { + // 3d x 3d without offsets - regular bmm + at::bmm_out(out, mat_a, mat_b); + } +} + + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Histogram.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Histogram.h new file mode 100644 index 0000000000000000000000000000000000000000..7ef0b8c873a0e6d9c5bac1c8e7e495ad854c87ae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Histogram.h @@ -0,0 +1,21 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::native { + +using histogramdd_fn = void(*)(const Tensor&, const std::optional&, bool, Tensor&, const TensorList&); +using histogramdd_linear_fn = void(*)(const Tensor&, const std::optional&, bool, Tensor&, const TensorList&, bool); +using histogram_select_outer_bin_edges_fn = void(*)(const Tensor& input, const int64_t N, std::vector &leftmost_edges, std::vector &rightmost_edges); + +DECLARE_DISPATCH(histogramdd_fn, histogramdd_stub) +DECLARE_DISPATCH(histogramdd_linear_fn, histogramdd_linear_stub) +DECLARE_DISPATCH(histogram_select_outer_bin_edges_fn, histogram_select_outer_bin_edges_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexKernel.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..00b2a30be157021b5d934108e68282ce5c51e708 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/IndexKernel.h @@ -0,0 +1,46 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { +class Tensor; +class TensorBase; +struct TensorIterator; +struct TensorIteratorBase; +} + +namespace c10 { +class Scalar; +} + +namespace at::native { + +using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides); +using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source); +using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride); +using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate); +using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate); +using take_fn = void(*)(TensorIterator & iter, const TensorBase& input); +using flip_fn = void(*)(TensorIterator &, const bool); +using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar); +using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride); +using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &); + +DECLARE_DISPATCH(index_fn, index_stub) +DECLARE_DISPATCH(index_fill_fn, index_fill_stub) +DECLARE_DISPATCH(index_copy_fn, index_copy_stub) +DECLARE_DISPATCH(index_put_fn, index_put_stub) +DECLARE_DISPATCH(put_fn, put_stub) +DECLARE_DISPATCH(take_fn, take_stub) +DECLARE_DISPATCH(flip_fn, flip_stub) +DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub) +DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub) +DECLARE_DISPATCH(masked_select_fn, masked_select_stub) +DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebra.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..fa0e95971a5231ef8327f12511a25509a6e19f69 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/LinearAlgebra.h @@ -0,0 +1,22 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +} + +namespace at::native { + +using addr_fn = void (*)(TensorIterator &, const Scalar& beta, const Scalar& alpha); +DECLARE_DISPATCH(addr_fn, addr_stub) +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Math.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Math.h new file mode 100644 index 0000000000000000000000000000000000000000..fed8e3e14cbff16108cce5c98e07ba4e276d8066 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Math.h @@ -0,0 +1,3932 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. +Below is the copyright. +Output was modified to be inf or -inf when input is 1 or -1. */ + + +/* + Copyright (c) 2014 Indiana University + All rights reserved. + + Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + Indiana University, Bloomington, IN + + This software is licensed under the New BSD license: + + Redistribution and use in source and binary forms, + with or without modification, are permitted provided + that the following conditions are met: + + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + + Redistributions in binary form must reproduce the + above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or + other materials provided with the distribution. + + Neither the name of Indiana University nor + the names of its contributors may be used to endorse + or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. +*/ + +namespace { +/* + * This function is derived from the implementation of the i0e function in the + * Cephes Math Library. See note [3-Clause BSD License for the Cephes Math + * Library]. + * + * Computes an approximation of the exponentially scaled zeroth order modified + * Bessel function of the first kind. The approximation is actually two + * (sub)approximations, both using a Chebyshev polynomial expansion. One + * approximates the function over [0, 8], and the other over (8, infinity). This + * function takes the absolute value of all inputs to convert them into the + * domain of the approximation. + */ +jiterator_also_stringify_as(jiterator_code( + template + JITERATOR_HOST_DEVICE T chbevl(T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + JITERATOR_HOST_DEVICE T calc_i0e(T _x) { + T x = std::fabs(_x); + + if (x <= T{8.0}) { + static const T coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return chbevl(y, coefficients, int{30}); + } + + // x > 8 + static const T coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return chbevl(T{32.0} / x - T{2.0}, coefficients, int{25}) / std::sqrt(x); + }), + i0e_string) // i0e_string +} + +#define CENTRAL_RANGE 0.7 + +template +inline typename std::enable_if_t, T> +calc_erfinv(T y) { +/* Function to calculate inverse error function. Rational approximation +is used to generate an initial approximation, which is then improved to +full accuracy by two steps of Newton's method. Code is a direct +translation of the erfinv m file in matlab version 2.0. +Author: Gary L. Pavlis, Indiana University +Date: February 1996 +*/ + T x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + T a[4] = { T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331) }; + T b[4] = { T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801) }; + T c[4] = { T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311) }; + T d[2] = { T(3.543889200), T(1.637067800) }; + T y_abs = std::abs(y); + if(y_abs > 1.0) return std::numeric_limits::quiet_NaN(); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + if(y_abs == 1.0) return copysign(std::numeric_limits::infinity(), y); +#else + if(y_abs == 1.0) return std::copysign(std::numeric_limits::infinity(), y); +#endif + if(y_abs <= static_cast(CENTRAL_RANGE)) { + z = y * y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast(1.0)); + x = y * num / dem; + } + else{ + z = std::sqrt(-std::log((static_cast(1.0)-y_abs)/static_cast(2.0))); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + static_cast(1.0); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + x = copysign(num, y) / dem; +#else + x = std::copysign(num, y) / dem; +#endif + } + /* Two steps of Newton-Raphson correction */ + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + + return x; +} + +#undef CENTRAL_RANGE + +/* + * Note [3-Clause BSD License for the Cephes Math Library] + * Code derived from implementations in the Cephes Math Library should mention its derivation and reference + * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note + * [3-Clause BSD License for the Cephes Math Library]. The license is: + * Copyright (c) 2018, Steven Moshier + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +template +C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { + using acc_t = at::acc_type; + const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; + constexpr acc_t zero = acc_t{0.0}; + constexpr acc_t half = acc_t{0.5}; + constexpr acc_t one = acc_t{1.0}; + static const acc_t A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + acc_t a, b, k, s, t, w; + if (x == one) { + return std::numeric_limits::infinity(); + } + + if (x < one) { + return std::numeric_limits::quiet_NaN(); + } + + if (q <= zero) { + if (q == std::floor(q)) { + return std::numeric_limits::infinity(); + } + if (x != std::floor(x)) { + return std::numeric_limits::quiet_NaN(); + } + } + + s = std::pow(q, -x); + a = q; + int i = 0; + b = zero; + while ((i < 9) || (a <= acc_t{9.0})) { + i += 1; + a += one; + b = ::pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return static_cast(s); + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = ::fabs(t / s); + if (t < MACHEP) { + return static_cast(s); + } + k += one; + a *= x + k; + b /= w; + k += one; + } + return static_cast(s); +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ +template +C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; +} + +inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { + double sign = +1; + double result = 0; + if (x < 0.5) { + sign = -1; + const double sin_pi_x = sin(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const double ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x; + return sign * result; +} + +inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { + float sign = +1; + float result = 0; + if (x < 0.5f) { + sign = -1; + const float sin_pi_x = sinf(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const float ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; + return sign * result; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +inline double calc_digamma(double x) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma + static double PSI_10 = 2.25175258906672110764; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == trunc(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + return calc_digamma(1 - x) - c10::pi / tan(c10::pi * r); + } + + // Push x to be >= 10 + double result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const double A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + double y = 0; + if (x < 1.0e17) { + double z = 1.0 / (x * x); + y = z * polevl(z, A, 6); + } + return result + log(x) - (0.5 / x) - y; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] + static float PSI_10 = 2.25175258906672110764f; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == truncf(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + float pi_over_tan_pi_x = (float)(c10::pi / tan(c10::pi * r)); + return calc_digamma(1 - x) - pi_over_tan_pi_x; + } + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const float A[] = { + 8.33333333333333333333E-2f, + -2.10927960927960927961E-2f, + 7.57575757575757575758E-3f, + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f, + }; + + float y = 0; + if (x < 1.0e17f) { + float z = 1 / (x * x); + y = z * polevl(z, A, 6); + } + return result + logf(x) - (0.5f / x) - y; +} + +inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { + return calc_digamma(static_cast(a)); +} + +inline c10::Half calc_digamma(c10::Half a) { + return calc_digamma(static_cast(a)); +} + +template +inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { + // already blocked if n <= 1 + const auto one = scalar_t{1}; + return ((n % 2) ? one : -one) * + std::exp(std::lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); +} + +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. + */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same_v ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static constexpr scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * c10::pi * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same_v ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same_v ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same_v ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the subtraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the subtraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +[[maybe_unused]] inline c10::BFloat16 calc_igamma( + c10::BFloat16 a, + c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::Half calc_igamma( + c10::Half a, + c10::Half x) { + return calc_igamma(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::BFloat16 calc_igammac( + c10::BFloat16 a, + c10::BFloat16 x) { + return calc_igammac(float(a), float(x)); +} + +template <> +[[maybe_unused]] inline c10::Half calc_igammac( + c10::Half a, + c10::Half x) { + return calc_igammac(float(a), float(x)); +} + +inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } + +template +inline T abs_impl(T v) { + return std::abs(v); +} + +template <> +[[maybe_unused]] inline uint8_t abs_impl(uint8_t v) { + return v; +} + +template +inline typename std::enable_if_t, T> +calc_gcd(T a, T b) { + a = abs_impl(a); + b = abs_impl(b); + while (a != 0) { + T c = a; + a = b % a; + b = c; + } + return b; +} + +template +C10_HOST_DEVICE T exp2_impl(T x) { + return std::exp2(x); +} + +template +C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + constexpr auto ln2 = c10::ln_2; + return std::exp(ln2 * x); +} + +/* + * This function is derived from the implementation of the chbevl function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates the series + * + * len-1 + * - ' + * y = > array[i] T (x/2) + * - i + * i=0 + * + * of Chebyshev polynomials Ti at argument x/2. + * + * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of + * coefficients, not the order. + * + * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before + * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined. + * + * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation + * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. + */ +template +inline typename std::enable_if_t, T> +chbevl(const T x, const T array[], size_t len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = static_cast(0.0); + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return (static_cast(0.5) * (b0 - b2)); +} + +/* + * This function is derived from the implementation of the i0 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the zeroth order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline std::tuple chebyshev_coefficients_i0e_A() { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T coeff[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + return std::make_tuple(coeff, 30); +} + +template +inline std::tuple chebyshev_coefficients_i0e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return std::make_tuple(coeff, 25); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + return std::make_tuple(coeff, 29); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + return std::make_tuple(coeff, 17); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + return std::make_tuple(coeff, 25); +} + +template +inline typename std::enable_if_t, std::tuple> +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + return std::make_tuple(coeff, 7); +} + +template +inline typename std::enable_if_t, T> +calc_i0(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i0e_A(); + T y = (x / T{2.0}) - T{2.0}; + return static_cast(std::exp(x) * chbevl(y, A, len)); + } + auto [B, len] = chebyshev_coefficients_i0e_B(); + return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } +inline c10::Half calc_i0(c10::Half a) { return calc_i0(static_cast(a)); } + +/* + * This function is derived from the implementation of the i1 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if_t, T> +calc_i1(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + T y = (x / T{2.0}) - T{2.0}; + const T out = std::exp(x) * x * chbevl(y, A, len); + return (_x < T{0.0}) ? -out : out; + } + auto [B, len] = chebyshev_coefficients_i1e_B(); + const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1(c10::BFloat16 a) { return calc_i1(static_cast(a)); } +inline c10::Half calc_i1(c10::Half a) { return calc_i1(static_cast(a)); } + + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the exponentially scaled first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if_t, T> +calc_i1e(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto [A, len] = chebyshev_coefficients_i1e_A(); + T y = (x / T{2.0}) - T{2.0}; + const T out = chbevl(y, A, len) * x; + return (_x < T{0.0}) ? -out : out; + } + auto [B, len] = chebyshev_coefficients_i1e_B(); + const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1e(c10::BFloat16 a) { return calc_i1e(static_cast(a)); } +inline c10::Half calc_i1e(c10::Half a) { return calc_i1e(static_cast(a)); } + + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ +template +inline C10_HOST_DEVICE T calc_ndtri(T y0) { + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + constexpr T one = 1; + constexpr T zero = 0; + + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. + */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + if (y0 == zero) { + return -std::numeric_limits::infinity(); + } + if (y0 == one) { + return std::numeric_limits::infinity(); + } + if (y0 < zero || y0 > one) { + return std::numeric_limits::quiet_NaN(); + } + bool code = true; + T y = y0; + if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */ + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8)); + return (x * s2pi); + } + + T x = ::sqrt(T{-2.0} * ::log(y)); + const T x0 = x - ::log(x) / x; + + const T z = one / x; + T x1; + if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */ + { + x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8); + } else { + x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8); + } + x = x0 - x1; + if (code) { + x = -x; + } + return x; +} + +/* The next function is taken from http://ab-initio.mit.edu/faddeeva */ + +/* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +/* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). + + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived there from] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. */ + +/* Given y100=100*y, where y = 4/(4+x) for x >= 0, compute erfc(x). + + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. */ + + +template +C10_HOST_DEVICE inline typename std::enable_if_t, T> +erfcx_y100(T y100) +{ + switch (static_cast(y100)) { +case 0: { +T t = 2*y100 - 1; +return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 1: { +T t = 2*y100 - 3; +return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 2: { +T t = 2*y100 - 5; +return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 3: { +T t = 2*y100 - 7; +return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; +} +case 4: { +T t = 2*y100 - 9; +return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 5: { +T t = 2*y100 - 11; +return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 6: { +T t = 2*y100 - 13; +return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 7: { +T t = 2*y100 - 15; +return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; +} +case 8: { +T t = 2*y100 - 17; +return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 9: { +T t = 2*y100 - 19; +return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 10: { +T t = 2*y100 - 21; +return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 11: { +T t = 2*y100 - 23; +return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 12: { +T t = 2*y100 - 25; +return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 13: { +T t = 2*y100 - 27; +return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 14: { +T t = 2*y100 - 29; +return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 15: { +T t = 2*y100 - 31; +return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 16: { +T t = 2*y100 - 33; +return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 17: { +T t = 2*y100 - 35; +return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 18: { +T t = 2*y100 - 37; +return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; +} +case 19: { +T t = 2*y100 - 39; +return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 20: { +T t = 2*y100 - 41; +return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 21: { +T t = 2*y100 - 43; +return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 22: { +T t = 2*y100 - 45; +return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 23: { +T t = 2*y100 - 47; +return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 24: { +T t = 2*y100 - 49; +return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 25: { +T t = 2*y100 - 51; +return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 26: { +T t = 2*y100 - 53; +return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 27: { +T t = 2*y100 - 55; +return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 28: { +T t = 2*y100 - 57; +return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 29: { +T t = 2*y100 - 59; +return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 30: { +T t = 2*y100 - 61; +return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 31: { +T t = 2*y100 - 63; +return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 32: { +T t = 2*y100 - 65; +return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 33: { +T t = 2*y100 - 67; +return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 34: { +T t = 2*y100 - 69; +return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; +} +case 35: { +T t = 2*y100 - 71; +return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 36: { +T t = 2*y100 - 73; +return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 37: { +T t = 2*y100 - 75; +return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 38: { +T t = 2*y100 - 77; +return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 39: { +T t = 2*y100 - 79; +return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 40: { +T t = 2*y100 - 81; +return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 41: { +T t = 2*y100 - 83; +return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 42: { +T t = 2*y100 - 85; +return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 43: { +T t = 2*y100 - 87; +return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 44: { +T t = 2*y100 - 89; +return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 45: { +T t = 2*y100 - 91; +return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 46: { +T t = 2*y100 - 93; +return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 47: { +T t = 2*y100 - 95; +return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 48: { +T t = 2*y100 - 97; +return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 49: { +T t = 2*y100 - 99; +return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 50: { +T t = 2*y100 - 101; +return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; +} +case 51: { +T t = 2*y100 - 103; +return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 52: { +T t = 2*y100 - 105; +return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 53: { +T t = 2*y100 - 107; +return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; +} +case 54: { +T t = 2*y100 - 109; +return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 55: { +T t = 2*y100 - 111; +return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 56: { +T t = 2*y100 - 113; +return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 57: { +T t = 2*y100 - 115; +return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 58: { +T t = 2*y100 - 117; +return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 59: { +T t = 2*y100 - 119; +return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 60: { +T t = 2*y100 - 121; +return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 61: { +T t = 2*y100 - 123; +return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 62: { +T t = 2*y100 - 125; +return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 63: { +T t = 2*y100 - 127; +return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 64: { +T t = 2*y100 - 129; +return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 65: { +T t = 2*y100 - 131; +return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 66: { +T t = 2*y100 - 133; +return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 67: { +T t = 2*y100 - 135; +return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 68: { +T t = 2*y100 - 137; +return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 69: { +T t = 2*y100 - 139; +return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 70: { +T t = 2*y100 - 141; +return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 71: { +T t = 2*y100 - 143; +return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 72: { +T t = 2*y100 - 145; +return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 73: { +T t = 2*y100 - 147; +return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; +} +case 74: { +T t = 2*y100 - 149; +return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 75: { +T t = 2*y100 - 151; +return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 76: { +T t = 2*y100 - 153; +return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 77: { +T t = 2*y100 - 155; +return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 78: { +T t = 2*y100 - 157; +return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 79: { +T t = 2*y100 - 159; +return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 80: { +T t = 2*y100 - 161; +return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 81: { +T t = 2*y100 - 163; +return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 82: { +T t = 2*y100 - 165; +return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 83: { +T t = 2*y100 - 167; +return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 84: { +T t = 2*y100 - 169; +return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 85: { +T t = 2*y100 - 171; +return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 86: { +T t = 2*y100 - 173; +return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 87: { +T t = 2*y100 - 175; +return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 88: { +T t = 2*y100 - 177; +return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 89: { +T t = 2*y100 - 179; +return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 90: { +T t = 2*y100 - 181; +return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 91: { +T t = 2*y100 - 183; +return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 92: { +T t = 2*y100 - 185; +return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 93: { +T t = 2*y100 - 187; +return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 94: { +T t = 2*y100 - 189; +return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 95: { +T t = 2*y100 - 191; +return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 96: { +T t = 2*y100 - 193; +return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 97: { +T t = 2*y100 - 195; +return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; +} +case 98: { +T t = 2*y100 - 197; +return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 99: { +T t = 2*y100 - 199; +return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; +} + } + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. + return 1.0; +} + +template +C10_HOST_DEVICE inline typename std::enable_if_t, T> +calc_erfcx(T x) +{ + if (at::_isnan(x)) { + return x; + } + + if (x >= 0) { + if (x > 50) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + if (x > 5e7) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi*((x*x) * (x*x+4.5) + 2) / (x * ((x*x) * (x*x+5) + 3.75)); + } + return erfcx_y100(400/(4+x)); + } + else { + if (x < -26.7) { + return std::numeric_limits::infinity(); + } + else if (x < -6.1) { + return 2*exp(x*x); + } + else { + return 2*exp(x*x) - erfcx_y100(400/(4-x)); + } + } +} + +/* + * Logarithm of Gaussian cumulative distribution function. + + * This implementation of log_ndtr and its helper functions + * follow SciPy's implementation + * See NOTICE for the licenses. + */ +template +inline C10_HOST_DEVICE T calc_log_ndtr(T x) { + T t = x * c10::frac_sqrt_2; + if (x < T{-1.0}) { + return std::log(calc_erfcx(-t) / 2) - t * t; + } else { + return std::log1p(-std::erfc(t) / 2); + } +} + +template +inline C10_HOST_DEVICE T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * c10::pi; + + return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * std::sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > std::numeric_limits::epsilon()) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = std::abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; +} // T airy_ai(T x) + +template +inline C10_HOST_DEVICE T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + +1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j1_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x <= T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y1_forward(T x) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x) < T(1.0))) { + return std::cos(n * std::acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); +} // chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x)) != T(0.0)) { + return std::sin((n + 1) * std::acos(x)) / std::sin(std::acos(x)); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { + return chebyshev_polynomial_u_forward(x, static_cast(n)); +} // chebyshev_polynomial_u_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x) / T(2.0)) != T(1.0)) { + return std::cos((n + T(0.5)) * std::acos(x)) / std::cos(std::acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); +} // chebyshev_polynomial_v_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::cos(std::acos(x) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x)) / std::sin(std::acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); +} // chebyshev_polynomial_w_forward(T x, T n) + +template +constexpr auto getHermitianLimit() { + if constexpr (std::is_same_v) { + return 128; + } else if constexpr (std::is_same_v) { + return 512; + } else { + return 1024; + } +} + +template +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + +template, int> = 0> +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); +} // hermite_polynomial_h_forward(T x, T n) + +template, int> = 0> +__ubsan_ignore_float_cast_overflow__ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, (!std::isinf(n) && !std::isnan(n)) ? static_cast(n) : static_cast(-1)); +} // hermite_polynomial_h_forward(T x, T n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + if (n > getHermitianLimit()) { + return std::numeric_limits::quiet_NaN(); + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_he_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); +} // hermite_polynomial_he_forward(T x, T n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // laguerre_polynomial_l_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); +} // laguerre_polynomial_l_forward(T x, T n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // legendre_polynomial_p_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); +} // legendre_polynomial_p_forward(T x, T n) + +template +inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { + static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x))); + } + + return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x))); + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i1_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return T(0.5) * (a - p) - std::log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / std::sqrt(x); +} // T scaled_modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / std::sqrt(x)); +} // T scaled_modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + return std::cos(n * std::acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0))) != T(0.0)) { + return std::sin((n + 1) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0))); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_u_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::cos((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_v_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::cos(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_w_forward(T x, T n) + +template +inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { + if (std::isinf(x)) { + return T(0.0); + } + + if (std::abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return std::sin(x) / x; +} // T spherical_bessel_j0_forward(T x) + +C10_CLANG_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h new file mode 100644 index 0000000000000000000000000000000000000000..3cdb4465af19f8c07cfbea06e2c17f69f64fe21a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h @@ -0,0 +1,76 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +namespace at { +// views and their in-place version ops +#define TORCH_VIEW_FNS(m) \ + m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \ + m.impl("detach", torch::CppFunction::makeFallthrough()); \ + m.impl("detach_", torch::CppFunction::makeFallthrough()); \ + m.impl("diagonal", torch::CppFunction::makeFallthrough()); \ + m.impl("expand", torch::CppFunction::makeFallthrough()); \ + m.impl("expand_as", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \ + m.impl("narrow", torch::CppFunction::makeFallthrough()); \ + m.impl("permute", torch::CppFunction::makeFallthrough()); \ + m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("select.int", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose_", torch::CppFunction::makeFallthrough()); \ + m.impl("t", torch::CppFunction::makeFallthrough()); \ + m.impl("t_", torch::CppFunction::makeFallthrough()); \ + m.impl("real", torch::CppFunction::makeFallthrough()); \ + m.impl("imag", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("unfold", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \ + m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapdims", torch::CppFunction::makeFallthrough()); \ + m.impl("chunk", torch::CppFunction::makeFallthrough()); \ + m.impl("reshape", torch::CppFunction::makeFallthrough()); \ + m.impl("alias", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \ + m.impl("resize_", torch::CppFunction::makeFallthrough()); + +#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ + m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.out", torch::CppFunction::makeFallthrough()); \ + m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("full_like", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.int", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("size.int", torch::CppFunction::makeFallthrough()); \ + m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("is_complex", torch::CppFunction::makeFallthrough()); \ + m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \ + m.impl("requires_grad_", torch::CppFunction::makeFallthrough()); +} + +#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \ + m.impl("as_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("view", torch::CppFunction::makeFallthrough()); + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonSymbolicBC.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonSymbolicBC.h new file mode 100644 index 0000000000000000000000000000000000000000..50a20a19a801b7b5b3d5b5f53a87a84a623f74f9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/NonSymbolicBC.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace at::native { +// This file contains non-symbolic signatures for ops that we have sym-intified the signature of. +// However, in certain cases (such as static runtime), we call the native versions of the ops directly. +// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation. +TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape); +TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length); +TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, std::optional dtype=std::nullopt, std::optional layout=std::nullopt, std::optional device=std::nullopt, std::optional pin_memory=std::nullopt, std::optional is_coalesced=std::nullopt); +TORCH_API at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index); +TORCH_API at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const std::optional& weight_opt, int64_t reduction, int64_t ignore_index); +// The below ops don't get a duplicated C++ implementation. +// They are backward ops, which make them very unlikely to be called directly +// by external code (at::native::trace_backward). +// They get their own declaration for BC purposes however. +TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const std::optional & per_sample_weights, int64_t padding_idx=-1); +TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const std::optional & per_sample_weights, int64_t padding_idx=-1); +TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim); +TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes); +TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index); +TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index); +TORCH_API std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim); +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PointwiseOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PointwiseOps.h new file mode 100644 index 0000000000000000000000000000000000000000..6a1dbfd7365ba7d264faacc07dde79920e8dec98 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/PointwiseOps.h @@ -0,0 +1,33 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Ternary and higher-order pointwise operations +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar); +using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar); +using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double); + +DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub) +DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub) +DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub) +DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub) +DECLARE_DISPATCH(pointwise_fn, mse_backward_stub) + +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pow.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pow.h new file mode 100644 index 0000000000000000000000000000000000000000..0d34f5638fe8f3a56f3e8e6df246ce6f691ed00b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Pow.h @@ -0,0 +1,74 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define HOST_DEVICE __host__ __device__ +#else +#define HOST_DEVICE +#endif + +// integral power in pytorch allows for negative exponents, giving truncated integral results. +// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the +// only non-zero result. +template , T>* = nullptr> +inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; +} + +template && !std::is_signed_v, T>* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + return powi_impl(a, b); +} + +template && std::is_signed_v, T>* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + if ( b < 0 ) { + if ( a == 1 ) { + return 1; + } else if ( a == -1 ) { + auto negative = (-b) % static_cast(2); + return negative ? -1 : 1; + } else { + return 0; + } + } + return powi_impl(a, b); +} + +using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&); +using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); + +DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub) +DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub) + +} // namespace native + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeFactories.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeFactories.h new file mode 100644 index 0000000000000000000000000000000000000000..cf12e97e4f8933320a058b7e0d6b31cbba867c2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeFactories.h @@ -0,0 +1,17 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include + +namespace at { +struct TensorIterator; + +namespace native { + +DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&), arange_stub) +DECLARE_DISPATCH(void(*)(TensorIterator&, const Scalar&, const Scalar&, int64_t), linspace_stub) + +}} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..5b5e724e6b34b42fe3ab10cf685c93d64a31a768 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/RangeUtils.h @@ -0,0 +1,65 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include + + + +namespace at::native { + +inline void arange_check_bounds( + const c10::Scalar& start, + const c10::Scalar& end, + const c10::Scalar& step) { + // use double precision for validation to avoid precision issues + double dstart = start.to(); + double dend = end.to(); + double dstep = step.to(); + + TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero"); + TORCH_CHECK( + std::isfinite(dstart) && std::isfinite(dend), + "unsupported range: ", + dstart, + " -> ", + dend); + TORCH_CHECK( + ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)), + "upper bound and lower bound inconsistent with step sign"); +} + +template +int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { + arange_check_bounds(start, end, step); + + // we use double precision for (start - end) / step + // to compute size_d for consistency across devices. + // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t, + // but double on cpu for the same, + // and the effective output size starts differing on CPU vs GPU because of precision issues, which + // we dont want. + // the corner-case we do want to take into account is int64_t, which has higher precision than double + double size_d; + if constexpr (std::is_same_v) { + using accscalar_t = at::acc_type; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); + int64_t sgn = (xstep > 0) - (xstep < 0); + size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); + } else { + size_d = std::ceil((end.to() - start.to()) + / step.to()); + } + + TORCH_CHECK(size_d >= 0 && size_d <= static_cast(std::numeric_limits::max()), + "invalid size, possible overflow?"); + + return static_cast(size_d); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceAllOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceAllOps.h new file mode 100644 index 0000000000000000000000000000000000000000..9c07a586c621fa9a1840ff25705d8f69fb7cf44f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceAllOps.h @@ -0,0 +1,21 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at { +class Tensor; +} + +namespace at::native { + +using reduce_all_fn = void (*)(Tensor & result, const Tensor & self); +using reduce_min_max_fn = void (*)(Tensor & max_result, Tensor & min_result, const Tensor & self); +DECLARE_DISPATCH(reduce_all_fn, min_all_stub) +DECLARE_DISPATCH(reduce_all_fn, max_all_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..0b5fe2800e2da9b761de750e4aaab8855903566a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReduceOps.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +class Tensor; +} + +namespace at::native { + +using reduce_fn = void(*)(TensorIterator &); + +DECLARE_DISPATCH(reduce_fn, sum_stub) +DECLARE_DISPATCH(reduce_fn, nansum_stub) +DECLARE_DISPATCH(reduce_fn, prod_stub) +DECLARE_DISPATCH(reduce_fn, mean_stub) +DECLARE_DISPATCH(reduce_fn, and_stub) +DECLARE_DISPATCH(reduce_fn, or_stub) +DECLARE_DISPATCH(reduce_fn, min_values_stub) +DECLARE_DISPATCH(reduce_fn, max_values_stub) +DECLARE_DISPATCH(reduce_fn, argmax_stub) +DECLARE_DISPATCH(reduce_fn, argmin_stub) +DECLARE_DISPATCH(reduce_fn, xor_sum_stub) + +using reduce_std_var_function = + void (*)(TensorIterator&, double correction, bool take_sqrt); +DECLARE_DISPATCH(reduce_std_var_function, std_var_stub) + +using reduce_norm_fn = + void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional); +DECLARE_DISPATCH(reduce_norm_fn, norm_kernel) + +using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&); +DECLARE_DISPATCH(reduce_fn_flag, norm_stub) + +using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t); +using cum_fn = void (*)(Tensor&, const Tensor&, int64_t); +DECLARE_DISPATCH(structured_cum_fn, cumsum_stub) +DECLARE_DISPATCH(structured_cum_fn, cumprod_stub) +DECLARE_DISPATCH(cum_fn, logcumsumexp_stub) + +DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub) +DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub) + +// Used in cuda/Normalization.cu +TORCH_API std::tuple var_mean_out( + Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, + int64_t correction, bool keepdim); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReductionType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReductionType.h new file mode 100644 index 0000000000000000000000000000000000000000..31f6dffa592c60dc8bd07ecc85dc5660c29e0bc6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ReductionType.h @@ -0,0 +1,45 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::native { + +enum class ReductionType {MAX, MEAN, MIN, SUM, PROD}; + +inline ReductionType get_reduction_enum(const std::string_view& reduce) { + if (reduce == "max" || reduce == "amax") { + return ReductionType::MAX; + } else if (reduce == "mean") { + return ReductionType::MEAN; + } else if (reduce == "min" || reduce == "amin") { + return ReductionType::MIN; + } else if (reduce == "sum") { + return ReductionType::SUM; + } else if (reduce == "prod") { + return ReductionType::PROD; + } else { + TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce); + } +} + +// used for `scatter_reduce`, old options for BC. +inline ReductionType get_operator_enum(const std::string_view reduce, bool use_new_options) { + if (use_new_options) { + return get_reduction_enum(reduce); + } else { + if (reduce == "add") { + return ReductionType::SUM; + } else if (reduce == "multiply") { + return ReductionType::PROD; + } else { + TORCH_CHECK(false, "reduce argument must be either add or multiply.") + } + } +} + +} // at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ResizeCommon.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ResizeCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..ccedbed59e28b2bb2774b8a06ec490e903b36ce7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ResizeCommon.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { + +template +inline T storage_size_for(ArrayRef size, ArrayRef stride) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(size.size() == stride.size(), + "storage_size_for(size, stride) requires that size and stride ", + "have the same size as a precondition."); + T storage_size = 1; + for (const auto dim : c10::irange(size.size())) { + if (size[dim] == 0) { + storage_size = 0; + break; + } + storage_size += (size[dim] - 1) * stride[dim]; + } + return storage_size; +} + +inline const Tensor& resize_named_tensor_( + const Tensor& self, + IntArrayRef size, + std::optional optional_memory_format) { + TORCH_INTERNAL_ASSERT(self.has_names()); + TORCH_CHECK( + self.sizes() == size, + "Cannot resize named tensor with resize_ or resize_as_ (tried to resize " + "Tensor", + self.names(), + " with size ", + self.sizes(), + " to ", + size, + "). This may be caused by passing a named tensor ", + "as an `out=` argument; please ensure that the sizes are the same. "); + TORCH_CHECK( + !optional_memory_format.has_value(), + "Unsupported memory format for named tensor resize ", + optional_memory_format.value()); + return self; +} + +// For deterministic output, fill new elements that were added after a storage +// resize with NaN or MAX_INT. `old_storage_nbytes` is the size of the storage +// before the resize happened. +inline const Tensor& fill_resize_deterministic_(const Tensor& tensor, int64_t old_storage_nbytes) { + const at::Storage& storage = tensor.unsafeGetTensorImpl()->unsafe_storage(); + int64_t new_storage_nbytes = storage.nbytes(); + int64_t old_storage_numel = old_storage_nbytes / tensor.itemsize(); + int64_t new_storage_numel = new_storage_nbytes / tensor.itemsize(); + if (new_storage_numel > old_storage_numel) { + at::Tensor tensor_view = at::empty({}, at::TensorOptions().dtype(tensor.scalar_type()).device(tensor.device())); + tensor_view.set_( + storage, + /*storage_offset=*/old_storage_numel, + /*size=*/{new_storage_numel - old_storage_numel}, + /*stride=*/{1}); + at::native::fill_empty_deterministic_(tensor_view); + } + return tensor; +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SegmentReduce.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SegmentReduce.h new file mode 100644 index 0000000000000000000000000000000000000000..add144d19e2a0c4872304c935a320d4a11b26f27 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SegmentReduce.h @@ -0,0 +1,55 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using segment_reduce_lengths_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub) + +using segment_reduce_offsets_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub) + +using segment_reduce_lengths_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub) + +using segment_reduce_offsets_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub) + +} // namespace native +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SharedReduceOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SharedReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..56e4db1423183f5049a21d2008dd9c8c5f147fa0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SharedReduceOps.h @@ -0,0 +1,550 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// Please note that this file is +// used across both CPU and GPU. + +#include +#include +#include +#include +#include +#include +#if defined(__CUDACC__) +#include +#include +#elif defined(__HIPCC__) +#include +#include +#endif +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#else +#include +#define device_sqrt std::sqrt +#endif +#if defined(__CUDACC__) || defined(__HIPCC__) +template +inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) { +#if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b)); +#else + scalar_t max = at::_isnan(b) ? b : std::max(a, b); +#endif + return max; +} +template +inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) { +#if defined(__HIPCC__) + // TODO: remove this special case for HIP when issue is fixed: + // https://github.com/ROCm/hip/issues/2209 + scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b)); +#else + scalar_t min = at::_isnan(b) ? b : std::min(a, b); +#endif + return min; +} +#define MAX(X, Y) max_propagate_nan(X,Y) +#define MIN(X, Y) min_propagate_nan(X,Y) +#else +#include +#define MAX(X, Y) max_impl(X,Y) +#define MIN(X, Y) min_impl(X,Y) +#endif + +// ROCM hcc doesn't work well with using std:: in kernel functions +#if defined(__CUDA_ARCH__) +#include +#define compat_pow c10::cuda::compat::pow +#elif defined(__HIPCC__) +#include +#define compat_pow c10::hip::compat::pow +#else +#define compat_pow std::pow +#endif + +namespace at::native { + +namespace detail { + +#if defined(__CUDACC__) || defined(__HIPCC__) +template using pair = thrust::pair; +#else +template using pair = std::pair; +#endif + +} // namespace detail + +template +struct WelfordData { + scalar_t mean; + scalar_t m2; + index_t n; + scalar_t nf; + + C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {} + + C10_HOST_DEVICE WelfordData( + scalar_t mean, + scalar_t m2, + index_t n, + scalar_t nf) + : mean(mean), m2(m2), n(n), nf(nf) {} +}; + + +template +struct WelfordOps { + acc_scalar_t correction; + bool take_sqrt; + public: + using acc_t = WelfordData; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + // We accumulate n in index_t to avoid cumulative rounding error, but still + // need nf for use in combine where int32 may overflow. + index_t new_n = acc.n + 1; + acc_scalar_t new_nf = static_cast(new_n); + acc_scalar_t delta = data - acc.mean; + acc_scalar_t new_mean = acc.mean + delta / new_nf; + acc_scalar_t new_delta = data - new_mean; + return { + new_mean, + acc.m2 + delta * new_delta, + new_n, + new_nf, + }; + } + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + if (a.nf == 0) { + return b; + } + if (b.nf == 0) { + return a; + } + acc_scalar_t delta = b.mean - a.mean; + acc_scalar_t new_count = a.nf + b.nf; + acc_scalar_t nb_over_n = b.nf / new_count; + return { + a.mean + delta * nb_over_n, + a.m2 + b.m2 + delta * delta * a.nf * nb_over_n, + // setting acc.n as -1 since acc.n might not be able to represent the count + // correctly within its range, setting it to -1 to avoid confusion + -1, + new_count + }; + } + inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ { + const auto mean = static_cast(acc.mean); + const auto divisor = acc.nf > correction ? acc.nf - correction : 0; + const auto var = acc.m2 / divisor; + res_t results(take_sqrt ? device_sqrt(var) : var, mean); + return results; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const { + return { + WARP_SHFL_DOWN(acc.mean, offset) + , WARP_SHFL_DOWN(acc.m2, offset) + , WARP_SHFL_DOWN(acc.n, offset) + , WARP_SHFL_DOWN(acc.nf, offset) + }; + } +#endif + C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt) + : correction(correction), take_sqrt(take_sqrt) {} +}; + +template +struct MeanOps { + factor_t factor; + + inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const { + return combine(a, static_cast(b)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a * factor; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif + + MeanOps(factor_t factor): factor(factor) { + } +}; + +// This accumulator template is used to calculate the minimum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct AbsMinOps { + + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MIN(acc, static_cast(std::abs(at::opmath_type(data)))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MIN(a, b); + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the maximum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct AbsMaxOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MAX(acc, static_cast(std::abs(at::opmath_type(data)))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return MAX(a, b); + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the norm of the absolute value +// of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormOps { + acc_t norm_; + + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + compat_pow(static_cast(std::abs(at::opmath_type(data))), norm_); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return compat_pow(a, static_cast(1.0) / norm_); + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif + + NormOps(acc_t norm_): norm_(norm_) { + } +}; + +// This accumulator template is used to calculate the order zero norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormZeroOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + (data == static_cast(0) ? static_cast(0) : static_cast(1)); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +// This accumulator template is used to calculate the order one norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormOneOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + static_cast(std::abs(at::opmath_type(data))); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return a; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + + +template +struct AbsSwitch {}; + +template +inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch /*unused*/) { + return static_cast(data); +} + +template +inline C10_DEVICE acc_t abs_if_complex(std::complex data, AbsSwitch /*unused*/) { + return static_cast(std::abs(data)); +} + +template +inline C10_DEVICE acc_t abs_if_complex(c10::complex data, AbsSwitch /*unused*/) { + return static_cast(std::abs(at::opmath_type>(data))); +} + +// This accumulator template is used to calculate the order two norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the accumulated +// value. These types differ for complex number input support. +template +struct NormTwoOps { + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + acc_t data_ = abs_if_complex(data, AbsSwitch()); + return acc + data_ * data_; + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE out_t project(acc_t a) const { + return device_sqrt(a); + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return WARP_SHFL_DOWN(acc, offset); + } +#endif +}; + +template +struct NanSumOps { + inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const { + return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b}); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline C10_DEVICE data_t project(acc_t a) const { + return data_t{a}; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +#endif +}; + +namespace detail { + +template +struct LessOrNan { + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a < b); + } +}; + +template +struct GreaterOrNan { + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); + } +}; + +template +struct MinMaxReductionOps { + using scalar_t = typename binary_function_traits::arg1_t; + using index_t = int64_t; + using arg_t = detail::pair; + + static C10_DEVICE arg_t project(arg_t arg) { + return arg; + } + + static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) { + return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx); + } + + static C10_DEVICE arg_t combine(arg_t a, arg_t b) { + return comp_t{}(a.first, b.first, a.second, b.second) ? a : b; + } + + static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) { + return {a.first, a.second + base_idx}; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) { + return arg_t(WARP_SHFL_DOWN(arg.first, offset), + WARP_SHFL_DOWN(arg.second, offset)); + } +#endif +}; + +template +struct ArgReductionOps : public MinMaxReductionOps { + using typename MinMaxReductionOps::scalar_t; + using typename MinMaxReductionOps::index_t; + using typename MinMaxReductionOps::arg_t; + + static C10_DEVICE index_t project(arg_t arg) { + return arg.second; + } +}; + +} // namespace detail + +template +struct ArgMaxOps : + public detail::ArgReductionOps> { +}; + +template +struct ArgMinOps : + public detail::ArgReductionOps> { +}; + +template +struct MinOps : + public detail::MinMaxReductionOps> { +}; + +template +struct MaxOps : + public detail::MinMaxReductionOps> { +}; + +template +struct MinMaxOps { + using acc_t = detail::pair; + inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + return combine(acc, {data, data}); + } + + inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const { + auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first; + auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second; + + return {min_val, max_val}; + } + + inline C10_DEVICE acc_t project(acc_t acc) const { + return acc; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const { + return { + WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset) + }; + } +#endif +}; + +} // namespace at::native + +#undef MAX +#undef MIN + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..b13470713490a02311357b44b5094b7517674585 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/// This file contains some tensor-agnostic operations to be used in the +/// core functions of the `SobolEngine` +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at::native::sobol_utils { + +/// Function to return the minimum of number of bits to represent the integer `n` +inline int64_t bit_length(const int64_t n) { + int64_t nbits, nloc; + for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++); + return nbits; +} + +/// Function to get the position of the rightmost zero in the bit representation of an integer +/// This value is the zero-indexed position +inline int64_t rightmost_zero(const int64_t n) { + int64_t z, i; + for (z = n, i = 0; z % 2 == 1; z /= 2, i++); + return i; +} + +/// Function to get a subsequence of bits in the representation of an integer starting from +/// `pos` and of length `length` +inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) { + return (n >> pos) & ((1 << length) - 1); +} + +/// Function to perform the inner product between a batched square matrix and a power of 2 vector +inline at::Tensor cdot_pow2(const at::Tensor& bmat) { + at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options()); + inter = at::pow(2, inter).expand_as(bmat); + return at::mul(inter, bmat).sum(-1); +} + +/// All definitions below this point are data. These are constant, and should not be modified +/// without notice + +constexpr int64_t MAXDIM = 21201; +constexpr int64_t MAXDEG = 18; +constexpr int64_t MAXBIT = 30; +constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT; +constexpr float RECIPD = 1.0 / LARGEST_NUMBER; + +extern const int64_t poly[MAXDIM]; +extern const int64_t initsobolstate[MAXDIM][MAXDEG]; + +} // namespace at::native::sobol_utils + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SparseTensorUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SparseTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..3916d49bf3f65de4e3c47f1d223eab7b809fdfe0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/SparseTensorUtils.h @@ -0,0 +1,195 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::sparse { + +// Just for documentary purposes +using SparseTensor = Tensor; +using SparseType = Type; + +// This is an internal utility function for getting at the SparseTensorImpl, +// so that we can write sparse tensor specific accessors for special fields +// in SparseTensor. You should only use this for writing low level +// setters/getters for SparseTensorImpl fields; otherwise, you should use +// the low level setters/getters that were implemented using this. +// +// This may be called repeatedly, so make sure it's pretty cheap. +inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) { + TORCH_INTERNAL_ASSERT( + self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor"); + return static_cast(self.unsafeGetTensorImpl()); +} + +// Takes indices and values and directly puts them into the sparse tensor, no +// copy. This used to be called THSTensor_(_move) +inline void alias_into_sparse( + const SparseTensor& self, + const Tensor& indices, + const Tensor& values) { + get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values); +} + +// Take indices and values and makes a (data) copy of them to put into the +// sparse indices/values. This used to be called THSTensor_(_set) +inline void copy_into_sparse( + const SparseTensor& self, + const Tensor& indices, + const Tensor& values, + bool non_blocking) { + alias_into_sparse( + self, + indices.to(self._indices().options(), non_blocking, /*copy=*/true), + values.to(self._values().options(), non_blocking, /*copy=*/true)); +} + +// TODO: put this into the public API +inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) { + return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl(); +} + +inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) { + return self.sparse_dim() == src.sparse_dim() && + self.dense_dim() == src.dense_dim(); +} + +// Give us a new values tensor, with the same dimensionality +// as 'values' but with a new number of non-zero elements. +// TODO: Expose this for real in ATen, some day? +// NB: Doesn't preserve data. +inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) { + std::vector size = values.sizes().vec(); + size[0] = nnz; + return at::empty(size, values.options()); +} + +// NOTE [ Flatten Sparse Indices ] +// This helper function flattens a sparse indices tensor (a Tensor) into a 1D +// indices tensor. E.g., +// input = [[2, 4, 0], +// [3, 1, 10]] +// full_size = [2, 12] +// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10] +// +// In other words, assuming that each `indices[i, :]` is a valid index to a +// tensor `t` of shape `full_size`. This returns the corresponding indices to +// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`. +// if forceClone is true, the result will forced to be a clone of self. +// if force_clone is true, the result will forced to be a clone of self. +TORCH_API Tensor flatten_indices( + const Tensor& indices, + IntArrayRef full_size, + bool force_clone = false); + +// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten +// Sparse Indices ], except this one allows partial flatten: only flatten on +// specified dims. Note that the flatten indices might be uncoalesced if +// dims_to_flatten.size() < sparse_dim. Also if input indices is already +// coalesced, the flattened indices will also be sorted. +// +// args: +// indices: sparse tensor indices +// sizes: sparse tensor sizes +// dims_to_flatten: a list of dim index to flatten +// +// Ex1: +// indices = [[2, 4, 0], +// [3, 1, 3]] +// sizes = [2, 12] +// dims_to_flatten = [0, 1] +// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] +// +// Ex2: +// dims_to_flatten = [1] +// new_indices = [ 3, 1, 3 ] # uncoalesced +TORCH_API Tensor flatten_indices_by_dims( + const Tensor& indices, + const IntArrayRef& sizes, + const IntArrayRef& dims_to_flatten); + +// Find the CSR representation for a row `indices` from the COO format +TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz); + +TORCH_API Tensor zeros_like_with_indices(const Tensor& t); + +template +class TensorGeometryHolder { + using geometry_holder_t = std::array; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options = {}) { + std::copy(sizes.begin(), sizes.end(), t_sizes.begin()); + std::copy(strides.begin(), strides.end(), t_strides.begin()); + } + + explicit TensorGeometryHolder(const Tensor& t) + : TensorGeometryHolder(t.sizes(), t.strides()) {} + + auto operator*() const { + return std::make_tuple(t_sizes, t_strides); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + +template <> +class TensorGeometryHolder<0> { + using geometry_holder_t = Tensor; + + public: + explicit TensorGeometryHolder( + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options) { + const int64_t t_ndims = sizes.size(); + const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU); + Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options); + t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options)); + t_sizes_and_strides_cpu.select(0, 1).copy_( + at::tensor(strides, cpu_options)); + const Tensor t_sizes_and_strides = + t_sizes_and_strides_cpu.to(options.device()); + t_sizes = t_sizes_and_strides.select(0, 0); + t_strides = t_sizes_and_strides.select(0, 1); + } + + explicit TensorGeometryHolder(const Tensor& t) + : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {} + + auto operator*() const { + return std::make_tuple( + t_sizes.template data_ptr(), + t_strides.template data_ptr()); + } + + private: + geometry_holder_t t_sizes; + geometry_holder_t t_strides; +}; + +// Return all indices of a tensor with the given shape. +// +// full_coo_indices(shape) is equivalent to +// torch.ones(shape).nonzero().transpose(-2, -1) but much faster. +TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options); + +} // namespace at::sparse + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorFactories.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorFactories.h new file mode 100644 index 0000000000000000000000000000000000000000..9983d87903ee97bf9a3f5ad120fa134d82ba3d3c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorFactories.h @@ -0,0 +1,174 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::native { +// Different combinations of row, col, and offset can lead to two cases: +// +// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col +// Example A: offset > 0 +// 1 1 0 0 0 +// 1 1 1 0 0 +// 1 1 1 1 0 +// Example B: offset <= 0 +// 0 0 0 +// 1 0 0 +// 1 1 0 +// In this case, we calculate the number of elements in the first row and +// last row of the tril respectively, and then compute the tril size. +// +// Case 2 - Trapezoid + Rectangle: row + offset > col +// Example: +// 1 1 0 +// 1 1 1 +// 1 1 1 +// In this case, we first calculate the size of top trapezoid, and then +// calculate the size of the bottom rectangle. +inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) { + // If either dimension is 0 then the there is no tril + if (row == 0 || col == 0) { + return 0; + } + // number of elements in the first row of the tril + auto m_first_row = offset > 0 ? std::min(col, 1 + offset) + : // upper bounded by col + row + offset > 0; // either 0 or 1 + // number of elements in the last row of the tril, bounded by [0, col] + auto m_last_row = std::max(0, std::min(col, row + offset)); + // number of rows, bounded by [0, row] + auto n_row_all = std::max(0, std::min(row, row + offset)); + auto n_row_trapezoid = (m_last_row - m_first_row + 1); + + // calculate # of elements in the top trapezoid + auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1; + + // calculate # of elements in the bottom rectangle if there is any + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * col; + } + + return tril_size; +} + +inline void check_args( + int64_t row, + int64_t col, + std::optional layout_opt) { + TORCH_CHECK(row >= 0, "row must be non-negative, got", row); + TORCH_CHECK(col >= 0, "col must be non-negative, got", col); + if (layout_opt.has_value()) { + TORCH_CHECK( + *layout_opt == at::kStrided, + "only support layout=torch.strided, got", + *layout_opt) + } +} + +using at::check_size_nonnegative; + +// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n)) +inline void check_supported_max_int_with_precision( + int64_t n, + const Tensor& tensor) { + // match defined() to behavior of checks below + TORCH_CHECK( + at::scalar_tensor(n > 0 ? n - 1 : n, tensor.options()).defined(), + "n is too large for result tensor type: '", + tensor.toString(), + "'"); + + // Ensure sufficient precision for floating point representation. + switch (tensor.scalar_type()) { + case at::ScalarType::Half: + TORCH_CHECK( + n <= (int64_t(1) << 11) + 1, + "n cannot be greater than 2049 for Half type."); + break; + case at::ScalarType::Float: + TORCH_CHECK( + n <= (int64_t(1) << 24) + 1, + "n cannot be greater than 2^24+1 for Float type."); + break; + case at::ScalarType::Double: // Unlikely to happen, but doesn't hurt to + // check + TORCH_CHECK( + n <= (int64_t(1) << 53) + 1, + "n cannot be greater than 2^53+1 for Double type."); + break; + default: + break; + } +} + +// Called by `empty*` functions when deterministic algorithms are enabled to +// fill the tensor with NaN if it is floating point or complex type, or fill +// with max value if it is integer type +inline Tensor& fill_empty_deterministic_(Tensor& tensor) { + if (tensor.is_floating_point() || tensor.is_complex()) { + AT_DISPATCH_V2( + tensor.scalar_type(), + "fill_empty_deterministic_", + AT_WRAP([&]() { + tensor.fill_(std::numeric_limits::quiet_NaN()); + }), + AT_EXPAND(AT_FLOATING_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + AT_EXPAND(AT_FLOAT8_TYPES), + kBFloat16, + kHalf, + kComplexHalf); + } else { + AT_DISPATCH_V2( + tensor.scalar_type(), + "fill_empty_deterministic_", + AT_WRAP([&]() { tensor.fill_(std::numeric_limits::max()); }), + kBool, + AT_EXPAND(AT_INTEGRAL_TYPES_V2)); + } + return tensor; +} + +// The ZeroTensor allocator ignores whatever allocation is requested and always +// gives you nullptr +struct ZeroTensorAllocator final : public at::Allocator { + ZeroTensorAllocator(at::Device device) : device_(device) {} + ~ZeroTensorAllocator() override = default; + static void deleter(void* const pointer) { + TORCH_INTERNAL_ASSERT(!pointer); + } + DataPtr allocate(const size_t /*nbytes*/) override { + return {nullptr, nullptr, &deleter, device_}; + } + DeleterFnPtr raw_deleter() const override { + return deleter; + } + void copy_data( + void* dest [[maybe_unused]], + const void* src [[maybe_unused]], + std::size_t count [[maybe_unused]]) const final {} + at::Device device_; +}; + +using binary_fn = void (*)(TensorIterator&); + +DECLARE_DISPATCH(binary_fn, complex_stub) +DECLARE_DISPATCH(binary_fn, polar_stub) + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIterator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIterator.h new file mode 100644 index 0000000000000000000000000000000000000000..149e4dd914a0ee816278dd14a1b730980a2b7fa0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorIterator.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorShape.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorShape.h new file mode 100644 index 0000000000000000000000000000000000000000..55295ea0a0a7e42962d40dca91161e7d48523676 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorShape.h @@ -0,0 +1,150 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace at::native { + +TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self); + +inline bool cat_should_skip_tensor(const Tensor& t) { + return t.sym_numel() == 0 && t.dim() == 1; +} + +// Check to see if the shape of tensors is compatible +// for being concatenated along a given dimension. +inline void check_cat_shape_except_dim( + const Tensor& first, + const Tensor& second, + int64_t dimension, + int64_t index) { + int64_t first_dims = first.dim(); + int64_t second_dims = second.dim(); + TORCH_CHECK( + first_dims == second_dims, + "Tensors must have same number of dimensions: got ", + first_dims, + " and ", + second_dims); + for (const auto dim : c10::irange(first_dims)) { + if (dim == dimension) { + continue; + } + int64_t first_dim_size = first.sizes()[dim]; + int64_t second_dim_size = second.sizes()[dim]; + TORCH_CHECK( + first_dim_size == second_dim_size, + "Sizes of tensors must match except in dimension ", + dimension, + ". Expected size ", + static_cast(first_dim_size), + " but got size ", + static_cast(second_dim_size), + " for tensor number ", + index, + " in the list."); + } +} + +inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) { + [[maybe_unused]] int64_t i = 0; + for (const Tensor& t : tensors) { + TORCH_CHECK( + t.dim() > 0, + "zero-dimensional tensor (at position ", + i, + ") cannot be concatenated"); + i++; + } +} + +inline int64_t get_num_splits( + const Tensor& self, + int64_t split_size, + int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + TORCH_CHECK( + split_size >= 0, + "split expects split_size be non-negative, but got split_size=", + split_size); + int64_t dim_size = self.size(dim); + TORCH_CHECK( + split_size > 0 || dim_size == 0, + "split_size can only be 0 if dimension size is 0, " + "but got dimension size of ", + dim_size); + // if split_size is 0 and dimension size is 0, there is 1 split. + int64_t num_splits = 1; + if (split_size != 0) { + // ensuring num_splits is at least 1 makes consistent the case where + // split_size > dim_size (returns a single split). We might want to error + // here, but keep it for BC. + num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + } + return num_splits; +} + +inline bool have_same_ndims(TensorList tensors) { + auto ndim = tensors[0].dim(); + for (const auto tensor_idx : c10::irange(tensors.size())) { + if (tensors[tensor_idx].dim() != ndim) { + return false; + } + } + return true; +} + +inline void leading_dimension_matches(TensorList tensors, int64_t dim) { + auto tensor_zero_size = tensors[0].sizes(); + std::vector leading_dim_sizes( + tensor_zero_size.begin(), tensor_zero_size.begin() + dim); + for (const auto i : c10::irange(tensors.size())) { + at::Tensor tensor = tensors[i]; + for (const auto j : c10::irange(dim)) { + TORCH_CHECK( + tensor.size(j) == leading_dim_sizes[j], + "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"); + } + } +} + +inline int64_t preprocess_chunk_cat_inputs( + TensorList tensors, + int64_t dim, + int64_t num_chunks) { + TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks"); + TORCH_CHECK( + !tensors.empty(), "_chunk_cat expects a non-empty input tensor list"); + auto expected_dtype = tensors[0].dtype(); + auto expected_device = tensors[0].device(); + for (const auto i : c10::irange(tensors.size())) { + TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor"); + TORCH_CHECK( + tensors[i].dtype() == expected_dtype, + "_chunk_cat expects all input tensors with the same dtype"); + TORCH_CHECK( + tensors[i].device() == expected_device, + "_chunk_cat expects all inputs tensors on the same device"); + } + if (have_same_ndims(tensors)) { + dim = maybe_wrap_dim(dim, tensors[0].dim()); + } else { + TORCH_CHECK( + dim >= 0, + "_chunk_cat expects non-negative dim when input tensors have different ndims") + for (const auto i : c10::irange(tensors.size())) { + TORCH_CHECK( + dim < tensors[i].ndimension(), + "_chunk_cat expects dim < ndim for all input tensors"); + } + } + leading_dimension_matches(tensors, dim); + return dim; +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorTransformations.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorTransformations.h new file mode 100644 index 0000000000000000000000000000000000000000..b9f22183a1275308da9dc2b68a74d3048fd88770 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TensorTransformations.h @@ -0,0 +1,40 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include + +namespace at::native { + +static inline Tensor roll_common( + const Tensor& self, + IntArrayRef shifts, + IntArrayRef dims) { + TORCH_CHECK(!shifts.empty(), "`shifts` required"); + if (dims.empty() && shifts.size() == 1) { + auto flattened = self.contiguous().view(self.numel()); + return roll(flattened, shifts[0], 0).view(self.sizes()); + } + TORCH_CHECK( + shifts.size() == dims.size(), + "shifts and dimensions must align. shifts: ", + shifts.size(), + ", dims:", + dims.size()); + AT_ASSERT(dims.size() > 1); + auto tail_shifts = shifts.slice(1); + auto tail_dims = dims.slice(1); + auto first_dim_rolled = roll(self, shifts[0], dims[0]); + return at::roll(first_dim_rolled, tail_shifts, tail_dims); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TopKImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TopKImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..dfef5ffb36b2ee1ccb84f1ba2e4882da4ef7d83a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TopKImpl.h @@ -0,0 +1,103 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at::native { + +#ifdef CPU_CAPABILITY +inline namespace CPU_CAPABILITY { +#else +inline namespace DEFAULT { +#endif + +// Core topk loop, shared between CPU and QuantizedCPU +template +void topk_impl_loop( + const int64_t mode_values_stride, + const int64_t mode_indices_stride, + const int64_t tmp_values_stride, + const int64_t k, + const int64_t dim_size, + const bool largest, + const bool sorted, + char** data, const int64_t* strides, const int64_t n) { + + // If k is zero, then output values and indices are empty tensors + // So iterating over other dims is pointless + if (k == 0) { + return; + } + using elem_t = std::pair; + std::vector queue(dim_size); + for (const auto i : c10::irange(n)) { + TensorAccessor mode_values( + reinterpret_cast(data[0] + i * strides[0]), + &k, &mode_values_stride); + TensorAccessor mode_indices( + reinterpret_cast(data[1] + i * strides[1]), + &k, &mode_indices_stride); + TensorAccessor tmp_values( + reinterpret_cast(data[2] + i * strides[2]), + &dim_size, &tmp_values_stride); + + auto n_2 = dim_size; + auto use_partial_sort = k * 64 <= n_2; + + for (const auto j : c10::irange(n_2)) { + queue[j].first = tmp_values[j]; + queue[j].second = j; + } + + // we want nan to be sorted as top for numpy compatibility + if (use_partial_sort) { + if (largest) { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } else { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } else { + if (largest) { + std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k - 1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } + } else { + std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k -1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } + } + + for (const auto j : c10::irange(k)) { + mode_values[j] = queue[j].first; + mode_indices[j] = queue[j].second; + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TypeProperties.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TypeProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..d276b7e8b8c9683ee776012fd6575781e31ef270 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/TypeProperties.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::native { + +struct ResultTypeState { + c10::ScalarType dimResult = ScalarType::Undefined; + c10::ScalarType wrappedResult = ScalarType::Undefined; + c10::ScalarType zeroResult = ScalarType::Undefined; +}; + +TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); +TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state); +TORCH_API ScalarType result_type(const ResultTypeState& state); + +TORCH_API ScalarType result_type(ITensorListRef tensors); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnaryOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8a1ff0748d64bd47eb4f99709dd768885d308b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UnaryOps.h @@ -0,0 +1,133 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at { +class Tensor; +class TensorBase; +struct TensorIteratorBase; +} + +namespace at::native { + +using unary_fn = void(*)(TensorIteratorBase&); +using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a); + +inline namespace CPU_CAPABILITY { +void conj_kernel(TensorIteratorBase &iter); +void neg_kernel(TensorIteratorBase &iter); +void reciprocal_kernel(TensorIteratorBase &iter); +void rsqrt_kernel(TensorIteratorBase& iter); +void sqrt_kernel(TensorIteratorBase& iter); +} // namespace CPU_CAPABILITY + +DECLARE_DISPATCH(unary_fn, abs_stub) +DECLARE_DISPATCH(unary_fn, angle_stub) +DECLARE_DISPATCH(unary_fn, conj_physical_stub) +DECLARE_DISPATCH(unary_fn, acos_stub) +DECLARE_DISPATCH(unary_fn, acosh_stub) +DECLARE_DISPATCH(unary_fn, asinh_stub) +DECLARE_DISPATCH(unary_fn, atanh_stub) +DECLARE_DISPATCH(unary_fn, asin_stub) +DECLARE_DISPATCH(unary_fn, atan_stub) +DECLARE_DISPATCH(unary_fn, bitwise_not_stub) +DECLARE_DISPATCH(unary_fn, logical_not_stub) +DECLARE_DISPATCH(unary_fn, ceil_stub) +DECLARE_DISPATCH(unary_fn, cos_stub) +DECLARE_DISPATCH(unary_fn, cosh_stub) +DECLARE_DISPATCH(unary_fn, digamma_stub) +DECLARE_DISPATCH(unary_fn, special_entr_stub) +DECLARE_DISPATCH(unary_fn, special_erfcx_stub) +DECLARE_DISPATCH(unary_fn, erf_stub) +DECLARE_DISPATCH(unary_fn, erfc_stub) +DECLARE_DISPATCH(unary_fn, erfinv_stub) +DECLARE_DISPATCH(unary_fn, exp_stub) +DECLARE_DISPATCH(unary_fn, exp2_stub) +DECLARE_DISPATCH(unary_fn, expm1_stub) +DECLARE_DISPATCH(unary_fn, floor_stub) +DECLARE_DISPATCH(unary_fn, frac_stub) +DECLARE_DISPATCH(unary_fn, frexp_stub) +DECLARE_DISPATCH(unary_fn, i0_stub) +DECLARE_DISPATCH(unary_fn, special_i0e_stub) +DECLARE_DISPATCH(unary_fn, special_i1_stub) +DECLARE_DISPATCH(unary_fn, special_i1e_stub) +DECLARE_DISPATCH(unary_fn, log_stub) +DECLARE_DISPATCH(unary_fn, log10_stub) +DECLARE_DISPATCH(unary_fn, log1p_stub) +DECLARE_DISPATCH(unary_fn, log2_stub) +DECLARE_DISPATCH(unary_fn, special_ndtri_stub) +DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub) +DECLARE_DISPATCH(unary_fn, neg_stub) + +DECLARE_DISPATCH(unary_fn, reciprocal_stub) +DECLARE_DISPATCH(unary_fn, round_stub) +DECLARE_DISPATCH(unary_fn, rsqrt_stub) +DECLARE_DISPATCH(unary_fn, sigmoid_stub) +DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub) +DECLARE_DISPATCH(unary_fn, sign_stub) +DECLARE_DISPATCH(unary_fn, signbit_stub) +DECLARE_DISPATCH(unary_fn, sgn_stub) +DECLARE_DISPATCH(unary_fn, sin_stub) +DECLARE_DISPATCH(unary_fn, sinc_stub) +DECLARE_DISPATCH(unary_fn, sinh_stub) +DECLARE_DISPATCH(unary_fn, sqrt_stub) +DECLARE_DISPATCH(unary_fn, tan_stub) +DECLARE_DISPATCH(unary_fn, tanh_stub) +DECLARE_DISPATCH(unary_fn, trigamma_stub) +DECLARE_DISPATCH(unary_fn, trunc_stub) +DECLARE_DISPATCH(unary_fn, lgamma_stub) +DECLARE_DISPATCH(unary_fn, special_airy_ai_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub) +DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub) +DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub) +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub) +DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub) +DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub) + +// NB: these are actually defined in Distribution +DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional), bernoulli_tensor_stub) +DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional), bernoulli_scalar_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), cauchy_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional), exponential_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional), geometric_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), log_normal_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional), uniform_stub) +DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional), normal_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional), random_from_to_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional), random_full_64_bits_range_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional), random_stub) + +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub) +DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub) +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, std::optional), + multinomial_with_replacement_stub) +DECLARE_DISPATCH( + void (*)( + TensorIteratorBase&, + std::optional, + std::optional, + std::optional), + nan_to_num_stub) +DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub) + +// Missing unary functions +// digamma +// lgamma +// erfinv +// clone +// contiguous +// zero +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold3d.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold3d.h new file mode 100644 index 0000000000000000000000000000000000000000..b8a2fa55234c0cd8181f8f286c438578df3368d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/Unfold3d.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::native { + +void Unfold3dCopyCPU( + ScalarType dtype, + const void *src, + int64_t C, + int64_t X_D, + int64_t X_H, + int64_t X_W, + int64_t Y_D, + int64_t Y_H, + int64_t Y_W, + int64_t kernel_d, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_d, + int64_t stride_h, + int64_t stride_w, + int64_t pad_d, + int64_t pad_h, + int64_t pad_w, + void* dst); + +void Unfold3dAccCPU( + ScalarType dtype, + const void *src, + int64_t C, + int64_t X_D, + int64_t X_H, + int64_t X_W, + int64_t Y_D, + int64_t Y_H, + int64_t Y_W, + int64_t kernel_d, + int64_t kernel_h, + int64_t kernel_w, + int64_t stride_d, + int64_t stride_h, + int64_t stride_w, + int64_t pad_d, + int64_t pad_h, + int64_t pad_w, + void *dst); + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UpSample.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UpSample.h new file mode 100644 index 0000000000000000000000000000000000000000..e4e210901a0a2960a0191bc8d760201af9e764bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/UpSample.h @@ -0,0 +1,514 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +/** + * Note [compute_scales_value] + * Note [area_pixel_compute_scale] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * Interpolate with scale_factor can have different behaviors + * depending on the value of recompute_scale_factor: + * + * - With recompute_scale_factor = True (current default behavior): + * the scale_factor, when provided by the user, are used to calculate + * the output size. The input size and the computed output_size + * are then used to infer new values for the scales which are + * used in the interpolation. Because floating-point math is not exact, + * this may be a different value from the user-supplied scales. + * + * - With recompute_scale_factor = False (which will be the default + * behavior starting 1.5.0): + * the behavior follows opencv logic, and the scales provided by + * the user are the ones used in the interpolation calculations. + * + * If the scales are not provided or if they are provided but + * recompute_scale_factor is set to True (default behavior), the scales + * are computed from the input and the output size; + * + * + * When the scales are inferred from the input and output sizes, + * we view each pixel as an area, idx + 0.5 as its center index. + * Here is an example formula in 1D case. + * if align_corners: center of two corner pixel areas are preserved, + * (0.5, 0.5) -> (0.5, 0.5), + * (input_size - 0.5, 0.5) -> (output_size - 0.5) + * scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5) + * src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5) + * if not align_corners: the whole range is scaled accordingly + * scale = input_size / output_size + * src_idx + 0.5 = scale * (dst_index + 0.5) + */ + +namespace at::native { + +namespace upsample { + +TORCH_API c10::SmallVector compute_output_size( + c10::IntArrayRef input_size, // Full input tensor size. + at::OptionalIntArrayRef output_size, + std::optional> scale_factors); + +inline std::optional get_scale_value(std::optional> scales, int idx) { + if (!scales) { + return std::nullopt; + } + return scales->at(idx); +} + +} // namespace upsample + +using scale_t = std::optional; +using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w); +using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w); +using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w); +using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w); +using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w); +using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w); +using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel) +DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel) +DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel) +DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel) +DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel) +DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel) +DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel) +DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel) +DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel) +DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel) +DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel) +DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel) +DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel) +DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel) +DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel) +DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel) +DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel) +DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel) + +[[maybe_unused]] inline std::array upsample_1d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + int64_t output_width = output_size[0]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_width = input_size[2]; + + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); + + return {nbatch, channels, output_width}; +} + +[[maybe_unused]] inline std::array upsample_2d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 4, + "It is expected input_size equals to 4, but got size ", + input_size.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_height = input_size[2]; + int64_t input_width = input_size[3]; + + TORCH_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", + input_height, + ", W: ", + input_width, + ") output (H: ", + output_height, + ", W: ", + output_width, + ")"); + + return {nbatch, channels, output_height, output_width}; +} + +[[maybe_unused]] inline std::array upsample_3d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 3, + "It is expected output_size equals to 3, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 5, + "It is expected input_size equals to 5, but got size ", + input_size.size()); + + int64_t output_depth = output_size[0]; + int64_t output_height = output_size[1]; + int64_t output_width = output_size[2]; + + int64_t nbatch = input_size[0]; + int64_t channels = input_size[1]; + int64_t input_depth = input_size[2]; + int64_t input_height = input_size[3]; + int64_t input_width = input_size[4]; + + TORCH_CHECK( + input_depth > 0 && input_height > 0 && input_width > 0 && + output_depth > 0 && output_height > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (D: ", + input_depth, + ", H: ", + input_height, + ", W: ", + input_width, + ") output (D: ", + output_depth, + ", H: ", + output_height, + ", W: ", + output_width, + ")"); + + + return {nbatch, channels, output_depth, output_height, output_width}; +} + +inline void upsample_2d_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t nbatch, + int64_t nchannels, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width) { + TORCH_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", + input_height, + ", W: ", + input_width, + ") output (H: ", + output_height, + ", W: ", + output_width, + ")"); + + if (input.defined()) { + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + (input.numel() != 0 || + (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0) + ) && + input.dim() == 4, + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + } else if (grad_output.defined()) { + check_dim_size(grad_output, 4, 0, nbatch); + check_dim_size(grad_output, 4, 1, nchannels); + check_dim_size(grad_output, 4, 2, output_height); + check_dim_size(grad_output, 4, 3, output_width); + } +} + +template +inline scalar_t compute_scales_value( + const std::optional scale, + int64_t input_size, + int64_t output_size) { + // see Note [compute_scales_value] + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) + ? static_cast(1.0 / scale.value()) + : (static_cast(input_size) / output_size); +} + +template +inline scalar_t area_pixel_compute_scale( + int64_t input_size, + int64_t output_size, + bool align_corners, + const std::optional scale) { + // see Note [area_pixel_compute_scale] + if(align_corners) { + if(output_size > 1) { + return static_cast(input_size - 1) / (output_size - 1); + } else { + return static_cast(0); + } + } else { + return compute_scales_value(scale, input_size, output_size); + } +} + +template +inline scalar_t area_pixel_compute_source_index( + scalar_t scale, + int64_t dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + scalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + // [Note] Follow Opencv resize logic: + // We allow negative src_idx here and later will use + // dx = src_idx - floorf(src_idx) + // to compute the "distance"(which affects weights). + // For linear modes, weight distribution doesn't matter + // for negative indices as they use 2 pixels to interpolate. + // For example, [-1, 0], they both use pixel 0 value so it + // doesn't affect if we bound the src_idx to 0 or not. + // TODO: Our current linear mode impls use unbound indices + // where we should and then remove this cubic flag. + // This matters in cubic mode, as we might need [-1, 0, 1, 2] + // to interpolate and the weights can be affected. + return (!cubic && src_idx < static_cast(0)) ? scalar_t(0) + : src_idx; + } +} + +inline int64_t nearest_neighbor_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // Index computation matching OpenCV INTER_NEAREST + // which is buggy and kept for BC + const int64_t src_index = + std::min(static_cast(floorf(dst_index * scale)), input_size - 1); + return src_index; +} + +inline int64_t nearest_neighbor_exact_compute_source_index( + const float scale, + int64_t dst_index, + int64_t input_size) { + // index_f32 = (output_index + 0.5) * scale - 0.5 + // input_index = round(index_f32) + // Same as Pillow and Scikit-Image/Scipy ndi.zoom + const int64_t src_index = + std::min(static_cast(floorf((dst_index + 0.5) * scale)), input_size - 1); + return src_index; +} + +inline int64_t nearest_idx( + int64_t output_index, + int64_t input_size, + int64_t output_size, + std::optional scales) { + // This method specifically treats cases: output_size == input_size or + // output_size == 2 * input_size, that we would like to get rid of + // We keep this method for BC and consider as deprecated. + // See nearest_exact_idx as replacement + if (output_size == input_size) { + // scale_factor = 1, simply copy + return output_index; + } else if (output_size == 2 * input_size) { + // scale_factor = 2, shift input index + return output_index >> 1; + } else { + float scale = compute_scales_value(scales, input_size, output_size); + return nearest_neighbor_compute_source_index(scale, output_index, input_size); + } +} + +inline int64_t nearest_exact_idx( + int64_t output_index, + int64_t input_size, + int64_t output_size, + std::optional scales) { + float scale = compute_scales_value(scales, input_size, output_size); + return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size); +} + +// Define a typedef to dispatch to nearest_idx or nearest_exact_idx +typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, std::optional); + +template +scalar_t upsample_get_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + return data[access_y * width + access_x]; +} + +template +void upsample_increment_value_bounded( + scalar_t* data, + int64_t width, + int64_t height, + int64_t x, + int64_t y, + scalar_t value) { + int64_t access_x = std::max(std::min(x, width - 1), static_cast(0)); + int64_t access_y = std::max(std::min(y, height - 1), static_cast(0)); + data[access_y * width + access_x] += value; +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +scalar_t cubic_convolution1(scalar_t x, scalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +scalar_t cubic_convolution2(scalar_t x, scalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +static inline void get_cubic_upsample_coefficients( + scalar_t coeffs[4], + scalar_t t) { + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +inline scalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + scalar_t t) { + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +// when `real_input_index` becomes larger than the range the floating point +// type can accurately represent, the type casting to `int64_t` might exceed +// `input_size`, causing overflow. So we guard it with `std::min` below. +template +inline void guard_index_and_lambda(const opmath_t& real_input_index, const int64_t& input_size, int64_t& input_index, scalar_t& lambda) { + input_index = std::min(static_cast(floorf(real_input_index)), input_size - 1); + lambda = std::min( + std::max(real_input_index - input_index, static_cast(0)), + static_cast(1) + ); +} + +template +inline void compute_source_index_and_lambda( + int64_t& input_index0, + int64_t& input_index1, + scalar_t& lambda0, + scalar_t& lambda1, + opmath_t ratio, + int64_t output_index, + int64_t input_size, + int64_t output_size, + bool align_corners) { + if (output_size == input_size) { + // scale_factor = 1, simply copy + input_index0 = output_index; + input_index1 = output_index; + lambda0 = static_cast(1); + lambda1 = static_cast(0); + } else { + const auto real_input_index = + area_pixel_compute_source_index( + ratio, output_index, align_corners, /*cubic=*/false); + guard_index_and_lambda(real_input_index, input_size, input_index0, lambda1); + int64_t offset = (input_index0 < input_size - 1) ? 1 : 0; + input_index1 = input_index0 + offset; + lambda0 = static_cast(1.) - lambda1; + } +} + +// It will not be used by data types other than BFloat16 and Half. +template || !std::is_same_v, int> = 0> +void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) { + TORCH_CHECK((is_reduced_floating_point_v), + "Upsample backward only support BFloat16 and Half in the lower precision data types on CPU.") + TORCH_CHECK((std::is_same_v), + "Upsample backward should use float as acc buffer for BFloat16 and Half grad input on CPU.") + return; +} + +template && std::is_same_v, int> = 0> +void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t size) { + using bVec = Vectorized; + using fVec = Vectorized; + int64_t d = 0; + for (; d < size - (size % bVec::size()); d += bVec::size()) { + bVec gin_bvec = bVec::loadu(gin + d); + auto [gin_fvec0, gin_fvec1] = convert_to_float(gin_bvec); + gin_fvec0 += fVec::loadu(buffer_ptr + d); + gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size()); + fVec(0).store(buffer_ptr + d); + fVec(0).store(buffer_ptr + d + fVec::size()); + convert_from_float(gin_fvec0, gin_fvec1).store(gin + d); + } + for (; d < size; d++) { + gin[d] += buffer_ptr[d]; + buffer_ptr[d] = 0; + } +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col_shape_check.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col_shape_check.h new file mode 100644 index 0000000000000000000000000000000000000000..8b2b946c01cf15fdaf723298a4b484071aae918e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/im2col_shape_check.h @@ -0,0 +1,246 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include + +namespace at::native { + +inline void col2im_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t output_height, + int64_t output_width, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + + int64_t ndim = input.ndimension(); + // allow dim=0 only the batch dimension. + TORCH_CHECK( + (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) || + (ndim == 3 && input.size(1) != 0 && input.size(2) != 0), + "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t batch_dim = (ndim == 3) ? 0 : -1; + int64_t n_input_plane = input.size(batch_dim + 1); + uint64_t prod_kernel_size = 1; + + TORCH_CHECK(!c10::mul_overflows(static_cast(kernel_width), static_cast(kernel_height), &prod_kernel_size), + "Given kernel_width = ", + kernel_width, + " and kernel_height = ", + kernel_height, + " the product of kernel_width and kernel_height overflowed."); + + if (n_input_plane % (kernel_width * kernel_height) != 0) { + TORCH_CHECK(false, + "Expected size of input's dimension 1 to be divisible by the " + "product of kernel_size, but got input.size(1)=", + n_input_plane, + " and kernel_size=(", + kernel_height, + ", ", + kernel_width, + ")."); + } + + int64_t input_length = input.size(batch_dim + 2); + int64_t n_blocks_height = + div_rtn( + output_height + 2 * pad_height - + dilation_height * (kernel_height - 1) - 1, + stride_height) + + 1; + int64_t n_blocks_width = div_rtn( + output_width + 2 * pad_width - + dilation_width * (kernel_width - 1) - 1, + stride_width) + + 1; + + if (input_length != (n_blocks_height * n_blocks_width)) { + TORCH_CHECK(false, + "Given output_size=(", + output_height, + ", ", + output_width, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), stride=(", + stride_height, + ", ", + stride_width, + "), expected size of input's dimension 2 to match the calculated number of ", + "sliding blocks ", + n_blocks_height, + " * ", + n_blocks_width, + " = ", + (n_blocks_height * n_blocks_width), + ", but got input.size(2)=", + input_length, + "."); + } + + TORCH_CHECK( + n_blocks_height >= 1 && n_blocks_width >= 1, + "Given output_size=(", output_height, ", ", output_width, "), ", + "kernel_size=(", kernel_height, ", ", kernel_width, "), ", + "dilation=(", dilation_height, ", ", dilation_width, "), ", + "padding=(", pad_height, ", ", pad_width, "), ", + "stride=(", stride_height, ", ", stride_width, "), ", + "calculated shape of the array of sliding blocks as ", + "(", n_blocks_height, ", ", n_blocks_width, "), ", + "which is too small (non-positive)"); + + if (output_width < 1 || output_height < 1) { + TORCH_CHECK(false, + "Expected output spatial size to be positive, but got: output_size=(", + output_height, + ", ", + output_width, + ")."); + } +} + +inline void im2col_shape_check( + const Tensor& input, + const Tensor& grad_output, + int64_t kernel_height, + int64_t kernel_width, + int64_t dilation_height, + int64_t dilation_width, + int64_t pad_height, + int64_t pad_width, + int64_t stride_height, + int64_t stride_width) { + TORCH_CHECK( + kernel_width > 0 && kernel_height > 0, + "kernel size should be greater than zero, but got kernel_height: ", + kernel_height, + " kernel_width: ", + kernel_width); + + TORCH_CHECK( + dilation_width > 0 && dilation_height > 0, + "dilation should be greater than zero, but got dilation_height: ", + dilation_height, + " dilation_width: ", + dilation_width); + + TORCH_CHECK( + pad_width >= 0 && pad_height >= 0, + "padding should be non-negative, but got pad_height: ", + pad_height, + " pad_width: ", + pad_width); + + TORCH_CHECK( + stride_width > 0 && stride_height > 0, + "stride should be greater than zero, but got stride_height: ", + stride_height, + " stride_width: ", + stride_width); + + int64_t ndim = input.ndimension(); + + // allow dim=0 only the batch dimension. + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + TORCH_CHECK( + (ndim == 3 && input.size(0) && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); + + int64_t dim_batch = 0; + + if (ndim == 3) { + dim_batch = -1; + } + + int64_t input_height = input.size(dim_batch + 2); + int64_t input_width = input.size(dim_batch + 3); + int64_t output_height = div_rtn( + input_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1), + stride_height) + + 1; + int64_t output_width = div_rtn( + input_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1), + stride_width) + + 1; + + if (output_height < 1 || output_width < 1) { + TORCH_CHECK(false, + "Given input with spatial size (", + input_height, + ", ", + input_height, + "), kernel_size=(", + kernel_height, + ", ", + kernel_width, + "), dilation=(", + dilation_height, + ", ", + dilation_width, + "), padding=(", + pad_height, + ", ", + pad_width, + "), calculated shape of the array of sliding blocks as (", + output_height, + ", ", + output_width, + "), but its components must be at least one."); + } +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/verbose_wrapper.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/verbose_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..bc144b9f92cc6441074485888f34cb4c44a683de --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/verbose_wrapper.h @@ -0,0 +1,13 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace torch::verbose { +TORCH_API int _mkl_set_verbose(int enable); +TORCH_API int _mkldnn_set_verbose(int level); +} // namespace torch::verbose + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..c153824e0607ef92b0828c1a670ad5d7644d2c9c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/CachingHostAllocator.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace at::xpu { + +C10_DEPRECATED_MESSAGE( + "at::xpu::getCachingHostAllocator() is deprecated. Please use at::getHostAllocator(at::kXPU) instead.") +inline TORCH_XPU_API at::HostAllocator* getCachingHostAllocator() { + return at::getHostAllocator(at::kXPU); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kXPU)->record_event(...) instead.") +inline TORCH_XPU_API bool CachingHostAllocator_recordEvent( + void* ptr, + void* ctx, + c10::xpu::XPUStream stream) { + return getHostAllocator(at::kXPU)->record_event(ptr, ctx, stream.unwrap()); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::CachingHostAllocator_emptyCache() is deprecated. Please use at::getHostAllocator(at::kXPU)->empty_cache() instead.") +inline TORCH_XPU_API void CachingHostAllocator_emptyCache() { + getHostAllocator(at::kXPU)->empty_cache(); +} + +C10_DEPRECATED_MESSAGE( + "at::xpu::HostAlloc(...) is deprecated. Please use at::getHostAllocator(at::kXPU)->allocate(...) instead.") +inline TORCH_XPU_API at::DataPtr HostAlloc(size_t size) { + return getHostAllocator(at::kXPU)->allocate(size); +} + +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PeerToPeerAccess.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PeerToPeerAccess.h new file mode 100644 index 0000000000000000000000000000000000000000..a807cd1ffc18d6f2d4248e9a525457b71fcb70ee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PeerToPeerAccess.h @@ -0,0 +1,20 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::xpu { +namespace detail { +void init_p2p_access_cache(c10::DeviceIndex num_devices); +} // namespace detail + +TORCH_XPU_API bool get_p2p_access( + c10::DeviceIndex dev, + c10::DeviceIndex dev_to_access); + +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PhiloxXpuState.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PhiloxXpuState.h new file mode 100644 index 0000000000000000000000000000000000000000..f3f602fe3dd5a6c3ebb9d20f8e51a4f246679977 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PhiloxXpuState.h @@ -0,0 +1,50 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +namespace at { + +struct PhiloxXpuState { + PhiloxXpuState() = default; + PhiloxXpuState(uint64_t seed, uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // for graph capture + PhiloxXpuState( + int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +namespace xpu::philox { +inline std::tuple unpack(at::PhiloxXpuState arg) { + if (arg.captured_) { + return std::make_tuple( + static_cast(*arg.seed_.ptr), + static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace xpu::philox +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..0ef4066089c40bbc28001edf111209932b995864 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/PinnedMemoryAllocator.h @@ -0,0 +1,16 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::xpu { + +inline TORCH_XPU_API at::HostAllocator* getPinnedMemoryAllocator() { + return at::getHostAllocator(at::kXPU); +} +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h new file mode 100644 index 0000000000000000000000000000000000000000..049b4f68267552a98ce8e07210dd7d3aa8e80e91 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUContext.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::xpu { + +// XPU is available if we compiled with XPU. +inline bool is_available() { + return c10::xpu::device_count() > 0; +} + +TORCH_XPU_API DeviceProp* getCurrentDeviceProperties(); + +TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device); + +TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device); + +TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer); + +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h new file mode 100644 index 0000000000000000000000000000000000000000..63b56c86c6ed26d2877eb4534dd831007830dbb4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUDevice.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::xpu { + +inline Device getDeviceFromPtr(void* ptr) { + auto device = c10::xpu::get_device_idx_from_pointer(ptr); + return {c10::DeviceType::XPU, device}; +} + +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h new file mode 100644 index 0000000000000000000000000000000000000000..be5c5b83169f0a632d913e08b161ab19bafb6421 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUEvent.h @@ -0,0 +1,8 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..15567d178f5848214055d9e9df6411ced16a3a5e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGeneratorImpl.h @@ -0,0 +1,78 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { + +namespace xpu { +struct XPUGraph; +} + +struct XPUGeneratorState : public c10::intrusive_ptr_target { + uint64_t seed_; + uint64_t philox_offset_per_thread_; + uint32_t offset_intragraph_; + bool capturing_{}; + at::TensorBase seed_extragraph_{}; + at::TensorBase offset_extragraph_{}; + + XPUGeneratorState( + uint64_t seed = default_rng_seed_val, + uint64_t philox_offset_per_thread = 0, + uint32_t offset_intragraph = 0) + : seed_(seed), + philox_offset_per_thread_(philox_offset_per_thread), + offset_intragraph_(offset_intragraph) {} + + void increase(uint64_t increment); + + c10::intrusive_ptr clone(); +}; + +struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { + // Constructors + XPUGeneratorImpl(DeviceIndex device_index = -1); + XPUGeneratorImpl( + DeviceIndex device_index, + c10::intrusive_ptr state_); + ~XPUGeneratorImpl() override = default; + + // XPUGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + + void set_philox_offset_per_thread(uint64_t offset); + uint64_t philox_offset_per_thread() const; + + PhiloxXpuState philox_xpu_state(uint64_t increment); + // will remove once all ops are refactored to use philox_xpu_state. + std::pair philox_engine_inputs(uint64_t increment); + static c10::DeviceType device_type(); + + private: + XPUGeneratorImpl* clone_impl() const override; + c10::intrusive_ptr state_; +}; + +namespace xpu::detail { + +TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1); + +TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1); + +} // namespace xpu::detail +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGraphsUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGraphsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..8b61e894d54d97ee140049b356477a82d38fd6b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUGraphsUtils.h @@ -0,0 +1,27 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::xpu { + +inline CaptureStatus currentStreamCaptureStatus() { + return c10::xpu::currentStreamCaptureStatusMayInitCtx(); +} + +inline void assertNotCapturing(const std::string& attempt) { + auto status = currentStreamCaptureStatus(); + TORCH_CHECK( + status == CaptureStatus::Executing, + attempt, + " during XPU graph capture. If you need this call to be captured, " + "please file an issue. " + "Current xpuStreamCaptureStatus: ", + status); +} + +} // namespace at::xpu + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUScaledBlas.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUScaledBlas.h new file mode 100644 index 0000000000000000000000000000000000000000..883e7642d968186b4a006c74e9ac558c2d3557ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/xpu/XPUScaledBlas.h @@ -0,0 +1,100 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include +#include +#include +#include +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef USE_FBGEMM_GENAI +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +using at::blas::ScalingType; + +namespace at::native::onednn::scaled { + +/** + * Track concrete implementations available + */ +enum class ScaledGemmImplementation { + NONE = 0, + TENSORWISE_TENSORWISE = 1, + ROWWISE_ROWWISE = 2, +}; + +/** + * Convert passed int (enum) from python back into a + * strictly-typed enum + */ +template +std::vector convert_int_to_enum(ArrayType& v) { + std::vector converted; + converted.reserve(v.size()); + + for (auto vi : v) { + converted.push_back(static_cast(vi)); + } + return converted; +} + +bool check_tensorwise_recipe( + c10::ScalarType, + std::vector&, + ArrayRef&, + c10::ScalarType, + std::vector&, + ArrayRef&); + +bool check_rowwise_recipe( + c10::ScalarType, + std::vector&, + ArrayRef&, + c10::ScalarType, + std::vector&, + ArrayRef&); + +} // namespace at::native::onednn::scaled + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Array.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Array.h new file mode 100644 index 0000000000000000000000000000000000000000..5cb2d8dff74253bf9c54d53b3aa532d91bee89a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Array.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +// This helper function creates a constexpr std::array +// From a compile time list of values, without requiring you to explicitly +// write out the length. +// +// See also https://stackoverflow.com/a/26351760/23845 +template +inline constexpr auto array_of(T&&... t) -> std::array { + return {{std::forward(t)...}}; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h new file mode 100644 index 0000000000000000000000000000000000000000..6865f84fa6af5dbd8e2fb60ff46f1bbabdead1fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16-math.h @@ -0,0 +1,304 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +namespace c10 { +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; +} // namespace c10 + +namespace std { + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +using c10::is_reduced_floating_point; +using c10::is_reduced_floating_point_v; +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std::enable_if_t, int> = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std::enable_if_t, int> = 0> +C10_HOST_DEVICE inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..90ca6b677ab3740550f4700479497fd58c35536b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/BFloat16.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..0a9e8d2c27ff43ab571d3883567ef5535c3287db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Backtrace.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_BACKTRACE_H_ +#define C10_UTIL_BACKTRACE_H_ + +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +// Symbolizing the backtrace can be expensive; pass it around as a lazy string +// so it is symbolized only if actually needed. +using Backtrace = std::shared_ptr>; + +// DEPRECATED: Prefer get_lazy_backtrace(). +C10_API std::string get_backtrace( + size_t frames_to_skip = 0, + size_t maximum_number_of_frames = 64, + bool skip_python_frames = true); + +C10_API Backtrace get_lazy_backtrace( + size_t frames_to_skip = 0, + size_t maximum_number_of_frames = 64, + bool skip_python_frames = true); + +} // namespace c10 + +#endif // C10_UTIL_BACKTRACE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h new file mode 100644 index 0000000000000000000000000000000000000000..5fd611a2add7563d8c6ca6fba28e704765c4ec79 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/DeadlockDetection.h @@ -0,0 +1,57 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +/// This file provides some simple utilities for detecting common deadlocks in +/// PyTorch. For now, we focus exclusively on detecting Python GIL deadlocks, +/// as the GIL is a wide ranging lock that is taken out in many situations. +/// The basic strategy is before performing an operation that may block, you +/// can use TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() to assert that the GIL is +/// not held. This macro is to be used in contexts where no static dependency +/// on Python is available (we will handle indirecting a virtual call for you). +/// +/// If the GIL is held by a torchdeploy interpreter, we always report false. +/// If you are in a context where Python bindings are available, it's better +/// to directly assert on PyGILState_Check (as it avoids a vcall and also +/// works correctly with torchdeploy.) + +#define TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP() \ + TORCH_INTERNAL_ASSERT( \ + !c10::impl::check_python_gil(), \ + "Holding GIL before a blocking operation! Please release the GIL before blocking, or see https://github.com/pytorch/pytorch/issues/56297 for how to release the GIL for destructors of objects") + +namespace c10::impl { + +C10_API bool check_python_gil(); + +struct C10_API PythonGILHooks { + virtual ~PythonGILHooks() = default; + // Returns true if we hold the GIL. If not linked against Python we + // always return false. + virtual bool check_python_gil() const = 0; +}; + +C10_API void SetPythonGILHooks(PythonGILHooks* factory); + +// DO NOT call this registerer from a torch deploy instance! You will clobber +// other registrations +struct C10_API PythonGILHooksRegisterer { + explicit PythonGILHooksRegisterer(PythonGILHooks* factory) { + SetPythonGILHooks(factory); + } + PythonGILHooksRegisterer(const PythonGILHooksRegisterer&) = delete; + PythonGILHooksRegisterer(PythonGILHooksRegisterer&&) = delete; + PythonGILHooksRegisterer& operator=(const PythonGILHooksRegisterer&) = delete; + PythonGILHooksRegisterer& operator=(PythonGILHooksRegisterer&&) = delete; + ~PythonGILHooksRegisterer() { + SetPythonGILHooks(nullptr); + } +}; + +} // namespace c10::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h new file mode 100644 index 0000000000000000000000000000000000000000..441e158ccc4ab86cd7c19a25963d1da7005c82e9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Enumerate.h @@ -0,0 +1,164 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Ported from folly/container/Enumerate.h + */ + +#pragma once + +#include +#include + +#ifdef _WIN32 +#include // @manual +using ssize_t = SSIZE_T; +#endif + +#include + +/** + * Similar to Python's enumerate(), enumerate() can be used to + * iterate a range with a for-range loop, and it also allows to + * retrieve the count of iterations so far. Can be used in constexpr + * context. + * + * For example: + * + * for (auto&& [index, element] : enumerate(vec)) { + * // index is a const reference to a size_t containing the iteration count. + * // element is a reference to the type contained within vec, mutable + * // unless vec is const. + * } + * + * If the binding is const, the element reference is too. + * + * for (const auto&& [index, element] : enumerate(vec)) { + * // element is always a const reference. + * } + * + * It can also be used as follows: + * + * for (auto&& it : enumerate(vec)) { + * // *it is a reference to the current element. Mutable unless vec is const. + * // it->member can be used as well. + * // it.index contains the iteration count. + * } + * + * As before, const auto&& it can also be used. + */ + +namespace c10 { + +namespace detail { + +template +struct MakeConst { + using type = const T; +}; +template +struct MakeConst { + using type = const T&; +}; +template +struct MakeConst { + using type = const T*; +}; + +template +class Enumerator { + public: + constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {} + + class Proxy { + public: + using difference_type = ssize_t; + using value_type = typename std::iterator_traits::value_type; + using reference = typename std::iterator_traits::reference; + using pointer = typename std::iterator_traits::pointer; + using iterator_category = std::input_iterator_tag; + + C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e) + : index(e.idx_), element(*e.it_) {} + + // Non-const Proxy: Forward constness from Iterator. + C10_ALWAYS_INLINE constexpr reference operator*() { + return element; + } + C10_ALWAYS_INLINE constexpr pointer operator->() { + return std::addressof(element); + } + + // Const Proxy: Force const references. + C10_ALWAYS_INLINE constexpr typename MakeConst::type operator*() + const { + return element; + } + C10_ALWAYS_INLINE constexpr typename MakeConst::type operator->() + const { + return std::addressof(element); + } + + public: + size_t index; + reference element; + }; + + C10_ALWAYS_INLINE constexpr Proxy operator*() const { + return Proxy(*this); + } + + C10_ALWAYS_INLINE constexpr Enumerator& operator++() { + ++it_; + ++idx_; + return *this; + } + + template + C10_ALWAYS_INLINE constexpr bool operator==( + const Enumerator& rhs) const { + return it_ == rhs.it_; + } + + template + C10_ALWAYS_INLINE constexpr bool operator!=( + const Enumerator& rhs) const { + return !(it_ == rhs.it_); + } + + private: + template + friend class Enumerator; + + Iterator it_; + size_t idx_ = 0; +}; + +template +class RangeEnumerator { + Range r_; + using BeginIteratorType = decltype(std::declval().begin()); + using EndIteratorType = decltype(std::declval().end()); + + public: + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward(r)) {} + + constexpr Enumerator begin() { + return Enumerator(r_.begin()); + } + constexpr Enumerator end() { + return Enumerator(r_.end()); + } +}; + +} // namespace detail + +template +constexpr detail::RangeEnumerator enumerate(Range&& r) { + return detail::RangeEnumerator(std::forward(r)); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..5b3a76fe9fc94776a70538d212e657435189b350 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ExclusivelyOwnedTensorTraits.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include + +namespace c10 { +// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and +// at::TensorBase. +template +struct ExclusivelyOwnedTensorTraits { + using repr_type = TensorType; + using pointer_type = TensorType*; + using const_pointer_type = const TensorType*; + + static repr_type nullRepr() { + return TensorType(); + } + + template + static repr_type createInPlace(Args&&... args) { + return TensorType(std::forward(args)...); + } + + static repr_type moveToRepr(TensorType&& x) { + return std::move(x); + } + + static void destroyOwned(TensorType& x) { + TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); + // May be 0 because UndefinedTensorImpl doesn't get its refcount + // incremented. + const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->refcount() == 1 || + (toDestroy->refcount() == 0 && isUndefined), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and refcount ", + toDestroy->refcount(), + ", expected 1 or, if isUndefined, 0!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->weakcount() == 1 || + (toDestroy->weakcount() == 0 && + toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with isUndefined ", + isUndefined, + " and weakcount ", + toDestroy->weakcount(), + ", expected 1 or, if isUndefined, 0!"); + if (!isUndefined) { +#ifndef NDEBUG + // Needed to pass the debug assertions in ~intrusive_ptr_target. + toDestroy->combined_refcount_.store(0, std::memory_order_relaxed); +#endif + delete toDestroy; + } + } + + static TensorType take(TensorType& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h new file mode 100644 index 0000000000000000000000000000000000000000..8ce3648d928f50cf474d26cab63c16df16dda728 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FbcodeMaps.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_FBCODEMAPS_H_ +#define C10_UTIL_FBCODEMAPS_H_ + +// Map typedefs so that we can use folly's F14 maps in fbcode without +// taking a folly dependency. + +#ifdef FBCODE_CAFFE2 +#include +#include +#else +#include +#include +#endif + +namespace c10 { +#ifdef FBCODE_CAFFE2 +template +using FastMap = folly::F14FastMap; +template +using FastSet = folly::F14FastSet; +#else +template +using FastMap = std::unordered_map; +template +using FastSet = std::unordered_set; +#endif +} // namespace c10 + +#endif // C10_UTIL_FBCODEMAPS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..ed07b955168f7ab08b4a20657d8f36ea7cd4123c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fn-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..30481a62430fdf08f2107bc1ab50e811314767f3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e4m3fnuz-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..f3e8c25099a630204f3c4ee345fd2a3653c14116 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e5m2fnuz-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..030b23d64750b7378c8fc281c96d2fe662e38d88 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Float8_e8m0fnu-inl.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h new file mode 100644 index 0000000000000000000000000000000000000000..342824b5b9095219b123ab4bfb19fbb3cd1a7819 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/FunctionRef.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some templates that are useful if you are working with the +// STL at all. +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// + +// c10: modified from llvm::function_ref +// c10: added more SFINAE to enable use in overloaded functions + +#pragma once + +#include +#include +#include + +namespace c10 { + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a function_ref. +template +class function_ref; + +template +class function_ref { + Ret (*callback)(intptr_t callable, Params... params) = nullptr; + intptr_t callable{}; + + template + static Ret callback_fn(intptr_t callable, Params... params) { + return (*reinterpret_cast(callable))( + std::forward(params)...); + } + + public: + function_ref() = default; + function_ref(std::nullptr_t) {} + + template + function_ref( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + Callable&& callable, + std::enable_if_t, + function_ref>>* /*unused*/ + = nullptr, + std::enable_if_t, + Ret>>* /*unused*/ + = nullptr) + : callback(callback_fn>), + callable(reinterpret_cast(&callable)) {} + + Ret operator()(Params... params) const { + return callback(callable, std::forward(params)...); + } + + operator bool() const { + return callback; + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h new file mode 100644 index 0000000000000000000000000000000000000000..b10ed7f5c9b33b99adbd031069af7c4e2fd3d0e3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Gauge.h @@ -0,0 +1,55 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include + +namespace c10::monitor { +namespace detail { + +class GaugeImpl; + +class GaugeBackendIf { + public: + virtual ~GaugeBackendIf() = default; + virtual void record(int64_t value) noexcept = 0; +}; + +class GaugeBackendFactoryIf { + public: + virtual ~GaugeBackendFactoryIf() = default; + + // May return nullptr if the gauge will be ignored by the given backend. + virtual std::unique_ptr create( + std::string_view key) noexcept = 0; +}; + +void C10_API + registerGaugeBackend(std::unique_ptr /*backend*/); +} // namespace detail + +// A handle to a Gauge. +class C10_API GaugeHandle { + public: + explicit GaugeHandle(std::string_view key); + void record(int64_t value); + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + detail::GaugeImpl& impl_; +}; + +} // namespace c10::monitor + +#define STATIC_GAUGE(_key) \ + []() -> ::c10::monitor::GaugeHandle& { \ + static ::c10::monitor::GaugeHandle handle(#_key); \ + return handle; \ + }() + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Half.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Half.h new file mode 100644 index 0000000000000000000000000000000000000000..0a3d4462657c7aa4d4e3827a2de811132911632b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Half.h @@ -0,0 +1,13 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// need to keep the following for BC because the APIs in here were exposed +// before migrating Half to torch/headeronly +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h new file mode 100644 index 0000000000000000000000000000000000000000..204fc205ef9940c397de23915c4fee7dba8673ec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Lazy.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10 { + +/** + * Thread-safe lazy value with opportunistic concurrency: on concurrent first + * access, the factory may be called by multiple threads, but only one result is + * stored and its reference returned to all the callers. + * + * Value is heap-allocated; this optimizes for the case in which the value is + * never actually computed. + */ +template +class OptimisticLazy { + public: + OptimisticLazy() = default; + OptimisticLazy(const OptimisticLazy& other) { + if (T* value = other.value_.load(std::memory_order_acquire)) { + value_ = new T(*value); + } + } + OptimisticLazy(OptimisticLazy&& other) noexcept + : value_(other.value_.exchange(nullptr, std::memory_order_acq_rel)) {} + ~OptimisticLazy() { + reset(); + } + + template + T& ensure(const Factory& factory) { + if (T* value = value_.load(std::memory_order_acquire)) { + return *value; + } + T* value = new T(factory()); + T* old = nullptr; + if (!value_.compare_exchange_strong( + old, value, std::memory_order_release, std::memory_order_acquire)) { + delete value; + value = old; + } + return *value; + } + + // The following methods are not thread-safe: they should not be called + // concurrently with any other method. + + OptimisticLazy& operator=(const OptimisticLazy& other) { + *this = OptimisticLazy{other}; + return *this; + } + + OptimisticLazy& operator=(OptimisticLazy&& other) noexcept { + if (this != &other) { + reset(); + value_.store( + other.value_.exchange(nullptr, std::memory_order_acquire), + std::memory_order_release); + } + return *this; + } + + void reset() { + if (T* old = value_.load(std::memory_order_relaxed)) { + value_.store(nullptr, std::memory_order_relaxed); + delete old; + } + } + + private: + std::atomic value_{nullptr}; +}; + +/** + * Interface for a value that is computed on first access. + */ +template +class LazyValue { + public: + virtual ~LazyValue() = default; + + virtual const T& get() const = 0; +}; + +/** + * Convenience thread-safe LazyValue implementation with opportunistic + * concurrency. + */ +template +class OptimisticLazyValue : public LazyValue { + public: + const T& get() const override { + return value_.ensure([this] { return compute(); }); + } + + private: + virtual T compute() const = 0; + + mutable OptimisticLazy value_; +}; + +/** + * Convenience immutable (thus thread-safe) LazyValue implementation for cases + * in which the value is not actually lazy. + */ +template +class PrecomputedLazyValue : public LazyValue { + public: + PrecomputedLazyValue(T value) : value_(std::move(value)) {} + + const T& get() const override { + return value_; + } + + private: + T value_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h new file mode 100644 index 0000000000000000000000000000000000000000..0435fffb73fdd7a8e6ef0cedc7d6feac6b818651 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/LeftRight.h @@ -0,0 +1,234 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +namespace detail { + +struct IncrementRAII final { + public: + explicit IncrementRAII(std::atomic* counter) : _counter(counter) { + _counter->fetch_add(1); + } + + ~IncrementRAII() { + _counter->fetch_sub(1); + } + IncrementRAII(IncrementRAII&&) = delete; + IncrementRAII& operator=(IncrementRAII&&) = delete; + + private: + std::atomic* _counter; + + C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII); +}; + +} // namespace detail + +// LeftRight wait-free readers synchronization primitive +// https://hal.archives-ouvertes.fr/hal-01207881/document +// +// LeftRight is quite easy to use (it can make an arbitrary +// data structure permit wait-free reads), but it has some +// particular performance characteristics you should be aware +// of if you're deciding to use it: +// +// - Reads still incur an atomic write (this is how LeftRight +// keeps track of how long it needs to keep around the old +// data structure) +// +// - Writes get executed twice, to keep both the left and right +// versions up to date. So if your write is expensive or +// nondeterministic, this is also an inappropriate structure +// +// LeftRight is used fairly rarely in PyTorch's codebase. If you +// are still not sure if you need it or not, consult your local +// C++ expert. +// +template +class LeftRight final { + public: + template + explicit LeftRight(const Args&... args) + : _counters{{{0}, {0}}}, + _foregroundCounterIndex(0), + _foregroundDataIndex(0), + _data{{T{args...}, T{args...}}} {} + + // Copying and moving would not be threadsafe. + // Needs more thought and careful design to make that work. + LeftRight(const LeftRight&) = delete; + LeftRight(LeftRight&&) noexcept = delete; + LeftRight& operator=(const LeftRight&) = delete; + LeftRight& operator=(LeftRight&&) noexcept = delete; + + ~LeftRight() { + // wait until any potentially running writers are finished + { + std::unique_lock lock(_writeMutex); + } + + // wait until any potentially running readers are finished + while (_counters[0].load() != 0 || _counters[1].load() != 0) { + std::this_thread::yield(); + } + } + + template + auto read(F&& readFunc) const { + detail::IncrementRAII _increment_counter( + &_counters[_foregroundCounterIndex.load()]); + + return std::forward(readFunc)(_data[_foregroundDataIndex.load()]); + } + + // Throwing an exception in writeFunc is ok but causes the state to be either + // the old or the new state, depending on if the first or the second call to + // writeFunc threw. + template + auto write(F&& writeFunc) { + std::unique_lock lock(_writeMutex); + + return _write(std::forward(writeFunc)); + } + + private: + template + auto _write(const F& writeFunc) { + /* + * Assume, A is in background and B in foreground. In simplified terms, we + * want to do the following: + * 1. Write to A (old background) + * 2. Switch A/B + * 3. Write to B (new background) + * + * More detailed algorithm (explanations on why this is important are below + * in code): + * 1. Write to A + * 2. Switch A/B data pointers + * 3. Wait until A counter is zero + * 4. Switch A/B counters + * 5. Wait until B counter is zero + * 6. Write to B + */ + + auto localDataIndex = _foregroundDataIndex.load(); + + // 1. Write to A + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + + // 2. Switch A/B data pointers + localDataIndex = localDataIndex ^ 1; + _foregroundDataIndex = localDataIndex; + + /* + * 3. Wait until A counter is zero + * + * In the previous write run, A was foreground and B was background. + * There was a time after switching _foregroundDataIndex (B to foreground) + * and before switching _foregroundCounterIndex, in which new readers could + * have read B but incremented A's counter. + * + * In this current run, we just switched _foregroundDataIndex (A back to + * foreground), but before writing to the new background B, we have to make + * sure A's counter was zero briefly, so all these old readers are gone. + */ + auto localCounterIndex = _foregroundCounterIndex.load(); + _waitForBackgroundCounterToBeZero(localCounterIndex); + + /* + * 4. Switch A/B counters + * + * Now that we know all readers on B are really gone, we can switch the + * counters and have new readers increment A's counter again, which is the + * correct counter since they're reading A. + */ + localCounterIndex = localCounterIndex ^ 1; + _foregroundCounterIndex = localCounterIndex; + + /* + * 5. Wait until B counter is zero + * + * This waits for all the readers on B that came in while both data and + * counter for B was in foreground, i.e. normal readers that happened + * outside of that brief gap between switching data and counter. + */ + _waitForBackgroundCounterToBeZero(localCounterIndex); + + // 6. Write to B + return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + } + + template + auto _callWriteFuncOnBackgroundInstance( + const F& writeFunc, + uint8_t localDataIndex) { + try { + return writeFunc(_data[localDataIndex ^ 1]); + } catch (...) { + // recover invariant by copying from the foreground instance + _data[localDataIndex ^ 1] = _data[localDataIndex]; + // rethrow + throw; + } + } + + void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { + while (_counters[counterIndex ^ 1].load() != 0) { + std::this_thread::yield(); + } + } + + mutable std::array, 2> _counters; + std::atomic _foregroundCounterIndex; + std::atomic _foregroundDataIndex; + std::array _data; + std::mutex _writeMutex; +}; + +// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a +// read-write lock to protect T (data). +template +class RWSafeLeftRightWrapper final { + public: + template + explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} + + // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight + // is not copyable or moveable. + RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; + RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; + RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; + ~RWSafeLeftRightWrapper() = default; + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto read(F&& readFunc) const { + return data_.withLock( + [&readFunc](T const& data) { return std::forward(readFunc)(data); }); + } + + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + auto write(F&& writeFunc) { + return data_.withLock( + [&writeFunc](T& data) { return std::forward(writeFunc)(data); }); + } + + private: + c10::Synchronized data_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Load.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Load.h new file mode 100644 index 0000000000000000000000000000000000000000..38aef4c1ea38d790799e49f3f594ff8c3c7a0d78 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Load.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace c10 { +namespace detail { + +template +struct LoadImpl { + C10_HOST_DEVICE static T apply(const void* src) { + return *reinterpret_cast(src); + } +}; + +template <> +struct LoadImpl { + C10_HOST_DEVICE static bool apply(const void* src) { + static_assert(sizeof(bool) == sizeof(char)); + // NOTE: [Loading boolean values] + // Protect against invalid boolean values by loading as a byte + // first, then converting to bool (see gh-54789). + return *reinterpret_cast(src); + } +}; + +} // namespace detail + +template +C10_HOST_DEVICE constexpr T load(const void* src) { + return c10::detail::LoadImpl::apply(src); +} + +template +C10_HOST_DEVICE constexpr scalar_t load(const scalar_t* src) { + return c10::detail::LoadImpl::apply(src); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Logging.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Logging.h new file mode 100644 index 0000000000000000000000000000000000000000..49420110eb333a07e7b15cc6f21a8c77af52e84d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Logging.h @@ -0,0 +1,378 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_H_ +#define C10_UTIL_LOGGING_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// CAFFE2_LOG_THRESHOLD is a compile time flag that would allow us to turn off +// logging at compile time so no logging message below that level is produced +// at all. The value should be between INT_MIN and CAFFE_FATAL. +#ifndef CAFFE2_LOG_THRESHOLD +// If we have not defined the compile time log threshold, we keep all the +// log cases. +#define CAFFE2_LOG_THRESHOLD INT_MIN +#endif // CAFFE2_LOG_THRESHOLD + +// Below are different implementations for glog and non-glog cases. +#ifdef C10_USE_GLOG +#include +#else // !C10_USE_GLOG +#include +#endif // C10_USE_GLOG + +C10_DECLARE_int(caffe2_log_level); +C10_DECLARE_bool(caffe2_use_fatal_for_enforce); + +// Some versions of GLOG support less-spammy version of LOG_EVERY_MS. If it's +// not available - just short-circuit to the always working one one. +// We define the C10_ name to avoid confusing other files +#ifdef LOG_EVERY_MS +#define C10_LOG_EVERY_MS(severity, ms) LOG_EVERY_MS(severity, ms) +#else +#define C10_LOG_EVERY_MS(severity, ms) LOG(severity) +#endif + +// Same for LOG_FIRST_N +#ifdef LOG_FIRST_N +#define C10_LOG_FIRST_N(severity, n) LOG_FIRST_N(severity, n) +#else +#define C10_LOG_FIRST_N(severity, n) LOG(severity) +#endif + +// Same for LOG_EVERY_N +#ifdef LOG_EVERY_N +#define C10_LOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n) +#else +#define C10_LOG_EVERY_N(severity, n) LOG(severity) +#endif + +namespace c10 { + +#if !defined(C10_NODEPRECATED) +using std::string; +#endif + +// Functions that we use for initialization. +C10_API bool InitCaffeLogging(int* argc, char** argv); +C10_API void UpdateLoggingLevelsFromFlags(); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] inline void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceNotMet(file, line, condition, "", caller); +} + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller = nullptr); + +[[noreturn]] C10_API void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + const char* msg, + const void* caller = nullptr); + +[[noreturn]] inline void ThrowEnforceFiniteNotMet( + const char* file, + const int line, + const char* condition, + detail::CompileTimeEmptyString /*msg*/, + const void* caller = nullptr) { + ThrowEnforceFiniteNotMet(file, line, condition, "", caller); +} + +constexpr bool IsUsingGoogleLogging() { +#ifdef C10_USE_GLOG + return true; +#else + return false; +#endif +} + +/** + * A utility to allow one to show log info to stderr after the program starts. + * + * This is similar to calling GLOG's --logtostderr, or setting caffe2_log_level + * to smaller than INFO. You are recommended to only use this in a few sparse + * cases, such as when you want to write a tutorial or something. Normally, use + * the commandline flags to set the log level. + */ +C10_API void ShowLogInfoToStderr(); + +C10_API void SetStackTraceFetcher(std::function<::c10::Backtrace()> fetcher); + +/** + * Convenience function for non-lazy stack trace fetchers. The Backtrace + * overload should be preferred when stringifying the backtrace is expensive. + */ +C10_API void SetStackTraceFetcher(std::function fetcher); + +using EnforceNotMet = ::c10::Error; + +#define CAFFE_ENFORCE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_FINITE(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceFiniteNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \ + } \ + } while (false) + +#define CAFFE_ENFORCE_WITH_CALLER(condition, ...) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + ::c10::ThrowEnforceNotMet( \ + __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__), this); \ + } \ + } while (false) + +#define CAFFE_THROW(...) \ + ::c10::ThrowEnforceNotMet(__FILE__, __LINE__, "", ::c10::str(__VA_ARGS__)) + +/** + * Rich logging messages + * + * CAFFE_ENFORCE_THAT can be used with one of the "checker functions" that + * capture input argument values and add it to the exception message. E.g. + * `CAFFE_ENFORCE_THAT(Equals(foo(x), bar(y)), "Optional additional message")` + * would evaluate both foo and bar only once and if the results are not equal - + * include them in the exception message. + * + * Some of the basic checker functions like Equals or Greater are already + * defined below. Other header might define customized checkers by adding + * functions to caffe2::enforce_detail namespace. For example: + * + * namespace caffe2 { namespace enforce_detail { + * inline EnforceFailMessage IsVector(const vector& shape) { + * if (shape.size() == 1) { return EnforceOK(); } + * return c10::str("Shape ", shape, " is not a vector"); + * } + * }} + * + * With further usages like `CAFFE_ENFORCE_THAT(IsVector(Input(0).dims()))` + * + * Convenient wrappers for binary operations like CAFFE_ENFORCE_EQ are provided + * too. Please use them instead of TORCH_CHECK_EQ and friends for failures in + * user-provided input. + */ + +namespace enforce_detail { + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y) { + return c10::str(x, " vs ", y); +} + +template +std::string enforceFailMsgImpl(const T1& x, const T2& y, const Args&... args) { + return c10::str(x, " vs ", y, ". ", args...); +} + +template +void enforceThatImpl( + Pred p, + const T1& lhs, + const T2& rhs, + const char* file, + int line, + const char* expr, + const void* caller, + GetFailMsgFunc getFailMsg) { + if (C10_UNLIKELY(!(p(lhs, rhs)))) { + ::c10::ThrowEnforceNotMet(file, line, expr, getFailMsg(lhs, rhs), caller); + } +} + +#define CAFFE_ENFORCE_THAT_IMPL(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + nullptr, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +#define CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER(op, lhs, rhs, expr, ...) \ + ::c10::enforce_detail::enforceThatImpl( \ + op, \ + (lhs), \ + (rhs), \ + __FILE__, \ + __LINE__, \ + expr, \ + this, \ + [&](const auto& arg1, const auto& arg2) { \ + return ::c10::enforce_detail::enforceFailMsgImpl( \ + arg1, arg2, ##__VA_ARGS__); \ + }) + +} // namespace enforce_detail + +#define CAFFE_ENFORCE_THAT(cmp, op, lhs, rhs, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, lhs, rhs, #lhs " " #op " " #rhs, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL(cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP(std::greater(), >, x, y, ##__VA_ARGS__) + +#define CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(cmp, op, x, y, ...) \ + CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER( \ + cmp, x, y, #x " " #op " " #y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_EQ_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::equal_to(), ==, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_NE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::not_equal_to(), !=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::less_equal(), <=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_LT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER(std::less(), <, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GE_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater_equal(), >=, x, y, ##__VA_ARGS__) +#define CAFFE_ENFORCE_GT_WITH_CALLER(x, y, ...) \ + CAFFE_ENFORCE_BINARY_OP_WITH_CALLER( \ + std::greater(), >, x, y, ##__VA_ARGS__) + +struct IValue; +class C10_API EventSampledHandler { + public: + virtual void log( + std::string_view model_id, + const std::vector& args) = 0; + virtual ~EventSampledHandler() = default; +}; + +#define C10_LOG_EVENT_SAMPLED(event, ...) \ + static const std::unique_ptr<::c10::EventSampledHandler>& \ + _##event##EventSampledHandler = ::c10::GetEventSampledHandler(#event); \ + if (_##event##EventSampledHandler) { \ + _##event##EventSampledHandler->log(__VA_ARGS__); \ + } + +// Must be called in the main thread before any other threads are spawned. +C10_API void InitEventSampledHandlers( + std::vector>> /*handlers*/); +C10_API const std::unique_ptr& GetEventSampledHandler( + std::string_view /*event*/); + +/** + * Very lightweight logging for the first time API usage. It's beneficial for + * tracking of individual functionality usage in larger applications. + * + * In order to ensure light-weightedness of logging, we utilize static variable + * trick - LogAPIUsage will be invoked only once and further invocations will + * just do an atomic check. + * + * Example: + * // Logs caller info with an arbitrary text event, if there is a usage. + * C10_LOG_API_USAGE_ONCE("my_api"); + */ +#define C10_LOG_API_USAGE_ONCE(...) \ + [[maybe_unused]] static bool C10_ANONYMOUS_VARIABLE(logFlag) = \ + ::c10::detail::LogAPIUsageFakeReturn(__VA_ARGS__); + +// API usage logging capabilities +C10_API void SetAPIUsageLogger(std::function logger); +C10_API void LogAPIUsage(const std::string& context); + +C10_API void SetAPIUsageMetadataLogger( + std::function& metadata_map)> logger); +C10_API void LogAPIUsageMetadata( + const std::string& context, + const std::map& metadata_map); + +// PyTorch ddp usage logging capabilities +// DDPLoggingData holds data that can be logged in applications +// for analysis and debugging. Data structure is defined in +// c10 directory so that it can be easily imported by both c10 +// and torch files. +struct DDPLoggingData { + // logging fields that are string types. + std::map strs_map; + // logging fields that are int64_t types. + std::map ints_map; +}; + +C10_API void SetPyTorchDDPUsageLogger( + std::function logger); +C10_API void LogPyTorchDDPUsage(const DDPLoggingData& ddpData); + +namespace detail { +// Return value is needed to do the static variable initialization trick +C10_API bool LogAPIUsageFakeReturn(const std::string& context); +} // namespace detail + +// Initializes the c10 logger. +C10_API void initLogging(); + +// Sets the rank, which will be included in log messages +C10_API void SetGlobalRank(int64_t rank); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h new file mode 100644 index 0000000000000000000000000000000000000000..61e6ed82f27a4a2b91300f0987612f5a03c3bea1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/MaybeOwned.h @@ -0,0 +1,242 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +/// MaybeOwnedTraits describes how to borrow from T. Here is how we +/// can implement borrowing from an arbitrary type T using a raw +/// pointer to const: +template +struct MaybeOwnedTraitsGenericImpl { + using owned_type = T; + using borrow_type = const T*; + + static borrow_type createBorrow(const owned_type& from) { + return &from; + } + + static void assignBorrow(borrow_type& lhs, borrow_type rhs) { + lhs = rhs; + } + + static void destroyBorrow(borrow_type& /*toDestroy*/) {} + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return *borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static bool debugBorrowIsValid(const borrow_type& borrow) { + return borrow != nullptr; + } +}; + +/// It is possible to eliminate the extra layer of indirection for +/// borrows for some types that we control. For examples, see +/// intrusive_ptr.h and TensorBody.h. + +template +struct MaybeOwnedTraits; + +// Explicitly enable MaybeOwned>, rather than allowing +// MaybeOwned to be used for any type right away. +template +struct MaybeOwnedTraits> + : public MaybeOwnedTraitsGenericImpl> {}; + +/// A smart pointer around either a borrowed or owned T. When +/// constructed with borrowed(), the caller MUST ensure that the +/// borrowed-from argument outlives this MaybeOwned. Compare to +/// Rust's std::borrow::Cow +/// (https://doc.rust-lang.org/std/borrow/enum.Cow.html), but note +/// that it is probably not suitable for general use because C++ has +/// no borrow checking. Included here to support +/// Tensor::expect_contiguous. +template +class MaybeOwned final { + using borrow_type = typename MaybeOwnedTraits::borrow_type; + using owned_type = typename MaybeOwnedTraits::owned_type; + + bool isBorrowed_; + union { + borrow_type borrow_; + owned_type own_; + }; + + /// Don't use this; use borrowed() instead. + explicit MaybeOwned(const owned_type& t) + : isBorrowed_(true), borrow_(MaybeOwnedTraits::createBorrow(t)) {} + + /// Don't use this; use owned() instead. + explicit MaybeOwned(T&& t) noexcept(std::is_nothrow_move_constructible_v) + : isBorrowed_(false), own_(std::move(t)) {} + + /// Don't use this; use owned() instead. + template + explicit MaybeOwned(std::in_place_t /*unused*/, Args&&... args) + : isBorrowed_(false), own_(std::forward(args)...) {} + + public: + explicit MaybeOwned() : isBorrowed_(true), borrow_() {} + + // Copying a borrow yields another borrow of the original, as with a + // T*. Copying an owned T yields another owned T for safety: no + // chains of borrowing by default! (Note you could get that behavior + // with MaybeOwned::borrowed(*rhs) if you wanted it.) + MaybeOwned(const MaybeOwned& rhs) : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(rhs.own_); + } + } + + MaybeOwned& operator=(const MaybeOwned& rhs) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = rhs.own_; + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(rhs.own_); + isBorrowed_ = false; + } + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isBorrowed_ == rhs.isBorrowed_); + return *this; + } + + MaybeOwned(MaybeOwned&& rhs) noexcept( + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_move_constructible_v && + std::is_nothrow_move_assignable_v) + : isBorrowed_(rhs.isBorrowed_) { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + new (&own_) T(std::move(rhs.own_)); + } + } + + MaybeOwned& operator=(MaybeOwned&& rhs) noexcept( + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_assignable_v && + std::is_nothrow_move_constructible_v && + // NOLINTNEXTLINE(*-noexcept-move-*) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (this == &rhs) { + return *this; + } + if (C10_UNLIKELY(!isBorrowed_)) { + if (rhs.isBorrowed_) { + own_.~T(); + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + isBorrowed_ = true; + } else { + own_ = std::move(rhs.own_); + } + } else { + if (C10_LIKELY(rhs.isBorrowed_)) { + MaybeOwnedTraits::assignBorrow(borrow_, rhs.borrow_); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + new (&own_) T(std::move(rhs.own_)); + isBorrowed_ = false; + } + } + return *this; + } + + static MaybeOwned borrowed(const T& t) { + return MaybeOwned(t); + } + + static MaybeOwned owned(T&& t) noexcept( + std::is_nothrow_move_constructible_v) { + return MaybeOwned(std::move(t)); + } + + template + static MaybeOwned owned(std::in_place_t /*unused*/, Args&&... args) { + return MaybeOwned(std::in_place, std::forward(args)...); + } + + ~MaybeOwned() noexcept( + // NOLINTNEXTLINE(*-noexcept-destructor) + std::is_nothrow_destructible_v && + std::is_nothrow_destructible_v) { + if (C10_UNLIKELY(!isBorrowed_)) { + own_.~T(); + } else { + MaybeOwnedTraits::destroyBorrow(borrow_); + } + } + + // This is an implementation detail! You should know what you're doing + // if you are testing this. If you just want to guarantee ownership move + // this into a T + bool unsafeIsBorrowed() const { + return isBorrowed_; + } + + const T& operator*() const& { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::referenceFromBorrow(borrow_) + : own_; + } + + const T* operator->() const { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + } + return C10_LIKELY(isBorrowed_) + ? MaybeOwnedTraits::pointerFromBorrow(borrow_) + : &own_; + } + + // If borrowed, copy the underlying T. If owned, move from + // it. borrowed/owned state remains the same, and either we + // reference the same borrow as before or we are an owned moved-from + // T. + T operator*() && { + if (isBorrowed_) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + MaybeOwnedTraits::debugBorrowIsValid(borrow_)); + return MaybeOwnedTraits::referenceFromBorrow(borrow_); + } else { + return std::move(own_); + } + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h new file mode 100644 index 0000000000000000000000000000000000000000..55c3fb2ba6db0dbc8bf8d00e616aefc3acab7c85 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Metaprogramming.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h new file mode 100644 index 0000000000000000000000000000000000000000..e029ae65773be41aa7d05402fc3e1c3d50dbb8a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/NetworkFlow.h @@ -0,0 +1,59 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +/** + * This file provides a network flow implementation. + * https://en.wikipedia.org/wiki/Flow_network + * + * It aims to mirror some of the behavior of networkx, which is/was used by + * functorch partitioners for splitting the graph into a forward and backward + * graph. + */ + +namespace c10 { + +enum class C10_API_ENUM MinCutStatus { + SUCCESS = 0, + UNBOUNDED = 1, + OVERFLOW_INF = 2, + INVALID = 3, +}; + +struct MinCutResult { + MinCutStatus status; + int64_t max_flow; + std::vector reachable; + std::vector unreachable; +}; + +// Modeled after networkx implementation +class C10_API NetworkFlowGraph { + public: + // selected such that INF + INF is < INT64_MAX + constexpr static int64_t INF = (1LL << 62) - 1; + + struct Edge { + std::string source, dest; + int64_t capacity; + }; + + MinCutStatus add_edge( + const std::string& source, + const std::string& dest, + int64_t capacity = 1); + + MinCutResult minimum_cut(const std::string& s, const std::string& t) const; + + std::vector edges; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..cd15a5f19d1db7673c8f3485a136d1730e34f433 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/OptionalArrayRef.h @@ -0,0 +1,242 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file defines OptionalArrayRef, a class that has almost the same +// exact functionality as std::optional>, except that its +// converting constructor fixes a dangling pointer issue. +// +// The implicit converting constructor of both std::optional> and +// std::optional> can cause the underlying ArrayRef to store +// a dangling pointer. OptionalArrayRef prevents this by wrapping +// a std::optional> and fixing the constructor implementation. +// +// See https://github.com/pytorch/pytorch/issues/63645 for more on this. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace c10 { + +template +class OptionalArrayRef final { + public: + // Constructors + + constexpr OptionalArrayRef() noexcept = default; + + constexpr OptionalArrayRef(std::nullopt_t /*unused*/) noexcept {} + + OptionalArrayRef(const OptionalArrayRef& other) = default; + + OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef(const std::optional>& other) noexcept + : wrapped_opt_array_ref(other) {} + + constexpr OptionalArrayRef(std::optional>&& other) noexcept + : wrapped_opt_array_ref(std::move(other)) {} + + constexpr OptionalArrayRef(const T& value) noexcept + : wrapped_opt_array_ref(value) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + std::is_convertible_v> && + !std::is_convertible_v, + bool> = false> + constexpr OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template < + typename U = ArrayRef, + std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + !std::is_convertible_v>, + bool> = false> + constexpr explicit OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + Args&&... args) noexcept + : wrapped_opt_array_ref(ip, std::forward(args)...) {} + + template + constexpr explicit OptionalArrayRef( + std::in_place_t ip, + std::initializer_list il, + Args&&... args) + : wrapped_opt_array_ref(ip, il, std::forward(args)...) {} + + constexpr OptionalArrayRef(const std::initializer_list& Vec) + : wrapped_opt_array_ref(ArrayRef(Vec)) {} + + // Destructor + + ~OptionalArrayRef() = default; + + // Assignment + + constexpr OptionalArrayRef& operator=(std::nullopt_t /*unused*/) noexcept { + wrapped_opt_array_ref = std::nullopt; + return *this; + } + + OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; + + OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef& operator=( + const std::optional>& other) noexcept { + wrapped_opt_array_ref = other; + return *this; + } + + constexpr OptionalArrayRef& operator=( + std::optional>&& other) noexcept { + wrapped_opt_array_ref = std::move(other); + return *this; + } + + template < + typename U = ArrayRef, + typename = std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + std::is_constructible_v, U&&> && + std::is_assignable_v&, U&&>>> + constexpr OptionalArrayRef& operator=(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&> && + std::is_nothrow_assignable_v&, U&&>) { + wrapped_opt_array_ref = std::forward(value); + return *this; + } + + // Observers + + constexpr ArrayRef* operator->() noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef* operator->() const noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef& operator*() & noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& operator*() const& noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& operator*() && noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& operator*() const&& noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr explicit operator bool() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr bool has_value() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr ArrayRef& value() & { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& value() const& { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& value() && { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& value() const&& { + return std::move(wrapped_opt_array_ref.value()); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) const& { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + template + constexpr std:: + enable_if_t>, ArrayRef> + value_or(U&& default_value) && { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + // Modifiers + + constexpr void swap(OptionalArrayRef& other) noexcept { + std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); + } + + constexpr void reset() noexcept { + wrapped_opt_array_ref.reset(); + } + + template + constexpr std:: + enable_if_t, Args&&...>, ArrayRef&> + emplace(Args&&... args) noexcept( + std::is_nothrow_constructible_v, Args&&...>) { + return wrapped_opt_array_ref.emplace(std::forward(args)...); + } + + template + constexpr ArrayRef& emplace( + std::initializer_list il, + Args&&... args) noexcept { + return wrapped_opt_array_ref.emplace(il, std::forward(args)...); + } + + private: + std::optional> wrapped_opt_array_ref; +}; + +using OptionalIntArrayRef = OptionalArrayRef; + +inline bool operator==( + const OptionalIntArrayRef& a1, + const IntArrayRef& other) { + if (!a1.has_value()) { + return false; + } + return a1.value() == other; +} + +inline bool operator==( + const c10::IntArrayRef& a1, + const c10::OptionalIntArrayRef& a2) { + return a2 == a1; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..e577497980fbf93d2e928b9c879f085cc1852a4d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ParallelGuard.h @@ -0,0 +1,25 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { + +// RAII thread local guard that tracks whether code is being executed in +// `at::parallel_for` or `at::parallel_reduce` loop function. +class C10_API ParallelGuard { + public: + static bool is_enabled(); + + ParallelGuard(bool state); + ~ParallelGuard(); + + private: + bool previous_state_; +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Registry.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Registry.h new file mode 100644 index 0000000000000000000000000000000000000000..92d1809d8c3094d19c927d9594afab15eba475ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Registry.h @@ -0,0 +1,334 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_REGISTRY_H_ +#define C10_UTIL_REGISTRY_H_ + +/** + * Simple registry implementation that uses static variables to + * register object creators during program initialization time. + */ + +// NB: This Registry works poorly when you have other namespaces. +// Make all macro invocations from inside the at namespace. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace c10 { + +template +inline std::string KeyStrRepr(const KeyType& /*key*/) { + return "[key type printing not supported]"; +} + +template <> +inline std::string KeyStrRepr(const std::string& key) { + return key; +} + +enum RegistryPriority { + REGISTRY_FALLBACK = 1, + REGISTRY_DEFAULT = 2, + REGISTRY_PREFERRED = 3, +}; + +/** + * @brief A template class that allows one to register classes by keys. + * + * The keys are usually a std::string specifying the name, but can be anything + * that can be used in a std::map. + * + * You should most likely not use the Registry class explicitly, but use the + * helper macros below to declare specific registries as well as registering + * objects. + */ +template +class Registry { + public: + typedef std::function Creator; + + Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {} + ~Registry() = default; + + void Register( + const SrcType& key, + Creator creator, + const RegistryPriority priority = REGISTRY_DEFAULT) { + std::lock_guard lock(register_mutex_); + // The if statement below is essentially the same as the following line: + // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key + // << " registered twice."; + // However, TORCH_CHECK_EQ depends on google logging, and since registration + // is carried out at static initialization time, we do not want to have an + // explicit dependency on glog's initialization function. + if (registry_.count(key) != 0) { + auto cur_priority = priority_[key]; + if (priority > cur_priority) { +#ifdef DEBUG + std::string warn_msg = + "Overwriting already registered item for key " + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); +#endif + registry_[key] = creator; + priority_[key] = priority; + } else if (priority == cur_priority) { + std::string err_msg = + "Key already registered with the same priority: " + KeyStrRepr(key); + fprintf(stderr, "%s\n", err_msg.c_str()); + if (terminate_) { + std::exit(1); + } else { + throw std::runtime_error(err_msg); + } + } else if (warning_) { + std::string warn_msg = + "Higher priority item already registered, skipping registration of " + + KeyStrRepr(key); + fprintf(stderr, "%s\n", warn_msg.c_str()); + } + } else { + registry_[key] = creator; + priority_[key] = priority; + } + } + + void Register( + const SrcType& key, + Creator creator, + const std::string& help_msg, + const RegistryPriority priority = REGISTRY_DEFAULT) { + Register(key, creator, priority); + help_message_[key] = help_msg; + } + + inline bool Has(const SrcType& key) { + return (registry_.count(key) != 0); + } + + ObjectPtrType Create(const SrcType& key, Args... args) { + auto it = registry_.find(key); + if (it == registry_.end()) { + // Returns nullptr if the key is not registered. + return nullptr; + } + return it->second(args...); + } + + /** + * Returns the keys currently registered as a std::vector. + */ + std::vector Keys() const { + std::vector keys; + keys.reserve(registry_.size()); + for (const auto& it : registry_) { + keys.push_back(it.first); + } + return keys; + } + + inline const std::unordered_map& HelpMessage() const { + return help_message_; + } + + const char* HelpMessage(const SrcType& key) const { + auto it = help_message_.find(key); + if (it == help_message_.end()) { + return nullptr; + } + return it->second.c_str(); + } + + // Used for testing, if terminate is unset, Registry throws instead of + // calling std::exit + void SetTerminate(bool terminate) { + terminate_ = terminate; + } + + C10_DISABLE_COPY_AND_ASSIGN(Registry); + Registry(Registry&&) = delete; + Registry& operator=(Registry&&) = delete; + + private: + std::unordered_map registry_; + std::unordered_map priority_; + bool terminate_{true}; + const bool warning_; + std::unordered_map help_message_; + std::mutex register_mutex_; +}; + +template +class Registerer { + public: + explicit Registerer( + const SrcType& key, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg); + } + + explicit Registerer( + const SrcType& key, + const RegistryPriority priority, + Registry* registry, + typename Registry::Creator creator, + const std::string& help_msg = "") { + registry->Register(key, creator, help_msg, priority); + } + + template + static ObjectPtrType DefaultCreator(Args... args) { + return ObjectPtrType(new DerivedType(args...)); + } +}; + +/** + * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function + * declaration, as well as creating a convenient typename for its corresponding + * registerer. + */ +// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE +// as import and DEFINE as export, because these registry macros will be used +// in downstream shared libraries as well, and one cannot use *_API - the API +// macro will be defined on a per-shared-library basis. Semantically, when one +// declares a typed registry it is always going to be IMPORT, and when one +// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE), +// the instantiation unit is always going to be exported. +// +// The only unique condition is when in the same file one does DECLARE and +// DEFINE - in Windows compilers, this generates a warning that dllimport and +// dllexport are mixed, but the warning is fine and linker will be properly +// exporting the symbol. Same thing happens in the gflags flag declaration and +// definition caes. +#define C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + TORCH_API ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName(); \ + typedef ::c10::Registerer, ##__VA_ARGS__> \ + Registerer##RegistryName + +#define C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = new ::c10:: \ + Registry, ##__VA_ARGS__>(); \ + return registry; \ + } + +#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, SrcType, ObjectType, PtrType, ...) \ + C10_EXPORT ::c10::Registry, ##__VA_ARGS__>* \ + RegistryName() { \ + static ::c10::Registry, ##__VA_ARGS__>* \ + registry = \ + new ::c10::Registry, ##__VA_ARGS__>( \ + false); \ + return registry; \ + } + +// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated +// creator with comma in its templated arguments. +#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, priority, RegistryName(), ##__VA_ARGS__); + +#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, key, priority, ...) \ + static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \ + key, \ + priority, \ + RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \ + ::c10::demangle_type<__VA_ARGS__>()); + +// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use +// std::string as the key type, because that is the most commonly used cases. +#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__) + +#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + TORCH_DECLARE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, ObjectType, ...) \ + C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING( \ + RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__) + +// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string +// as the key +// type, because that is the most commonly used cases. +#define C10_REGISTER_CREATOR(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +#define C10_REGISTER_CLASS(RegistryName, key, ...) \ + C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__) + +#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \ + C10_REGISTER_TYPED_CLASS_WITH_PRIORITY( \ + RegistryName, #key, priority, __VA_ARGS__) + +} // namespace c10 + +#endif // C10_UTIL_REGISTRY_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h new file mode 100644 index 0000000000000000000000000000000000000000..1a0e63680bee7b6aa9107d6f0a20fa50388d6acb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Semaphore.h @@ -0,0 +1,76 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +/* + a simple semaphore interface. +*/ + +// note: __cpp_lib_semaphore will not be defined in some apple platforms +// even if >= C++20. +#if __has_include() && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L +#define C10_SEMAPHORE_USE_STL +#endif + +#ifdef C10_SEMAPHORE_USE_STL +#include +#else +// To use moodycamel semaphore, we need to include the header file +// for concurrentqueue first. Hiding implementation detail here. +#ifdef BLOCK_SIZE +#pragma push_macro("BLOCK_SIZE") +#undef BLOCK_SIZE +#include // @manual +#pragma pop_macro("BLOCK_SIZE") +#else +#include // @manual +#endif + +#include // @manual +#endif + +namespace c10 { + +class Semaphore { + public: + Semaphore(int32_t initial_count = 0) : impl_(initial_count) {} + + void release(int32_t n = 1) { +#ifdef C10_SEMAPHORE_USE_STL + impl_.release(n); +#else + impl_.signal(n); +#endif + } + + void acquire() { +#ifdef C10_SEMAPHORE_USE_STL + impl_.acquire(); +#else + impl_.wait(); +#endif + } + + bool tryAcquire() { +#ifdef C10_SEMAPHORE_USE_STL + return impl_.try_acquire(); +#else + return impl_.tryWait(); +#endif + } + + private: +#ifdef C10_SEMAPHORE_USE_STL + std::counting_semaphore<> impl_; +#else + moodycamel::LightweightSemaphore impl_; +#endif +}; +} // namespace c10 + +#undef C10_SEMAPHORE_USE_STL + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h new file mode 100644 index 0000000000000000000000000000000000000000..e5b92117a67fed4731ba92dc0b116b6c5aa80bcd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ThreadLocal.h @@ -0,0 +1,161 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +/** + * Android versions with libgnustl incorrectly handle thread_local C++ + * qualifier with composite types. NDK up to r17 version is affected. + * + * (A fix landed on Jun 4 2018: + * https://android-review.googlesource.com/c/toolchain/gcc/+/683601) + * + * In such cases, use c10::ThreadLocal wrapper + * which is `pthread_*` based with smart pointer semantics. + * + * In addition, convenient macro C10_DEFINE_TLS_static is available. + * To define static TLS variable of type std::string, do the following + * ``` + * C10_DEFINE_TLS_static(std::string, str_tls_); + * /////// + * { + * *str_tls_ = "abc"; + * assert(str_tls_->length(), 3); + * } + * ``` + * + * (see c10/test/util/ThreadLocal_test.cpp for more examples) + */ +#if !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 +#define C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE +#endif // defined(C10_ANDROID) && defined(__GLIBCXX__) && __GLIBCXX__ < 20180604 + +#endif // !defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) +#include +#include +#include +#include +namespace c10 { + +/** + * @brief Temporary thread_local C++ qualifier replacement for Android + * based on `pthread_*`. + * To be used with composite types that provide default ctor. + */ +template +class ThreadLocal { + public: + ThreadLocal() { + pthread_key_create( + &key_, [](void* buf) { delete static_cast(buf); }); + } + + ~ThreadLocal() { + if (void* current = pthread_getspecific(key_)) { + delete static_cast(current); + } + + pthread_key_delete(key_); + } + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal& operator=(const ThreadLocal&) = delete; + + Type& get() { + if (void* current = pthread_getspecific(key_)) { + return *static_cast(current); + } + + std::unique_ptr ptr = std::make_unique(); + if (0 == pthread_setspecific(key_, ptr.get())) { + return *ptr.release(); + } + + int err = errno; + TORCH_INTERNAL_ASSERT(false, "pthread_setspecific() failed, errno = ", err); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + pthread_key_t key_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) static ::c10::ThreadLocal Name + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name + +#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +namespace c10 { + +/** + * @brief Default thread_local implementation for non-Android cases. + * To be used with composite types that provide default ctor. + */ +template +class ThreadLocal { + public: + using Accessor = Type* (*)(); + explicit ThreadLocal(Accessor accessor) : accessor_(accessor) {} + + ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal(ThreadLocal&&) noexcept = default; + ThreadLocal& operator=(const ThreadLocal&) = delete; + ThreadLocal& operator=(ThreadLocal&&) noexcept = default; + ~ThreadLocal() = default; + + Type& get() { + return *accessor_(); + } + + Type& operator*() { + return get(); + } + + Type* operator->() { + return &get(); + } + + private: + Accessor accessor_; +}; + +} // namespace c10 + +#define C10_DEFINE_TLS_static(Type, Name) \ + static ::c10::ThreadLocal Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#define C10_DECLARE_TLS_class_static(Class, Type, Name) \ + static ::c10::ThreadLocal Name + +#define C10_DEFINE_TLS_class_static(Class, Type, Name) \ + ::c10::ThreadLocal Class::Name([]() { \ + static thread_local Type var; \ + return &var; \ + }) + +#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Type.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Type.h new file mode 100644 index 0000000000000000000000000000000000000000..9f460d4bde11da8629abece4d994a800e7918fc4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Type.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_TYPE_H_ +#define C10_UTIL_TYPE_H_ + +#include +#include +#ifdef __GXX_RTTI +#include +#endif // __GXX_RTTI + +#include + +namespace c10 { + +/// Utility to demangle a C++ symbol name. +C10_API std::string demangle(const char* name); + +/// Returns the printable name of the type. +template +inline const char* demangle_type() { +#ifdef __GXX_RTTI + static const auto& name = *(new std::string(demangle(typeid(T).name()))); + return name.c_str(); +#else // __GXX_RTTI + return "(RTTI disabled, cannot show name)"; +#endif // __GXX_RTTI +} + +} // namespace c10 + +#endif // C10_UTIL_TYPE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h new file mode 100644 index 0000000000000000000000000000000000000000..1d95fd90929796735962e4fb4fe1855cda857ac5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeCast.h @@ -0,0 +1,215 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +template +struct needs_real { + constexpr static bool value = + (is_complex::value && !is_complex::value); +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_real { + C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + return src.real(); + } +}; + +template +struct maybe_bool { + C10_HOST_DEVICE static inline src_t apply(src_t src) { + return src; + } +}; + +template +struct maybe_bool { + C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { + // Don't use bool operator so as to also compile for ComplexHalf. + return src.real() || src.imag(); + } +}; + +// Note: deliberately ignores undefined behavior, consistent with NumPy. +// PyTorch's type conversions can cause a variety of undefined behavior, +// including float to integral overflow and signed to unsigned integer overflow. +// Some of this undefined behavior is addressed below. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( + src_t src) { + constexpr bool real = needs_real::value; + auto r = maybe_real::apply(src); + return static_cast(r); + } +}; + +// Partial template specialization for casting to bool. +// Need to handle complex types separately, as we don't +// simply want to cast the real part to bool. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE static inline bool apply(src_t src) { + constexpr bool complex = needs_real::value; + return static_cast(maybe_bool::apply(src)); + } +}; + +// Partial template instantiation for casting to uint8. +// Note: Converting from negative float values to unsigned integer types is +// undefined behavior in C++, and current CPU and GPU compilers exhibit +// divergent behavior. Casting from negative float values to signed +// integer types and then to unsigned integer types is not undefined, +// however, so this cast improves the consistency of type conversions +// to uint8 across compilers. +// Further note: Type conversions across compilers still have other undefined +// and divergent behavior. +template +struct static_cast_with_inter_type { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( + src_t src) { + constexpr bool real = needs_real::value; + return static_cast( + static_cast(maybe_real::apply(src))); + } +}; + +template <> +struct static_cast_with_inter_type, c10::BFloat16> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::BFloat16 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Float8_e5m2> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2 src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e5m2fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e5m2fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fn> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fn src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e4m3fnuz> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e4m3fnuz src) { + return static_cast>(c10::complex{src}); + } +}; + +// TODO(#146647): Can we make all these template specialization happen +// based off our apply macros? +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e8m0fnu> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e8m0fnu src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type, c10::Half> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Half src) { + return static_cast>(c10::complex{src}); + } +}; + +template <> +struct static_cast_with_inter_type< + c10::complex, + c10::complex> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::complex src) { + return static_cast>( + static_cast>(src)); + } +}; + +template +C10_HOST_DEVICE To convert(From f) { + return static_cast_with_inter_type::apply(f); +} + +// Define separately to avoid being inlined and prevent code-size bloat +[[noreturn]] C10_API void report_overflow(const char* name); + +template +To checked_convert(From f, const char* name) { + // Converting to bool can't overflow so we exclude this case from checking. + if (!std::is_same_v && overflows(f)) { + report_overflow(name); + } + return convert(f); +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +// Trigger tests for D25440771. TODO: Remove this line any time you want. + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h new file mode 100644 index 0000000000000000000000000000000000000000..7386baccad1420dd13c2530c31b52b0344fe5b9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/TypeList.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h new file mode 100644 index 0000000000000000000000000000000000000000..dc2ba274cb76d7d7b7c810c8cc318abbe412106a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/UniqueVoidPtr.h @@ -0,0 +1,145 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +#include +#include + +namespace c10 { + +using DeleterFnPtr = void (*)(void*); + +namespace detail { + +// Does not delete anything +C10_API void deleteNothing(void* /*unused*/); + +// A detail::UniqueVoidPtr is an owning smart pointer like unique_ptr, but +// with three major differences: +// +// 1) It is specialized to void +// +// 2) It is specialized for a function pointer deleter +// void(void* ctx); i.e., the deleter doesn't take a +// reference to the data, just to a context pointer +// (erased as void*). In fact, internally, this pointer +// is implemented as having an owning reference to +// context, and a non-owning reference to data; this is why +// you release_context(), not release() (the conventional +// API for release() wouldn't give you enough information +// to properly dispose of the object later.) +// +// 3) The deleter is guaranteed to be called when the unique +// pointer is destructed and the context is non-null; this is different +// from std::unique_ptr where the deleter is not called if the +// data pointer is null. +// +// Some of the methods have slightly different types than std::unique_ptr +// to reflect this. +// +class UniqueVoidPtr { + private: + // Lifetime tied to ctx_ + void* data_; + std::unique_ptr ctx_; + + public: + UniqueVoidPtr() : data_(nullptr), ctx_(nullptr, &deleteNothing) {} + explicit UniqueVoidPtr(void* data) + : data_(data), ctx_(nullptr, &deleteNothing) {} + UniqueVoidPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter) + : data_(data), ctx_(ctx, ctx_deleter ? ctx_deleter : &deleteNothing) {} + void* operator->() const { + return data_; + } + void clear() { + ctx_ = nullptr; + data_ = nullptr; + } + void* get() const { + return data_; + } + + bool /* success */ unsafe_reset_data_and_ctx(void* new_data_and_ctx) { + if (C10_UNLIKELY(ctx_.get_deleter() != &deleteNothing)) { + return false; + } + // seems quicker than calling the no-op deleter when we reset + // NOLINTNEXTLINE(bugprone-unused-return-value) + ctx_.release(); + ctx_.reset(new_data_and_ctx); + data_ = new_data_and_ctx; + return true; + } + + void* get_context() const { + return ctx_.get(); + } + void* release_context() { + return ctx_.release(); + } + std::unique_ptr&& move_context() { + return std::move(ctx_); + } + [[nodiscard]] bool compare_exchange_deleter( + DeleterFnPtr expected_deleter, + DeleterFnPtr new_deleter) { + if (get_deleter() != expected_deleter) + return false; + ctx_ = std::unique_ptr(ctx_.release(), new_deleter); + return true; + } + + template + T* cast_context(DeleterFnPtr expected_deleter) const { + if (get_deleter() != expected_deleter) + return nullptr; + return static_cast(get_context()); + } + operator bool() const { + return data_ || ctx_; + } + DeleterFnPtr get_deleter() const { + return ctx_.get_deleter(); + } +}; + +// Note [How UniqueVoidPtr is implemented] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// UniqueVoidPtr solves a common problem for allocators of tensor data, which +// is that the data pointer (e.g., float*) which you are interested in, is not +// the same as the context pointer (e.g., DLManagedTensor) which you need +// to actually deallocate the data. Under a conventional deleter design, you +// have to store extra context in the deleter itself so that you can actually +// delete the right thing. Implementing this with standard C++ is somewhat +// error-prone: if you use a std::unique_ptr to manage tensors, the deleter will +// not be called if the data pointer is nullptr, which can cause a leak if the +// context pointer is non-null (and the deleter is responsible for freeing both +// the data pointer and the context pointer). +// +// So, in our reimplementation of unique_ptr, which just store the context +// directly in the unique pointer, and attach the deleter to the context +// pointer itself. In simple cases, the context pointer is just the pointer +// itself. + +inline bool operator==(const UniqueVoidPtr& sp, std::nullptr_t) noexcept { + return !sp; +} +inline bool operator==(std::nullptr_t, const UniqueVoidPtr& sp) noexcept { + return !sp; +} +inline bool operator!=(const UniqueVoidPtr& sp, std::nullptr_t) noexcept { + return sp; +} +inline bool operator!=(std::nullptr_t, const UniqueVoidPtr& sp) noexcept { + return sp; +} + +} // namespace detail +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h new file mode 100644 index 0000000000000000000000000000000000000000..c1470391c8c4ac75f5055848de538b66beea00b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/Unroll.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +// Utility to guarantee complete unrolling of a loop where the bounds are known +// at compile time. Various pragmas achieve similar effects, but are not as +// portable across compilers. + +// Example: c10::ForcedUnroll<4>{}(f); is equivalent to f(0); f(1); f(2); f(3); + +namespace c10 { + +template +struct ForcedUnroll { + template + C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + ForcedUnroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct ForcedUnroll<1> { + template + C10_ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..141d5431adcc1f51286b864d02cc30c2035e3371 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/WaitCounterDynamicBackend.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::monitor::detail { + +struct WaitCounterDynamicBackend { + void* self{nullptr}; + intptr_t (*start)(void* self, int64_t nowUs){nullptr}; + void (*stop)(void* self, int64_t nowUs, intptr_t ctx){nullptr}; + void (*destroy)(void* self){nullptr}; +}; + +using WaitCounterDynamicBackendInit = + void (*)(WaitCounterDynamicBackend*, const char* key, std::size_t keyLen); + +// This name needs to be updated if anything in the API above is changed. +constexpr std::string_view kWaitCounterDynamicBackendInitFn = + "c10_monitor_wait_counter_dynamic_backend_init_v1"; +} // namespace c10::monitor::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h new file mode 100644 index 0000000000000000000000000000000000000000..df0899a2ce0697b9ff2d8c395dc81fbb9c2d0f84 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/accumulate.h @@ -0,0 +1,129 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// Sum of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t sum_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), container.end(), static_cast(0)); +} + +/// Sum of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t sum_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate(begin, end, static_cast(0)); +} + +/// Product of a list of integers; accumulates into the int64_t datatype +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t multiply_integers(const C& container) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + container.begin(), + container.end(), + static_cast(1), + std::multiplies<>()); +} + +/// Product of integer elements referred to by iterators; accumulates into the +/// int64_t datatype +template < + typename Iter, + std::enable_if_t< + std::is_integral_v::value_type>, + int> = 0> +inline int64_t multiply_integers(Iter begin, Iter end) { + // std::accumulate infers return type from `init` type, so if the `init` type + // is not large enough to hold the result, computation can overflow. We use + // `int64_t` here to avoid this. + return std::accumulate( + begin, end, static_cast(1), std::multiplies<>()); +} + +/// Return product of all dimensions starting from k +/// Returns 1 if k>=dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_from_dim(const int k, const C& dims) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0); + + if (k > static_cast(dims.size())) { + return 1; + } else { + auto cbegin = dims.cbegin(); + std::advance(cbegin, k); + return multiply_integers(cbegin, dims.cend()); + } +} + +/// Product of all dims up to k (not including dims[k]) +/// Throws an error if k>dims.size() +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_to_dim(const int k, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size()); + + auto cend = dims.cbegin(); + std::advance(cend, k); + return multiply_integers(dims.cbegin(), cend); +} + +/// Product of all dims between k and l (including dims[k] and excluding +/// dims[l]) k and l may be supplied in either order +template < + typename C, + std::enable_if_t, int> = 0> +inline int64_t numelements_between_dim(int k, int l, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT(0 <= l); + + if (k > l) { + std::swap(k, l); + } + + TORCH_INTERNAL_ASSERT((unsigned)l < dims.size()); + + auto cbegin = dims.cbegin(); + auto cend = dims.cbegin(); + std::advance(cbegin, k); + std::advance(cend, l); + return multiply_integers(cbegin, cend); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/complex.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/complex.h new file mode 100644 index 0000000000000000000000000000000000000000..ff5ea55c508872c075b181518ff6e1cf537bbc3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/complex.h @@ -0,0 +1,83 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include + +// std functions +// +// The implementation of these functions also follow the design of C++20 + +namespace std { + +template +constexpr T real(const c10::complex& z) { + return z.real(); +} + +template +constexpr T imag(const c10::complex& z) { + return z.imag(); +} + +template +C10_HOST_DEVICE T abs(const c10::complex& z) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return thrust::abs(static_cast>(z)); +#else + return std::abs(static_cast>(z)); +#endif +} + +#if defined(USE_ROCM) +#define ROCm_Bug(x) +#else +#define ROCm_Bug(x) x +#endif + +template +C10_HOST_DEVICE T arg(const c10::complex& z) { + return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); +} + +#undef ROCm_Bug + +template +constexpr T norm(const c10::complex& z) { + return z.real() * z.real() + z.imag() * z.imag(); +} + +// For std::conj, there are other versions of it: +// constexpr std::complex conj( float z ); +// template< class DoubleOrInteger > +// constexpr std::complex conj( DoubleOrInteger z ); +// constexpr std::complex conj( long double z ); +// These are not implemented +// TODO(@zasdfgbnm): implement them as c10::conj +template +constexpr c10::complex conj(const c10::complex& z) { + return c10::complex(z.real(), -z.imag()); +} + +// Thrust does not have complex --> complex version of thrust::proj, +// so this function is not implemented at c10 right now. +// TODO(@zasdfgbnm): implement it by ourselves + +// There is no c10 version of std::polar, because std::polar always +// returns std::complex. Use c10::polar instead; + +} // namespace std + +#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H +// math functions are included in a separate file +#include // IWYU pragma: keep +// utilities for complex types +#include // IWYU pragma: keep +#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/error.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/error.h new file mode 100644 index 0000000000000000000000000000000000000000..4afd8a9ab673ff71cb1d0a58e209262096e86347 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/error.h @@ -0,0 +1,16 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace c10::utils { + +// Get an error string in the thread-safe way. +C10_API std::string str_error(int errnum); + +} // namespace c10::utils + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h new file mode 100644 index 0000000000000000000000000000000000000000..653401395d4098ea77752e4bafdb64682ac8c242 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/flat_hash_map.h @@ -0,0 +1,2107 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Taken from +// https://github.com/skarupke/flat_hash_map/blob/2c4687431f978f02a3780e24b8b701d22aa32d9c/flat_hash_map.hpp +// with fixes applied: +// - https://github.com/skarupke/flat_hash_map/pull/25 +// - https://github.com/skarupke/flat_hash_map/pull/26 +// - replace size_t with uint64_t to fix it for 32bit +// - add "GCC diagnostic" pragma to ignore -Wshadow +// - make sherwood_v3_table::convertible_to_iterator public because GCC5 seems +// to have issues with it otherwise +// - fix compiler warnings in operator templated_iterator +// - make use of 'if constexpr' and eliminate AssignIfTrue template + +// Copyright Malte Skarupke 2017. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 4624) // destructor was implicitly defined as deleted +#endif + +#ifdef _MSC_VER +#define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ +#else +#define SKA_NOINLINE(...) __VA_ARGS__ __attribute__((noinline)) +#endif + +namespace ska { +struct prime_number_hash_policy; +struct power_of_two_hash_policy; +struct fibonacci_hash_policy; + +namespace detailv3 { +template +struct functor_storage : Functor { + functor_storage() = default; + functor_storage(const Functor& functor) : Functor(functor) {} + template + Result operator()(Args&&... args) { + return static_cast(*this)(std::forward(args)...); + } + template + Result operator()(Args&&... args) const { + return static_cast(*this)(std::forward(args)...); + } +}; +template +struct functor_storage { + typedef Result (*function_ptr)(Args...); + function_ptr function; + functor_storage(function_ptr function) : function(function) {} + Result operator()(Args... args) const { + return function(std::forward(args)...); + } + operator function_ptr&() { + return function; + } + operator const function_ptr&() { + return function; + } +}; +template +struct KeyOrValueHasher : functor_storage { + typedef functor_storage hasher_storage; + KeyOrValueHasher() = default; + KeyOrValueHasher(const hasher& hash) : hasher_storage(hash) {} + uint64_t operator()(const key_type& key) { + return static_cast(*this)(key); + } + uint64_t operator()(const key_type& key) const { + return static_cast(*this)(key); + } + uint64_t operator()(const value_type& value) { + return static_cast(*this)(value.first); + } + uint64_t operator()(const value_type& value) const { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair& value) const { + return static_cast(*this)(value.first); + } +}; +template +struct KeyOrValueEquality : functor_storage { + typedef functor_storage equality_storage; + KeyOrValueEquality() = default; + KeyOrValueEquality(const key_equal& equality) : equality_storage(equality) {} + bool operator()(const key_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs, rhs); + } + bool operator()(const key_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + bool operator()(const value_type& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + bool operator()(const value_type& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const key_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs, rhs.first); + } + template + bool operator()(const std::pair& lhs, const key_type& rhs) { + return static_cast(*this)(lhs.first, rhs); + } + template + bool operator()(const value_type& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const value_type& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair& lhs, const std::pair& rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } +}; +static constexpr int8_t min_lookups = 4; +template +struct sherwood_v3_entry { + sherwood_v3_entry() = default; + sherwood_v3_entry(int8_t distance_from_desired) + : distance_from_desired(distance_from_desired) {} + ~sherwood_v3_entry() = default; + + bool has_value() const { + return distance_from_desired >= 0; + } + bool is_empty() const { + return distance_from_desired < 0; + } + bool is_at_desired_position() const { + return distance_from_desired <= 0; + } + template + void emplace(int8_t distance, Args&&... args) { + new (std::addressof(value)) T(std::forward(args)...); + distance_from_desired = distance; + } + + void destroy_value() { + value.~T(); + distance_from_desired = -1; + } + + int8_t distance_from_desired = -1; + static constexpr int8_t special_end_value = 0; + union { + T value; + }; +}; + +inline int8_t log2(uint64_t value) { + // NOLINTNEXTLINE(*c-arrays*) + static constexpr int8_t table[64] = { + 63, 0, 58, 1, 59, 47, 53, 2, 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, 44, 24, 15, 8, 23, 7, 6, 5}; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return table[((value - (value >> 1)) * 0x07EDD5E59A4E28C2) >> 58]; +} + +inline uint64_t next_power_of_two(uint64_t i) { + --i; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + i |= i >> 32; + ++i; + return i; +} + +// Implementation taken from http://en.cppreference.com/w/cpp/types/void_t +// (it takes CWG1558 into account and also works for older compilers) +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; + +template +struct HashPolicySelector { + typedef fibonacci_hash_policy type; +}; +template +struct HashPolicySelector> { + typedef typename T::hash_policy type; +}; + +template < + typename T, + typename FindKey, + typename ArgumentHash, + typename DetailHasher, + typename ArgumentEqual, + typename Equal, + typename ArgumentAlloc, + typename EntryAlloc> +class sherwood_v3_table : private EntryAlloc, + private DetailHasher, + private Equal { + using Entry = detailv3::sherwood_v3_entry; + using AllocatorTraits = std::allocator_traits; + using EntryPointer = typename AllocatorTraits::pointer; + + public: + struct convertible_to_iterator; + + using value_type = T; + using size_type = uint64_t; + using difference_type = std::ptrdiff_t; + using hasher = ArgumentHash; + using key_equal = ArgumentEqual; + using allocator_type = EntryAlloc; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + sherwood_v3_table() = default; + explicit sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : EntryAlloc(alloc), DetailHasher(hash), Equal(equal) { + rehash(bucket_count); + } + sherwood_v3_table(size_type bucket_count, const ArgumentAlloc& alloc) + : sherwood_v3_table( + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(bucket_count, hash, ArgumentEqual(), alloc) {} + explicit sherwood_v3_table(const ArgumentAlloc& alloc) : EntryAlloc(alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + insert(first, last); + } + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + template + sherwood_v3_table( + It first, + It last, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + first, + last, + bucket_count, + hash, + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count = 0, + const ArgumentHash& hash = ArgumentHash(), + const ArgumentEqual& equal = ArgumentEqual(), + const ArgumentAlloc& alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) { + if (bucket_count == 0) + rehash(il.size()); + insert(il.begin(), il.end()); + } + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentAlloc& alloc) + : sherwood_v3_table( + il, + bucket_count, + ArgumentHash(), + ArgumentEqual(), + alloc) {} + sherwood_v3_table( + std::initializer_list il, + size_type bucket_count, + const ArgumentHash& hash, + const ArgumentAlloc& alloc) + : sherwood_v3_table(il, bucket_count, hash, ArgumentEqual(), alloc) {} + sherwood_v3_table(const sherwood_v3_table& other) + : sherwood_v3_table( + other, + AllocatorTraits::select_on_container_copy_construction( + other.get_allocator())) {} + sherwood_v3_table(const sherwood_v3_table& other, const ArgumentAlloc& alloc) + : EntryAlloc(alloc), + DetailHasher(other), + Equal(other), + _max_load_factor(other._max_load_factor) { + rehash_for_other_container(other); + try { + insert(other.begin(), other.end()); + } catch (...) { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + throw; + } + } + sherwood_v3_table(sherwood_v3_table&& other) noexcept + : EntryAlloc(std::move(other)), + DetailHasher(std::move(other)), + Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table( + sherwood_v3_table&& other, + const ArgumentAlloc& alloc) noexcept + : EntryAlloc(alloc), + DetailHasher(std::move(other)), + Equal(std::move(other)) { + swap_pointers(other); + } + sherwood_v3_table& operator=(const sherwood_v3_table& other) { + if (this == std::addressof(other)) + return *this; + + clear(); + if constexpr (AllocatorTraits::propagate_on_container_copy_assignment:: + value) { + if (static_cast(*this) != + static_cast(other)) { + reset_to_empty_state(); + } + static_cast(*this) = other; + } + _max_load_factor = other._max_load_factor; + static_cast(*this) = other; + static_cast(*this) = other; + rehash_for_other_container(other); + insert(other.begin(), other.end()); + return *this; + } + sherwood_v3_table& operator=(sherwood_v3_table&& other) noexcept { + if (this == std::addressof(other)) + return *this; + else if constexpr (AllocatorTraits::propagate_on_container_move_assignment:: + value) { + clear(); + reset_to_empty_state(); + static_cast(*this) = std::move(other); + swap_pointers(other); + } else if ( + static_cast(*this) == static_cast(other)) { + swap_pointers(other); + } else { + clear(); + _max_load_factor = other._max_load_factor; + rehash_for_other_container(other); + for (T& elem : other) + emplace(std::move(elem)); + other.clear(); + } + static_cast(*this) = std::move(other); + static_cast(*this) = std::move(other); + return *this; + } + ~sherwood_v3_table() { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + } + + const allocator_type& get_allocator() const { + return static_cast(*this); + } + const ArgumentEqual& key_eq() const { + return static_cast(*this); + } + const ArgumentHash& hash_function() const { + return static_cast(*this); + } + + template + struct templated_iterator { + templated_iterator() = default; + templated_iterator(EntryPointer current) : current(current) {} + EntryPointer current = EntryPointer(); + + using iterator_category = std::forward_iterator_tag; + using value_type = ValueType; + using difference_type = ptrdiff_t; + using pointer = ValueType*; + using reference = ValueType&; + + friend bool operator==( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return lhs.current == rhs.current; + } + friend bool operator!=( + const templated_iterator& lhs, + const templated_iterator& rhs) { + return !(lhs == rhs); + } + + templated_iterator& operator++() { + do { + ++current; + } while (current->is_empty()); + return *this; + } + templated_iterator operator++(int) { + templated_iterator copy(*this); + ++*this; + return copy; + } + + ValueType& operator*() const { + return current->value; + } + ValueType* operator->() const { + return std::addressof(current->value); + } + + // the template automatically disables the operator when value_type is + // already const, because that would cause a lot of compiler warnings + // otherwise. + template < + class target_type = const value_type, + class = std::enable_if_t< + std::is_same_v && + !std::is_same_v>> + operator templated_iterator() const { + return {current}; + } + }; + using iterator = templated_iterator; + using const_iterator = templated_iterator; + + iterator begin() { + for (EntryPointer it = entries;; ++it) { + if (it->has_value()) + return {it}; + } + } + const_iterator begin() const { + for (EntryPointer it = entries;; ++it) { + if (it->has_value()) + return {it}; + } + } + const_iterator cbegin() const { + return begin(); + } + iterator end() { + return { + entries + static_cast(num_slots_minus_one + max_lookups)}; + } + const_iterator end() const { + return { + entries + static_cast(num_slots_minus_one + max_lookups)}; + } + const_iterator cend() const { + return end(); + } + + iterator find(const FindKey& key) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer it = entries + ptrdiff_t(index); + for (int8_t distance = 0; it->distance_from_desired >= distance; + ++distance, ++it) { + if (compares_equal(key, it->value)) + return {it}; + } + return end(); + } + const_iterator find(const FindKey& key) const { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + return const_cast(this)->find(key); + } + uint64_t count(const FindKey& key) const { + return find(key) == end() ? 0 : 1; + } + std::pair equal_range(const FindKey& key) { + iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + std::pair equal_range( + const FindKey& key) const { + const_iterator found = find(key); + if (found == end()) + return {found, found}; + else + return {found, std::next(found)}; + } + + template + std::pair emplace(Key&& key, Args&&... args) { + uint64_t index = + hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer current_entry = entries + ptrdiff_t(index); + int8_t distance_from_desired = 0; + for (; current_entry->distance_from_desired >= distance_from_desired; + ++current_entry, ++distance_from_desired) { + if (compares_equal(key, current_entry->value)) + return {{current_entry}, false}; + } + return emplace_new_key( + distance_from_desired, + current_entry, + std::forward(key), + std::forward(args)...); + } + + std::pair insert(const value_type& value) { + return emplace(value); + } + std::pair insert(value_type&& value) { + return emplace(std::move(value)); + } + template + iterator emplace_hint(const_iterator /*unused*/, Args&&... args) { + return emplace(std::forward(args)...).first; + } + iterator insert(const_iterator /*unused*/, const value_type& value) { + return emplace(value).first; + } + iterator insert(const_iterator /*unused*/, value_type&& value) { + return emplace(std::move(value)).first; + } + + template + void insert(It begin, It end) { + for (; begin != end; ++begin) { + emplace(*begin); + } + } + void insert(std::initializer_list il) { + insert(il.begin(), il.end()); + } + + void rehash(uint64_t num_buckets) { + num_buckets = std::max( + num_buckets, + static_cast( + std::ceil(num_elements / static_cast(_max_load_factor)))); + if (num_buckets == 0) { + reset_to_empty_state(); + return; + } + auto new_prime_index = hash_policy.next_size_over(num_buckets); + if (num_buckets == bucket_count()) + return; + int8_t new_max_lookups = compute_max_lookups(num_buckets); + EntryPointer new_buckets( + AllocatorTraits::allocate(*this, num_buckets + new_max_lookups)); + EntryPointer special_end_item = + new_buckets + static_cast(num_buckets + new_max_lookups - 1); + for (EntryPointer it = new_buckets; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + std::swap(entries, new_buckets); + std::swap(num_slots_minus_one, num_buckets); + --num_slots_minus_one; + hash_policy.commit(new_prime_index); + int8_t old_max_lookups = max_lookups; + max_lookups = new_max_lookups; + num_elements = 0; + for (EntryPointer + it = new_buckets, + end = it + static_cast(num_buckets + old_max_lookups); + it != end; + ++it) { + if (it->has_value()) { + emplace(std::move(it->value)); + it->destroy_value(); + } + } + deallocate_data(new_buckets, num_buckets, old_max_lookups); + } + + void reserve(uint64_t num_elements_) { + uint64_t required_buckets = num_buckets_for_reserve(num_elements_); + if (required_buckets > bucket_count()) + rehash(required_buckets); + } + + // the return value is a type that can be converted to an iterator + // the reason for doing this is that it's not free to find the + // iterator pointing at the next element. if you care about the + // next iterator, turn the return value into an iterator + convertible_to_iterator erase(const_iterator to_erase) { + EntryPointer current = to_erase.current; + current->destroy_value(); + --num_elements; + for (EntryPointer next = current + ptrdiff_t(1); + !next->is_at_desired_position(); + ++current, ++next) { + current->emplace(next->distance_from_desired - 1, std::move(next->value)); + next->destroy_value(); + } + return {to_erase.current}; + } + + iterator erase(const_iterator begin_it, const_iterator end_it) { + if (begin_it == end_it) + return {begin_it.current}; + for (EntryPointer it = begin_it.current, end = end_it.current; it != end; + ++it) { + if (it->has_value()) { + it->destroy_value(); + --num_elements; + } + } + if (end_it == this->end()) + return this->end(); + ptrdiff_t num_to_move = std::min( + static_cast(end_it.current->distance_from_desired), + end_it.current - begin_it.current); + EntryPointer to_return = end_it.current - num_to_move; + for (EntryPointer it = end_it.current; !it->is_at_desired_position();) { + EntryPointer target = it - num_to_move; + target->emplace( + it->distance_from_desired - num_to_move, std::move(it->value)); + it->destroy_value(); + ++it; + num_to_move = std::min( + static_cast(it->distance_from_desired), num_to_move); + } + return {to_return}; + } + + uint64_t erase(const FindKey& key) { + auto found = find(key); + if (found == end()) + return 0; + else { + erase(found); + return 1; + } + } + + void clear() { + for (EntryPointer it = entries, + end = it + + static_cast(num_slots_minus_one + max_lookups); + it != end; + ++it) { + if (it->has_value()) + it->destroy_value(); + } + num_elements = 0; + } + + void shrink_to_fit() { + rehash_for_other_container(*this); + } + + void swap(sherwood_v3_table& other) noexcept { + using std::swap; + swap_pointers(other); + swap(static_cast(*this), static_cast(other)); + swap( + static_cast(*this), static_cast(other)); + if (AllocatorTraits::propagate_on_container_swap::value) + swap(static_cast(*this), static_cast(other)); + } + + uint64_t size() const { + return num_elements; + } + uint64_t max_size() const { + return (AllocatorTraits::max_size(*this)) / sizeof(Entry); + } + uint64_t bucket_count() const { + return num_slots_minus_one ? num_slots_minus_one + 1 : 0; + } + size_type max_bucket_count() const { + return (AllocatorTraits::max_size(*this) - min_lookups) / sizeof(Entry); + } + uint64_t bucket(const FindKey& key) const { + return hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + } + float load_factor() const { + uint64_t buckets = bucket_count(); + if (buckets) + return static_cast(num_elements) / bucket_count(); + else + return 0; + } + void max_load_factor(float value) { + _max_load_factor = value; + } + float max_load_factor() const { + return _max_load_factor; + } + + bool empty() const { + return num_elements == 0; + } + + private: + EntryPointer entries = empty_default_table(); + uint64_t num_slots_minus_one = 0; + typename HashPolicySelector::type hash_policy; + int8_t max_lookups = detailv3::min_lookups - 1; + float _max_load_factor = 0.5f; + uint64_t num_elements = 0; + + EntryPointer empty_default_table() { + EntryPointer result = + AllocatorTraits::allocate(*this, detailv3::min_lookups); + EntryPointer special_end_item = + result + static_cast(detailv3::min_lookups - 1); + for (EntryPointer it = result; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + return result; + } + + static int8_t compute_max_lookups(uint64_t num_buckets) { + int8_t desired = detailv3::log2(num_buckets); + return std::max(detailv3::min_lookups, desired); + } + + uint64_t num_buckets_for_reserve(uint64_t num_elements_) const { + return static_cast(std::ceil( + static_cast(num_elements_) / + std::min(0.5, static_cast(_max_load_factor)))); + } + void rehash_for_other_container(const sherwood_v3_table& other) { + rehash( + std::min(num_buckets_for_reserve(other.size()), other.bucket_count())); + } + + void swap_pointers(sherwood_v3_table& other) { + using std::swap; + swap(hash_policy, other.hash_policy); + swap(entries, other.entries); + swap(num_slots_minus_one, other.num_slots_minus_one); + swap(num_elements, other.num_elements); + swap(max_lookups, other.max_lookups); + swap(_max_load_factor, other._max_load_factor); + } + + template + SKA_NOINLINE(std::pair) + emplace_new_key( + int8_t distance_from_desired, + EntryPointer current_entry, + Key&& key, + Args&&... args) { + using std::swap; + if (num_slots_minus_one == 0 || distance_from_desired == max_lookups || + num_elements + 1 > + (num_slots_minus_one + 1) * static_cast(_max_load_factor)) { + grow(); + return emplace(std::forward(key), std::forward(args)...); + } else if (current_entry->is_empty()) { + current_entry->emplace( + distance_from_desired, + std::forward(key), + std::forward(args)...); + ++num_elements; + return {{current_entry}, true}; + } + value_type to_insert(std::forward(key), std::forward(args)...); + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + iterator result = {current_entry}; + for (++distance_from_desired, ++current_entry;; ++current_entry) { + if (current_entry->is_empty()) { + current_entry->emplace(distance_from_desired, std::move(to_insert)); + ++num_elements; + return {result, true}; + } else if (current_entry->distance_from_desired < distance_from_desired) { + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + ++distance_from_desired; + } else { + ++distance_from_desired; + if (distance_from_desired == max_lookups) { + swap(to_insert, result.current->value); + grow(); + return emplace(std::move(to_insert)); + } + } + } + } + + void grow() { + rehash(std::max(uint64_t(4), 2 * bucket_count())); + } + + void deallocate_data( + EntryPointer begin, + uint64_t num_slots_minus_one_, + int8_t max_lookups_) { + AllocatorTraits::deallocate( + *this, begin, num_slots_minus_one_ + max_lookups_ + 1); + } + + void reset_to_empty_state() { + deallocate_data(entries, num_slots_minus_one, max_lookups); + entries = empty_default_table(); + num_slots_minus_one = 0; + hash_policy.reset(); + max_lookups = detailv3::min_lookups - 1; + } + + template + uint64_t hash_object(const U& key) { + return static_cast(*this)(key); + } + template + uint64_t hash_object(const U& key) const { + return static_cast(*this)(key); + } + template + bool compares_equal(const L& lhs, const R& rhs) { + return static_cast(*this)(lhs, rhs); + } + + public: + struct convertible_to_iterator { + EntryPointer it; + + operator iterator() { + if (it->has_value()) + return {it}; + else + return ++iterator{it}; + } + operator const_iterator() { + if (it->has_value()) + return {it}; + else + return ++const_iterator{it}; + } + }; +}; +} // namespace detailv3 + +struct prime_number_hash_policy { + static uint64_t mod0(uint64_t /*unused*/) { + return 0llu; + } + static uint64_t mod2(uint64_t hash) { + return hash % 2llu; + } + static uint64_t mod3(uint64_t hash) { + return hash % 3llu; + } + static uint64_t mod5(uint64_t hash) { + return hash % 5llu; + } + static uint64_t mod7(uint64_t hash) { + return hash % 7llu; + } + static uint64_t mod11(uint64_t hash) { + return hash % 11llu; + } + static uint64_t mod13(uint64_t hash) { + return hash % 13llu; + } + static uint64_t mod17(uint64_t hash) { + return hash % 17llu; + } + static uint64_t mod23(uint64_t hash) { + return hash % 23llu; + } + static uint64_t mod29(uint64_t hash) { + return hash % 29llu; + } + static uint64_t mod37(uint64_t hash) { + return hash % 37llu; + } + static uint64_t mod47(uint64_t hash) { + return hash % 47llu; + } + static uint64_t mod59(uint64_t hash) { + return hash % 59llu; + } + static uint64_t mod73(uint64_t hash) { + return hash % 73llu; + } + static uint64_t mod97(uint64_t hash) { + return hash % 97llu; + } + static uint64_t mod127(uint64_t hash) { + return hash % 127llu; + } + static uint64_t mod151(uint64_t hash) { + return hash % 151llu; + } + static uint64_t mod197(uint64_t hash) { + return hash % 197llu; + } + static uint64_t mod251(uint64_t hash) { + return hash % 251llu; + } + static uint64_t mod313(uint64_t hash) { + return hash % 313llu; + } + static uint64_t mod397(uint64_t hash) { + return hash % 397llu; + } + static uint64_t mod499(uint64_t hash) { + return hash % 499llu; + } + static uint64_t mod631(uint64_t hash) { + return hash % 631llu; + } + static uint64_t mod797(uint64_t hash) { + return hash % 797llu; + } + static uint64_t mod1009(uint64_t hash) { + return hash % 1009llu; + } + static uint64_t mod1259(uint64_t hash) { + return hash % 1259llu; + } + static uint64_t mod1597(uint64_t hash) { + return hash % 1597llu; + } + static uint64_t mod2011(uint64_t hash) { + return hash % 2011llu; + } + static uint64_t mod2539(uint64_t hash) { + return hash % 2539llu; + } + static uint64_t mod3203(uint64_t hash) { + return hash % 3203llu; + } + static uint64_t mod4027(uint64_t hash) { + return hash % 4027llu; + } + static uint64_t mod5087(uint64_t hash) { + return hash % 5087llu; + } + static uint64_t mod6421(uint64_t hash) { + return hash % 6421llu; + } + static uint64_t mod8089(uint64_t hash) { + return hash % 8089llu; + } + static uint64_t mod10193(uint64_t hash) { + return hash % 10193llu; + } + static uint64_t mod12853(uint64_t hash) { + return hash % 12853llu; + } + static uint64_t mod16193(uint64_t hash) { + return hash % 16193llu; + } + static uint64_t mod20399(uint64_t hash) { + return hash % 20399llu; + } + static uint64_t mod25717(uint64_t hash) { + return hash % 25717llu; + } + static uint64_t mod32401(uint64_t hash) { + return hash % 32401llu; + } + static uint64_t mod40823(uint64_t hash) { + return hash % 40823llu; + } + static uint64_t mod51437(uint64_t hash) { + return hash % 51437llu; + } + static uint64_t mod64811(uint64_t hash) { + return hash % 64811llu; + } + static uint64_t mod81649(uint64_t hash) { + return hash % 81649llu; + } + static uint64_t mod102877(uint64_t hash) { + return hash % 102877llu; + } + static uint64_t mod129607(uint64_t hash) { + return hash % 129607llu; + } + static uint64_t mod163307(uint64_t hash) { + return hash % 163307llu; + } + static uint64_t mod205759(uint64_t hash) { + return hash % 205759llu; + } + static uint64_t mod259229(uint64_t hash) { + return hash % 259229llu; + } + static uint64_t mod326617(uint64_t hash) { + return hash % 326617llu; + } + static uint64_t mod411527(uint64_t hash) { + return hash % 411527llu; + } + static uint64_t mod518509(uint64_t hash) { + return hash % 518509llu; + } + static uint64_t mod653267(uint64_t hash) { + return hash % 653267llu; + } + static uint64_t mod823117(uint64_t hash) { + return hash % 823117llu; + } + static uint64_t mod1037059(uint64_t hash) { + return hash % 1037059llu; + } + static uint64_t mod1306601(uint64_t hash) { + return hash % 1306601llu; + } + static uint64_t mod1646237(uint64_t hash) { + return hash % 1646237llu; + } + static uint64_t mod2074129(uint64_t hash) { + return hash % 2074129llu; + } + static uint64_t mod2613229(uint64_t hash) { + return hash % 2613229llu; + } + static uint64_t mod3292489(uint64_t hash) { + return hash % 3292489llu; + } + static uint64_t mod4148279(uint64_t hash) { + return hash % 4148279llu; + } + static uint64_t mod5226491(uint64_t hash) { + return hash % 5226491llu; + } + static uint64_t mod6584983(uint64_t hash) { + return hash % 6584983llu; + } + static uint64_t mod8296553(uint64_t hash) { + return hash % 8296553llu; + } + static uint64_t mod10453007(uint64_t hash) { + return hash % 10453007llu; + } + static uint64_t mod13169977(uint64_t hash) { + return hash % 13169977llu; + } + static uint64_t mod16593127(uint64_t hash) { + return hash % 16593127llu; + } + static uint64_t mod20906033(uint64_t hash) { + return hash % 20906033llu; + } + static uint64_t mod26339969(uint64_t hash) { + return hash % 26339969llu; + } + static uint64_t mod33186281(uint64_t hash) { + return hash % 33186281llu; + } + static uint64_t mod41812097(uint64_t hash) { + return hash % 41812097llu; + } + static uint64_t mod52679969(uint64_t hash) { + return hash % 52679969llu; + } + static uint64_t mod66372617(uint64_t hash) { + return hash % 66372617llu; + } + static uint64_t mod83624237(uint64_t hash) { + return hash % 83624237llu; + } + static uint64_t mod105359939(uint64_t hash) { + return hash % 105359939llu; + } + static uint64_t mod132745199(uint64_t hash) { + return hash % 132745199llu; + } + static uint64_t mod167248483(uint64_t hash) { + return hash % 167248483llu; + } + static uint64_t mod210719881(uint64_t hash) { + return hash % 210719881llu; + } + static uint64_t mod265490441(uint64_t hash) { + return hash % 265490441llu; + } + static uint64_t mod334496971(uint64_t hash) { + return hash % 334496971llu; + } + static uint64_t mod421439783(uint64_t hash) { + return hash % 421439783llu; + } + static uint64_t mod530980861(uint64_t hash) { + return hash % 530980861llu; + } + static uint64_t mod668993977(uint64_t hash) { + return hash % 668993977llu; + } + static uint64_t mod842879579(uint64_t hash) { + return hash % 842879579llu; + } + static uint64_t mod1061961721(uint64_t hash) { + return hash % 1061961721llu; + } + static uint64_t mod1337987929(uint64_t hash) { + return hash % 1337987929llu; + } + static uint64_t mod1685759167(uint64_t hash) { + return hash % 1685759167llu; + } + static uint64_t mod2123923447(uint64_t hash) { + return hash % 2123923447llu; + } + static uint64_t mod2675975881(uint64_t hash) { + return hash % 2675975881llu; + } + static uint64_t mod3371518343(uint64_t hash) { + return hash % 3371518343llu; + } + static uint64_t mod4247846927(uint64_t hash) { + return hash % 4247846927llu; + } + static uint64_t mod5351951779(uint64_t hash) { + return hash % 5351951779llu; + } + static uint64_t mod6743036717(uint64_t hash) { + return hash % 6743036717llu; + } + static uint64_t mod8495693897(uint64_t hash) { + return hash % 8495693897llu; + } + static uint64_t mod10703903591(uint64_t hash) { + return hash % 10703903591llu; + } + static uint64_t mod13486073473(uint64_t hash) { + return hash % 13486073473llu; + } + static uint64_t mod16991387857(uint64_t hash) { + return hash % 16991387857llu; + } + static uint64_t mod21407807219(uint64_t hash) { + return hash % 21407807219llu; + } + static uint64_t mod26972146961(uint64_t hash) { + return hash % 26972146961llu; + } + static uint64_t mod33982775741(uint64_t hash) { + return hash % 33982775741llu; + } + static uint64_t mod42815614441(uint64_t hash) { + return hash % 42815614441llu; + } + static uint64_t mod53944293929(uint64_t hash) { + return hash % 53944293929llu; + } + static uint64_t mod67965551447(uint64_t hash) { + return hash % 67965551447llu; + } + static uint64_t mod85631228929(uint64_t hash) { + return hash % 85631228929llu; + } + static uint64_t mod107888587883(uint64_t hash) { + return hash % 107888587883llu; + } + static uint64_t mod135931102921(uint64_t hash) { + return hash % 135931102921llu; + } + static uint64_t mod171262457903(uint64_t hash) { + return hash % 171262457903llu; + } + static uint64_t mod215777175787(uint64_t hash) { + return hash % 215777175787llu; + } + static uint64_t mod271862205833(uint64_t hash) { + return hash % 271862205833llu; + } + static uint64_t mod342524915839(uint64_t hash) { + return hash % 342524915839llu; + } + static uint64_t mod431554351609(uint64_t hash) { + return hash % 431554351609llu; + } + static uint64_t mod543724411781(uint64_t hash) { + return hash % 543724411781llu; + } + static uint64_t mod685049831731(uint64_t hash) { + return hash % 685049831731llu; + } + static uint64_t mod863108703229(uint64_t hash) { + return hash % 863108703229llu; + } + static uint64_t mod1087448823553(uint64_t hash) { + return hash % 1087448823553llu; + } + static uint64_t mod1370099663459(uint64_t hash) { + return hash % 1370099663459llu; + } + static uint64_t mod1726217406467(uint64_t hash) { + return hash % 1726217406467llu; + } + static uint64_t mod2174897647073(uint64_t hash) { + return hash % 2174897647073llu; + } + static uint64_t mod2740199326961(uint64_t hash) { + return hash % 2740199326961llu; + } + static uint64_t mod3452434812973(uint64_t hash) { + return hash % 3452434812973llu; + } + static uint64_t mod4349795294267(uint64_t hash) { + return hash % 4349795294267llu; + } + static uint64_t mod5480398654009(uint64_t hash) { + return hash % 5480398654009llu; + } + static uint64_t mod6904869625999(uint64_t hash) { + return hash % 6904869625999llu; + } + static uint64_t mod8699590588571(uint64_t hash) { + return hash % 8699590588571llu; + } + static uint64_t mod10960797308051(uint64_t hash) { + return hash % 10960797308051llu; + } + static uint64_t mod13809739252051(uint64_t hash) { + return hash % 13809739252051llu; + } + static uint64_t mod17399181177241(uint64_t hash) { + return hash % 17399181177241llu; + } + static uint64_t mod21921594616111(uint64_t hash) { + return hash % 21921594616111llu; + } + static uint64_t mod27619478504183(uint64_t hash) { + return hash % 27619478504183llu; + } + static uint64_t mod34798362354533(uint64_t hash) { + return hash % 34798362354533llu; + } + static uint64_t mod43843189232363(uint64_t hash) { + return hash % 43843189232363llu; + } + static uint64_t mod55238957008387(uint64_t hash) { + return hash % 55238957008387llu; + } + static uint64_t mod69596724709081(uint64_t hash) { + return hash % 69596724709081llu; + } + static uint64_t mod87686378464759(uint64_t hash) { + return hash % 87686378464759llu; + } + static uint64_t mod110477914016779(uint64_t hash) { + return hash % 110477914016779llu; + } + static uint64_t mod139193449418173(uint64_t hash) { + return hash % 139193449418173llu; + } + static uint64_t mod175372756929481(uint64_t hash) { + return hash % 175372756929481llu; + } + static uint64_t mod220955828033581(uint64_t hash) { + return hash % 220955828033581llu; + } + static uint64_t mod278386898836457(uint64_t hash) { + return hash % 278386898836457llu; + } + static uint64_t mod350745513859007(uint64_t hash) { + return hash % 350745513859007llu; + } + static uint64_t mod441911656067171(uint64_t hash) { + return hash % 441911656067171llu; + } + static uint64_t mod556773797672909(uint64_t hash) { + return hash % 556773797672909llu; + } + static uint64_t mod701491027718027(uint64_t hash) { + return hash % 701491027718027llu; + } + static uint64_t mod883823312134381(uint64_t hash) { + return hash % 883823312134381llu; + } + static uint64_t mod1113547595345903(uint64_t hash) { + return hash % 1113547595345903llu; + } + static uint64_t mod1402982055436147(uint64_t hash) { + return hash % 1402982055436147llu; + } + static uint64_t mod1767646624268779(uint64_t hash) { + return hash % 1767646624268779llu; + } + static uint64_t mod2227095190691797(uint64_t hash) { + return hash % 2227095190691797llu; + } + static uint64_t mod2805964110872297(uint64_t hash) { + return hash % 2805964110872297llu; + } + static uint64_t mod3535293248537579(uint64_t hash) { + return hash % 3535293248537579llu; + } + static uint64_t mod4454190381383713(uint64_t hash) { + return hash % 4454190381383713llu; + } + static uint64_t mod5611928221744609(uint64_t hash) { + return hash % 5611928221744609llu; + } + static uint64_t mod7070586497075177(uint64_t hash) { + return hash % 7070586497075177llu; + } + static uint64_t mod8908380762767489(uint64_t hash) { + return hash % 8908380762767489llu; + } + static uint64_t mod11223856443489329(uint64_t hash) { + return hash % 11223856443489329llu; + } + static uint64_t mod14141172994150357(uint64_t hash) { + return hash % 14141172994150357llu; + } + static uint64_t mod17816761525534927(uint64_t hash) { + return hash % 17816761525534927llu; + } + static uint64_t mod22447712886978529(uint64_t hash) { + return hash % 22447712886978529llu; + } + static uint64_t mod28282345988300791(uint64_t hash) { + return hash % 28282345988300791llu; + } + static uint64_t mod35633523051069991(uint64_t hash) { + return hash % 35633523051069991llu; + } + static uint64_t mod44895425773957261(uint64_t hash) { + return hash % 44895425773957261llu; + } + static uint64_t mod56564691976601587(uint64_t hash) { + return hash % 56564691976601587llu; + } + static uint64_t mod71267046102139967(uint64_t hash) { + return hash % 71267046102139967llu; + } + static uint64_t mod89790851547914507(uint64_t hash) { + return hash % 89790851547914507llu; + } + static uint64_t mod113129383953203213(uint64_t hash) { + return hash % 113129383953203213llu; + } + static uint64_t mod142534092204280003(uint64_t hash) { + return hash % 142534092204280003llu; + } + static uint64_t mod179581703095829107(uint64_t hash) { + return hash % 179581703095829107llu; + } + static uint64_t mod226258767906406483(uint64_t hash) { + return hash % 226258767906406483llu; + } + static uint64_t mod285068184408560057(uint64_t hash) { + return hash % 285068184408560057llu; + } + static uint64_t mod359163406191658253(uint64_t hash) { + return hash % 359163406191658253llu; + } + static uint64_t mod452517535812813007(uint64_t hash) { + return hash % 452517535812813007llu; + } + static uint64_t mod570136368817120201(uint64_t hash) { + return hash % 570136368817120201llu; + } + static uint64_t mod718326812383316683(uint64_t hash) { + return hash % 718326812383316683llu; + } + static uint64_t mod905035071625626043(uint64_t hash) { + return hash % 905035071625626043llu; + } + static uint64_t mod1140272737634240411(uint64_t hash) { + return hash % 1140272737634240411llu; + } + static uint64_t mod1436653624766633509(uint64_t hash) { + return hash % 1436653624766633509llu; + } + static uint64_t mod1810070143251252131(uint64_t hash) { + return hash % 1810070143251252131llu; + } + static uint64_t mod2280545475268481167(uint64_t hash) { + return hash % 2280545475268481167llu; + } + static uint64_t mod2873307249533267101(uint64_t hash) { + return hash % 2873307249533267101llu; + } + static uint64_t mod3620140286502504283(uint64_t hash) { + return hash % 3620140286502504283llu; + } + static uint64_t mod4561090950536962147(uint64_t hash) { + return hash % 4561090950536962147llu; + } + static uint64_t mod5746614499066534157(uint64_t hash) { + return hash % 5746614499066534157llu; + } + static uint64_t mod7240280573005008577(uint64_t hash) { + return hash % 7240280573005008577llu; + } + static uint64_t mod9122181901073924329(uint64_t hash) { + return hash % 9122181901073924329llu; + } + static uint64_t mod11493228998133068689(uint64_t hash) { + return hash % 11493228998133068689llu; + } + static uint64_t mod14480561146010017169(uint64_t hash) { + return hash % 14480561146010017169llu; + } + static uint64_t mod18446744073709551557(uint64_t hash) { + return hash % 18446744073709551557llu; + } + + using mod_function = uint64_t (*)(uint64_t); + + mod_function next_size_over(uint64_t& size) const { + // prime numbers generated by the following method: + // 1. start with a prime p = 2 + // 2. go to wolfram alpha and get p = NextPrime(2 * p) + // 3. repeat 2. until you overflow 64 bits + // you now have large gaps which you would hit if somebody called reserve() + // with an unlucky number. + // 4. to fill the gaps for every prime p go to wolfram alpha and get + // ClosestPrime(p * 2^(1/3)) and ClosestPrime(p * 2^(2/3)) and put those in + // the gaps + // 5. get PrevPrime(2^64) and put it at the end + // NOLINTNEXTLINE(*c-arrays*) + static constexpr const uint64_t prime_list[] = { + 2llu, + 3llu, + 5llu, + 7llu, + 11llu, + 13llu, + 17llu, + 23llu, + 29llu, + 37llu, + 47llu, + 59llu, + 73llu, + 97llu, + 127llu, + 151llu, + 197llu, + 251llu, + 313llu, + 397llu, + 499llu, + 631llu, + 797llu, + 1009llu, + 1259llu, + 1597llu, + 2011llu, + 2539llu, + 3203llu, + 4027llu, + 5087llu, + 6421llu, + 8089llu, + 10193llu, + 12853llu, + 16193llu, + 20399llu, + 25717llu, + 32401llu, + 40823llu, + 51437llu, + 64811llu, + 81649llu, + 102877llu, + 129607llu, + 163307llu, + 205759llu, + 259229llu, + 326617llu, + 411527llu, + 518509llu, + 653267llu, + 823117llu, + 1037059llu, + 1306601llu, + 1646237llu, + 2074129llu, + 2613229llu, + 3292489llu, + 4148279llu, + 5226491llu, + 6584983llu, + 8296553llu, + 10453007llu, + 13169977llu, + 16593127llu, + 20906033llu, + 26339969llu, + 33186281llu, + 41812097llu, + 52679969llu, + 66372617llu, + 83624237llu, + 105359939llu, + 132745199llu, + 167248483llu, + 210719881llu, + 265490441llu, + 334496971llu, + 421439783llu, + 530980861llu, + 668993977llu, + 842879579llu, + 1061961721llu, + 1337987929llu, + 1685759167llu, + 2123923447llu, + 2675975881llu, + 3371518343llu, + 4247846927llu, + 5351951779llu, + 6743036717llu, + 8495693897llu, + 10703903591llu, + 13486073473llu, + 16991387857llu, + 21407807219llu, + 26972146961llu, + 33982775741llu, + 42815614441llu, + 53944293929llu, + 67965551447llu, + 85631228929llu, + 107888587883llu, + 135931102921llu, + 171262457903llu, + 215777175787llu, + 271862205833llu, + 342524915839llu, + 431554351609llu, + 543724411781llu, + 685049831731llu, + 863108703229llu, + 1087448823553llu, + 1370099663459llu, + 1726217406467llu, + 2174897647073llu, + 2740199326961llu, + 3452434812973llu, + 4349795294267llu, + 5480398654009llu, + 6904869625999llu, + 8699590588571llu, + 10960797308051llu, + 13809739252051llu, + 17399181177241llu, + 21921594616111llu, + 27619478504183llu, + 34798362354533llu, + 43843189232363llu, + 55238957008387llu, + 69596724709081llu, + 87686378464759llu, + 110477914016779llu, + 139193449418173llu, + 175372756929481llu, + 220955828033581llu, + 278386898836457llu, + 350745513859007llu, + 441911656067171llu, + 556773797672909llu, + 701491027718027llu, + 883823312134381llu, + 1113547595345903llu, + 1402982055436147llu, + 1767646624268779llu, + 2227095190691797llu, + 2805964110872297llu, + 3535293248537579llu, + 4454190381383713llu, + 5611928221744609llu, + 7070586497075177llu, + 8908380762767489llu, + 11223856443489329llu, + 14141172994150357llu, + 17816761525534927llu, + 22447712886978529llu, + 28282345988300791llu, + 35633523051069991llu, + 44895425773957261llu, + 56564691976601587llu, + 71267046102139967llu, + 89790851547914507llu, + 113129383953203213llu, + 142534092204280003llu, + 179581703095829107llu, + 226258767906406483llu, + 285068184408560057llu, + 359163406191658253llu, + 452517535812813007llu, + 570136368817120201llu, + 718326812383316683llu, + 905035071625626043llu, + 1140272737634240411llu, + 1436653624766633509llu, + 1810070143251252131llu, + 2280545475268481167llu, + 2873307249533267101llu, + 3620140286502504283llu, + 4561090950536962147llu, + 5746614499066534157llu, + 7240280573005008577llu, + 9122181901073924329llu, + 11493228998133068689llu, + 14480561146010017169llu, + 18446744073709551557llu}; + // NOLINTNEXTLINE(*c-arrays*) + static constexpr uint64_t (*const mod_functions[])(uint64_t) = { + &mod0, + &mod2, + &mod3, + &mod5, + &mod7, + &mod11, + &mod13, + &mod17, + &mod23, + &mod29, + &mod37, + &mod47, + &mod59, + &mod73, + &mod97, + &mod127, + &mod151, + &mod197, + &mod251, + &mod313, + &mod397, + &mod499, + &mod631, + &mod797, + &mod1009, + &mod1259, + &mod1597, + &mod2011, + &mod2539, + &mod3203, + &mod4027, + &mod5087, + &mod6421, + &mod8089, + &mod10193, + &mod12853, + &mod16193, + &mod20399, + &mod25717, + &mod32401, + &mod40823, + &mod51437, + &mod64811, + &mod81649, + &mod102877, + &mod129607, + &mod163307, + &mod205759, + &mod259229, + &mod326617, + &mod411527, + &mod518509, + &mod653267, + &mod823117, + &mod1037059, + &mod1306601, + &mod1646237, + &mod2074129, + &mod2613229, + &mod3292489, + &mod4148279, + &mod5226491, + &mod6584983, + &mod8296553, + &mod10453007, + &mod13169977, + &mod16593127, + &mod20906033, + &mod26339969, + &mod33186281, + &mod41812097, + &mod52679969, + &mod66372617, + &mod83624237, + &mod105359939, + &mod132745199, + &mod167248483, + &mod210719881, + &mod265490441, + &mod334496971, + &mod421439783, + &mod530980861, + &mod668993977, + &mod842879579, + &mod1061961721, + &mod1337987929, + &mod1685759167, + &mod2123923447, + &mod2675975881, + &mod3371518343, + &mod4247846927, + &mod5351951779, + &mod6743036717, + &mod8495693897, + &mod10703903591, + &mod13486073473, + &mod16991387857, + &mod21407807219, + &mod26972146961, + &mod33982775741, + &mod42815614441, + &mod53944293929, + &mod67965551447, + &mod85631228929, + &mod107888587883, + &mod135931102921, + &mod171262457903, + &mod215777175787, + &mod271862205833, + &mod342524915839, + &mod431554351609, + &mod543724411781, + &mod685049831731, + &mod863108703229, + &mod1087448823553, + &mod1370099663459, + &mod1726217406467, + &mod2174897647073, + &mod2740199326961, + &mod3452434812973, + &mod4349795294267, + &mod5480398654009, + &mod6904869625999, + &mod8699590588571, + &mod10960797308051, + &mod13809739252051, + &mod17399181177241, + &mod21921594616111, + &mod27619478504183, + &mod34798362354533, + &mod43843189232363, + &mod55238957008387, + &mod69596724709081, + &mod87686378464759, + &mod110477914016779, + &mod139193449418173, + &mod175372756929481, + &mod220955828033581, + &mod278386898836457, + &mod350745513859007, + &mod441911656067171, + &mod556773797672909, + &mod701491027718027, + &mod883823312134381, + &mod1113547595345903, + &mod1402982055436147, + &mod1767646624268779, + &mod2227095190691797, + &mod2805964110872297, + &mod3535293248537579, + &mod4454190381383713, + &mod5611928221744609, + &mod7070586497075177, + &mod8908380762767489, + &mod11223856443489329, + &mod14141172994150357, + &mod17816761525534927, + &mod22447712886978529, + &mod28282345988300791, + &mod35633523051069991, + &mod44895425773957261, + &mod56564691976601587, + &mod71267046102139967, + &mod89790851547914507, + &mod113129383953203213, + &mod142534092204280003, + &mod179581703095829107, + &mod226258767906406483, + &mod285068184408560057, + &mod359163406191658253, + &mod452517535812813007, + &mod570136368817120201, + &mod718326812383316683, + &mod905035071625626043, + &mod1140272737634240411, + &mod1436653624766633509, + &mod1810070143251252131, + &mod2280545475268481167, + &mod2873307249533267101, + &mod3620140286502504283, + &mod4561090950536962147, + &mod5746614499066534157, + &mod7240280573005008577, + &mod9122181901073924329, + &mod11493228998133068689, + &mod14480561146010017169, + &mod18446744073709551557}; + const uint64_t* found = std::lower_bound( + std::begin(prime_list), std::end(prime_list) - 1, size); + size = *found; + return mod_functions[1 + found - prime_list]; + } + void commit(mod_function new_mod_function) { + current_mod_function = new_mod_function; + } + void reset() { + current_mod_function = &mod0; + } + + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return current_mod_function(hash); + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index > num_slots_minus_one ? current_mod_function(index) : index; + } + + private: + mod_function current_mod_function = &mod0; +}; + +struct power_of_two_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t num_slots_minus_one) const { + return hash & num_slots_minus_one; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index_for_hash(index, num_slots_minus_one); + } + int8_t next_size_over(uint64_t& size) const { + size = detailv3::next_power_of_two(size); + return 0; + } + void commit(int8_t /*unused*/) {} + void reset() {} +}; + +struct fibonacci_hash_policy { + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) + const { + return (11400714819323198485ull * hash) >> shift; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const { + return index & num_slots_minus_one; + } + + int8_t next_size_over(uint64_t& size) const { + size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); + return static_cast(64 - detailv3::log2(size)); + } + void commit(int8_t shift_) { + shift = shift_; + } + void reset() { + shift = 63; + } + + private: + int8_t shift = 63; +}; + +template < + typename K, + typename V, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator>> +class flat_hash_map + : public detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>> { + using Table = detailv3::sherwood_v3_table< + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>>; + + public: + using key_type = K; + using mapped_type = V; + + using Table::Table; + flat_hash_map() = default; + + inline V& operator[](const K& key) { + return emplace(key, convertible_to_value()).first->second; + } + inline V& operator[](K&& key) { + return emplace(std::move(key), convertible_to_value()).first->second; + } + V& at(const K& key) { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + const V& at(const K& key) const { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + + using Table::emplace; + std::pair emplace() { + return emplace(key_type(), convertible_to_value()); + } + template + std::pair insert_or_assign( + const key_type& key, + M&& m) { + auto emplace_result = emplace(key, std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + std::pair insert_or_assign( + key_type&& key, + M&& m) { + auto emplace_result = emplace(std::move(key), std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + const key_type& key, + M&& m) { + return insert_or_assign(key, std::forward(m)).first; + } + template + typename Table::iterator insert_or_assign( + typename Table::const_iterator /*unused*/, + key_type&& key, + M&& m) { + return insert_or_assign(std::move(key), std::forward(m)).first; + } + + friend bool operator==(const flat_hash_map& lhs, const flat_hash_map& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const typename Table::value_type& value : lhs) { + auto found = rhs.find(value.first); + if (found == rhs.end() || value.second != found->second) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_map& lhs, const flat_hash_map& rhs) { + return !(lhs == rhs); + } + + private: + struct convertible_to_value { + operator V() const { + return V(); + } + }; +}; + +template < + typename T, + typename H = std::hash, + typename E = std::equal_to, + typename A = std::allocator> +class flat_hash_set + : public detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>> { + using Table = detailv3::sherwood_v3_table< + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc< + detailv3::sherwood_v3_entry>>; + + public: + using key_type = T; + + using Table::Table; + flat_hash_set() = default; + + template + std::pair emplace(Args&&... args) { + return Table::emplace(T(std::forward(args)...)); + } + std::pair emplace(const key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(key_type& arg) { + return Table::emplace(arg); + } + std::pair emplace(const key_type&& arg) { + return Table::emplace(std::move(arg)); + } + std::pair emplace(key_type&& arg) { + return Table::emplace(std::move(arg)); + } + + friend bool operator==(const flat_hash_set& lhs, const flat_hash_set& rhs) { + if (lhs.size() != rhs.size()) + return false; + for (const T& value : lhs) { + if (rhs.find(value) == rhs.end()) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_set& lhs, const flat_hash_set& rhs) { + return !(lhs == rhs); + } +}; + +template +struct power_of_two_std_hash : std::hash { + typedef ska::power_of_two_hash_policy hash_policy; +}; + +} // end namespace ska + +C10_CLANG_DIAGNOSTIC_POP() + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h new file mode 100644 index 0000000000000000000000000000000000000000..969e095ef59a8ad07bf80089a054d14e84c682d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/generic_math.h @@ -0,0 +1,113 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#if defined(__CUDA_ARCH__) +#include +#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign +#elif defined(__HIPCC__) +#include +#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign +#else +#include +#define C10_COMPAT_COPYSIGN c10::copysign +#endif + +// The functions in this file should be header-only as it is used under +// ABI-compatibility mode. + +namespace c10 { + +// NOTE: [Floor Division in Python] +// Python's __floordiv__ operator is more complicated than just floor(a / b). +// It aims to maintain the property: a == (a // b) * b + remainder(a, b) +// which can otherwise fail due to rounding errors in the remainder. +// So, instead it is calculated as: a // b = (a - remainder(a, b)) / b +// With some additional fix-ups added to the result. +// +// For reference, see CPython's implementation: +// https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + +template +inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (C10_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = C10_COMPAT_COPYSIGN(scalar_t(0), a / b); + } + return floordiv; +} + +template +inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) { + if (C10_UNLIKELY( + std::is_signed::value && + a == std::numeric_limits::min() && b == scalar_t(-1))) { + return a; + } + + if (c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the remainder of + // the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + return a / b; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) + __ubsan_ignore_float_divide_by_zero__ { + if (C10_UNLIKELY(b == 0)) { + // Divide by zero: return standard IEEE result + return std::fmod(a, b); + } + + auto mod = std::fmod(a, b); + if (mod == 0) { + mod = C10_COMPAT_COPYSIGN(scalar_t(0), b); + } else if ((b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +template < + typename scalar_t, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { + auto mod = a % b; + if (mod != 0 && (b < 0) != (mod < 0)) { + mod += b; + } + return mod; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/hash.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/hash.h new file mode 100644 index 0000000000000000000000000000000000000000..c3fff128439efb6d4ddf143493fd9a3d46b04435 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/hash.h @@ -0,0 +1,384 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +// NOTE: hash_combine and SHA1 hashing is based on implementation from Boost +// +// Boost Software License - Version 1.0 - August 17th, 2003 +// +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +inline size_t hash_combine(size_t seed, size_t value) { + return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u)); +} + +// Creates the SHA1 hash of a string. A 160-bit hash. +// Based on the implementation in Boost (see notice above). +// Note that SHA1 hashes are no longer considered cryptographically +// secure, but are the standard hash for generating unique ids. +// Usage: +// // Let 'code' be a std::string +// c10::sha1 sha1_hash{code}; +// const auto hash_code = sha1_hash.str(); +// TODO: Compare vs OpenSSL and/or CryptoPP implementations +struct sha1 { + typedef unsigned int(digest_type)[5]; + + sha1(const std::string& s = "") { + if (!s.empty()) { + reset(); + process_bytes(s.c_str(), s.size()); + } + } + + void reset() { + h_[0] = 0x67452301; + h_[1] = 0xEFCDAB89; + h_[2] = 0x98BADCFE; + h_[3] = 0x10325476; + h_[4] = 0xC3D2E1F0; + + block_byte_index_ = 0; + bit_count_low = 0; + bit_count_high = 0; + } + + std::string str() { + unsigned int digest[5]; + get_digest(digest); + + std::ostringstream buf; + for (unsigned int i : digest) { + buf << std::hex << std::setfill('0') << std::setw(8) << i; + } + + return buf.str(); + } + + private: + unsigned int left_rotate(unsigned int x, std::size_t n) { + return (x << n) ^ (x >> (32 - n)); + } + + void process_block_impl() { + unsigned int w[80]; + + for (std::size_t i = 0; i < 16; ++i) { + w[i] = (block_[i * 4 + 0] << 24); + w[i] |= (block_[i * 4 + 1] << 16); + w[i] |= (block_[i * 4 + 2] << 8); + w[i] |= (block_[i * 4 + 3]); + } + + for (std::size_t i = 16; i < 80; ++i) { + w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1); + } + + unsigned int a = h_[0]; + unsigned int b = h_[1]; + unsigned int c = h_[2]; + unsigned int d = h_[3]; + unsigned int e = h_[4]; + + for (std::size_t i = 0; i < 80; ++i) { + unsigned int f = 0; + unsigned int k = 0; + + if (i < 20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + + unsigned temp = left_rotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = left_rotate(b, 30); + b = a; + a = temp; + } + + h_[0] += a; + h_[1] += b; + h_[2] += c; + h_[3] += d; + h_[4] += e; + } + + void process_byte_impl(unsigned char byte) { + block_[block_byte_index_++] = byte; + + if (block_byte_index_ == 64) { + block_byte_index_ = 0; + process_block_impl(); + } + } + + void process_byte(unsigned char byte) { + process_byte_impl(byte); + + // size_t max value = 0xFFFFFFFF + // if (bit_count_low + 8 >= 0x100000000) { // would overflow + // if (bit_count_low >= 0x100000000-8) { + if (bit_count_low < 0xFFFFFFF8) { + bit_count_low += 8; + } else { + bit_count_low = 0; + + if (bit_count_high <= 0xFFFFFFFE) { + ++bit_count_high; + } else { + TORCH_CHECK(false, "sha1 too many bytes"); + } + } + } + + void process_block(void const* bytes_begin, void const* bytes_end) { + unsigned char const* begin = static_cast(bytes_begin); + unsigned char const* end = static_cast(bytes_end); + for (; begin != end; ++begin) { + process_byte(*begin); + } + } + + void process_bytes(void const* buffer, std::size_t byte_count) { + unsigned char const* b = static_cast(buffer); + process_block(b, b + byte_count); + } + + void get_digest(digest_type& digest) { + // append the bit '1' to the message + process_byte_impl(0x80); + + // append k bits '0', where k is the minimum number >= 0 + // such that the resulting message length is congruent to 56 (mod 64) + // check if there is enough space for padding and bit_count + if (block_byte_index_ > 56) { + // finish this block + while (block_byte_index_ != 0) { + process_byte_impl(0); + } + + // one more block + while (block_byte_index_ < 56) { + process_byte_impl(0); + } + } else { + while (block_byte_index_ < 56) { + process_byte_impl(0); + } + } + + // append length of message (before pre-processing) + // as a 64-bit big-endian integer + process_byte_impl( + static_cast((bit_count_high >> 24) & 0xFF)); + process_byte_impl( + static_cast((bit_count_high >> 16) & 0xFF)); + process_byte_impl(static_cast((bit_count_high >> 8) & 0xFF)); + process_byte_impl(static_cast((bit_count_high) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 24) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 16) & 0xFF)); + process_byte_impl(static_cast((bit_count_low >> 8) & 0xFF)); + process_byte_impl(static_cast((bit_count_low) & 0xFF)); + + // get final digest + digest[0] = h_[0]; + digest[1] = h_[1]; + digest[2] = h_[2]; + digest[3] = h_[3]; + digest[4] = h_[4]; + } + + unsigned int h_[5]{}; + unsigned char block_[64]{}; + std::size_t block_byte_index_{}; + std::size_t bit_count_low{}; + std::size_t bit_count_high{}; +}; + +constexpr uint64_t twang_mix64(uint64_t key) noexcept { + key = (~key) + (key << 21); // key *= (1 << 21) - 1; key -= 1; + key = key ^ (key >> 24); + key = key + (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8) + key = key ^ (key >> 14); + key = key + (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4) + key = key ^ (key >> 28); + key = key + (key << 31); // key *= 1 + (1 << 31) + return key; +} + +//////////////////////////////////////////////////////////////////////////////// +// c10::hash implementation +//////////////////////////////////////////////////////////////////////////////// + +namespace _hash_detail { + +// Use template argument deduction to shorten calls to c10::hash +template +size_t simple_get_hash(const T& o); + +template +using type_if_not_enum = std::enable_if_t, V>; + +// Use SFINAE to dispatch to std::hash if possible, cast enum types to int +// automatically, and fall back to T::hash otherwise. NOTE: C++14 added support +// for hashing enum types to the standard, and some compilers implement it even +// when C++14 flags aren't specified. This is why we have to disable this +// overload if T is an enum type (and use the one below in this case). +template +auto dispatch_hash(const T& o) + -> decltype(std::hash()(o), type_if_not_enum()) { + return std::hash()(o); +} + +template +std::enable_if_t, size_t> dispatch_hash(const T& o) { + using R = std::underlying_type_t; + return std::hash()(static_cast(o)); +} + +template +auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) { + return T::hash(o); +} + +} // namespace _hash_detail + +// Hasher struct +template +struct hash { + size_t operator()(const T& o) const { + return _hash_detail::dispatch_hash(o); + } +}; + +// Specialization for std::tuple +template +struct hash> { + template + struct tuple_hash { + size_t operator()(const std::tuple& t) const { + return hash_combine( + _hash_detail::simple_get_hash(std::get(t)), + tuple_hash()(t)); + } + }; + + template + struct tuple_hash<0, Ts...> { + size_t operator()(const std::tuple& t) const { + return _hash_detail::simple_get_hash(std::get<0>(t)); + } + }; + + size_t operator()(const std::tuple& t) const { + return tuple_hash()(t); + } +}; + +template +struct hash> { + size_t operator()(const std::pair& pair) const { + std::tuple tuple = std::make_tuple(pair.first, pair.second); + return _hash_detail::simple_get_hash(tuple); + } +}; + +template +struct hash> { + size_t operator()(c10::ArrayRef v) const { + size_t seed = 0; + for (const auto& elem : v) { + seed = hash_combine(seed, _hash_detail::simple_get_hash(elem)); + } + return seed; + } +}; + +// Specialization for std::vector +template +struct hash> { + size_t operator()(const std::vector& v) const { + return hash>()(v); + } +}; + +namespace _hash_detail { + +template +size_t simple_get_hash(const T& o) { + return c10::hash()(o); +} + +} // namespace _hash_detail + +// Use this function to actually hash multiple things in one line. +// Dispatches to c10::hash, so it can hash containers. +// Example: +// +// static size_t hash(const MyStruct& s) { +// return get_hash(s.member1, s.member2, s.member3); +// } +template +size_t get_hash(const Types&... args) { + return c10::hash()(std::tie(args...)); +} + +// Specialization for c10::complex +template +struct hash> { + size_t operator()(const c10::complex& c) const { + return get_hash(c.real(), c.imag()); + } +}; + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h new file mode 100644 index 0000000000000000000000000000000000000000..8d881f4de245b1fe650b322cffa1dc294bd019dd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_common.h @@ -0,0 +1,79 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_COMMON_H_ +#define C10_UTIL_LOGGING_COMMON_H_ + +#include +#include + +namespace c10 { + +// MessageLogger that throws exceptions instead of aborting (glog version) +// or logs and may abort (non-glog version). +class C10_API MessageLogger { + public: + MessageLogger( + const char* file, + int line, + int severity, + bool exit_on_fatal = true); + ~MessageLogger() noexcept(false); + + // Return the stream associated with the logger object. + std::stringstream& stream(); + + private: + // When there is a fatal log, and fatal == true, we abort + // otherwise, we throw. + void DealWithFatal(); + +#if defined(ANDROID) && !defined(C10_USE_GLOG) + const char* tag_{"native"}; +#endif + std::stringstream stream_; + int severity_; + bool exit_on_fatal_; +}; + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class C10_API LoggerVoidify { + public: + LoggerVoidify() = default; + // This has to be an operator with a precedence lower than << but + // higher than ?: + void operator&(const std::ostream& s [[maybe_unused]]) {} +}; + +// Forward declarations for CheckNotNull functions +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal = true); + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_COMMON_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h new file mode 100644 index 0000000000000000000000000000000000000000..082e0b86484f7b62d7b0d383d6c717ef5a9d9340 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_google_glog.h @@ -0,0 +1,110 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ +#define C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ + +#include +#include +#include + +#include // because some of the caffe2 code uses e.g. std::setw +// Using google glog. For glog 0.3.2 versions, stl_logging.h needs to be before +// logging.h to actually use stl_logging. Because template magic. +// In addition, we do not do stl logging in .cu files because nvcc does not like +// it. Some mobile platforms do not like stl_logging, so we add an +// overload in that case as well. + +#ifdef __CUDACC__ +#include +#endif + +#if !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG) +#include + +// Old versions of glog don't declare this using declaration, so help +// them out. Fortunately, C++ won't complain if you declare the same +// using declaration multiple times. +namespace std { +using ::operator<<; +} + +#else // !defined(__CUDACC__) && !defined(C10_USE_MINIMAL_GLOG) + +// In the cudacc compiler scenario, we will simply ignore the container +// printout feature. Basically we need to register a fake overload for +// vector/string - here, we just ignore the entries in the logs. + +namespace std { +#define INSTANTIATE_FOR_CONTAINER(container) \ + template \ + ostream& operator<<(ostream& out, const container&) { \ + return out; \ + } + +INSTANTIATE_FOR_CONTAINER(vector) +INSTANTIATE_FOR_CONTAINER(map) +INSTANTIATE_FOR_CONTAINER(set) +#undef INSTANTIATE_FOR_CONTAINER +} // namespace std + +#endif + +#include +#include + +namespace c10 { + +[[noreturn]] void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller); + +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + if (t == nullptr) { + MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. "; + } + return t; +} + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +} // namespace c10 + +// Log with source location information override (to be used in generic +// warning/error handlers implemented as functions, not macros) +// +// Note, we don't respect GOOGLE_STRIP_LOG here for simplicity +#define LOG_AT_FILE_LINE(n, file, line) \ + ::google::LogMessage(file, line, ::google::GLOG_##n).stream() + +#endif // C10_UTIL_LOGGING_IS_GOOGLE_GLOG_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h new file mode 100644 index 0000000000000000000000000000000000000000..efeffb93afc3e05f0780dc554534fb856c49de3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/logging_is_not_google_glog.h @@ -0,0 +1,186 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ +#define C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV"; + +namespace c10 { + +// Log severity level constants. +const int GLOG_FATAL = 3; +const int GLOG_ERROR = 2; +const int GLOG_WARNING = 1; +const int GLOG_INFO = 0; + +// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw +// pointers and smart pointers. +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + if (t == nullptr) { + MessageLogger(file, line, GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. "; + } + return t; +} + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} +} // namespace c10 + +// ---------------------- Logging Macro definitions -------------------------- + +static_assert( + CAFFE2_LOG_THRESHOLD <= ::c10::GLOG_FATAL, + "CAFFE2_LOG_THRESHOLD should at most be GLOG_FATAL."); +// If n is under the compile time caffe log threshold, The _CAFFE_LOG(n) +// should not generate anything in optimized code. +#define LOG(n) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#define VLOG(n) \ + if (-n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(__FILE__, __LINE__, -n).stream() + +#define LOG_IF(n, condition) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD && (condition)) \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#define VLOG_IF(n, condition) \ + if (-n >= CAFFE2_LOG_THRESHOLD && (condition)) \ + ::c10::MessageLogger(__FILE__, __LINE__, -n).stream() + +#define VLOG_IS_ON(verboselevel) (CAFFE2_LOG_THRESHOLD <= -(verboselevel)) + +// Log with source location information override (to be used in generic +// warning/error handlers implemented as functions, not macros) +#define LOG_AT_FILE_LINE(n, file, line) \ + if (::c10::GLOG_##n >= CAFFE2_LOG_THRESHOLD) \ + ::c10::MessageLogger(file, line, ::c10::GLOG_##n).stream() + +// Log only if condition is met. Otherwise evaluates to void. +#define FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream() + +// Check for a given boolean condition. +#define CHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " + +#ifndef NDEBUG +// Debug only version of CHECK +#define DCHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " +#define DLOG(severity) LOG(severity) +#else // NDEBUG +// Optimized version - generates no code. +#define DCHECK(condition) \ + while (false) \ + CHECK(condition) + +#define DLOG(n) \ + true ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() +#endif // NDEBUG + +// ---------------------- Support for std objects -------------------------- +// These are adapted from glog to support a limited set of logging capability +// for STL objects. + +namespace std { +// Forward declare these two, and define them after all the container streams +// operators so that we can recurse from pair -> container -> container -> pair +// properly. +template +std::ostream& operator<<(std::ostream& out, const std::pair& p); +} // namespace std + +namespace c10 { +template +void PrintSequence(std::ostream& ss, Iter begin, Iter end); +} // namespace c10 + +namespace std { +#define INSTANTIATE_FOR_CONTAINER(container) \ + template \ + std::ostream& operator<<( \ + std::ostream& out, const container& seq) { \ + c10::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ + } + +INSTANTIATE_FOR_CONTAINER(std::vector) +INSTANTIATE_FOR_CONTAINER(std::map) +INSTANTIATE_FOR_CONTAINER(std::set) +#undef INSTANTIATE_FOR_CONTAINER + +template +inline std::ostream& operator<<( + std::ostream& out, + const std::pair& p) { + out << '(' << p.first << ", " << p.second << ')'; + return out; +} + +inline std::ostream& operator<<( + std::ostream& out, + const std::nullptr_t& /*unused*/) { + out << "(null)"; + return out; +} +} // namespace std + +namespace c10 { +template +inline void PrintSequence(std::ostream& out, Iter begin, Iter end) { + // Output at most 100 elements -- appropriate if used for logging. + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) + out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } +} +} // namespace c10 + +#endif // C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h new file mode 100644 index 0000000000000000000000000000000000000000..b7781bc5772828da4ec97e1db4bbab2b7f54dd42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint2x4.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint8.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint8.h new file mode 100644 index 0000000000000000000000000000000000000000..5445be70945ff028d6ad98cff1732b678c7245da --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/quint8.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h new file mode 100644 index 0000000000000000000000000000000000000000..f376f9dfd8a529851dd45a9319da482eae1c7c60 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/safe_numerics.h @@ -0,0 +1,119 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#include +#include +#include + +// GCC has __builtin_mul_overflow from before it supported __has_builtin +#ifdef _MSC_VER +#define C10_HAS_BUILTIN_OVERFLOW() (0) +#include +#include +#else +#define C10_HAS_BUILTIN_OVERFLOW() (1) +#endif + +namespace c10 { + +template , int> = 0> +C10_ALWAYS_INLINE bool add_overflows(T a, T b, T* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_add_overflow(a, b, out); +#else + if constexpr (std::is_signed_v) { + // For signed types, detect overflow by checking sign changes + volatile T tmp = a + b; + *out = tmp; + + // If both operands have the same sign, check if result changed sign + // unexpectedly. + if ((a > 0) == (b > 0)) { + if ((a > 0) && (tmp <= 0)) { + return true; // Positive overflow + } + if ((a < 0) && (tmp >= 0)) { + return true; // Negative overflow + } + } + return false; + } else { + // For unsigned types, overflow causes wrap-around + volatile T tmp = a + b; + *out = tmp; + return (tmp < a || tmp < b); + } +#endif +} + +C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return add_overflows(a, b, out); +} + +template , int> = 0> +C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + if constexpr (std::is_signed_v) { + // For signed types, use the division-based check + volatile T tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); + } else { + // For unsigned types, use leading zeros approach + // This test isn't exact, but avoids doing integer division + *out = a * b; + constexpr int bits = sizeof(T) * 8; + return ( + (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < + bits); + } +#endif +} + +C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return mul_overflows(a, b, out); +} + +template +bool safe_multiplies_u64(It first, It last, uint64_t* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + uint64_t prod = 1; + bool overflow = false; + for (; first != last; ++first) { + overflow |= c10::mul_overflows(prod, *first, &prod); + } + *out = prod; + return overflow; +#else + uint64_t prod = 1; + uint64_t prod_log2 = 0; + bool is_zero = false; + for (; first != last; ++first) { + auto x = static_cast(*first); + prod *= x; + // log2(0) isn't valid, so need to track it specially + is_zero |= (x == 0); + prod_log2 += c10::llvm::Log2_64_Ceil(x); + } + *out = prod; + // This test isn't exact, but avoids doing integer division + return !is_zero && (prod_log2 >= 64); +#endif +} + +template +bool safe_multiplies_u64(const Container& c, uint64_t* out) { + return safe_multiplies_u64(c.begin(), c.end(), out); +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h new file mode 100644 index 0000000000000000000000000000000000000000..877b4fb52f0ed04a6bb555201f0cc58163bfe552 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/sparse_bitset.h @@ -0,0 +1,898 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SparseBitVector class. See the doxygen comment for +// SparseBitVector for more details on the algorithm used. +// +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// SparseBitVector is an implementation of a bitvector that is sparse by only +/// storing the elements that have non-zero bits set. In order to make this +/// fast for the most common cases, SparseBitVector is implemented as a linked +/// list of SparseBitVectorElements. We maintain a pointer to the last +/// SparseBitVectorElement accessed (in the form of a list iterator), in order +/// to make multiple in-order test/set constant time after the first one is +/// executed. Note that using vectors to store SparseBitVectorElement's does +/// not work out very well because it causes insertion in the middle to take +/// enormous amounts of time with a large amount of bits. Other structures that +/// have better worst cases for insertion in the middle (various balanced trees, +/// etc) do not perform as well in practice as a linked list with this iterator +/// kept up to date. They are also significantly more memory intensive. + +template +struct SparseBitVectorElement { + public: + using BitWord = unsigned long; + using size_type = unsigned; + enum { + BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT, + BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE, + BITS_PER_ELEMENT = ElementSize + }; + + private: + // Index of Element in terms of where first bit starts. + unsigned ElementIndex; + std::array Bits{}; + + SparseBitVectorElement() : ElementIndex(~0U) {} + + public: + explicit SparseBitVectorElement(unsigned Idx) : ElementIndex(Idx) {} + + // Comparison. + bool operator==(const SparseBitVectorElement& RHS) const { + if (ElementIndex != RHS.ElementIndex) + return false; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != RHS.Bits[i]) + return false; + return true; + } + + bool operator!=(const SparseBitVectorElement& RHS) const { + return !(*this == RHS); + } + + // Return the bits that make up word Idx in our element. + BitWord word(unsigned Idx) const { + assert(Idx < BITWORDS_PER_ELEMENT); + return Bits[Idx]; + } + + unsigned index() const { + return ElementIndex; + } + + bool empty() const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i]) + return false; + return true; + } + + void set(unsigned Idx) { + Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE); + } + + bool test_and_set(unsigned Idx) { + bool old = test(Idx); + if (!old) { + set(Idx); + return true; + } + return false; + } + + void reset(unsigned Idx) { + Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE)); + } + + bool test(unsigned Idx) const { + return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE)); + } + + size_type count() const { + unsigned NumBits = 0; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + NumBits += llvm::countPopulation(Bits[i]); + return NumBits; + } + + /// find_first - Returns the index of the first set bit. + int find_first() const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != 0) + return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); + throw std::runtime_error("Illegal empty element"); + } + + /// find_last - Returns the index of the last set bit. + int find_last() const { + for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) { + unsigned Idx = BITWORDS_PER_ELEMENT - I - 1; + if (Bits[Idx] != 0) + return Idx * BITWORD_SIZE + BITWORD_SIZE - + llvm::countLeadingZeros(Bits[Idx]); + } + throw std::runtime_error("Illegal empty element"); + } + + /// find_next - Returns the index of the next set bit starting from the + /// "Curr" bit. Returns -1 if the next set bit is not found. + int find_next(unsigned Curr) const { + if (Curr >= BITS_PER_ELEMENT) + return -1; + + unsigned WordPos = Curr / BITWORD_SIZE; + unsigned BitPos = Curr % BITWORD_SIZE; + BitWord Copy = Bits[WordPos]; + assert( + WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element"); + + // Mask off previous bits. + Copy &= ~0UL << BitPos; + + if (Copy != 0) + return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy); + + // Check subsequent words. + for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i) + if (Bits[i] != 0) + return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]); + return -1; + } + + // Union this element with RHS and return true if this one changed. + bool unionWith(const SparseBitVectorElement& RHS) { + bool changed = false; + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] |= RHS.Bits[i]; + if (!changed && old != Bits[i]) + changed = true; + } + return changed; + } + + // Return true if we have any bits in common with RHS + bool intersects(const SparseBitVectorElement& RHS) const { + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + if (RHS.Bits[i] & Bits[i]) + return true; + } + return false; + } + + // Intersect this Element with RHS and return true if this one changed. + // BecameZero is set to true if this element became all-zero bits. + bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) { + bool changed = false; + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] &= RHS.Bits[i]; + if (Bits[i] != 0) + allzero = false; + + if (!changed && old != Bits[i]) + changed = true; + } + BecameZero = allzero; + return changed; + } + + // Intersect this Element with the complement of RHS and return true if this + // one changed. BecameZero is set to true if this element became all-zero + // bits. + bool intersectWithComplement( + const SparseBitVectorElement& RHS, + bool& BecameZero) { + bool changed = false; + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + BitWord old = changed ? 0 : Bits[i]; + + Bits[i] &= ~RHS.Bits[i]; + if (Bits[i] != 0) + allzero = false; + + if (!changed && old != Bits[i]) + changed = true; + } + BecameZero = allzero; + return changed; + } + + // Three argument version of intersectWithComplement that intersects + // RHS1 & ~RHS2 into this element + void intersectWithComplement( + const SparseBitVectorElement& RHS1, + const SparseBitVectorElement& RHS2, + bool& BecameZero) { + bool allzero = true; + + for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) { + Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i]; + if (Bits[i] != 0) + allzero = false; + } + BecameZero = allzero; + } +}; + +template +class SparseBitVector { + using ElementList = std::list>; + using ElementListIter = typename ElementList::iterator; + using ElementListConstIter = typename ElementList::const_iterator; + enum { BITWORD_SIZE = SparseBitVectorElement::BITWORD_SIZE }; + + ElementList Elements; + // Pointer to our current Element. This has no visible effect on the external + // state of a SparseBitVector, it's just used to improve performance in the + // common case of testing/modifying bits with similar indices. + mutable ElementListIter CurrElementIter; + + // This is like std::lower_bound, except we do linear searching from the + // current position. + ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const { + // We cache a non-const iterator so we're forced to resort to const_cast to + // get the begin/end in the case where 'this' is const. To avoid duplication + // of code with the only difference being whether the const cast is present + // 'this' is always const in this particular function and we sort out the + // difference in FindLowerBound and FindLowerBoundConst. + ElementListIter Begin = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast*>(this)->Elements.begin(); + ElementListIter End = + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast*>(this)->Elements.end(); + + if (Elements.empty()) { + CurrElementIter = Begin; + return CurrElementIter; + } + + // Make sure our current iterator is valid. + if (CurrElementIter == End) + --CurrElementIter; + + // Search from our current iterator, either backwards or forwards, + // depending on what element we are looking for. + ElementListIter ElementIter = CurrElementIter; + if (CurrElementIter->index() == ElementIndex) { + return ElementIter; + } else if (CurrElementIter->index() > ElementIndex) { + while (ElementIter != Begin && ElementIter->index() > ElementIndex) + --ElementIter; + } else { + while (ElementIter != End && ElementIter->index() < ElementIndex) + ++ElementIter; + } + CurrElementIter = ElementIter; + return ElementIter; + } + ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const { + return FindLowerBoundImpl(ElementIndex); + } + ElementListIter FindLowerBound(unsigned ElementIndex) { + return FindLowerBoundImpl(ElementIndex); + } + + // Iterator to walk set bits in the bitmap. This iterator is a lot uglier + // than it would be, in order to be efficient. + class SparseBitVectorIterator { + private: + bool AtEnd{false}; + + const SparseBitVector* BitVector = nullptr; + + // Current element inside of bitmap. + ElementListConstIter Iter; + + // Current bit number inside of our bitmap. + unsigned BitNumber{0}; + + // Current word number inside of our element. + unsigned WordNumber{0}; + + // Current bits from the element. + typename SparseBitVectorElement::BitWord Bits{0}; + + // Move our iterator to the first non-zero bit in the bitmap. + void AdvanceToFirstNonZero() { + if (AtEnd) + return; + if (BitVector->Elements.empty()) { + AtEnd = true; + return; + } + Iter = BitVector->Elements.begin(); + BitNumber = Iter->index() * ElementSize; + unsigned BitPos = Iter->find_first(); + BitNumber += BitPos; + WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= BitPos % BITWORD_SIZE; + } + + // Move our iterator to the next non-zero bit. + void AdvanceToNextNonZero() { + if (AtEnd) + return; + + while (Bits && !(Bits & 1)) { + Bits >>= 1; + BitNumber += 1; + } + + // See if we ran out of Bits in this word. + if (!Bits) { + int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize); + // If we ran out of set bits in this element, move to next element. + if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) { + ++Iter; + WordNumber = 0; + + // We may run out of elements in the bitmap. + if (Iter == BitVector->Elements.end()) { + AtEnd = true; + return; + } + // Set up for next non-zero word in bitmap. + BitNumber = Iter->index() * ElementSize; + NextSetBitNumber = Iter->find_first(); + BitNumber += NextSetBitNumber; + WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= NextSetBitNumber % BITWORD_SIZE; + } else { + WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE; + Bits = Iter->word(WordNumber); + Bits >>= NextSetBitNumber % BITWORD_SIZE; + BitNumber = Iter->index() * ElementSize; + BitNumber += NextSetBitNumber; + } + } + } + + public: + SparseBitVectorIterator() = default; + + SparseBitVectorIterator( + const SparseBitVector* RHS, + bool end = false) + : AtEnd(end), + BitVector(RHS), + Iter(BitVector->Elements.begin()), + WordNumber(~0) { + AdvanceToFirstNonZero(); + } + + // Preincrement. + inline SparseBitVectorIterator& operator++() { + ++BitNumber; + Bits >>= 1; + AdvanceToNextNonZero(); + return *this; + } + + // Postincrement. + inline SparseBitVectorIterator operator++(int) { + SparseBitVectorIterator tmp = *this; + ++*this; + return tmp; + } + + // Return the current set bit number. + unsigned operator*() const { + return BitNumber; + } + + bool operator==(const SparseBitVectorIterator& RHS) const { + // If they are both at the end, ignore the rest of the fields. + if (AtEnd && RHS.AtEnd) + return true; + // Otherwise they are the same if they have the same bit number and + // bitmap. + return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber; + } + + bool operator!=(const SparseBitVectorIterator& RHS) const { + return !(*this == RHS); + } + }; + + public: + using iterator = SparseBitVectorIterator; + + SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {} + + SparseBitVector(const SparseBitVector& RHS) + : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {} + SparseBitVector(SparseBitVector&& RHS) noexcept + : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {} + ~SparseBitVector() = default; + + // Clear. + void clear() { + Elements.clear(); + } + + // Assignment + SparseBitVector& operator=(const SparseBitVector& RHS) { + if (this == &RHS) + return *this; + + Elements = RHS.Elements; + CurrElementIter = Elements.begin(); + return *this; + } + SparseBitVector& operator=(SparseBitVector&& RHS) noexcept { + Elements = std::move(RHS.Elements); + CurrElementIter = Elements.begin(); + return *this; + } + + // Test, Reset, and Set a bit in the bitmap. + bool test(unsigned Idx) const { + if (Elements.empty()) + return false; + + unsigned ElementIndex = Idx / ElementSize; + ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex); + + // If we can't find an element that is supposed to contain this bit, there + // is nothing more to do. + if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) + return false; + return ElementIter->test(Idx % ElementSize); + } + + void reset(unsigned Idx) { + if (Elements.empty()) + return; + + unsigned ElementIndex = Idx / ElementSize; + ElementListIter ElementIter = FindLowerBound(ElementIndex); + + // If we can't find an element that is supposed to contain this bit, there + // is nothing more to do. + if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex) + return; + ElementIter->reset(Idx % ElementSize); + + // When the element is zeroed out, delete it. + if (ElementIter->empty()) { + ++CurrElementIter; + Elements.erase(ElementIter); + } + } + + void set(unsigned Idx) { + unsigned ElementIndex = Idx / ElementSize; + ElementListIter ElementIter; + if (Elements.empty()) { + ElementIter = Elements.emplace(Elements.end(), ElementIndex); + } else { + ElementIter = FindLowerBound(ElementIndex); + + if (ElementIter == Elements.end() || + ElementIter->index() != ElementIndex) { + // We may have hit the beginning of our SparseBitVector, in which case, + // we may need to insert right after this element, which requires moving + // the current iterator forward one, because insert does insert before. + if (ElementIter != Elements.end() && + ElementIter->index() < ElementIndex) + ++ElementIter; + ElementIter = Elements.emplace(ElementIter, ElementIndex); + } + } + CurrElementIter = ElementIter; + + ElementIter->set(Idx % ElementSize); + } + + bool test_and_set(unsigned Idx) { + bool old = test(Idx); + if (!old) { + set(Idx); + return true; + } + return false; + } + + bool operator!=(const SparseBitVector& RHS) const { + return !(*this == RHS); + } + + bool operator==(const SparseBitVector& RHS) const { + ElementListConstIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end(); + ++Iter1, ++Iter2) { + if (*Iter1 != *Iter2) + return false; + } + return Iter1 == Elements.end() && Iter2 == RHS.Elements.end(); + } + + // Union our bitmap with the RHS and return true if we changed. + bool operator|=(const SparseBitVector& RHS) { + if (this == &RHS) + return false; + + if (empty()) { + *this = RHS; + return true; + } + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // If RHS is empty, we are done + if (RHS.Elements.empty()) + return false; + + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) { + Elements.insert(Iter1, *Iter2); + ++Iter2; + changed = true; + } else if (Iter1->index() == Iter2->index()) { + changed |= Iter1->unionWith(*Iter2); + ++Iter1; + ++Iter2; + } else { + ++Iter1; + } + } + CurrElementIter = Elements.begin(); + return changed; + } + + // Intersect our bitmap with the RHS and return true if ours changed. + bool operator-=(const SparseBitVector& RHS) { + return intersectWithComplement(RHS); + } + + // Intersect our bitmap with the RHS and return true if ours changed. + bool operator&=(const SparseBitVector& RHS) { + if (this == &RHS) + return false; + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // Check if both bitmaps are empty. + if (Elements.empty() && RHS.Elements.empty()) + return false; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) { + CurrElementIter = Elements.begin(); + return changed; + } + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + changed |= Iter1->intersectWith(*Iter2, BecameZero); + if (BecameZero) { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + } else { + ++Iter1; + } + ++Iter2; + } else { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + changed = true; + } + } + if (Iter1 != Elements.end()) { + Elements.erase(Iter1, Elements.end()); + changed = true; + } + CurrElementIter = Elements.begin(); + return changed; + } + + // Intersect our bitmap with the complement of the RHS and return true + // if ours changed. + bool intersectWithComplement(const SparseBitVector& RHS) { + if (this == &RHS) { + if (!empty()) { + clear(); + return true; + } + return false; + } + + bool changed = false; + ElementListIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // If either our bitmap or RHS is empty, we are done + if (Elements.empty() || RHS.Elements.empty()) + return false; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) { + CurrElementIter = Elements.begin(); + return changed; + } + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + changed |= Iter1->intersectWithComplement(*Iter2, BecameZero); + if (BecameZero) { + ElementListIter IterTmp = Iter1; + ++Iter1; + Elements.erase(IterTmp); + } else { + ++Iter1; + } + ++Iter2; + } else { + ++Iter1; + } + } + CurrElementIter = Elements.begin(); + return changed; + } + + bool intersectWithComplement(const SparseBitVector* RHS) const { + return intersectWithComplement(*RHS); + } + + // Three argument version of intersectWithComplement. + // Result of RHS1 & ~RHS2 is stored into this bitmap. + void intersectWithComplement( + const SparseBitVector& RHS1, + const SparseBitVector& RHS2) { + if (this == &RHS1) { + intersectWithComplement(RHS2); + return; + } else if (this == &RHS2) { + SparseBitVector RHS2Copy(RHS2); + intersectWithComplement(RHS1, RHS2Copy); + return; + } + + Elements.clear(); + CurrElementIter = Elements.begin(); + ElementListConstIter Iter1 = RHS1.Elements.begin(); + ElementListConstIter Iter2 = RHS2.Elements.begin(); + + // If RHS1 is empty, we are done + // If RHS2 is empty, we still have to copy RHS1 + if (RHS1.Elements.empty()) + return; + + // Loop through, intersecting as we go, erasing elements when necessary. + while (Iter2 != RHS2.Elements.end()) { + if (Iter1 == RHS1.Elements.end()) + return; + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + bool BecameZero = false; + Elements.emplace_back(Iter1->index()); + Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero); + if (BecameZero) + Elements.pop_back(); + ++Iter1; + ++Iter2; + } else { + Elements.push_back(*Iter1++); + } + } + + // copy the remaining elements + std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements)); + } + + void intersectWithComplement( + const SparseBitVector* RHS1, + const SparseBitVector* RHS2) { + intersectWithComplement(*RHS1, *RHS2); + } + + bool intersects(const SparseBitVector* RHS) const { + return intersects(*RHS); + } + + // Return true if we share any bits in common with RHS + bool intersects(const SparseBitVector& RHS) const { + ElementListConstIter Iter1 = Elements.begin(); + ElementListConstIter Iter2 = RHS.Elements.begin(); + + // Check if both bitmaps are empty. + if (Elements.empty() && RHS.Elements.empty()) + return false; + + // Loop through, intersecting stopping when we hit bits in common. + while (Iter2 != RHS.Elements.end()) { + if (Iter1 == Elements.end()) + return false; + + if (Iter1->index() > Iter2->index()) { + ++Iter2; + } else if (Iter1->index() == Iter2->index()) { + if (Iter1->intersects(*Iter2)) + return true; + ++Iter1; + ++Iter2; + } else { + ++Iter1; + } + } + return false; + } + + // Return true iff all bits set in this SparseBitVector are + // also set in RHS. + bool contains(const SparseBitVector& RHS) const { + SparseBitVector Result(*this); + Result &= RHS; + return (Result == RHS); + } + + // Return the first set bit in the bitmap. Return -1 if no bits are set. + int find_first() const { + if (Elements.empty()) + return -1; + const SparseBitVectorElement& First = *(Elements.begin()); + return (First.index() * ElementSize) + First.find_first(); + } + + // Return the last set bit in the bitmap. Return -1 if no bits are set. + int find_last() const { + if (Elements.empty()) + return -1; + const SparseBitVectorElement& Last = *(Elements.rbegin()); + return (Last.index() * ElementSize) + Last.find_last(); + } + + // Return true if the SparseBitVector is empty + bool empty() const { + return Elements.empty(); + } + + unsigned count() const { + unsigned BitCount = 0; + for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end(); + ++Iter) + BitCount += Iter->count(); + + return BitCount; + } + + iterator begin() const { + return iterator(this); + } + + iterator end() const { + return iterator(this, true); + } +}; + +// Convenience functions to allow Or and And without dereferencing in the user +// code. + +template +inline bool operator|=( + SparseBitVector& LHS, + const SparseBitVector* RHS) { + return LHS |= *RHS; +} + +template +inline bool operator|=( + SparseBitVector* LHS, + const SparseBitVector& RHS) { + return LHS->operator|=(RHS); +} + +template +inline bool operator&=( + SparseBitVector* LHS, + const SparseBitVector& RHS) { + return LHS->operator&=(RHS); +} + +template +inline bool operator&=( + SparseBitVector& LHS, + const SparseBitVector* RHS) { + return LHS &= *RHS; +} + +// Convenience functions for infix union, intersection, difference operators. + +template +inline SparseBitVector operator|( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result(LHS); + Result |= RHS; + return Result; +} + +template +inline SparseBitVector operator&( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result(LHS); + Result &= RHS; + return Result; +} + +template +inline SparseBitVector operator-( + const SparseBitVector& LHS, + const SparseBitVector& RHS) { + SparseBitVector Result; + Result.intersectWithComplement(LHS, RHS); + return Result; +} + +template +std::ostream& operator<<( + std::ostream& stream, + const SparseBitVector& vec) { + bool first = true; + stream << '{'; + for (auto el : vec) { + if (first) { + first = false; + } else { + stream << ", "; + } + stream << el; + } + stream << '}'; + return stream; +} + +} // end namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ssize.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ssize.h new file mode 100644 index 0000000000000000000000000000000000000000..395bf8a2eb7c5ef35f0de9530f4f9ccb9fe18e42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/ssize.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +// Implementations of std::ssize() from C++ 20. +// +// This is useful in particular for avoiding -Werror=sign-compare +// issues. +// +// Use this with argument-dependent lookup, e.g.: +// use c10::ssize; +// auto size = ssize(container); +// +// As with the standard library version, containers are permitted to +// specialize this with a free function defined in the same namespace. +// +// See https://en.cppreference.com/w/cpp/iterator/size for more +// information as well as the source of our implementations. +// +// We augment the implementation by adding an assert() if an overflow +// would occur. + +template +constexpr auto ssize(const C& c) -> std:: + common_type_t> { + using R = std:: + common_type_t>; + // We expect this to be exceedingly rare to fire and don't wish to + // pay a performance hit in release mode. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!greater_than_max(c.size())); + return static_cast(c.size()); +} + +template +// NOLINTNEXTLINE(*-c-arrays) +constexpr auto ssize(const T (&array)[N]) noexcept -> std::ptrdiff_t { + return N; +} + +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h new file mode 100644 index 0000000000000000000000000000000000000000..afcf4504c87a49112a6f4f21c6fe08f153c49a2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/tempfile.h @@ -0,0 +1,94 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace c10 { +struct C10_API TempFile { + TempFile(std::string_view name, int fd = -1) noexcept : fd(fd), name(name) {} + TempFile(const TempFile&) = delete; + TempFile(TempFile&& other) noexcept + : fd(other.fd), name(std::move(other.name)) { + other.fd = -1; + } + + TempFile& operator=(const TempFile&) = delete; + TempFile& operator=(TempFile&& other) noexcept { + fd = other.fd; + name = std::move(other.name); + other.fd = -1; + return *this; + } +#if defined(_WIN32) + bool open(); +#endif + + ~TempFile(); + + int fd; + + std::string name; +}; + +struct C10_API TempDir { + TempDir() = delete; + explicit TempDir(std::string_view name) noexcept : name(name) {} + TempDir(const TempDir&) = delete; + TempDir(TempDir&& other) noexcept : name(std::move(other.name)) { + other.name.clear(); + } + + TempDir& operator=(const TempDir&) = delete; + TempDir& operator=(TempDir&& other) noexcept { + name = std::move(other.name); + return *this; + } + + ~TempDir(); + + std::string name; +}; + +/// Attempts to return a temporary file or returns `nullopt` if an error +/// occurred. +/// +/// The file returned follows the pattern +/// `/`, where `` is the value of +/// the `"TMPDIR"`, `"TMP"`, `"TEMP"` or +/// `"TEMPDIR"` environment variable if any is set, or otherwise `/tmp`; +/// `` is the value supplied to this function, and +/// `` is a random sequence of numbers. +/// On Windows, `name_prefix` is ignored and `tmpnam_s` is used, +/// and no temporary file is opened. +C10_API std::optional try_make_tempfile( + std::string_view name_prefix = "torch-file-"); + +/// Like `try_make_tempfile`, but throws an exception if a temporary file could +/// not be returned. +C10_API TempFile make_tempfile(std::string_view name_prefix = "torch-file-"); + +/// Attempts to return a temporary directory or returns `nullopt` if an error +/// occurred. +/// +/// The directory returned follows the pattern +/// `//`, where `` is the value +/// of the `"TMPDIR"`, `"TMP"`, `"TEMP"` or +/// `"TEMPDIR"` environment variable if any is set, or otherwise `/tmp`; +/// `` is the value supplied to this function, and +/// `` is a random sequence of numbers. +/// On Windows, `name_prefix` is ignored. +C10_API std::optional try_make_tempdir( + std::string_view name_prefix = "torch-dir-"); + +/// Like `try_make_tempdir`, but throws an exception if a temporary directory +/// could not be returned. +C10_API TempDir make_tempdir(std::string_view name_prefix = "torch-dir-"); +} // namespace c10 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/typeid.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/typeid.h new file mode 100644 index 0000000000000000000000000000000000000000..3f7da4264ad5339af2535aeb10863ef07c75515a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/typeid.h @@ -0,0 +1,720 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +/* + * TypeIdentifier is a small type containing an id. + * Types must be registered using CAFFE_DECLARE_KNOWN_TYPE() (in their header) + * and CAFFE_DEFINE_KNOWN_TYPE() (in their .cpp file) for them to have a type + * id. If a type is registered, you can also create an object containing meta + * data like constructor, destructor, stringified name, ... about the type by + * calling TypeMeta::Make. This returns a TypeMeta() object, which is + * basically just a pointer to the type information, so it's cheap to pass + * around. + */ + +// TODO: This file is still in the caffe2 namespace, despite living +// in the ATen directory. This is because the macro +// CAFFE_KNOWN_TYPE (and CAFFE_DECLARE_KNOWN_TYPE) defines a template +// specialization, which relies +// on the namespace of TypeMeta matching the namespace where the macro is +// called. This requires us to fix all of the call-sites, which I want to do +// later. So the namespace is not fixed at the moment. + +// Make at::Half a fundamental type. + +namespace c10::guts { +template <> +struct is_fundamental : std::true_type {}; +} // namespace c10::guts + +namespace caffe2 { + +/** + * A type id is a unique id for a given C++ type. + * You need to register your types using CAFFE_KNOWN_TYPE(MyType) to be able to + * use TypeIdentifier with custom types. This is for example used to store the + * dtype of tensors. + */ +class C10_API TypeIdentifier final + : public at::IdWrapper { + public: + friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId); + friend constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs); + + /** + * Returns the unique id for the given type T. The id is unique for the type T + * in the sense that for any two different types, their ids are different; for + * the same type T, the id remains the same over different calls of the + * function. However, this is not guaranteed over different runs, as the id + * is generated during run-time. Do NOT serialize the id for storage. + */ + template + static constexpr TypeIdentifier Get() noexcept { + return TypeIdentifier(c10::util::get_type_index()); + } + + static constexpr TypeIdentifier uninitialized() { + return TypeIdentifier(c10::util::type_index{0}); + } + + private: + constexpr explicit TypeIdentifier(c10::util::type_index id) : IdWrapper(id) {} +}; + +// Allow usage in std::map / std::set +// TODO Disallow this and rather use std::unordered_map/set everywhere +inline constexpr bool operator<(TypeIdentifier lhs, TypeIdentifier rhs) { + return lhs.underlyingId() < rhs.underlyingId(); +} + +inline std::ostream& operator<<( + std::ostream& stream, + caffe2::TypeIdentifier typeId) { + return stream << typeId.underlyingId(); +} + +} // namespace caffe2 + +namespace at { +using DataType = caffe2::TypeIdentifier; +} + +C10_DEFINE_HASH_FOR_IDWRAPPER(caffe2::TypeIdentifier) + +namespace caffe2 { + +namespace detail { + +// This struct holds the actual type information. There will be +// one allocated per type. TypeMeta objects will then point to the struct +// instance for the type they're configured for. +struct TypeMetaData final { + using New = void*(); + using PlacementNew = void(void*, size_t); + using Copy = void(const void*, void*, size_t); + using PlacementDelete = void(void*, size_t); + using Delete = void(void*); + + constexpr TypeMetaData() noexcept + : itemsize_(0), + new_(nullptr), + placementNew_(nullptr), + copy_(nullptr), + placementDelete_(nullptr), + delete_(nullptr), + id_(TypeIdentifier::uninitialized()), + name_("nullptr (uninitialized)") {} + + constexpr TypeMetaData( + size_t itemsize, + New* newFn, + PlacementNew* placementNew, + Copy* copy, + PlacementDelete* placementDelete, + Delete* deleteFn, + TypeIdentifier id, + std::string_view name) noexcept + : itemsize_(itemsize), + new_(newFn), + placementNew_(placementNew), + copy_(copy), + placementDelete_(placementDelete), + delete_(deleteFn), + id_(id), + name_(name) {} + + size_t itemsize_; + New* new_; + PlacementNew* placementNew_; + Copy* copy_; + PlacementDelete* placementDelete_; + Delete* delete_; + TypeIdentifier id_; + std::string_view name_; +}; + +// Mechanism for throwing errors which can't be prevented at compile time +// due to type erasure. E.g. somebody calling TypeMeta::copy() for +// non-copyable type. Right now just throws exception but is implemented +// in .cpp to manage dependencies +[[noreturn]] C10_API void _ThrowRuntimeTypeLogicError(const std::string& msg); + +/** + * Placement new function for the type. + */ +template +inline void _PlacementNew(void* ptr, size_t n) { + T* typed_ptr = static_cast(ptr); + for (const auto i : c10::irange(n)) { + new (typed_ptr + i) T; + } +} + +template +inline void _PlacementNewNotDefault(void* /*ptr*/, size_t /*n*/) { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " is not default-constructible."); +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_PlacementNew; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() { + static_assert( + !c10::guts::is_fundamental::value && !std::is_pointer_v, + "this should have picked the other SFINAE case"); + return &_PlacementNewNotDefault; +} + +template +inline void* _New() { + return new T; +} + +template +inline void* _NewNotDefault() { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " is not default-constructible."); +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::New* _PickNew() { + return &_New; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::New* _PickNew() { + return &_NewNotDefault; +} + +/** + * Typed copy function for classes. + */ +template +inline void _Copy(const void* src, void* dst, size_t n) { + const T* typed_src = static_cast(src); + T* typed_dst = static_cast(dst); + for (const auto i : c10::irange(n)) { + typed_dst[i] = typed_src[i]; + } +} + +/** + * A placeholder function for types that do not allow assignment. + */ +template +inline void _CopyNotAllowed(const void* /*src*/, void* /*dst*/, size_t /*n*/) { + _ThrowRuntimeTypeLogicError( + "Type " + std::string(c10::util::get_fully_qualified_type_name()) + + " does not allow assignment."); +} + +template >* = nullptr> +inline constexpr TypeMetaData::Copy* _PickCopy() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_Copy; +} + +template < + typename T, + std::enable_if_t>* = nullptr> +inline constexpr TypeMetaData::Copy* _PickCopy() { + static_assert( + !c10::guts::is_fundamental::value && !std::is_pointer_v, + "this should have picked the other SFINAE case"); + return &_CopyNotAllowed; +} + +/** + * Destructor for non-fundamental types. + */ +template +inline void _PlacementDelete(void* ptr, size_t n) { + T* typed_ptr = static_cast(ptr); + for (const auto i : c10::irange(n)) { + typed_ptr[i].~T(); + } +} + +template +inline constexpr TypeMetaData::PlacementDelete* _PickPlacementDelete() { + return (c10::guts::is_fundamental::value || std::is_pointer_v) + ? nullptr + : &_PlacementDelete; +} + +template +inline void _Delete(void* ptr) { + T* typed_ptr = static_cast(ptr); + delete typed_ptr; +} + +template +inline constexpr TypeMetaData::Delete* _PickDelete() noexcept { + return &_Delete; +} + +class _Uninitialized final {}; + +} // namespace detail + +// +// note: this is outside TypeMeta bc gcc seems to have trouble +// with scalarTypeItemSizes as a constexpr static member used by +// a public inline instance method +// + +// item sizes for TypeMeta::itemsize() fast path +static constexpr std::array scalarTypeItemSizes = { +#define SCALAR_TYPE_SIZE(T, name) sizeof(T), + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_SIZE) +#undef SCALAR_TYPE_SIZE + 0, // Undefined +}; + +/** + * TypeMeta is a thin class that allows us to store the type of a container such + * as a blob, or the data type of a tensor, with a unique run-time id. It also + * stores some additional data such as the item size and the name of the type + * for run-time inspection. + */ +class C10_API TypeMeta final { + public: + using New = detail::TypeMetaData::New; + using PlacementNew = detail::TypeMetaData::PlacementNew; + using Copy = detail::TypeMetaData::Copy; + using PlacementDelete = detail::TypeMetaData::PlacementDelete; + using Delete = detail::TypeMetaData::Delete; + + /** Create a dummy TypeMeta object. To create a TypeMeta object for a specific + * type, use TypeMeta::Make(). + */ + TypeMeta() noexcept; + ~TypeMeta() = default; + + /** + * Copy constructor. + */ + TypeMeta(const TypeMeta& src) noexcept = default; + + /** + * Assignment operators. + */ + TypeMeta& operator=(const TypeMeta& src) noexcept = default; + + TypeMeta& operator=(TypeMeta&& src) noexcept = default; + TypeMeta(TypeMeta&& rhs) noexcept = default; + + inline TypeMeta& operator=(ScalarType scalar_type) noexcept { + index_ = static_cast(scalar_type); + return *this; + } + + private: + // TypeMeta can only be created by Make, making sure that we do not + // create incorrectly mixed up TypeMeta objects. + explicit TypeMeta(const uint16_t index) noexcept : index_(index) {} + + public: + /** + * Returns the type id. + */ + TypeIdentifier id() const noexcept { + return data().id_; + } + /** + * true if we represent some ScalarType type + */ + inline bool isScalarType() const noexcept { + return index_ < NumScalarTypes; + } + /** + * true if we represent ScalarType scalar_type + */ + inline bool isScalarType(ScalarType scalar_type) const noexcept { + return index_ == static_cast(scalar_type); + } + /** + * Returns the size of the item. + */ + inline size_t itemsize() const noexcept { + if (C10_LIKELY(isScalarType())) { + return scalarTypeItemSizes[index_]; + } + return data().itemsize_; + } + /** + * Returns the new function pointer for individual items. + */ + New* newFn() const noexcept { + return data().new_; + } + /** + * Returns the placement new function pointer for individual items. + */ + PlacementNew* placementNew() const noexcept { + return data().placementNew_; + } + /** + * Returns the typed copy function pointer for individual items. + */ + Copy* copy() const noexcept { + return data().copy_; + } + /** + * Returns the destructor function pointer for individual items. + */ + PlacementDelete* placementDelete() const noexcept { + return data().placementDelete_; + } + Delete* deleteFn() const noexcept { + return data().delete_; + } + /** + * Returns a printable name for the type. + */ + std::string_view name() const noexcept { + return data().name_; + } + + friend bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept; + + template + bool Match() const noexcept { + return (*this == Make()); + } + + // Below are static functions that can be called by passing a specific type. + + template + static constexpr TypeIdentifier Id() noexcept { + return TypeIdentifier::Get(); + } + + template + static std::string_view TypeName() noexcept { + return c10::util::get_fully_qualified_type_name(); + } + + template + static constexpr size_t ItemSize() noexcept { + return sizeof(T); + } + + /** + * Returns a TypeMeta object that corresponds to the typename T. + */ + template + static TypeMeta Make() { + // The instance pointed to is declared here, but defined in a .cpp file. + // We need to silence the compiler warning about using an undefined + // variable template. '-Wpragmas' and '-Wunknown-warning-option' has to be + // disabled for compilers that don't know '-Wundefined-var-template' and + // would error at our attempt to disable it. +#ifndef _MSC_VER +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Wunknown-warning-option" +#pragma GCC diagnostic ignored "-Wundefined-var-template" +#endif + return TypeMeta(_typeMetaData()); +#ifndef _MSC_VER +#pragma GCC diagnostic pop +#endif + } + + /** + * convert ScalarType enum values to TypeMeta handles + */ + static inline caffe2::TypeMeta fromScalarType(ScalarType scalar_type) { + const auto index = static_cast(scalar_type); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + index < NumScalarTypes, + "Unrecognized Scalartype ", + scalar_type, + " (please report this error)"); + return TypeMeta(index); + } + + /** + * convert TypeMeta handles to ScalarType enum values + */ + inline ScalarType toScalarType() const { + if (C10_LIKELY(isScalarType())) { + return static_cast(index_); + } + error_unsupported_typemeta(*this); + } + + private: + [[noreturn]] static void error_unsupported_typemeta(caffe2::TypeMeta dtype); + + // hard limit number of registered types + // note: constexpr provokes Windows compilation error "member may not be + // initialized" static constexpr size_t MaxTypeIndex = 32; + // +#if defined C10_MOBILE +// The reason for this not to be UINT8_MAX is that the array +// initialization takes space which is proportional to the size of the array. +// The compiler seems to add code (or data padding) to initialize the array with +// empty elements. Please see +// https://github.com/pytorch/pytorch/pull/51881 for details. +// +#define MaxTypeIndex \ + (NumScalarTypes + 15 /* number of CAFFE_DEFINE_KNOWN_TYPE in typeid.cpp */ + \ + 1 /* 1 more for caffe2 tensor */) +#else +#define MaxTypeIndex UINT8_MAX +#endif + + // Protects type metadata allocation. + // NOLINTNEXTLINE(facebook-hte-NonPodStaticDeclaration) + static std::mutex& getTypeMetaDatasLock(); + static uint16_t nextTypeIndex; + + static detail::TypeMetaData* typeMetaDatas(); + + static uint16_t existingMetaDataIndexForType(TypeIdentifier identifier); + + public: +#ifdef __CUDACC__ + // NOTE [ TypeIdentifier::Get nvcc/clang discrepancy] + // nvcc and clang do not produce identical results for + // TypeIdentifier::Get, because TypeIdentifier::Get relies on + // __PRETTY_FUNCTION__ and they don't agree on the canonical names + // of types (e.g., nvcc normalizes to `short unsigned int`, but clang + // calls it `unsigned short`). Hide the implementation of this function + // from nvcc so that we always use clang (or whatever host C++ compiler) + // for TypeIdentifier::Get. + template + C10_EXPORT static uint16_t addTypeMetaData(); +#else + template + C10_EXPORT static uint16_t addTypeMetaData() { + const auto identifier = TypeIdentifier::Get(); + // Need to hold this for the rest of the function, protecting: + // 1) existingMetaDataIndexForType() + // 2) nextTypeIndex++ + // 3) the write into typeMetaDatas() + std::lock_guard lock(getTypeMetaDatasLock()); + // It may exist already if added in a different dynamic shared library. + const uint16_t existing_index = existingMetaDataIndexForType(identifier); + if (existing_index != MaxTypeIndex) { + return existing_index; + } + const uint16_t index = nextTypeIndex++; + TORCH_CHECK( + index <= MaxTypeIndex, + "Maximum number of CAFFE_KNOWN_TYPE declarations has been exceeded. ", + "Please report this issue."); + typeMetaDatas()[index] = detail::TypeMetaData{ + sizeof(T), + detail::_PickNew(), + detail::_PickPlacementNew(), + detail::_PickCopy(), + detail::_PickPlacementDelete(), + detail::_PickDelete(), + identifier, + c10::util::get_fully_qualified_type_name()}; + return index; + } +#endif + + private: + // specializations return indexes into typeMetaDataInstances() + template + C10_API static uint16_t _typeMetaData() noexcept; + + // + // TypeMeta just wraps this index + // + + uint16_t index_; + + inline const detail::TypeMetaData& data() const { + return typeMetaDatas()[index_]; + } +}; + +// specializations of TypeMeta::_typeMetaData for ScalarType types + +#define DEFINE_SCALAR_METADATA_INSTANCE(T, name) \ + template <> \ + constexpr uint16_t TypeMeta::_typeMetaData() noexcept { \ + return static_cast(ScalarType::name); \ + } +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_METADATA_INSTANCE) +#undef DEFINE_SCALAR_METADATA_INSTANCE + +template <> +C10_EXPORT constexpr uint16_t TypeMeta::_typeMetaData< + detail::_Uninitialized>() noexcept { + return static_cast(ScalarType::Undefined); +} + +inline TypeMeta::TypeMeta() noexcept + : index_(_typeMetaData()) {} + +inline bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept { + return (lhs.index_ == rhs.index_); +} +inline bool operator!=(const TypeMeta& lhs, const TypeMeta& rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline std::ostream& operator<<( + std::ostream& stream, + caffe2::TypeMeta typeMeta) { + return stream << typeMeta.name(); +} + +/** + * Register unique id for a type so it can be used in TypeMeta context, e.g. be + * used as a type for Blob or for Tensor elements. + * + * CAFFE_KNOWN_TYPE is deprecated; prefer CAFFE_DECLARE_KNOWN_TYPE and + * CAFFE_DEFINE_KNOWN_TYPE. + * + * CAFFE_KNOWN_TYPE does explicit instantiation of TypeIdentifier::Get + * template function and thus needs to be put in a single translation unit (.cpp + * file) for a given type T. Other translation units that use type T as a type + * of the caffe2::Blob or element type of caffe2::Tensor need to depend on the + * translation unit that contains CAFFE_KNOWN_TYPE declaration via regular + * linkage dependencies. + * + * NOTE: the macro needs to be invoked in ::caffe2 namespace + */ +// Implementation note: in MSVC, we will need to prepend the C10_API +// keyword in order to get things compiled properly. in Linux, gcc seems to +// create attribute ignored error for explicit template instantiations, see +// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2017/p0537r0.html +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51930 +// and as a result, we define these two macros slightly differently. +#if defined(_MSC_VER) || defined(__clang__) +#define EXPORT_IF_NOT_GCC C10_EXPORT +#else +#define EXPORT_IF_NOT_GCC +#endif + +// CAFFE_KNOWN_TYPE is deprecated! Use CAFFE_DECLARE_KNOWN_TYPE and +// CAFFE_DEFINE_KNOWN_TYPE instead. +#define CAFFE_KNOWN_TYPE(T) \ + template uint16_t TypeMeta::addTypeMetaData(); \ + template <> \ + EXPORT_IF_NOT_GCC uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ + } + +#define CAFFE_DEFINE_KNOWN_TYPE(T, ident) \ + template uint16_t TypeMeta::addTypeMetaData(); \ + namespace detail { \ + EXPORT_IF_NOT_GCC const uint16_t ident##_metadata_index = \ + TypeMeta::addTypeMetaData(); \ + } // namespace detail + +// Unlike CAFFE_KNOWN_TYPE, CAFFE_DECLARE_KNOWN_TYPE avoids a function +// call to access _typeMetaData in the common case. +#define CAFFE_DECLARE_KNOWN_TYPE(T, ident) \ + extern template uint16_t TypeMeta::addTypeMetaData(); \ + namespace detail { \ + extern C10_API const uint16_t ident##_metadata_index; \ + } /* namespace detail */ \ + template <> \ + EXPORT_IF_NOT_GCC C10_ALWAYS_INLINE uint16_t \ + TypeMeta::_typeMetaData() noexcept { \ + return detail::ident##_metadata_index; \ + } + +#define CAFFE_KNOWN_TYPE_NOEXPORT(T) \ + template <> \ + uint16_t TypeMeta::_typeMetaData() noexcept { \ + static const uint16_t index = addTypeMetaData(); \ + return index; \ + } + +CAFFE_DECLARE_KNOWN_TYPE(std::string, std_string) +CAFFE_DECLARE_KNOWN_TYPE(char, char) +CAFFE_DECLARE_KNOWN_TYPE(std::unique_ptr, std_unique_ptr_std_mutex) +CAFFE_DECLARE_KNOWN_TYPE( + std::unique_ptr>, + std_unique_ptr_std_atomic_bool) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_int32_t) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_int64_t) +CAFFE_DECLARE_KNOWN_TYPE(std::vector, std_vector_unsigned_long) +CAFFE_DECLARE_KNOWN_TYPE(bool*, bool_ptr) +CAFFE_DECLARE_KNOWN_TYPE(char*, char_ptr) +CAFFE_DECLARE_KNOWN_TYPE(int*, int_ptr) + +// For some of the compilers, long is defined separately from int32_t and +// int64_t. As a result we will need to actually define them separately. +// It is recommended that one does NOT use long - use int32_t and int64_t +// explicitly. Explicit long type annotation may go away in the future. +// details: This hack works by defining a _guard_long_unique type, which is +// long iff the compiler has a separate long type and is a dummy type otherwise. +// we then allocate a type id to that _guard_long_unique. If the compiler has a +// separate long type, this allocates a type id for long. Otherwise, it +// allocates a type id for the dummy type, which doesn't matter. +namespace detail { +template +class _guard_long_unique_dummy final {}; +template +using _guard_long_unique = std::conditional_t< + std::is_same_v || std::is_same_v, + _guard_long_unique_dummy, + T>; +} // namespace detail + +CAFFE_DECLARE_KNOWN_TYPE( + detail::_guard_long_unique, + detail_guard_long_unique_long) +CAFFE_DECLARE_KNOWN_TYPE( + detail::_guard_long_unique>, + detail_guard_long_unique_std_vector_long) + +CAFFE_DECLARE_KNOWN_TYPE(float*, float_ptr) +CAFFE_DECLARE_KNOWN_TYPE(at::Half*, at_Half) + +} // namespace caffe2 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h new file mode 100644 index 0000000000000000000000000000000000000000..f9eb55948a858c8551a52c81ece8eab8c862c324 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/c10/util/win32-headers.h @@ -0,0 +1,65 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifndef NOKERNEL +#define NOKERNEL +#endif +#ifndef NOUSER +#define NOUSER +#endif +#ifndef NOSERVICE +#define NOSERVICE +#endif +#ifndef NOSOUND +#define NOSOUND +#endif +#ifndef NOMCX +#define NOMCX +#endif +#ifndef NOGDI +#define NOGDI +#endif +#ifndef NOMSG +#define NOMSG +#endif +#ifndef NOMB +#define NOMB +#endif +#ifndef NOCLIPBOARD +#define NOCLIPBOARD +#endif + +// dbghelp seems to require windows.h. +// clang-format off +#include +#include +// clang-format on + +#undef VOID +#undef DELETE +#undef IN +#undef THIS +#undef CONST +#undef NAN +#undef UNKNOWN +#undef NONE +#undef ANY +#undef IGNORE +#undef STRICT +#undef GetObject +#undef CreateSemaphore +#undef Yield +#undef RotateRight32 +#undef RotateLeft32 +#undef RotateRight64 +#undef RotateLeft64 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..c38c35da9761396b4fcb630d97c388422b7dd4c4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/code_generator.h @@ -0,0 +1,197 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Defines the abstract interface implemented by each of the language-specific +// code generators. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +namespace io { +class ZeroCopyOutputStream; +} +class FileDescriptor; + +namespace compiler { +class AccessInfoMap; + +class Version; + +// Defined in this file. +class CodeGenerator; +class GeneratorContext; + +// The abstract interface to a class which generates code implementing a +// particular proto file in a particular language. A number of these may +// be registered with CommandLineInterface to support various languages. +class PROTOC_EXPORT CodeGenerator { + public: + inline CodeGenerator() {} + virtual ~CodeGenerator(); + + // Generates code for the given proto file, generating one or more files in + // the given output directory. + // + // A parameter to be passed to the generator can be specified on the command + // line. This is intended to be used to pass generator specific parameters. + // It is empty if no parameter was given. ParseGeneratorParameter (below), + // can be used to accept multiple parameters within the single parameter + // command line flag. + // + // Returns true if successful. Otherwise, sets *error to a description of + // the problem (e.g. "invalid parameter") and returns false. + virtual bool Generate(const FileDescriptor* file, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const = 0; + + // Generates code for all given proto files. + // + // WARNING: The canonical code generator design produces one or two output + // files per input .proto file, and we do not wish to encourage alternate + // designs. + // + // A parameter is given as passed on the command line, as in |Generate()| + // above. + // + // Returns true if successful. Otherwise, sets *error to a description of + // the problem (e.g. "invalid parameter") and returns false. + virtual bool GenerateAll(const std::vector& files, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const; + + // Sync with plugin.proto. + enum Feature { + FEATURE_PROTO3_OPTIONAL = 1, + }; + + // Implement this to indicate what features this code generator supports. + // This should be a bitwise OR of features from the Features enum in + // plugin.proto. + virtual uint64_t GetSupportedFeatures() const { return 0; } + + // This is no longer used, but this class is part of the opensource protobuf + // library, so it has to remain to keep vtables the same for the current + // version of the library. When protobufs does a api breaking change, the + // method can be removed. + virtual bool HasGenerateAll() const { return true; } + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodeGenerator); +}; + +// CodeGenerators generate one or more files in a given directory. This +// abstract interface represents the directory to which the CodeGenerator is +// to write and other information about the context in which the Generator +// runs. +class PROTOC_EXPORT GeneratorContext { + public: + inline GeneratorContext() { + } + virtual ~GeneratorContext(); + + // Opens the given file, truncating it if it exists, and returns a + // ZeroCopyOutputStream that writes to the file. The caller takes ownership + // of the returned object. This method never fails (a dummy stream will be + // returned instead). + // + // The filename given should be relative to the root of the source tree. + // E.g. the C++ generator, when generating code for "foo/bar.proto", will + // generate the files "foo/bar.pb.h" and "foo/bar.pb.cc"; note that + // "foo/" is included in these filenames. The filename is not allowed to + // contain "." or ".." components. + virtual io::ZeroCopyOutputStream* Open(const std::string& filename) = 0; + + // Similar to Open() but the output will be appended to the file if exists + virtual io::ZeroCopyOutputStream* OpenForAppend(const std::string& filename); + + // Creates a ZeroCopyOutputStream which will insert code into the given file + // at the given insertion point. See plugin.proto (plugin.pb.h) for more + // information on insertion points. The default implementation + // assert-fails -- it exists only for backwards-compatibility. + // + // WARNING: This feature is currently EXPERIMENTAL and is subject to change. + virtual io::ZeroCopyOutputStream* OpenForInsert( + const std::string& filename, const std::string& insertion_point); + + // Returns a vector of FileDescriptors for all the files being compiled + // in this run. Useful for languages, such as Go, that treat files + // differently when compiled as a set rather than individually. + virtual void ListParsedFiles(std::vector* output); + + // Retrieves the version number of the protocol compiler associated with + // this GeneratorContext. + virtual void GetCompilerVersion(Version* version) const; + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GeneratorContext); +}; + +// The type GeneratorContext was once called OutputDirectory. This typedef +// provides backward compatibility. +typedef GeneratorContext OutputDirectory; + +// Several code generators treat the parameter argument as holding a +// list of options separated by commas. This helper function parses +// a set of comma-delimited name/value pairs: e.g., +// "foo=bar,baz,qux=corge" +// parses to the pairs: +// ("foo", "bar"), ("baz", ""), ("qux", "corge") +PROTOC_EXPORT void ParseGeneratorParameter( + const std::string&, std::vector >*); + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CODE_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..8f71eb7ed759e197781e1fb47773435eb88faf72 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/command_line_interface.h @@ -0,0 +1,468 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Implements the Protocol Compiler front-end such that it may be reused by +// custom compilers written to support other languages. + +#ifndef GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ +#define GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace google { +namespace protobuf { + +class Descriptor; // descriptor.h +class DescriptorDatabase; // descriptor_database.h +class DescriptorPool; // descriptor.h +class FileDescriptor; // descriptor.h +class FileDescriptorSet; // descriptor.h +class FileDescriptorProto; // descriptor.pb.h +template +class RepeatedPtrField; // repeated_field.h +class SimpleDescriptorDatabase; // descriptor_database.h + +namespace compiler { + +class CodeGenerator; // code_generator.h +class GeneratorContext; // code_generator.h +class DiskSourceTree; // importer.h + +// This class implements the command-line interface to the protocol compiler. +// It is designed to make it very easy to create a custom protocol compiler +// supporting the languages of your choice. For example, if you wanted to +// create a custom protocol compiler binary which includes both the regular +// C++ support plus support for your own custom output "Foo", you would +// write a class "FooGenerator" which implements the CodeGenerator interface, +// then write a main() procedure like this: +// +// int main(int argc, char* argv[]) { +// google::protobuf::compiler::CommandLineInterface cli; +// +// // Support generation of C++ source and headers. +// google::protobuf::compiler::cpp::CppGenerator cpp_generator; +// cli.RegisterGenerator("--cpp_out", &cpp_generator, +// "Generate C++ source and header."); +// +// // Support generation of Foo code. +// FooGenerator foo_generator; +// cli.RegisterGenerator("--foo_out", &foo_generator, +// "Generate Foo file."); +// +// return cli.Run(argc, argv); +// } +// +// The compiler is invoked with syntax like: +// protoc --cpp_out=outdir --foo_out=outdir --proto_path=src src/foo.proto +// +// The .proto file to compile can be specified on the command line using either +// its physical file path, or a virtual path relative to a directory specified +// in --proto_path. For example, for src/foo.proto, the following two protoc +// invocations work the same way: +// 1. protoc --proto_path=src src/foo.proto (physical file path) +// 2. protoc --proto_path=src foo.proto (virtual path relative to src) +// +// If a file path can be interpreted both as a physical file path and as a +// relative virtual path, the physical file path takes precendence. +// +// For a full description of the command-line syntax, invoke it with --help. +class PROTOC_EXPORT CommandLineInterface { + public: + static const char* const kPathSeparator; + + CommandLineInterface(); + ~CommandLineInterface(); + + // Register a code generator for a language. + // + // Parameters: + // * flag_name: The command-line flag used to specify an output file of + // this type. The name must start with a '-'. If the name is longer + // than one letter, it must start with two '-'s. + // * generator: The CodeGenerator which will be called to generate files + // of this type. + // * help_text: Text describing this flag in the --help output. + // + // Some generators accept extra parameters. You can specify this parameter + // on the command-line by placing it before the output directory, separated + // by a colon: + // protoc --foo_out=enable_bar:outdir + // The text before the colon is passed to CodeGenerator::Generate() as the + // "parameter". + void RegisterGenerator(const std::string& flag_name, CodeGenerator* generator, + const std::string& help_text); + + // Register a code generator for a language. + // Besides flag_name you can specify another option_flag_name that could be + // used to pass extra parameters to the registered code generator. + // Suppose you have registered a generator by calling: + // command_line_interface.RegisterGenerator("--foo_out", "--foo_opt", ...) + // Then you could invoke the compiler with a command like: + // protoc --foo_out=enable_bar:outdir --foo_opt=enable_baz + // This will pass "enable_bar,enable_baz" as the parameter to the generator. + void RegisterGenerator(const std::string& flag_name, + const std::string& option_flag_name, + CodeGenerator* generator, + const std::string& help_text); + + // Enables "plugins". In this mode, if a command-line flag ends with "_out" + // but does not match any registered generator, the compiler will attempt to + // find a "plugin" to implement the generator. Plugins are just executables. + // They should live somewhere in the PATH. + // + // The compiler determines the executable name to search for by concatenating + // exe_name_prefix with the unrecognized flag name, removing "_out". So, for + // example, if exe_name_prefix is "protoc-" and you pass the flag --foo_out, + // the compiler will try to run the program "protoc-gen-foo". + // + // The plugin program should implement the following usage: + // plugin [--out=OUTDIR] [--parameter=PARAMETER] PROTO_FILES < DESCRIPTORS + // --out indicates the output directory (as passed to the --foo_out + // parameter); if omitted, the current directory should be used. --parameter + // gives the generator parameter, if any was provided (see below). The + // PROTO_FILES list the .proto files which were given on the compiler + // command-line; these are the files for which the plugin is expected to + // generate output code. Finally, DESCRIPTORS is an encoded FileDescriptorSet + // (as defined in descriptor.proto). This is piped to the plugin's stdin. + // The set will include descriptors for all the files listed in PROTO_FILES as + // well as all files that they import. The plugin MUST NOT attempt to read + // the PROTO_FILES directly -- it must use the FileDescriptorSet. + // + // The plugin should generate whatever files are necessary, as code generators + // normally do. It should write the names of all files it generates to + // stdout. The names should be relative to the output directory, NOT absolute + // names or relative to the current directory. If any errors occur, error + // messages should be written to stderr. If an error is fatal, the plugin + // should exit with a non-zero exit code. + // + // Plugins can have generator parameters similar to normal built-in + // generators. Extra generator parameters can be passed in via a matching + // "_opt" parameter. For example: + // protoc --plug_out=enable_bar:outdir --plug_opt=enable_baz + // This will pass "enable_bar,enable_baz" as the parameter to the plugin. + // + void AllowPlugins(const std::string& exe_name_prefix); + + // Run the Protocol Compiler with the given command-line parameters. + // Returns the error code which should be returned by main(). + // + // It may not be safe to call Run() in a multi-threaded environment because + // it calls strerror(). I'm not sure why you'd want to do this anyway. + int Run(int argc, const char* const argv[]); + + // DEPRECATED. Calling this method has no effect. Protocol compiler now + // always try to find the .proto file relative to the current directory + // first and if the file is not found, it will then treat the input path + // as a virtual path. + void SetInputsAreProtoPathRelative(bool /* enable */) {} + + // Provides some text which will be printed when the --version flag is + // used. The version of libprotoc will also be printed on the next line + // after this text. + void SetVersionInfo(const std::string& text) { version_info_ = text; } + + + private: + // ----------------------------------------------------------------- + + class ErrorPrinter; + class GeneratorContextImpl; + class MemoryOutputStream; + typedef std::unordered_map> + GeneratorContextMap; + + // Clear state from previous Run(). + void Clear(); + + // Remaps the proto file so that it is relative to one of the directories + // in proto_path_. Returns false if an error occurred. + bool MakeProtoProtoPathRelative(DiskSourceTree* source_tree, + std::string* proto, + DescriptorDatabase* fallback_database); + + // Remaps each file in input_files_ so that it is relative to one of the + // directories in proto_path_. Returns false if an error occurred. + bool MakeInputsBeProtoPathRelative(DiskSourceTree* source_tree, + DescriptorDatabase* fallback_database); + + // Is this .proto file whitelisted, or do we have a command-line flag allowing + // us to use proto3 optional? This is a temporary control to avoid people from + // using proto3 optional until code generators have implemented it. + bool AllowProto3Optional(const FileDescriptor& file) const; + + // Fails if these files use proto3 optional and the code generator doesn't + // support it. This is a permanent check. + bool EnforceProto3OptionalSupport( + const std::string& codegen_name, uint64 supported_features, + const std::vector& parsed_files) const; + + + // Return status for ParseArguments() and InterpretArgument(). + enum ParseArgumentStatus { + PARSE_ARGUMENT_DONE_AND_CONTINUE, + PARSE_ARGUMENT_DONE_AND_EXIT, + PARSE_ARGUMENT_FAIL + }; + + // Parse all command-line arguments. + ParseArgumentStatus ParseArguments(int argc, const char* const argv[]); + + // Read an argument file and append the file's content to the list of + // arguments. Return false if the file cannot be read. + bool ExpandArgumentFile(const std::string& file, + std::vector* arguments); + + // Parses a command-line argument into a name/value pair. Returns + // true if the next argument in the argv should be used as the value, + // false otherwise. + // + // Examples: + // "-Isrc/protos" -> + // name = "-I", value = "src/protos" + // "--cpp_out=src/foo.pb2.cc" -> + // name = "--cpp_out", value = "src/foo.pb2.cc" + // "foo.proto" -> + // name = "", value = "foo.proto" + bool ParseArgument(const char* arg, std::string* name, std::string* value); + + // Interprets arguments parsed with ParseArgument. + ParseArgumentStatus InterpretArgument(const std::string& name, + const std::string& value); + + // Print the --help text to stderr. + void PrintHelpText(); + + // Loads proto_path_ into the provided source_tree. + bool InitializeDiskSourceTree(DiskSourceTree* source_tree, + DescriptorDatabase* fallback_database); + + // Verify that all the input files exist in the given database. + bool VerifyInputFilesInDescriptors(DescriptorDatabase* fallback_database); + + // Parses input_files_ into parsed_files + bool ParseInputFiles(DescriptorPool* descriptor_pool, + DiskSourceTree* source_tree, + std::vector* parsed_files); + + // Generate the given output file from the given input. + struct OutputDirective; // see below + bool GenerateOutput(const std::vector& parsed_files, + const OutputDirective& output_directive, + GeneratorContext* generator_context); + bool GeneratePluginOutput( + const std::vector& parsed_files, + const std::string& plugin_name, const std::string& parameter, + GeneratorContext* generator_context, std::string* error); + + // Implements --encode and --decode. + bool EncodeOrDecode(const DescriptorPool* pool); + + // Implements the --descriptor_set_out option. + bool WriteDescriptorSet( + const std::vector& parsed_files); + + // Implements the --dependency_out option + bool GenerateDependencyManifestFile( + const std::vector& parsed_files, + const GeneratorContextMap& output_directories, + DiskSourceTree* source_tree); + + // Get all transitive dependencies of the given file (including the file + // itself), adding them to the given list of FileDescriptorProtos. The + // protos will be ordered such that every file is listed before any file that + // depends on it, so that you can call DescriptorPool::BuildFile() on them + // in order. Any files in *already_seen will not be added, and each file + // added will be inserted into *already_seen. If include_source_code_info is + // true then include the source code information in the FileDescriptorProtos. + // If include_json_name is true, populate the json_name field of + // FieldDescriptorProto for all fields. + static void GetTransitiveDependencies( + const FileDescriptor* file, bool include_json_name, + bool include_source_code_info, + std::set* already_seen, + RepeatedPtrField* output); + + // Implements the --print_free_field_numbers. This function prints free field + // numbers into stdout for the message and it's nested message types in + // post-order, i.e. nested types first. Printed range are left-right + // inclusive, i.e. [a, b]. + // + // Groups: + // For historical reasons, groups are considered to share the same + // field number space with the parent message, thus it will not print free + // field numbers for groups. The field numbers used in the groups are + // excluded in the free field numbers of the parent message. + // + // Extension Ranges: + // Extension ranges are considered ocuppied field numbers and they will not be + // listed as free numbers in the output. + void PrintFreeFieldNumbers(const Descriptor* descriptor); + + // ----------------------------------------------------------------- + + // The name of the executable as invoked (i.e. argv[0]). + std::string executable_name_; + + // Version info set with SetVersionInfo(). + std::string version_info_; + + // Registered generators. + struct GeneratorInfo { + std::string flag_name; + std::string option_flag_name; + CodeGenerator* generator; + std::string help_text; + }; + typedef std::map GeneratorMap; + GeneratorMap generators_by_flag_name_; + GeneratorMap generators_by_option_name_; + // A map from generator names to the parameters specified using the option + // flag. For example, if the user invokes the compiler with: + // protoc --foo_out=outputdir --foo_opt=enable_bar ... + // Then there will be an entry ("--foo_out", "enable_bar") in this map. + std::map generator_parameters_; + // Similar to generator_parameters_, but stores the parameters for plugins. + std::map plugin_parameters_; + + // See AllowPlugins(). If this is empty, plugins aren't allowed. + std::string plugin_prefix_; + + // Maps specific plugin names to files. When executing a plugin, this map + // is searched first to find the plugin executable. If not found here, the + // PATH (or other OS-specific search strategy) is searched. + std::map plugins_; + + // Stuff parsed from command line. + enum Mode { + MODE_COMPILE, // Normal mode: parse .proto files and compile them. + MODE_ENCODE, // --encode: read text from stdin, write binary to stdout. + MODE_DECODE, // --decode: read binary from stdin, write text to stdout. + MODE_PRINT, // Print mode: print info of the given .proto files and exit. + }; + + Mode mode_ = MODE_COMPILE; + + enum PrintMode { + PRINT_NONE, // Not in MODE_PRINT + PRINT_FREE_FIELDS, // --print_free_fields + }; + + PrintMode print_mode_ = PRINT_NONE; + + enum ErrorFormat { + ERROR_FORMAT_GCC, // GCC error output format (default). + ERROR_FORMAT_MSVS // Visual Studio output (--error_format=msvs). + }; + + ErrorFormat error_format_ = ERROR_FORMAT_GCC; + + std::vector > + proto_path_; // Search path for proto files. + std::vector input_files_; // Names of the input proto files. + + // Names of proto files which are allowed to be imported. Used by build + // systems to enforce depend-on-what-you-import. + std::set direct_dependencies_; + bool direct_dependencies_explicitly_set_ = false; + + // If there's a violation of depend-on-what-you-import, this string will be + // presented to the user. "%s" will be replaced with the violating import. + std::string direct_dependencies_violation_msg_; + + // output_directives_ lists all the files we are supposed to output and what + // generator to use for each. + struct OutputDirective { + std::string name; // E.g. "--foo_out" + CodeGenerator* generator; // NULL for plugins + std::string parameter; + std::string output_location; + }; + std::vector output_directives_; + + // When using --encode or --decode, this names the type we are encoding or + // decoding. (Empty string indicates --decode_raw.) + std::string codec_type_; + + // If --descriptor_set_in was given, these are filenames containing + // parsed FileDescriptorSets to be used for loading protos. Otherwise, empty. + std::vector descriptor_set_in_names_; + + // If --descriptor_set_out was given, this is the filename to which the + // FileDescriptorSet should be written. Otherwise, empty. + std::string descriptor_set_out_name_; + + // If --dependency_out was given, this is the path to the file where the + // dependency file will be written. Otherwise, empty. + std::string dependency_out_name_; + + // True if --include_imports was given, meaning that we should + // write all transitive dependencies to the DescriptorSet. Otherwise, only + // the .proto files listed on the command-line are added. + bool imports_in_descriptor_set_; + + // True if --include_source_info was given, meaning that we should not strip + // SourceCodeInfo from the DescriptorSet. + bool source_info_in_descriptor_set_ = false; + + // Was the --disallow_services flag used? + bool disallow_services_ = false; + + // Was the --experimental_allow_proto3_optional flag used? + bool allow_proto3_optional_ = false; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CommandLineInterface); +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_COMMAND_LINE_INTERFACE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..365b8e12f8f7897b0bb881b891ae525a2ca72ca9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/cpp/cpp_generator.h @@ -0,0 +1,111 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Generates C++ code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace cpp { + +// CodeGenerator implementation which generates a C++ source file and +// header. If you create your own protocol compiler binary and you want +// it to support C++ output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT CppGenerator : public CodeGenerator { + public: + CppGenerator(); + ~CppGenerator(); + + enum class Runtime { + kGoogle3, // Use the internal google3 runtime. + kOpensource, // Use the open-source runtime. + + // Use the open-source runtime with google3 #include paths. We make these + // absolute to avoid ambiguity, so the runtime will be #included like: + // #include "third_party/protobuf/.../google/protobuf/message.h" + kOpensourceGoogle3 + }; + + void set_opensource_runtime(bool opensource) { + opensource_runtime_ = opensource; + } + + // If set to a non-empty string, generated code will do: + // #include "/google/protobuf/message.h" + // instead of: + // #include + // This has no effect if opensource_runtime = false. + void set_runtime_include_base(const std::string& base) { + runtime_include_base_ = base; + } + + // implements CodeGenerator ---------------------------------------- + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override { + // We don't fully support this yet, but this is needed to unblock the tests, + // and we will have full support before the experimental flag is removed. + return FEATURE_PROTO3_OPTIONAL; + } + + private: + bool opensource_runtime_ = true; + std::string runtime_include_base_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppGenerator); +}; + +} // namespace cpp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CPP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..b85832eb7688472c8547cadd4c4e83ffba6d1fec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_generator.h @@ -0,0 +1,75 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates C# code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ + +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace csharp { + +// CodeGenerator implementation which generates a C# source file and +// header. If you create your own protocol compiler binary and you want +// it to support C# output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator(); + ~Generator(); + bool Generate( + const FileDescriptor* file, + const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + uint64_t GetSupportedFeatures() const override; +}; + +} // namespace csharp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CSHARP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h new file mode 100644 index 0000000000000000000000000000000000000000..972b097817a14ef490829ad2fd725a58122afbcf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/csharp/csharp_names.h @@ -0,0 +1,112 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Provides a mechanism for mapping a descriptor to the +// fully-qualified name of the corresponding C# class. + +#ifndef GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ +#define GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FileDescriptor; +class ServiceDescriptor; + +namespace compiler { +namespace csharp { + +// Requires: +// descriptor != NULL +// +// Returns: +// The namespace to use for given file descriptor. +string PROTOC_EXPORT GetFileNamespace(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified C# class name. +string PROTOC_EXPORT GetClassName(const Descriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified name of the C# class that provides +// access to the file descriptor. Proto compiler generates +// such class for each .proto file processed. +string PROTOC_EXPORT GetReflectionClassName(const FileDescriptor* descriptor); + +// Generates output file name for given file descriptor. If generate_directories +// is true, the output file will be put under directory corresponding to file's +// namespace. base_namespace can be used to strip some of the top level +// directories. E.g. for file with namespace "Bar.Foo" and base_namespace="Bar", +// the resulting file will be put under directory "Foo" (and not "Bar/Foo"). +// +// Requires: +// descriptor != NULL +// error != NULL +// +// Returns: +// The file name to use as output file for given file descriptor. In case +// of failure, this function will return empty string and error parameter +// will contain the error message. +string PROTOC_EXPORT GetOutputFile(const FileDescriptor* descriptor, + const string file_extension, + const bool generate_directories, + const string base_namespace, string* error); + +} // namespace csharp +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_CSHARP_NAMES_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f2c6df0962a9141c3b567524cc6062e00f4ea11 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/importer.h @@ -0,0 +1,341 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file is the public interface to the .proto file parser. + +#ifndef GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ +#define GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +namespace io { +class ZeroCopyInputStream; +} + +namespace compiler { + +// Defined in this file. +class Importer; +class MultiFileErrorCollector; +class SourceTree; +class DiskSourceTree; + +// TODO(kenton): Move all SourceTree stuff to a separate file? + +// An implementation of DescriptorDatabase which loads files from a SourceTree +// and parses them. +// +// Note: This class is not thread-safe since it maintains a table of source +// code locations for error reporting. However, when a DescriptorPool wraps +// a DescriptorDatabase, it uses mutex locking to make sure only one method +// of the database is called at a time, even if the DescriptorPool is used +// from multiple threads. Therefore, there is only a problem if you create +// multiple DescriptorPools wrapping the same SourceTreeDescriptorDatabase +// and use them from multiple threads. +// +// Note: This class does not implement FindFileContainingSymbol() or +// FindFileContainingExtension(); these will always return false. +class PROTOBUF_EXPORT SourceTreeDescriptorDatabase : public DescriptorDatabase { + public: + SourceTreeDescriptorDatabase(SourceTree* source_tree); + + // If non-NULL, fallback_database will be checked if a file doesn't exist in + // the specified source_tree. + SourceTreeDescriptorDatabase(SourceTree* source_tree, + DescriptorDatabase* fallback_database); + ~SourceTreeDescriptorDatabase(); + + // Instructs the SourceTreeDescriptorDatabase to report any parse errors + // to the given MultiFileErrorCollector. This should be called before + // parsing. error_collector must remain valid until either this method + // is called again or the SourceTreeDescriptorDatabase is destroyed. + void RecordErrorsTo(MultiFileErrorCollector* error_collector) { + error_collector_ = error_collector; + } + + // Gets a DescriptorPool::ErrorCollector which records errors to the + // MultiFileErrorCollector specified with RecordErrorsTo(). This collector + // has the ability to determine exact line and column numbers of errors + // from the information given to it by the DescriptorPool. + DescriptorPool::ErrorCollector* GetValidationErrorCollector() { + using_validation_error_collector_ = true; + return &validation_error_collector_; + } + + // implements DescriptorDatabase ----------------------------------- + bool FindFileByName(const std::string& filename, + FileDescriptorProto* output) override; + bool FindFileContainingSymbol(const std::string& symbol_name, + FileDescriptorProto* output) override; + bool FindFileContainingExtension(const std::string& containing_type, + int field_number, + FileDescriptorProto* output) override; + + private: + class SingleFileErrorCollector; + + SourceTree* source_tree_; + DescriptorDatabase* fallback_database_; + MultiFileErrorCollector* error_collector_; + + class PROTOBUF_EXPORT ValidationErrorCollector + : public DescriptorPool::ErrorCollector { + public: + ValidationErrorCollector(SourceTreeDescriptorDatabase* owner); + ~ValidationErrorCollector(); + + // implements ErrorCollector --------------------------------------- + void AddError(const std::string& filename, const std::string& element_name, + const Message* descriptor, ErrorLocation location, + const std::string& message) override; + + void AddWarning(const std::string& filename, + const std::string& element_name, const Message* descriptor, + ErrorLocation location, + const std::string& message) override; + + private: + SourceTreeDescriptorDatabase* owner_; + }; + friend class ValidationErrorCollector; + + bool using_validation_error_collector_; + SourceLocationTable source_locations_; + ValidationErrorCollector validation_error_collector_; +}; + +// Simple interface for parsing .proto files. This wraps the process +// of opening the file, parsing it with a Parser, recursively parsing all its +// imports, and then cross-linking the results to produce a FileDescriptor. +// +// This is really just a thin wrapper around SourceTreeDescriptorDatabase. +// You may find that SourceTreeDescriptorDatabase is more flexible. +// +// TODO(kenton): I feel like this class is not well-named. +class PROTOBUF_EXPORT Importer { + public: + Importer(SourceTree* source_tree, MultiFileErrorCollector* error_collector); + ~Importer(); + + // Import the given file and build a FileDescriptor representing it. If + // the file is already in the DescriptorPool, the existing FileDescriptor + // will be returned. The FileDescriptor is property of the DescriptorPool, + // and will remain valid until it is destroyed. If any errors occur, they + // will be reported using the error collector and Import() will return NULL. + // + // A particular Importer object will only report errors for a particular + // file once. All future attempts to import the same file will return NULL + // without reporting any errors. The idea is that you might want to import + // a lot of files without seeing the same errors over and over again. If + // you want to see errors for the same files repeatedly, you can use a + // separate Importer object to import each one (but use the same + // DescriptorPool so that they can be cross-linked). + const FileDescriptor* Import(const std::string& filename); + + // The DescriptorPool in which all imported FileDescriptors and their + // contents are stored. + inline const DescriptorPool* pool() const { return &pool_; } + + void AddUnusedImportTrackFile(const std::string& file_name, + bool is_error = false); + void ClearUnusedImportTrackFiles(); + + + private: + SourceTreeDescriptorDatabase database_; + DescriptorPool pool_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Importer); +}; + +// If the importer encounters problems while trying to import the proto files, +// it reports them to a MultiFileErrorCollector. +class PROTOBUF_EXPORT MultiFileErrorCollector { + public: + inline MultiFileErrorCollector() {} + virtual ~MultiFileErrorCollector(); + + // Line and column numbers are zero-based. A line number of -1 indicates + // an error with the entire file (e.g. "not found"). + virtual void AddError(const std::string& filename, int line, int column, + const std::string& message) = 0; + + virtual void AddWarning(const std::string& filename, int line, int column, + const std::string& message) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MultiFileErrorCollector); +}; + +// Abstract interface which represents a directory tree containing proto files. +// Used by the default implementation of Importer to resolve import statements +// Most users will probably want to use the DiskSourceTree implementation, +// below. +class PROTOBUF_EXPORT SourceTree { + public: + inline SourceTree() {} + virtual ~SourceTree(); + + // Open the given file and return a stream that reads it, or NULL if not + // found. The caller takes ownership of the returned object. The filename + // must be a path relative to the root of the source tree and must not + // contain "." or ".." components. + virtual io::ZeroCopyInputStream* Open(const std::string& filename) = 0; + + // If Open() returns NULL, calling this method immediately will return an + // description of the error. + // Subclasses should implement this method and return a meaningful value for + // better error reporting. + // TODO(xiaofeng): change this to a pure virtual function. + virtual std::string GetLastErrorMessage(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(SourceTree); +}; + +// An implementation of SourceTree which loads files from locations on disk. +// Multiple mappings can be set up to map locations in the DiskSourceTree to +// locations in the physical filesystem. +class PROTOBUF_EXPORT DiskSourceTree : public SourceTree { + public: + DiskSourceTree(); + ~DiskSourceTree(); + + // Map a path on disk to a location in the SourceTree. The path may be + // either a file or a directory. If it is a directory, the entire tree + // under it will be mapped to the given virtual location. To map a directory + // to the root of the source tree, pass an empty string for virtual_path. + // + // If multiple mapped paths apply when opening a file, they will be searched + // in order. For example, if you do: + // MapPath("bar", "foo/bar"); + // MapPath("", "baz"); + // and then you do: + // Open("bar/qux"); + // the DiskSourceTree will first try to open foo/bar/qux, then baz/bar/qux, + // returning the first one that opens successfully. + // + // disk_path may be an absolute path or relative to the current directory, + // just like a path you'd pass to open(). + void MapPath(const std::string& virtual_path, const std::string& disk_path); + + // Return type for DiskFileToVirtualFile(). + enum DiskFileToVirtualFileResult { + SUCCESS, + SHADOWED, + CANNOT_OPEN, + NO_MAPPING + }; + + // Given a path to a file on disk, find a virtual path mapping to that + // file. The first mapping created with MapPath() whose disk_path contains + // the filename is used. However, that virtual path may not actually be + // usable to open the given file. Possible return values are: + // * SUCCESS: The mapping was found. *virtual_file is filled in so that + // calling Open(*virtual_file) will open the file named by disk_file. + // * SHADOWED: A mapping was found, but using Open() to open this virtual + // path will end up returning some different file. This is because some + // other mapping with a higher precedence also matches this virtual path + // and maps it to a different file that exists on disk. *virtual_file + // is filled in as it would be in the SUCCESS case. *shadowing_disk_file + // is filled in with the disk path of the file which would be opened if + // you were to call Open(*virtual_file). + // * CANNOT_OPEN: The mapping was found and was not shadowed, but the + // file specified cannot be opened. When this value is returned, + // errno will indicate the reason the file cannot be opened. *virtual_file + // will be set to the virtual path as in the SUCCESS case, even though + // it is not useful. + // * NO_MAPPING: Indicates that no mapping was found which contains this + // file. + DiskFileToVirtualFileResult DiskFileToVirtualFile( + const std::string& disk_file, std::string* virtual_file, + std::string* shadowing_disk_file); + + // Given a virtual path, find the path to the file on disk. + // Return true and update disk_file with the on-disk path if the file exists. + // Return false and leave disk_file untouched if the file doesn't exist. + bool VirtualFileToDiskFile(const std::string& virtual_file, + std::string* disk_file); + + // implements SourceTree ------------------------------------------- + io::ZeroCopyInputStream* Open(const std::string& filename) override; + + std::string GetLastErrorMessage() override; + + private: + struct Mapping { + std::string virtual_path; + std::string disk_path; + + inline Mapping(const std::string& virtual_path_param, + const std::string& disk_path_param) + : virtual_path(virtual_path_param), disk_path(disk_path_param) {} + }; + std::vector mappings_; + std::string last_error_message_; + + // Like Open(), but returns the on-disk path in disk_file if disk_file is + // non-NULL and the file could be successfully opened. + io::ZeroCopyInputStream* OpenVirtualFile(const std::string& virtual_file, + std::string* disk_file); + + // Like Open() but given the actual on-disk path. + io::ZeroCopyInputStream* OpenDiskFile(const std::string& filename); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(DiskSourceTree); +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_IMPORTER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..99014924c68d1b481a64ef1a65cf3d787356c511 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_generator.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Generates Java code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace java { + +// CodeGenerator implementation which generates Java code. If you create your +// own protocol compiler binary and you want it to support Java output, you +// can do so by registering an instance of this CodeGenerator with the +// CommandLineInterface in your main() function. +class PROTOC_EXPORT JavaGenerator : public CodeGenerator { + public: + JavaGenerator(); + ~JavaGenerator(); + + // implements CodeGenerator ---------------------------------------- + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* context, std::string* error) const override; + + uint64_t GetSupportedFeatures() const override; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(JavaGenerator); +}; + +} // namespace java +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_JAVA_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h new file mode 100644 index 0000000000000000000000000000000000000000..1e82f60fb15b283a8db5e33f3c0ec2b02288ab8b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/java/java_names.h @@ -0,0 +1,117 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Provides a mechanism for mapping a descriptor to the +// fully-qualified name of the corresponding Java class. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ +#define GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FileDescriptor; +class FieldDescriptor; +class ServiceDescriptor; + +namespace compiler { +namespace java { + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const Descriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const EnumDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// The fully-qualified Java class name. +std::string ClassName(const ServiceDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// +// Returns: +// Java package name. +std::string FileJavaPackage(const FileDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Capitalized camel case name field name. +std::string CapitalizedFieldName(const FieldDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Primitive Java type name for the field. +const char* PrimitiveTypeName(const FieldDescriptor* descriptor); + +// Requires: +// descriptor != NULL +// Returns: +// Boes primitive Java type name for the field. +const char* BoxedPrimitiveTypeName(const FieldDescriptor* descriptor); + +} // namespace java +} // namespace compiler +} // namespace protobuf +} // namespace google +#endif // GOOGLE_PROTOBUF_COMPILER_JAVA_NAMES_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..87f69bd39e91d3fd5857ca9513e1469a25dc6409 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/js_generator.h @@ -0,0 +1,344 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates JavaScript code for a given .proto file. +// +#ifndef GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ + +#include +#include + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class FieldDescriptor; +class OneofDescriptor; +class FileDescriptor; + +namespace io { +class Printer; +} + +namespace compiler { +namespace js { + +struct GeneratorOptions { + // Output path. + std::string output_dir; + // Namespace prefix. + std::string namespace_prefix; + // Enable binary-format support? + bool binary; + // What style of imports should be used. + enum ImportStyle { + kImportClosure, // goog.require() + kImportCommonJs, // require() + kImportCommonJsStrict, // require() with no global export + kImportBrowser, // no import statements + kImportEs6, // import { member } from '' + } import_style; + + GeneratorOptions() + : output_dir("."), + namespace_prefix(""), + binary(false), + import_style(kImportClosure), + add_require_for_enums(false), + testonly(false), + library(""), + error_on_name_conflict(false), + extension(".js"), + one_output_file_per_input_file(false), + annotate_code(false) {} + + bool ParseFromOptions( + const std::vector >& options, + std::string* error); + + // Returns the file name extension to use for generated code. + std::string GetFileNameExtension() const { + return import_style == kImportClosure ? extension : "_pb.js"; + } + + enum OutputMode { + // Create an output file for each input .proto file. + kOneOutputFilePerInputFile, + // Create an output file for each type. + kOneOutputFilePerSCC, + // Put everything in a single file named by the library option. + kEverythingInOneFile, + }; + + // Indicates how to output the generated code based on the provided options. + OutputMode output_mode() const; + + // The remaining options are only relevant when we are using kImportClosure. + + // Add a `goog.requires()` call for each enum type used. If not set, a + // forward declaration with `goog.forwardDeclare` is produced instead. + bool add_require_for_enums; + // Set this as a test-only module via `goog.setTestOnly();`. + bool testonly; + // Create a library with name _lib.js rather than a separate .js file + // per type? + std::string library; + // Error if there are two types that would generate the same output file? + bool error_on_name_conflict; + // The extension to use for output file names. + std::string extension; + // Create a separate output file for each input file? + bool one_output_file_per_input_file; + // If true, we should append annotations as commen on the last line for + // generated .js file. Annotations used by tools like https://kythe.io + // to provide cross-references between .js and .proto files. Annotations + // are enced as base64 proto of GeneratedCodeInfo message (see + // descriptor.proto). + bool annotate_code; +}; + +// CodeGenerator implementation which generates a JavaScript source file and +// header. If you create your own protocol compiler binary and you want it to +// support JavaScript output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator() {} + virtual ~Generator() {} + + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* context, std::string* error) const override { + *error = "Unimplemented Generate() method. Call GenerateAll() instead."; + return false; + } + + bool HasGenerateAll() const override { return true; } + + bool GenerateAll(const std::vector& files, + const std::string& parameter, GeneratorContext* context, + std::string* error) const override; + + uint64 GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + private: + void GenerateHeader(const GeneratorOptions& options, + const FileDescriptor* file, io::Printer* printer) const; + + // Generate goog.provides() calls. + void FindProvides(const GeneratorOptions& options, io::Printer* printer, + const std::vector& file, + std::set* provided) const; + void FindProvidesForFile(const GeneratorOptions& options, + io::Printer* printer, const FileDescriptor* file, + std::set* provided) const; + void FindProvidesForMessage(const GeneratorOptions& options, + io::Printer* printer, const Descriptor* desc, + std::set* provided) const; + void FindProvidesForEnum(const GeneratorOptions& options, + io::Printer* printer, const EnumDescriptor* enumdesc, + std::set* provided) const; + // For extension fields at file scope. + void FindProvidesForFields(const GeneratorOptions& options, + io::Printer* printer, + const std::vector& fields, + std::set* provided) const; + // Print the goog.provides() found by the methods above. + void GenerateProvides(const GeneratorOptions& options, io::Printer* printer, + std::set* provided) const; + + // Generate goog.setTestOnly() if indicated. + void GenerateTestOnly(const GeneratorOptions& options, + io::Printer* printer) const; + + // Generate goog.requires() calls. + void GenerateRequiresForLibrary( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& files, + std::set* provided) const; + void GenerateRequiresForSCC(const GeneratorOptions& options, + io::Printer* printer, const SCC* scc, + std::set* provided) const; + // For extension fields at file scope. + void GenerateRequiresForExtensions( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& fields, + std::set* provided) const; + void GenerateRequiresImpl(const GeneratorOptions& options, + io::Printer* printer, + std::set* required, + std::set* forwards, + std::set* provided, bool require_jspb, + bool require_extension, bool require_map) const; + void FindRequiresForMessage(const GeneratorOptions& options, + const Descriptor* desc, + std::set* required, + std::set* forwards, + bool* have_message) const; + void FindRequiresForField(const GeneratorOptions& options, + const FieldDescriptor* field, + std::set* required, + std::set* forwards) const; + void FindRequiresForExtension(const GeneratorOptions& options, + const FieldDescriptor* field, + std::set* required, + std::set* forwards) const; + // Generate all things in a proto file into one file. + // If use_short_name is true, the generated file's name will only be short + // name that without directory, otherwise filename equals file->name() + bool GenerateFile(const FileDescriptor* file, const GeneratorOptions& options, + GeneratorContext* context, bool use_short_name) const; + void GenerateFile(const GeneratorOptions& options, io::Printer* printer, + const FileDescriptor* file) const; + + // Generate definitions for all message classes and enums in all files, + // processing the files in dependence order. + void GenerateFilesInDepOrder( + const GeneratorOptions& options, io::Printer* printer, + const std::vector& file) const; + // Helper for above. + void GenerateFileAndDeps(const GeneratorOptions& options, + io::Printer* printer, const FileDescriptor* root, + std::set* all_files, + std::set* generated) const; + + // Generate definitions for all message classes and enums. + void GenerateClassesAndEnums(const GeneratorOptions& options, + io::Printer* printer, + const FileDescriptor* file) const; + + void GenerateFieldValueExpression(io::Printer* printer, + const char* obj_reference, + const FieldDescriptor* field, + bool use_default) const; + + // Generate definition for one class. + void GenerateClass(const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassConstructor(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldInfo(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassConstructorAndDeclareExtensionFieldInfo( + const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassXid(const GeneratorOptions& options, io::Printer* printer, + const Descriptor* desc) const; + void GenerateOneofCaseDefinition(const GeneratorOptions& options, + io::Printer* printer, + const OneofDescriptor* oneof) const; + void GenerateObjectTypedef(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassToObject(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldToObject(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassFromObject(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFieldFromObject(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassRegistration(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassFields(const GeneratorOptions& options, + io::Printer* printer, const Descriptor* desc) const; + void GenerateClassField(const GeneratorOptions& options, io::Printer* printer, + const FieldDescriptor* desc) const; + void GenerateClassExtensionFieldInfo(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserialize(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserializeBinary(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassDeserializeBinaryField(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + void GenerateClassSerializeBinary(const GeneratorOptions& options, + io::Printer* printer, + const Descriptor* desc) const; + void GenerateClassSerializeBinaryField(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + + // Generate definition for one enum. + void GenerateEnum(const GeneratorOptions& options, io::Printer* printer, + const EnumDescriptor* enumdesc) const; + + // Generate an extension definition. + void GenerateExtension(const GeneratorOptions& options, io::Printer* printer, + const FieldDescriptor* field) const; + + // Generate addFoo() method for repeated primitive fields. + void GenerateRepeatedPrimitiveHelperMethods(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field, + bool untyped) const; + + // Generate addFoo() method for repeated message fields. + void GenerateRepeatedMessageHelperMethods(const GeneratorOptions& options, + io::Printer* printer, + const FieldDescriptor* field) const; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator); +}; + +} // namespace js +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_JS_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h new file mode 100644 index 0000000000000000000000000000000000000000..5e3d8361ab42ed19536e2d1e284fc9c0a0526122 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/js/well_known_types_embed.h @@ -0,0 +1,48 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ +#define GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ + +#include + +struct FileToc { + const char* name; + const char* data; +}; + +extern struct FileToc well_known_types_js[]; + +#endif // GOOGLE_PROTOBUF_COMPILER_JS_WELL_KNOWN_TYPES_EMBED_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5673a7bdaa2950a9e9b3a14eebbf5313c9696ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h @@ -0,0 +1,87 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates ObjectiveC code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ + +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace objectivec { + +// CodeGenerator implementation which generates a ObjectiveC source file and +// header. If you create your own protocol compiler binary and you want it to +// support ObjectiveC output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT ObjectiveCGenerator : public CodeGenerator { + public: + ObjectiveCGenerator(); + ~ObjectiveCGenerator(); + + ObjectiveCGenerator(const ObjectiveCGenerator&) = delete; + ObjectiveCGenerator& operator=(const ObjectiveCGenerator&) = delete; + + // implements CodeGenerator ---------------------------------------- + bool HasGenerateAll() const override; + bool Generate(const FileDescriptor* file, + const string& parameter, + GeneratorContext* context, + string* error) const override; + bool GenerateAll(const std::vector& files, + const string& parameter, + GeneratorContext* context, + string* error) const override; + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +} // namespace objectivec +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..f170d077ab39a25c3ca56d9337c249f16cb57445 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h @@ -0,0 +1,333 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Helper functions for generating ObjectiveC code. + +#ifndef GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ +#define GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ + +#include +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace objectivec { + +// Generator options (see objectivec_generator.cc for a description of each): +struct Options { + Options(); + string expected_prefixes_path; + std::vector expected_prefixes_suppressions; + string generate_for_named_framework; + string named_framework_to_proto_path_mappings_path; + string runtime_import_prefix; +}; + +// Escape C++ trigraphs by escaping question marks to "\?". +string PROTOC_EXPORT EscapeTrigraphs(const string& to_escape); + +// Strips ".proto" or ".protodevel" from the end of a filename. +string PROTOC_EXPORT StripProto(const string& filename); + +// Remove white space from either end of a StringPiece. +void PROTOC_EXPORT TrimWhitespace(StringPiece* input); + +// Returns true if the name requires a ns_returns_not_retained attribute applied +// to it. +bool PROTOC_EXPORT IsRetainedName(const string& name); + +// Returns true if the name starts with "init" and will need to have special +// handling under ARC. +bool PROTOC_EXPORT IsInitName(const string& name); + +// Gets the objc_class_prefix. +string PROTOC_EXPORT FileClassPrefix(const FileDescriptor* file); + +// Gets the path of the file we're going to generate (sans the .pb.h +// extension). The path will be dependent on the objectivec package +// declared in the proto package. +string PROTOC_EXPORT FilePath(const FileDescriptor* file); + +// Just like FilePath(), but without the directory part. +string PROTOC_EXPORT FilePathBasename(const FileDescriptor* file); + +// Gets the name of the root class we'll generate in the file. This class +// is not meant for external consumption, but instead contains helpers that +// the rest of the classes need +string PROTOC_EXPORT FileClassName(const FileDescriptor* file); + +// These return the fully-qualified class name corresponding to the given +// descriptor. +string PROTOC_EXPORT ClassName(const Descriptor* descriptor); +string PROTOC_EXPORT ClassName(const Descriptor* descriptor, + string* out_suffix_added); +string PROTOC_EXPORT EnumName(const EnumDescriptor* descriptor); + +// Returns the fully-qualified name of the enum value corresponding to the +// the descriptor. +string PROTOC_EXPORT EnumValueName(const EnumValueDescriptor* descriptor); + +// Returns the name of the enum value corresponding to the descriptor. +string PROTOC_EXPORT EnumValueShortName(const EnumValueDescriptor* descriptor); + +// Reverse what an enum does. +string PROTOC_EXPORT UnCamelCaseEnumShortName(const string& name); + +// Returns the name to use for the extension (used as the method off the file's +// Root class). +string PROTOC_EXPORT ExtensionMethodName(const FieldDescriptor* descriptor); + +// Returns the transformed field name. +string PROTOC_EXPORT FieldName(const FieldDescriptor* field); +string PROTOC_EXPORT FieldNameCapitalized(const FieldDescriptor* field); + +// Returns the transformed oneof name. +string PROTOC_EXPORT OneofEnumName(const OneofDescriptor* descriptor); +string PROTOC_EXPORT OneofName(const OneofDescriptor* descriptor); +string PROTOC_EXPORT OneofNameCapitalized(const OneofDescriptor* descriptor); + +// Returns a symbol that can be used in C code to refer to an Objective C +// class without initializing the class. +string PROTOC_EXPORT ObjCClass(const string& class_name); + +// Declares an Objective C class without initializing the class so that it can +// be refrerred to by ObjCClass. +string PROTOC_EXPORT ObjCClassDeclaration(const string& class_name); + +inline bool HasPreservingUnknownEnumSemantics(const FileDescriptor* file) { + return file->syntax() == FileDescriptor::SYNTAX_PROTO3; +} + +inline bool IsMapEntryMessage(const Descriptor* descriptor) { + return descriptor->options().map_entry(); +} + +// Reverse of the above. +string PROTOC_EXPORT UnCamelCaseFieldName(const string& name, + const FieldDescriptor* field); + +enum ObjectiveCType { + OBJECTIVECTYPE_INT32, + OBJECTIVECTYPE_UINT32, + OBJECTIVECTYPE_INT64, + OBJECTIVECTYPE_UINT64, + OBJECTIVECTYPE_FLOAT, + OBJECTIVECTYPE_DOUBLE, + OBJECTIVECTYPE_BOOLEAN, + OBJECTIVECTYPE_STRING, + OBJECTIVECTYPE_DATA, + OBJECTIVECTYPE_ENUM, + OBJECTIVECTYPE_MESSAGE +}; + +enum FlagType { + FLAGTYPE_DESCRIPTOR_INITIALIZATION, + FLAGTYPE_EXTENSION, + FLAGTYPE_FIELD +}; + +template +string GetOptionalDeprecatedAttribute( + const TDescriptor* descriptor, + const FileDescriptor* file = NULL, + bool preSpace = true, bool postNewline = false) { + bool isDeprecated = descriptor->options().deprecated(); + // The file is only passed when checking Messages & Enums, so those types + // get tagged. At the moment, it doesn't seem to make sense to tag every + // field or enum value with when the file is deprecated. + bool isFileLevelDeprecation = false; + if (!isDeprecated && file) { + isFileLevelDeprecation = file->options().deprecated(); + isDeprecated = isFileLevelDeprecation; + } + if (isDeprecated) { + string message; + const FileDescriptor* sourceFile = descriptor->file(); + if (isFileLevelDeprecation) { + message = sourceFile->name() + " is deprecated."; + } else { + message = descriptor->full_name() + " is deprecated (see " + + sourceFile->name() + ")."; + } + + string result = string("GPB_DEPRECATED_MSG(\"") + message + "\")"; + if (preSpace) { + result.insert(0, " "); + } + if (postNewline) { + result.append("\n"); + } + return result; + } else { + return ""; + } +} + +string PROTOC_EXPORT GetCapitalizedType(const FieldDescriptor* field); + +ObjectiveCType PROTOC_EXPORT +GetObjectiveCType(FieldDescriptor::Type field_type); + +inline ObjectiveCType GetObjectiveCType(const FieldDescriptor* field) { + return GetObjectiveCType(field->type()); +} + +bool PROTOC_EXPORT IsPrimitiveType(const FieldDescriptor* field); +bool PROTOC_EXPORT IsReferenceType(const FieldDescriptor* field); + +string PROTOC_EXPORT GPBGenericValueFieldName(const FieldDescriptor* field); +string PROTOC_EXPORT DefaultValue(const FieldDescriptor* field); +bool PROTOC_EXPORT HasNonZeroDefaultValue(const FieldDescriptor* field); + +string PROTOC_EXPORT BuildFlagsString(const FlagType type, + const std::vector& strings); + +// Builds HeaderDoc/appledoc style comments out of the comments in the .proto +// file. +string PROTOC_EXPORT BuildCommentsString(const SourceLocation& location, + bool prefer_single_line); + +// The name the commonly used by the library when built as a framework. +// This lines up to the name used in the CocoaPod. +extern PROTOC_EXPORT const char* const ProtobufLibraryFrameworkName; +// Returns the CPP symbol name to use as the gate for framework style imports +// for the given framework name to use. +string PROTOC_EXPORT +ProtobufFrameworkImportSymbol(const string& framework_name); + +// Checks if the file is one of the proto's bundled with the library. +bool PROTOC_EXPORT +IsProtobufLibraryBundledProtoFile(const FileDescriptor* file); + +// Checks the prefix for the given files and outputs any warnings as needed. If +// there are flat out errors, then out_error is filled in with the first error +// and the result is false. +bool PROTOC_EXPORT +ValidateObjCClassPrefixes(const std::vector& files, + const Options& generation_options, string* out_error); + +// Generate decode data needed for ObjC's GPBDecodeTextFormatName() to transform +// the input into the expected output. +class PROTOC_EXPORT TextFormatDecodeData { + public: + TextFormatDecodeData(); + ~TextFormatDecodeData(); + + TextFormatDecodeData(const TextFormatDecodeData&) = delete; + TextFormatDecodeData& operator=(const TextFormatDecodeData&) = delete; + + void AddString(int32 key, const string& input_for_decode, + const string& desired_output); + size_t num_entries() const { return entries_.size(); } + string Data() const; + + static string DecodeDataForString(const string& input_for_decode, + const string& desired_output); + + private: + typedef std::pair DataEntry; + std::vector entries_; +}; + +// Helper for parsing simple files. +class PROTOC_EXPORT LineConsumer { + public: + LineConsumer(); + virtual ~LineConsumer(); + virtual bool ConsumeLine(const StringPiece& line, string* out_error) = 0; +}; + +bool PROTOC_EXPORT ParseSimpleFile(const string& path, + LineConsumer* line_consumer, + string* out_error); + +// Helper class for parsing framework import mappings and generating +// import statements. +class PROTOC_EXPORT ImportWriter { + public: + ImportWriter(const string& generate_for_named_framework, + const string& named_framework_to_proto_path_mappings_path, + const string& runtime_import_prefix, + bool include_wkt_imports); + ~ImportWriter(); + + void AddFile(const FileDescriptor* file, const string& header_extension); + void Print(io::Printer *printer) const; + + static void PrintRuntimeImports(io::Printer *printer, + const std::vector& header_to_import, + const string& runtime_import_prefix, + bool default_cpp_symbol = false); + + private: + class ProtoFrameworkCollector : public LineConsumer { + public: + ProtoFrameworkCollector(std::map* inout_proto_file_to_framework_name) + : map_(inout_proto_file_to_framework_name) {} + + virtual bool ConsumeLine(const StringPiece& line, string* out_error); + + private: + std::map* map_; + }; + + void ParseFrameworkMappings(); + + const string generate_for_named_framework_; + const string named_framework_to_proto_path_mappings_path_; + const string runtime_import_prefix_; + const bool include_wkt_imports_; + std::map proto_file_to_framework_name_; + bool need_to_parse_mapping_file_; + + std::vector protobuf_imports_; + std::vector other_framework_imports_; + std::vector other_imports_; +}; + +} // namespace objectivec +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_OBJECTIVEC_HELPERS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h new file mode 100644 index 0000000000000000000000000000000000000000..ea3b64dc72f5316e51261584c153274455c033f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/parser.h @@ -0,0 +1,605 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Implements parsing of .proto files to FileDescriptorProtos. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PARSER_H__ +#define GOOGLE_PROTOBUF_COMPILER_PARSER_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Message; + +namespace compiler { + +// Defined in this file. +class Parser; +class SourceLocationTable; + +// Implements parsing of protocol definitions (such as .proto files). +// +// Note that most users will be more interested in the Importer class. +// Parser is a lower-level class which simply converts a single .proto file +// to a FileDescriptorProto. It does not resolve import directives or perform +// many other kinds of validation needed to construct a complete +// FileDescriptor. +class PROTOBUF_EXPORT Parser { + public: + Parser(); + ~Parser(); + + // Parse the entire input and construct a FileDescriptorProto representing + // it. Returns true if no errors occurred, false otherwise. + bool Parse(io::Tokenizer* input, FileDescriptorProto* file); + + // Optional features: + + // DEPRECATED: New code should use the SourceCodeInfo embedded in the + // FileDescriptorProto. + // + // Requests that locations of certain definitions be recorded to the given + // SourceLocationTable while parsing. This can be used to look up exact line + // and column numbers for errors reported by DescriptorPool during validation. + // Set to NULL (the default) to discard source location information. + void RecordSourceLocationsTo(SourceLocationTable* location_table) { + source_location_table_ = location_table; + } + + // Requests that errors be recorded to the given ErrorCollector while + // parsing. Set to NULL (the default) to discard error messages. + void RecordErrorsTo(io::ErrorCollector* error_collector) { + error_collector_ = error_collector; + } + + // Returns the identifier used in the "syntax = " declaration, if one was + // seen during the last call to Parse(), or the empty string otherwise. + const std::string& GetSyntaxIdentifier() { return syntax_identifier_; } + + // If set true, input files will be required to begin with a syntax + // identifier. Otherwise, files may omit this. If a syntax identifier + // is provided, it must be 'syntax = "proto2";' and must appear at the + // top of this file regardless of whether or not it was required. + void SetRequireSyntaxIdentifier(bool value) { + require_syntax_identifier_ = value; + } + + // Call SetStopAfterSyntaxIdentifier(true) to tell the parser to stop + // parsing as soon as it has seen the syntax identifier, or lack thereof. + // This is useful for quickly identifying the syntax of the file without + // parsing the whole thing. If this is enabled, no error will be recorded + // if the syntax identifier is something other than "proto2" (since + // presumably the caller intends to deal with that), but other kinds of + // errors (e.g. parse errors) will still be reported. When this is enabled, + // you may pass a NULL FileDescriptorProto to Parse(). + void SetStopAfterSyntaxIdentifier(bool value) { + stop_after_syntax_identifier_ = value; + } + + private: + class LocationRecorder; + + // ================================================================= + // Error recovery helpers + + // Consume the rest of the current statement. This consumes tokens + // until it sees one of: + // ';' Consumes the token and returns. + // '{' Consumes the brace then calls SkipRestOfBlock(). + // '}' Returns without consuming. + // EOF Returns (can't consume). + // The Parser often calls SkipStatement() after encountering a syntax + // error. This allows it to go on parsing the following lines, allowing + // it to report more than just one error in the file. + void SkipStatement(); + + // Consume the rest of the current block, including nested blocks, + // ending after the closing '}' is encountered and consumed, or at EOF. + void SkipRestOfBlock(); + + // ----------------------------------------------------------------- + // Single-token consuming helpers + // + // These make parsing code more readable. + + // True if the current token is TYPE_END. + inline bool AtEnd(); + + // True if the next token matches the given text. + inline bool LookingAt(const char* text); + // True if the next token is of the given type. + inline bool LookingAtType(io::Tokenizer::TokenType token_type); + + // If the next token exactly matches the text given, consume it and return + // true. Otherwise, return false without logging an error. + bool TryConsume(const char* text); + + // These attempt to read some kind of token from the input. If successful, + // they return true. Otherwise they return false and add the given error + // to the error list. + + // Consume a token with the exact text given. + bool Consume(const char* text, const char* error); + // Same as above, but automatically generates the error "Expected \"text\".", + // where "text" is the expected token text. + bool Consume(const char* text); + // Consume a token of type IDENTIFIER and store its text in "output". + bool ConsumeIdentifier(std::string* output, const char* error); + // Consume an integer and store its value in "output". + bool ConsumeInteger(int* output, const char* error); + // Consume a signed integer and store its value in "output". + bool ConsumeSignedInteger(int* output, const char* error); + // Consume a 64-bit integer and store its value in "output". If the value + // is greater than max_value, an error will be reported. + bool ConsumeInteger64(uint64 max_value, uint64* output, const char* error); + // Consume a number and store its value in "output". This will accept + // tokens of either INTEGER or FLOAT type. + bool ConsumeNumber(double* output, const char* error); + // Consume a string literal and store its (unescaped) value in "output". + bool ConsumeString(std::string* output, const char* error); + + // Consume a token representing the end of the statement. Comments between + // this token and the next will be harvested for documentation. The given + // LocationRecorder should refer to the declaration that was just parsed; + // it will be populated with these comments. + // + // TODO(kenton): The LocationRecorder is const because historically locations + // have been passed around by const reference, for no particularly good + // reason. We should probably go through and change them all to mutable + // pointer to make this more intuitive. + bool TryConsumeEndOfDeclaration(const char* text, + const LocationRecorder* location); + bool TryConsumeEndOfDeclarationFinishScope(const char* text, + const LocationRecorder* location); + + bool ConsumeEndOfDeclaration(const char* text, + const LocationRecorder* location); + + // ----------------------------------------------------------------- + // Error logging helpers + + // Invokes error_collector_->AddError(), if error_collector_ is not NULL. + void AddError(int line, int column, const std::string& error); + + // Invokes error_collector_->AddError() with the line and column number + // of the current token. + void AddError(const std::string& error); + + // Invokes error_collector_->AddWarning() with the line and column number + // of the current token. + void AddWarning(const std::string& warning); + + // Records a location in the SourceCodeInfo.location table (see + // descriptor.proto). We use RAII to ensure that the start and end locations + // are recorded -- the constructor records the start location and the + // destructor records the end location. Since the parser is + // recursive-descent, this works out beautifully. + class PROTOBUF_EXPORT LocationRecorder { + public: + // Construct the file's "root" location. + LocationRecorder(Parser* parser); + + // Construct a location that represents a declaration nested within the + // given parent. E.g. a field's location is nested within the location + // for a message type. The parent's path will be copied, so you should + // call AddPath() only to add the path components leading from the parent + // to the child (as opposed to leading from the root to the child). + LocationRecorder(const LocationRecorder& parent); + + // Convenience constructors that call AddPath() one or two times. + LocationRecorder(const LocationRecorder& parent, int path1); + LocationRecorder(const LocationRecorder& parent, int path1, int path2); + + // Creates a recorder that generates locations into given source code info. + LocationRecorder(const LocationRecorder& parent, int path1, + SourceCodeInfo* source_code_info); + + ~LocationRecorder(); + + // Add a path component. See SourceCodeInfo.Location.path in + // descriptor.proto. + void AddPath(int path_component); + + // By default the location is considered to start at the current token at + // the time the LocationRecorder is created. StartAt() sets the start + // location to the given token instead. + void StartAt(const io::Tokenizer::Token& token); + + // Start at the same location as some other LocationRecorder. + void StartAt(const LocationRecorder& other); + + // By default the location is considered to end at the previous token at + // the time the LocationRecorder is destroyed. EndAt() sets the end + // location to the given token instead. + void EndAt(const io::Tokenizer::Token& token); + + // Records the start point of this location to the SourceLocationTable that + // was passed to RecordSourceLocationsTo(), if any. SourceLocationTable + // is an older way of keeping track of source locations which is still + // used in some places. + void RecordLegacyLocation( + const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location); + void RecordLegacyImportLocation(const Message* descriptor, + const std::string& name); + + // Returns the number of path components in the recorder's current location. + int CurrentPathSize() const; + + // Attaches leading and trailing comments to the location. The two strings + // will be swapped into place, so after this is called *leading and + // *trailing will be empty. + // + // TODO(kenton): See comment on TryConsumeEndOfDeclaration(), above, for + // why this is const. + void AttachComments(std::string* leading, std::string* trailing, + std::vector* detached_comments) const; + + private: + // Indexes of parent and current location in the parent + // SourceCodeInfo.location repeated field. For top-level elements, + // parent_index_ is -1. + Parser* parser_; + SourceCodeInfo* source_code_info_; + SourceCodeInfo::Location* location_; + + void Init(const LocationRecorder& parent, SourceCodeInfo* source_code_info); + }; + + // ================================================================= + // Parsers for various language constructs + + // Parses the "syntax = \"proto2\";" line at the top of the file. Returns + // false if it failed to parse or if the syntax identifier was not + // recognized. + bool ParseSyntaxIdentifier(const LocationRecorder& parent); + + // These methods parse various individual bits of code. They return + // false if they completely fail to parse the construct. In this case, + // it is probably necessary to skip the rest of the statement to recover. + // However, if these methods return true, it does NOT mean that there + // were no errors; only that there were no *syntax* errors. For instance, + // if a service method is defined using proper syntax but uses a primitive + // type as its input or output, ParseMethodField() still returns true + // and only reports the error by calling AddError(). In practice, this + // makes logic much simpler for the caller. + + // Parse a top-level message, enum, service, etc. + bool ParseTopLevelStatement(FileDescriptorProto* file, + const LocationRecorder& root_location); + + // Parse various language high-level language construrcts. + bool ParseMessageDefinition(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumDefinition(EnumDescriptorProto* enum_type, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceDefinition(ServiceDescriptorProto* service, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + bool ParsePackage(FileDescriptorProto* file, + const LocationRecorder& root_location, + const FileDescriptorProto* containing_file); + bool ParseImport(RepeatedPtrField* dependency, + RepeatedField* public_dependency, + RepeatedField* weak_dependency, + const LocationRecorder& root_location, + const FileDescriptorProto* containing_file); + + // These methods parse the contents of a message, enum, or service type and + // add them to the given object. They consume the entire block including + // the beginning and ending brace. + bool ParseMessageBlock(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumBlock(EnumDescriptorProto* enum_type, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceBlock(ServiceDescriptorProto* service, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + + // Parse one statement within a message, enum, or service block, including + // final semicolon. + bool ParseMessageStatement(DescriptorProto* message, + const LocationRecorder& message_location, + const FileDescriptorProto* containing_file); + bool ParseEnumStatement(EnumDescriptorProto* message, + const LocationRecorder& enum_location, + const FileDescriptorProto* containing_file); + bool ParseServiceStatement(ServiceDescriptorProto* message, + const LocationRecorder& service_location, + const FileDescriptorProto* containing_file); + + // Parse a field of a message. If the field is a group, its type will be + // added to "messages". + // + // parent_location and location_field_number_for_nested_type are needed when + // parsing groups -- we need to generate a nested message type within the + // parent and record its location accordingly. Since the parent could be + // either a FileDescriptorProto or a DescriptorProto, we must pass in the + // correct field number to use. + bool ParseMessageField(FieldDescriptorProto* field, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Like ParseMessageField() but expects the label has already been filled in + // by the caller. + bool ParseMessageFieldNoLabel(FieldDescriptorProto* field, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse an "extensions" declaration. + bool ParseExtensions(DescriptorProto* message, + const LocationRecorder& extensions_location, + const FileDescriptorProto* containing_file); + + // Parse a "reserved" declaration. + bool ParseReserved(DescriptorProto* message, + const LocationRecorder& message_location); + bool ParseReservedNames(DescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReservedNumbers(DescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReserved(EnumDescriptorProto* message, + const LocationRecorder& message_location); + bool ParseReservedNames(EnumDescriptorProto* message, + const LocationRecorder& parent_location); + bool ParseReservedNumbers(EnumDescriptorProto* message, + const LocationRecorder& parent_location); + + // Parse an "extend" declaration. (See also comments for + // ParseMessageField().) + bool ParseExtend(RepeatedPtrField* extensions, + RepeatedPtrField* messages, + const LocationRecorder& parent_location, + int location_field_number_for_nested_type, + const LocationRecorder& extend_location, + const FileDescriptorProto* containing_file); + + // Parse a "oneof" declaration. The caller is responsible for setting + // oneof_decl->label() since it will have had to parse the label before it + // knew it was parsing a oneof. + bool ParseOneof(OneofDescriptorProto* oneof_decl, + DescriptorProto* containing_type, int oneof_index, + const LocationRecorder& oneof_location, + const LocationRecorder& containing_type_location, + const FileDescriptorProto* containing_file); + + // Parse a single enum value within an enum block. + bool ParseEnumConstant(EnumValueDescriptorProto* enum_value, + const LocationRecorder& enum_value_location, + const FileDescriptorProto* containing_file); + + // Parse enum constant options, i.e. the list in square brackets at the end + // of the enum constant value definition. + bool ParseEnumConstantOptions(EnumValueDescriptorProto* value, + const LocationRecorder& enum_value_location, + const FileDescriptorProto* containing_file); + + // Parse a single method within a service definition. + bool ParseServiceMethod(MethodDescriptorProto* method, + const LocationRecorder& method_location, + const FileDescriptorProto* containing_file); + + + // Parse options of a single method or stream. + bool ParseMethodOptions(const LocationRecorder& parent_location, + const FileDescriptorProto* containing_file, + const int optionsFieldNumber, + Message* mutable_options); + + // Parse "required", "optional", or "repeated" and fill in "label" + // with the value. Returns true if such a label is consumed. + bool ParseLabel(FieldDescriptorProto::Label* label, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse a type name and fill in "type" (if it is a primitive) or + // "type_name" (if it is not) with the type parsed. + bool ParseType(FieldDescriptorProto::Type* type, std::string* type_name); + // Parse a user-defined type and fill in "type_name" with the name. + // If a primitive type is named, it is treated as an error. + bool ParseUserDefinedType(std::string* type_name); + + // Parses field options, i.e. the stuff in square brackets at the end + // of a field definition. Also parses default value. + bool ParseFieldOptions(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + // Parse the "default" option. This needs special handling because its + // type is the field's type. + bool ParseDefaultAssignment(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + bool ParseJsonName(FieldDescriptorProto* field, + const LocationRecorder& field_location, + const FileDescriptorProto* containing_file); + + enum OptionStyle { + OPTION_ASSIGNMENT, // just "name = value" + OPTION_STATEMENT // "option name = value;" + }; + + // Parse a single option name/value pair, e.g. "ctype = CORD". The name + // identifies a field of the given Message, and the value of that field + // is set to the parsed value. + bool ParseOption(Message* options, const LocationRecorder& options_location, + const FileDescriptorProto* containing_file, + OptionStyle style); + + // Parses a single part of a multipart option name. A multipart name consists + // of names separated by dots. Each name is either an identifier or a series + // of identifiers separated by dots and enclosed in parentheses. E.g., + // "foo.(bar.baz).qux". + bool ParseOptionNamePart(UninterpretedOption* uninterpreted_option, + const LocationRecorder& part_location, + const FileDescriptorProto* containing_file); + + // Parses a string surrounded by balanced braces. Strips off the outer + // braces and stores the enclosed string in *value. + // E.g., + // { foo } *value gets 'foo' + // { foo { bar: box } } *value gets 'foo { bar: box }' + // {} *value gets '' + // + // REQUIRES: LookingAt("{") + // When finished successfully, we are looking at the first token past + // the ending brace. + bool ParseUninterpretedBlock(std::string* value); + + struct MapField { + // Whether the field is a map field. + bool is_map_field; + // The types of the key and value if they are primitive types. + FieldDescriptorProto::Type key_type; + FieldDescriptorProto::Type value_type; + // Or the type names string if the types are customized types. + std::string key_type_name; + std::string value_type_name; + + MapField() : is_map_field(false) {} + }; + // Desugar the map syntax to generate a nested map entry message. + void GenerateMapEntry(const MapField& map_field, FieldDescriptorProto* field, + RepeatedPtrField* messages); + + // Whether fields without label default to optional fields. + bool DefaultToOptionalFields() const { + return syntax_identifier_ == "proto3"; + } + + + bool ValidateEnum(const EnumDescriptorProto* proto); + + // ================================================================= + + io::Tokenizer* input_; + io::ErrorCollector* error_collector_; + SourceCodeInfo* source_code_info_; + SourceLocationTable* source_location_table_; // legacy + bool had_errors_; + bool require_syntax_identifier_; + bool stop_after_syntax_identifier_; + std::string syntax_identifier_; + + // Leading doc comments for the next declaration. These are not complete + // yet; use ConsumeEndOfDeclaration() to get the complete comments. + std::string upcoming_doc_comments_; + + // Detached comments are not connected to any syntax entities. Elements in + // this vector are paragraphs of comments separated by empty lines. The + // detached comments will be put into the leading_detached_comments field for + // the next element (See SourceCodeInfo.Location in descriptor.proto), when + // ConsumeEndOfDeclaration() is called. + std::vector upcoming_detached_comments_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Parser); +}; + +// A table mapping (descriptor, ErrorLocation) pairs -- as reported by +// DescriptorPool when validating descriptors -- to line and column numbers +// within the original source code. +// +// This is semi-obsolete: FileDescriptorProto.source_code_info now contains +// far more complete information about source locations. However, as of this +// writing you still need to use SourceLocationTable when integrating with +// DescriptorPool. +class PROTOBUF_EXPORT SourceLocationTable { + public: + SourceLocationTable(); + ~SourceLocationTable(); + + // Finds the precise location of the given error and fills in *line and + // *column with the line and column numbers. If not found, sets *line to + // -1 and *column to 0 (since line = -1 is used to mean "error has no exact + // location" in the ErrorCollector interface). Returns true if found, false + // otherwise. + bool Find(const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location, int* line, + int* column) const; + bool FindImport(const Message* descriptor, const std::string& name, int* line, + int* column) const; + + // Adds a location to the table. + void Add(const Message* descriptor, + DescriptorPool::ErrorCollector::ErrorLocation location, int line, + int column); + void AddImport(const Message* descriptor, const std::string& name, int line, + int column); + + // Clears the contents of the table. + void Clear(); + + private: + typedef std::map< + std::pair, + std::pair > + LocationMap; + LocationMap location_map_; + std::map, std::pair > + import_location_map_; +}; + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PARSER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..e2610ec4dde2b539e6e021d55dd30e252bdc94ea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/php/php_generator.h @@ -0,0 +1,97 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ + +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace php { + +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + virtual bool Generate( + const FileDescriptor* file, + const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + + bool GenerateAll(const std::vector& files, + const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } + + private: + bool Generate( + const FileDescriptor* file, + bool is_descriptor, + bool aggregate_metadata, + const std::set& aggregate_metadata_prefixes, + GeneratorContext* generator_context, + string* error) const; +}; + +// To skip reserved keywords in php, some generated classname are prefixed. +// Other code generators may need following API to figure out the actual +// classname. +PROTOC_EXPORT std::string GeneratedClassName(const Descriptor* desc); +PROTOC_EXPORT std::string GeneratedClassName(const EnumDescriptor* desc); +PROTOC_EXPORT std::string GeneratedClassName(const ServiceDescriptor* desc); + +inline bool IsWrapperType(const FieldDescriptor* descriptor) { + return descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE && + descriptor->message_type()->file()->name() == "google/protobuf/wrappers.proto"; +} + +} // namespace php +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PHP_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..a25079235a36c11ad7ca8ef8eeea78c172178eeb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.h @@ -0,0 +1,99 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// +// Front-end for protoc code generator plugins written in C++. +// +// To implement a protoc plugin in C++, simply write an implementation of +// CodeGenerator, then create a main() function like: +// int main(int argc, char* argv[]) { +// MyCodeGenerator generator; +// return google::protobuf::compiler::PluginMain(argc, argv, &generator); +// } +// You must link your plugin against libprotobuf and libprotoc. +// +// The core part of PluginMain is to invoke the given CodeGenerator on a +// CodeGeneratorRequest to generate a CodeGeneratorResponse. This part is +// abstracted out and made into function GenerateCode so that it can be reused, +// for example, to implement a variant of PluginMain that does some +// preprocessing on the input CodeGeneratorRequest before feeding the request +// to the given code generator. +// +// To get protoc to use the plugin, do one of the following: +// * Place the plugin binary somewhere in the PATH and give it the name +// "protoc-gen-NAME" (replacing "NAME" with the name of your plugin). If you +// then invoke protoc with the parameter --NAME_out=OUT_DIR (again, replace +// "NAME" with your plugin's name), protoc will invoke your plugin to generate +// the output, which will be placed in OUT_DIR. +// * Place the plugin binary anywhere, with any name, and pass the --plugin +// parameter to protoc to direct it to your plugin like so: +// protoc --plugin=protoc-gen-NAME=path/to/mybinary --NAME_out=OUT_DIR +// On Windows, make sure to include the .exe suffix: +// protoc --plugin=protoc-gen-NAME=path/to/mybinary.exe --NAME_out=OUT_DIR + +#ifndef GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ +#define GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { + +class CodeGenerator; // code_generator.h +class CodeGeneratorRequest; +class CodeGeneratorResponse; + +// Implements main() for a protoc plugin exposing the given code generator. +PROTOC_EXPORT int PluginMain(int argc, char* argv[], + const CodeGenerator* generator); + +// Generates code using the given code generator. Returns true if the code +// generation is successful. If the code generation fails, error_msg may be +// populated to describe the failure cause. +bool GenerateCode(const CodeGeneratorRequest& request, + const CodeGenerator& generator, + CodeGeneratorResponse* response, std::string* error_msg); + +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PLUGIN_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..81ba11cec535baed50d5b8f892ed5ba82e45a58f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/plugin.pb.h @@ -0,0 +1,1803 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: google/protobuf/compiler/plugin.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2fcompiler_2fplugin_2eproto PROTOC_EXPORT +#ifdef major +#undef major +#endif +#ifdef minor +#undef minor +#endif +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct PROTOC_EXPORT TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[4] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern PROTOC_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +PROTOBUF_NAMESPACE_OPEN +namespace compiler { +class CodeGeneratorRequest; +class CodeGeneratorRequestDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorRequestDefaultTypeInternal _CodeGeneratorRequest_default_instance_; +class CodeGeneratorResponse; +class CodeGeneratorResponseDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorResponseDefaultTypeInternal _CodeGeneratorResponse_default_instance_; +class CodeGeneratorResponse_File; +class CodeGeneratorResponse_FileDefaultTypeInternal; +PROTOC_EXPORT extern CodeGeneratorResponse_FileDefaultTypeInternal _CodeGeneratorResponse_File_default_instance_; +class Version; +class VersionDefaultTypeInternal; +PROTOC_EXPORT extern VersionDefaultTypeInternal _Version_default_instance_; +} // namespace compiler +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorRequest* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* Arena::CreateMaybeMessage(Arena*); +template<> PROTOC_EXPORT PROTOBUF_NAMESPACE_ID::compiler::Version* Arena::CreateMaybeMessage(Arena*); +PROTOBUF_NAMESPACE_CLOSE +PROTOBUF_NAMESPACE_OPEN +namespace compiler { + +enum CodeGeneratorResponse_Feature : int { + CodeGeneratorResponse_Feature_FEATURE_NONE = 0, + CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL = 1 +}; +PROTOC_EXPORT bool CodeGeneratorResponse_Feature_IsValid(int value); +constexpr CodeGeneratorResponse_Feature CodeGeneratorResponse_Feature_Feature_MIN = CodeGeneratorResponse_Feature_FEATURE_NONE; +constexpr CodeGeneratorResponse_Feature CodeGeneratorResponse_Feature_Feature_MAX = CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL; +constexpr int CodeGeneratorResponse_Feature_Feature_ARRAYSIZE = CodeGeneratorResponse_Feature_Feature_MAX + 1; + +PROTOC_EXPORT const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* CodeGeneratorResponse_Feature_descriptor(); +template +inline const std::string& CodeGeneratorResponse_Feature_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function CodeGeneratorResponse_Feature_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + CodeGeneratorResponse_Feature_descriptor(), enum_t_value); +} +inline bool CodeGeneratorResponse_Feature_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, CodeGeneratorResponse_Feature* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + CodeGeneratorResponse_Feature_descriptor(), name, value); +} +// =================================================================== + +class PROTOC_EXPORT Version PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.Version) */ { + public: + inline Version() : Version(nullptr) {} + virtual ~Version(); + + Version(const Version& from); + Version(Version&& from) noexcept + : Version() { + *this = ::std::move(from); + } + + inline Version& operator=(const Version& from) { + CopyFrom(from); + return *this; + } + inline Version& operator=(Version&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Version& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Version* internal_default_instance() { + return reinterpret_cast( + &_Version_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Version& a, Version& b) { + a.Swap(&b); + } + inline void Swap(Version* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Version* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Version* New() const final { + return CreateMaybeMessage(nullptr); + } + + Version* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Version& from); + void MergeFrom(const Version& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Version* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.Version"; + } + protected: + explicit Version(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kSuffixFieldNumber = 4, + kMajorFieldNumber = 1, + kMinorFieldNumber = 2, + kPatchFieldNumber = 3, + }; + // optional string suffix = 4; + bool has_suffix() const; + private: + bool _internal_has_suffix() const; + public: + void clear_suffix(); + const std::string& suffix() const; + void set_suffix(const std::string& value); + void set_suffix(std::string&& value); + void set_suffix(const char* value); + void set_suffix(const char* value, size_t size); + std::string* mutable_suffix(); + std::string* release_suffix(); + void set_allocated_suffix(std::string* suffix); + private: + const std::string& _internal_suffix() const; + void _internal_set_suffix(const std::string& value); + std::string* _internal_mutable_suffix(); + public: + + // optional int32 major = 1; + bool has_major() const; + private: + bool _internal_has_major() const; + public: + void clear_major(); + ::PROTOBUF_NAMESPACE_ID::int32 major() const; + void set_major(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_major() const; + void _internal_set_major(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 minor = 2; + bool has_minor() const; + private: + bool _internal_has_minor() const; + public: + void clear_minor(); + ::PROTOBUF_NAMESPACE_ID::int32 minor() const; + void set_minor(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_minor() const; + void _internal_set_minor(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional int32 patch = 3; + bool has_patch() const; + private: + bool _internal_has_patch() const; + public: + void clear_patch(); + ::PROTOBUF_NAMESPACE_ID::int32 patch() const; + void set_patch(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_patch() const; + void _internal_set_patch(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.Version) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr suffix_; + ::PROTOBUF_NAMESPACE_ID::int32 major_; + ::PROTOBUF_NAMESPACE_ID::int32 minor_; + ::PROTOBUF_NAMESPACE_ID::int32 patch_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorRequest PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorRequest) */ { + public: + inline CodeGeneratorRequest() : CodeGeneratorRequest(nullptr) {} + virtual ~CodeGeneratorRequest(); + + CodeGeneratorRequest(const CodeGeneratorRequest& from); + CodeGeneratorRequest(CodeGeneratorRequest&& from) noexcept + : CodeGeneratorRequest() { + *this = ::std::move(from); + } + + inline CodeGeneratorRequest& operator=(const CodeGeneratorRequest& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorRequest& operator=(CodeGeneratorRequest&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorRequest& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorRequest* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorRequest_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(CodeGeneratorRequest& a, CodeGeneratorRequest& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorRequest* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorRequest* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorRequest* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorRequest* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorRequest& from); + void MergeFrom(const CodeGeneratorRequest& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorRequest* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorRequest"; + } + protected: + explicit CodeGeneratorRequest(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFileToGenerateFieldNumber = 1, + kProtoFileFieldNumber = 15, + kParameterFieldNumber = 2, + kCompilerVersionFieldNumber = 3, + }; + // repeated string file_to_generate = 1; + int file_to_generate_size() const; + private: + int _internal_file_to_generate_size() const; + public: + void clear_file_to_generate(); + const std::string& file_to_generate(int index) const; + std::string* mutable_file_to_generate(int index); + void set_file_to_generate(int index, const std::string& value); + void set_file_to_generate(int index, std::string&& value); + void set_file_to_generate(int index, const char* value); + void set_file_to_generate(int index, const char* value, size_t size); + std::string* add_file_to_generate(); + void add_file_to_generate(const std::string& value); + void add_file_to_generate(std::string&& value); + void add_file_to_generate(const char* value); + void add_file_to_generate(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& file_to_generate() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_file_to_generate(); + private: + const std::string& _internal_file_to_generate(int index) const; + std::string* _internal_add_file_to_generate(); + public: + + // repeated .google.protobuf.FileDescriptorProto proto_file = 15; + int proto_file_size() const; + private: + int _internal_proto_file_size() const; + public: + void clear_proto_file(); + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* mutable_proto_file(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* + mutable_proto_file(); + private: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& _internal_proto_file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* _internal_add_proto_file(); + public: + const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& proto_file(int index) const; + PROTOBUF_NAMESPACE_ID::FileDescriptorProto* add_proto_file(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& + proto_file() const; + + // optional string parameter = 2; + bool has_parameter() const; + private: + bool _internal_has_parameter() const; + public: + void clear_parameter(); + const std::string& parameter() const; + void set_parameter(const std::string& value); + void set_parameter(std::string&& value); + void set_parameter(const char* value); + void set_parameter(const char* value, size_t size); + std::string* mutable_parameter(); + std::string* release_parameter(); + void set_allocated_parameter(std::string* parameter); + private: + const std::string& _internal_parameter() const; + void _internal_set_parameter(const std::string& value); + std::string* _internal_mutable_parameter(); + public: + + // optional .google.protobuf.compiler.Version compiler_version = 3; + bool has_compiler_version() const; + private: + bool _internal_has_compiler_version() const; + public: + void clear_compiler_version(); + const PROTOBUF_NAMESPACE_ID::compiler::Version& compiler_version() const; + PROTOBUF_NAMESPACE_ID::compiler::Version* release_compiler_version(); + PROTOBUF_NAMESPACE_ID::compiler::Version* mutable_compiler_version(); + void set_allocated_compiler_version(PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version); + private: + const PROTOBUF_NAMESPACE_ID::compiler::Version& _internal_compiler_version() const; + PROTOBUF_NAMESPACE_ID::compiler::Version* _internal_mutable_compiler_version(); + public: + void unsafe_arena_set_allocated_compiler_version( + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version); + PROTOBUF_NAMESPACE_ID::compiler::Version* unsafe_arena_release_compiler_version(); + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorRequest) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField file_to_generate_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto > proto_file_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr parameter_; + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorResponse_File PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorResponse.File) */ { + public: + inline CodeGeneratorResponse_File() : CodeGeneratorResponse_File(nullptr) {} + virtual ~CodeGeneratorResponse_File(); + + CodeGeneratorResponse_File(const CodeGeneratorResponse_File& from); + CodeGeneratorResponse_File(CodeGeneratorResponse_File&& from) noexcept + : CodeGeneratorResponse_File() { + *this = ::std::move(from); + } + + inline CodeGeneratorResponse_File& operator=(const CodeGeneratorResponse_File& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorResponse_File& operator=(CodeGeneratorResponse_File&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorResponse_File& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorResponse_File* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorResponse_File_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(CodeGeneratorResponse_File& a, CodeGeneratorResponse_File& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorResponse_File* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorResponse_File* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorResponse_File* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorResponse_File* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorResponse_File& from); + void MergeFrom(const CodeGeneratorResponse_File& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorResponse_File* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorResponse.File"; + } + protected: + explicit CodeGeneratorResponse_File(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kInsertionPointFieldNumber = 2, + kContentFieldNumber = 15, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string insertion_point = 2; + bool has_insertion_point() const; + private: + bool _internal_has_insertion_point() const; + public: + void clear_insertion_point(); + const std::string& insertion_point() const; + void set_insertion_point(const std::string& value); + void set_insertion_point(std::string&& value); + void set_insertion_point(const char* value); + void set_insertion_point(const char* value, size_t size); + std::string* mutable_insertion_point(); + std::string* release_insertion_point(); + void set_allocated_insertion_point(std::string* insertion_point); + private: + const std::string& _internal_insertion_point() const; + void _internal_set_insertion_point(const std::string& value); + std::string* _internal_mutable_insertion_point(); + public: + + // optional string content = 15; + bool has_content() const; + private: + bool _internal_has_content() const; + public: + void clear_content(); + const std::string& content() const; + void set_content(const std::string& value); + void set_content(std::string&& value); + void set_content(const char* value); + void set_content(const char* value, size_t size); + std::string* mutable_content(); + std::string* release_content(); + void set_allocated_content(std::string* content); + private: + const std::string& _internal_content() const; + void _internal_set_content(const std::string& value); + std::string* _internal_mutable_content(); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorResponse.File) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr insertion_point_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr content_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// ------------------------------------------------------------------- + +class PROTOC_EXPORT CodeGeneratorResponse PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.compiler.CodeGeneratorResponse) */ { + public: + inline CodeGeneratorResponse() : CodeGeneratorResponse(nullptr) {} + virtual ~CodeGeneratorResponse(); + + CodeGeneratorResponse(const CodeGeneratorResponse& from); + CodeGeneratorResponse(CodeGeneratorResponse&& from) noexcept + : CodeGeneratorResponse() { + *this = ::std::move(from); + } + + inline CodeGeneratorResponse& operator=(const CodeGeneratorResponse& from) { + CopyFrom(from); + return *this; + } + inline CodeGeneratorResponse& operator=(CodeGeneratorResponse&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const CodeGeneratorResponse& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const CodeGeneratorResponse* internal_default_instance() { + return reinterpret_cast( + &_CodeGeneratorResponse_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(CodeGeneratorResponse& a, CodeGeneratorResponse& b) { + a.Swap(&b); + } + inline void Swap(CodeGeneratorResponse* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(CodeGeneratorResponse* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline CodeGeneratorResponse* New() const final { + return CreateMaybeMessage(nullptr); + } + + CodeGeneratorResponse* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const CodeGeneratorResponse& from); + void MergeFrom(const CodeGeneratorResponse& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(CodeGeneratorResponse* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "google.protobuf.compiler.CodeGeneratorResponse"; + } + protected: + explicit CodeGeneratorResponse(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto); + return ::descriptor_table_google_2fprotobuf_2fcompiler_2fplugin_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef CodeGeneratorResponse_File File; + + typedef CodeGeneratorResponse_Feature Feature; + static constexpr Feature FEATURE_NONE = + CodeGeneratorResponse_Feature_FEATURE_NONE; + static constexpr Feature FEATURE_PROTO3_OPTIONAL = + CodeGeneratorResponse_Feature_FEATURE_PROTO3_OPTIONAL; + static inline bool Feature_IsValid(int value) { + return CodeGeneratorResponse_Feature_IsValid(value); + } + static constexpr Feature Feature_MIN = + CodeGeneratorResponse_Feature_Feature_MIN; + static constexpr Feature Feature_MAX = + CodeGeneratorResponse_Feature_Feature_MAX; + static constexpr int Feature_ARRAYSIZE = + CodeGeneratorResponse_Feature_Feature_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Feature_descriptor() { + return CodeGeneratorResponse_Feature_descriptor(); + } + template + static inline const std::string& Feature_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Feature_Name."); + return CodeGeneratorResponse_Feature_Name(enum_t_value); + } + static inline bool Feature_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Feature* value) { + return CodeGeneratorResponse_Feature_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kFileFieldNumber = 15, + kErrorFieldNumber = 1, + kSupportedFeaturesFieldNumber = 2, + }; + // repeated .google.protobuf.compiler.CodeGeneratorResponse.File file = 15; + int file_size() const; + private: + int _internal_file_size() const; + public: + void clear_file(); + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* mutable_file(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >* + mutable_file(); + private: + const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& _internal_file(int index) const; + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* _internal_add_file(); + public: + const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& file(int index) const; + PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* add_file(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >& + file() const; + + // optional string error = 1; + bool has_error() const; + private: + bool _internal_has_error() const; + public: + void clear_error(); + const std::string& error() const; + void set_error(const std::string& value); + void set_error(std::string&& value); + void set_error(const char* value); + void set_error(const char* value, size_t size); + std::string* mutable_error(); + std::string* release_error(); + void set_allocated_error(std::string* error); + private: + const std::string& _internal_error() const; + void _internal_set_error(const std::string& value); + std::string* _internal_mutable_error(); + public: + + // optional uint64 supported_features = 2; + bool has_supported_features() const; + private: + bool _internal_has_supported_features() const; + public: + void clear_supported_features(); + ::PROTOBUF_NAMESPACE_ID::uint64 supported_features() const; + void set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_supported_features() const; + void _internal_set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // @@protoc_insertion_point(class_scope:google.protobuf.compiler.CodeGeneratorResponse) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File > file_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr error_; + ::PROTOBUF_NAMESPACE_ID::uint64 supported_features_; + friend struct ::TableStruct_google_2fprotobuf_2fcompiler_2fplugin_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Version + +// optional int32 major = 1; +inline bool Version::_internal_has_major() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Version::has_major() const { + return _internal_has_major(); +} +inline void Version::clear_major() { + major_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_major() const { + return major_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::major() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.major) + return _internal_major(); +} +inline void Version::_internal_set_major(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + major_ = value; +} +inline void Version::set_major(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_major(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.major) +} + +// optional int32 minor = 2; +inline bool Version::_internal_has_minor() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Version::has_minor() const { + return _internal_has_minor(); +} +inline void Version::clear_minor() { + minor_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_minor() const { + return minor_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::minor() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.minor) + return _internal_minor(); +} +inline void Version::_internal_set_minor(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + minor_ = value; +} +inline void Version::set_minor(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_minor(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.minor) +} + +// optional int32 patch = 3; +inline bool Version::_internal_has_patch() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool Version::has_patch() const { + return _internal_has_patch(); +} +inline void Version::clear_patch() { + patch_ = 0; + _has_bits_[0] &= ~0x00000008u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::_internal_patch() const { + return patch_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Version::patch() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.patch) + return _internal_patch(); +} +inline void Version::_internal_set_patch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000008u; + patch_ = value; +} +inline void Version::set_patch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_patch(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.patch) +} + +// optional string suffix = 4; +inline bool Version::_internal_has_suffix() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Version::has_suffix() const { + return _internal_has_suffix(); +} +inline void Version::clear_suffix() { + suffix_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Version::suffix() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.Version.suffix) + return _internal_suffix(); +} +inline void Version::set_suffix(const std::string& value) { + _internal_set_suffix(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.Version.suffix) +} +inline std::string* Version::mutable_suffix() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.Version.suffix) + return _internal_mutable_suffix(); +} +inline const std::string& Version::_internal_suffix() const { + return suffix_.Get(); +} +inline void Version::_internal_set_suffix(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Version::set_suffix(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.Version.suffix) +} +inline void Version::set_suffix(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.Version.suffix) +} +inline void Version::set_suffix(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + suffix_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.Version.suffix) +} +inline std::string* Version::_internal_mutable_suffix() { + _has_bits_[0] |= 0x00000001u; + return suffix_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Version::release_suffix() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.Version.suffix) + if (!_internal_has_suffix()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return suffix_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Version::set_allocated_suffix(std::string* suffix) { + if (suffix != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + suffix_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), suffix, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.Version.suffix) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorRequest + +// repeated string file_to_generate = 1; +inline int CodeGeneratorRequest::_internal_file_to_generate_size() const { + return file_to_generate_.size(); +} +inline int CodeGeneratorRequest::file_to_generate_size() const { + return _internal_file_to_generate_size(); +} +inline void CodeGeneratorRequest::clear_file_to_generate() { + file_to_generate_.Clear(); +} +inline std::string* CodeGeneratorRequest::add_file_to_generate() { + // @@protoc_insertion_point(field_add_mutable:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return _internal_add_file_to_generate(); +} +inline const std::string& CodeGeneratorRequest::_internal_file_to_generate(int index) const { + return file_to_generate_.Get(index); +} +inline const std::string& CodeGeneratorRequest::file_to_generate(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return _internal_file_to_generate(index); +} +inline std::string* CodeGeneratorRequest::mutable_file_to_generate(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return file_to_generate_.Mutable(index); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + file_to_generate_.Mutable(index)->assign(value); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + file_to_generate_.Mutable(index)->assign(std::move(value)); +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + file_to_generate_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::set_file_to_generate(int index, const char* value, size_t size) { + file_to_generate_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline std::string* CodeGeneratorRequest::_internal_add_file_to_generate() { + return file_to_generate_.Add(); +} +inline void CodeGeneratorRequest::add_file_to_generate(const std::string& value) { + file_to_generate_.Add()->assign(value); + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(std::string&& value) { + file_to_generate_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(const char* value) { + GOOGLE_DCHECK(value != nullptr); + file_to_generate_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline void CodeGeneratorRequest::add_file_to_generate(const char* value, size_t size) { + file_to_generate_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +CodeGeneratorRequest::file_to_generate() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return file_to_generate_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +CodeGeneratorRequest::mutable_file_to_generate() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorRequest.file_to_generate) + return &file_to_generate_; +} + +// optional string parameter = 2; +inline bool CodeGeneratorRequest::_internal_has_parameter() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorRequest::has_parameter() const { + return _internal_has_parameter(); +} +inline void CodeGeneratorRequest::clear_parameter() { + parameter_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorRequest::parameter() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.parameter) + return _internal_parameter(); +} +inline void CodeGeneratorRequest::set_parameter(const std::string& value) { + _internal_set_parameter(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline std::string* CodeGeneratorRequest::mutable_parameter() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.parameter) + return _internal_mutable_parameter(); +} +inline const std::string& CodeGeneratorRequest::_internal_parameter() const { + return parameter_.Get(); +} +inline void CodeGeneratorRequest::_internal_set_parameter(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorRequest::set_parameter(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline void CodeGeneratorRequest::set_parameter(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline void CodeGeneratorRequest::set_parameter(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + parameter_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} +inline std::string* CodeGeneratorRequest::_internal_mutable_parameter() { + _has_bits_[0] |= 0x00000001u; + return parameter_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorRequest::release_parameter() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorRequest.parameter) + if (!_internal_has_parameter()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return parameter_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorRequest::set_allocated_parameter(std::string* parameter) { + if (parameter != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + parameter_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), parameter, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.parameter) +} + +// repeated .google.protobuf.FileDescriptorProto proto_file = 15; +inline int CodeGeneratorRequest::_internal_proto_file_size() const { + return proto_file_.size(); +} +inline int CodeGeneratorRequest::proto_file_size() const { + return _internal_proto_file_size(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::mutable_proto_file(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return proto_file_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >* +CodeGeneratorRequest::mutable_proto_file() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return &proto_file_; +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& CodeGeneratorRequest::_internal_proto_file(int index) const { + return proto_file_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::FileDescriptorProto& CodeGeneratorRequest::proto_file(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return _internal_proto_file(index); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::_internal_add_proto_file() { + return proto_file_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::FileDescriptorProto* CodeGeneratorRequest::add_proto_file() { + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return _internal_add_proto_file(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::FileDescriptorProto >& +CodeGeneratorRequest::proto_file() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorRequest.proto_file) + return proto_file_; +} + +// optional .google.protobuf.compiler.Version compiler_version = 3; +inline bool CodeGeneratorRequest::_internal_has_compiler_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || compiler_version_ != nullptr); + return value; +} +inline bool CodeGeneratorRequest::has_compiler_version() const { + return _internal_has_compiler_version(); +} +inline void CodeGeneratorRequest::clear_compiler_version() { + if (compiler_version_ != nullptr) compiler_version_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const PROTOBUF_NAMESPACE_ID::compiler::Version& CodeGeneratorRequest::_internal_compiler_version() const { + const PROTOBUF_NAMESPACE_ID::compiler::Version* p = compiler_version_; + return p != nullptr ? *p : *reinterpret_cast( + &PROTOBUF_NAMESPACE_ID::compiler::_Version_default_instance_); +} +inline const PROTOBUF_NAMESPACE_ID::compiler::Version& CodeGeneratorRequest::compiler_version() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + return _internal_compiler_version(); +} +inline void CodeGeneratorRequest::unsafe_arena_set_allocated_compiler_version( + PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(compiler_version_); + } + compiler_version_ = compiler_version; + if (compiler_version) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::release_compiler_version() { + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::compiler::Version* temp = compiler_version_; + compiler_version_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::unsafe_arena_release_compiler_version() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + _has_bits_[0] &= ~0x00000002u; + PROTOBUF_NAMESPACE_ID::compiler::Version* temp = compiler_version_; + compiler_version_ = nullptr; + return temp; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::_internal_mutable_compiler_version() { + _has_bits_[0] |= 0x00000002u; + if (compiler_version_ == nullptr) { + auto* p = CreateMaybeMessage(GetArena()); + compiler_version_ = p; + } + return compiler_version_; +} +inline PROTOBUF_NAMESPACE_ID::compiler::Version* CodeGeneratorRequest::mutable_compiler_version() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) + return _internal_mutable_compiler_version(); +} +inline void CodeGeneratorRequest::set_allocated_compiler_version(PROTOBUF_NAMESPACE_ID::compiler::Version* compiler_version) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete compiler_version_; + } + if (compiler_version) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(compiler_version); + if (message_arena != submessage_arena) { + compiler_version = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, compiler_version, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + compiler_version_ = compiler_version; + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorRequest.compiler_version) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorResponse_File + +// optional string name = 1; +inline bool CodeGeneratorResponse_File::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_name() const { + return _internal_has_name(); +} +inline void CodeGeneratorResponse_File::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorResponse_File::name() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.name) + return _internal_name(); +} +inline void CodeGeneratorResponse_File::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline std::string* CodeGeneratorResponse_File::mutable_name() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.name) + return _internal_mutable_name(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_name() const { + return name_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline void CodeGeneratorResponse_File::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline void CodeGeneratorResponse_File::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_name() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.name) +} + +// optional string insertion_point = 2; +inline bool CodeGeneratorResponse_File::_internal_has_insertion_point() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_insertion_point() const { + return _internal_has_insertion_point(); +} +inline void CodeGeneratorResponse_File::clear_insertion_point() { + insertion_point_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& CodeGeneratorResponse_File::insertion_point() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + return _internal_insertion_point(); +} +inline void CodeGeneratorResponse_File::set_insertion_point(const std::string& value) { + _internal_set_insertion_point(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline std::string* CodeGeneratorResponse_File::mutable_insertion_point() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + return _internal_mutable_insertion_point(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_insertion_point() const { + return insertion_point_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_insertion_point(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_insertion_point(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline void CodeGeneratorResponse_File::set_insertion_point(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline void CodeGeneratorResponse_File::set_insertion_point(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + insertion_point_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_insertion_point() { + _has_bits_[0] |= 0x00000002u; + return insertion_point_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_insertion_point() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) + if (!_internal_has_insertion_point()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return insertion_point_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_insertion_point(std::string* insertion_point) { + if (insertion_point != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + insertion_point_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), insertion_point, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.insertion_point) +} + +// optional string content = 15; +inline bool CodeGeneratorResponse_File::_internal_has_content() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool CodeGeneratorResponse_File::has_content() const { + return _internal_has_content(); +} +inline void CodeGeneratorResponse_File::clear_content() { + content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& CodeGeneratorResponse_File::content() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.File.content) + return _internal_content(); +} +inline void CodeGeneratorResponse_File::set_content(const std::string& value) { + _internal_set_content(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline std::string* CodeGeneratorResponse_File::mutable_content() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.File.content) + return _internal_mutable_content(); +} +inline const std::string& CodeGeneratorResponse_File::_internal_content() const { + return content_.Get(); +} +inline void CodeGeneratorResponse_File::_internal_set_content(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse_File::set_content(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + content_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline void CodeGeneratorResponse_File::set_content(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline void CodeGeneratorResponse_File::set_content(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} +inline std::string* CodeGeneratorResponse_File::_internal_mutable_content() { + _has_bits_[0] |= 0x00000004u; + return content_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse_File::release_content() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.File.content) + if (!_internal_has_content()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return content_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse_File::set_allocated_content(std::string* content) { + if (content != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + content_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), content, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.File.content) +} + +// ------------------------------------------------------------------- + +// CodeGeneratorResponse + +// optional string error = 1; +inline bool CodeGeneratorResponse::_internal_has_error() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool CodeGeneratorResponse::has_error() const { + return _internal_has_error(); +} +inline void CodeGeneratorResponse::clear_error() { + error_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& CodeGeneratorResponse::error() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.error) + return _internal_error(); +} +inline void CodeGeneratorResponse::set_error(const std::string& value) { + _internal_set_error(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline std::string* CodeGeneratorResponse::mutable_error() { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.error) + return _internal_mutable_error(); +} +inline const std::string& CodeGeneratorResponse::_internal_error() const { + return error_.Get(); +} +inline void CodeGeneratorResponse::_internal_set_error(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void CodeGeneratorResponse::set_error(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + error_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline void CodeGeneratorResponse::set_error(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline void CodeGeneratorResponse::set_error(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + error_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:google.protobuf.compiler.CodeGeneratorResponse.error) +} +inline std::string* CodeGeneratorResponse::_internal_mutable_error() { + _has_bits_[0] |= 0x00000001u; + return error_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* CodeGeneratorResponse::release_error() { + // @@protoc_insertion_point(field_release:google.protobuf.compiler.CodeGeneratorResponse.error) + if (!_internal_has_error()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return error_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void CodeGeneratorResponse::set_allocated_error(std::string* error) { + if (error != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + error_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), error, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:google.protobuf.compiler.CodeGeneratorResponse.error) +} + +// optional uint64 supported_features = 2; +inline bool CodeGeneratorResponse::_internal_has_supported_features() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool CodeGeneratorResponse::has_supported_features() const { + return _internal_has_supported_features(); +} +inline void CodeGeneratorResponse::clear_supported_features() { + supported_features_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 CodeGeneratorResponse::_internal_supported_features() const { + return supported_features_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 CodeGeneratorResponse::supported_features() const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.supported_features) + return _internal_supported_features(); +} +inline void CodeGeneratorResponse::_internal_set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00000002u; + supported_features_ = value; +} +inline void CodeGeneratorResponse::set_supported_features(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_supported_features(value); + // @@protoc_insertion_point(field_set:google.protobuf.compiler.CodeGeneratorResponse.supported_features) +} + +// repeated .google.protobuf.compiler.CodeGeneratorResponse.File file = 15; +inline int CodeGeneratorResponse::_internal_file_size() const { + return file_.size(); +} +inline int CodeGeneratorResponse::file_size() const { + return _internal_file_size(); +} +inline void CodeGeneratorResponse::clear_file() { + file_.Clear(); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::mutable_file(int index) { + // @@protoc_insertion_point(field_mutable:google.protobuf.compiler.CodeGeneratorResponse.file) + return file_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >* +CodeGeneratorResponse::mutable_file() { + // @@protoc_insertion_point(field_mutable_list:google.protobuf.compiler.CodeGeneratorResponse.file) + return &file_; +} +inline const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& CodeGeneratorResponse::_internal_file(int index) const { + return file_.Get(index); +} +inline const PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File& CodeGeneratorResponse::file(int index) const { + // @@protoc_insertion_point(field_get:google.protobuf.compiler.CodeGeneratorResponse.file) + return _internal_file(index); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::_internal_add_file() { + return file_.Add(); +} +inline PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File* CodeGeneratorResponse::add_file() { + // @@protoc_insertion_point(field_add:google.protobuf.compiler.CodeGeneratorResponse.file) + return _internal_add_file(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_File >& +CodeGeneratorResponse::file() const { + // @@protoc_insertion_point(field_list:google.protobuf.compiler.CodeGeneratorResponse.file) + return file_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace compiler +PROTOBUF_NAMESPACE_CLOSE + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature>() { + return PROTOBUF_NAMESPACE_ID::compiler::CodeGeneratorResponse_Feature_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2fcompiler_2fplugin_2eproto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..b3d3e7fd6c1dff0a191a3298b6a74d4c54cb3734 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/python/python_generator.h @@ -0,0 +1,187 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: robinson@google.com (Will Robinson) +// +// Generates Python code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ + +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class EnumDescriptor; +class EnumValueDescriptor; +class FieldDescriptor; +class OneofDescriptor; +class ServiceDescriptor; + +namespace io { +class Printer; +} + +namespace compiler { +namespace python { + +// CodeGenerator implementation for generated Python protocol buffer classes. +// If you create your own protocol compiler binary and you want it to support +// Python output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + public: + Generator(); + virtual ~Generator(); + + // CodeGenerator methods. + bool Generate(const FileDescriptor* file, const std::string& parameter, + GeneratorContext* generator_context, + std::string* error) const override; + + uint64_t GetSupportedFeatures() const override; + + private: + void PrintImports() const; + void PrintFileDescriptor() const; + void PrintTopLevelEnums() const; + void PrintAllNestedEnumsInFile() const; + void PrintNestedEnums(const Descriptor& descriptor) const; + void PrintEnum(const EnumDescriptor& enum_descriptor) const; + + void PrintTopLevelExtensions() const; + + void PrintFieldDescriptor(const FieldDescriptor& field, + bool is_extension) const; + void PrintFieldDescriptorsInDescriptor( + const Descriptor& message_descriptor, bool is_extension, + const std::string& list_variable_name, int (Descriptor::*CountFn)() const, + const FieldDescriptor* (Descriptor::*GetterFn)(int)const) const; + void PrintFieldsInDescriptor(const Descriptor& message_descriptor) const; + void PrintExtensionsInDescriptor(const Descriptor& message_descriptor) const; + void PrintMessageDescriptors() const; + void PrintDescriptor(const Descriptor& message_descriptor) const; + void PrintNestedDescriptors(const Descriptor& containing_descriptor) const; + + void PrintMessages() const; + void PrintMessage(const Descriptor& message_descriptor, + const std::string& prefix, + std::vector* to_register, + bool is_nested) const; + void PrintNestedMessages(const Descriptor& containing_descriptor, + const std::string& prefix, + std::vector* to_register) const; + + void FixForeignFieldsInDescriptors() const; + void FixForeignFieldsInDescriptor( + const Descriptor& descriptor, + const Descriptor* containing_descriptor) const; + void FixForeignFieldsInField(const Descriptor* containing_type, + const FieldDescriptor& field, + const std::string& python_dict_name) const; + void AddMessageToFileDescriptor(const Descriptor& descriptor) const; + void AddEnumToFileDescriptor(const EnumDescriptor& descriptor) const; + void AddExtensionToFileDescriptor(const FieldDescriptor& descriptor) const; + void AddServiceToFileDescriptor(const ServiceDescriptor& descriptor) const; + std::string FieldReferencingExpression( + const Descriptor* containing_type, const FieldDescriptor& field, + const std::string& python_dict_name) const; + template + void FixContainingTypeInDescriptor( + const DescriptorT& descriptor, + const Descriptor* containing_descriptor) const; + + void FixForeignFieldsInExtensions() const; + void FixForeignFieldsInExtension( + const FieldDescriptor& extension_field) const; + void FixForeignFieldsInNestedExtensions(const Descriptor& descriptor) const; + + void PrintServices() const; + void PrintServiceDescriptors() const; + void PrintServiceDescriptor(const ServiceDescriptor& descriptor) const; + void PrintServiceClass(const ServiceDescriptor& descriptor) const; + void PrintServiceStub(const ServiceDescriptor& descriptor) const; + void PrintDescriptorKeyAndModuleName( + const ServiceDescriptor& descriptor) const; + + void PrintEnumValueDescriptor(const EnumValueDescriptor& descriptor) const; + std::string OptionsValue(const std::string& serialized_options) const; + bool GeneratingDescriptorProto() const; + + template + std::string ModuleLevelDescriptorName(const DescriptorT& descriptor) const; + std::string ModuleLevelMessageName(const Descriptor& descriptor) const; + std::string ModuleLevelServiceDescriptorName( + const ServiceDescriptor& descriptor) const; + + template + void PrintSerializedPbInterval(const DescriptorT& descriptor, + DescriptorProtoT& proto) const; + + void FixAllDescriptorOptions() const; + void FixOptionsForField(const FieldDescriptor& field) const; + void FixOptionsForOneof(const OneofDescriptor& oneof) const; + void FixOptionsForEnum(const EnumDescriptor& descriptor) const; + void FixOptionsForMessage(const Descriptor& descriptor) const; + + void CopyPublicDependenciesAliases(const std::string& copy_from, + const FileDescriptor* file) const; + + // Very coarse-grained lock to ensure that Generate() is reentrant. + // Guards file_, printer_ and file_descriptor_serialized_. + mutable Mutex mutex_; + mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_. + mutable std::string file_descriptor_serialized_; + mutable io::Printer* printer_; // Set in Generate(). Under mutex_. + mutable bool pure_python_workable_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator); +}; + +} // namespace python +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_PYTHON_GENERATOR_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..9d297c5f183c8d7a041db0c5c074bf133bf9a222 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/compiler/ruby/ruby_generator.h @@ -0,0 +1,73 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Generates Ruby code for a given .proto file. + +#ifndef GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ +#define GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ + +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace compiler { +namespace ruby { + +// CodeGenerator implementation for generated Ruby protocol buffer classes. +// If you create your own protocol compiler binary and you want it to support +// Ruby output, you can do so by registering an instance of this +// CodeGenerator with the CommandLineInterface in your main() function. +class PROTOC_EXPORT Generator : public CodeGenerator { + bool Generate(const FileDescriptor* file, const string& parameter, + GeneratorContext* generator_context, + string* error) const override; + uint64_t GetSupportedFeatures() const override { + return FEATURE_PROTO3_OPTIONAL; + } +}; + +} // namespace ruby +} // namespace compiler +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_COMPILER_RUBY_GENERATOR_H__ + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..061d60cd71990af74cabfd23441975367ff636ce --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/coded_stream.h @@ -0,0 +1,1719 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains the CodedInputStream and CodedOutputStream classes, +// which wrap a ZeroCopyInputStream or ZeroCopyOutputStream, respectively, +// and allow you to read or write individual pieces of data in various +// formats. In particular, these implement the varint encoding for +// integers, a simple variable-length encoding in which smaller numbers +// take fewer bytes. +// +// Typically these classes will only be used internally by the protocol +// buffer library in order to encode and decode protocol buffers. Clients +// of the library only need to know about this class if they wish to write +// custom message parsing or serialization procedures. +// +// CodedOutputStream example: +// // Write some data to "myfile". First we write a 4-byte "magic number" +// // to identify the file type, then write a length-delimited string. The +// // string is composed of a varint giving the length followed by the raw +// // bytes. +// int fd = open("myfile", O_CREAT | O_WRONLY); +// ZeroCopyOutputStream* raw_output = new FileOutputStream(fd); +// CodedOutputStream* coded_output = new CodedOutputStream(raw_output); +// +// int magic_number = 1234; +// char text[] = "Hello world!"; +// coded_output->WriteLittleEndian32(magic_number); +// coded_output->WriteVarint32(strlen(text)); +// coded_output->WriteRaw(text, strlen(text)); +// +// delete coded_output; +// delete raw_output; +// close(fd); +// +// CodedInputStream example: +// // Read a file created by the above code. +// int fd = open("myfile", O_RDONLY); +// ZeroCopyInputStream* raw_input = new FileInputStream(fd); +// CodedInputStream* coded_input = new CodedInputStream(raw_input); +// +// coded_input->ReadLittleEndian32(&magic_number); +// if (magic_number != 1234) { +// cerr << "File not in expected format." << endl; +// return; +// } +// +// uint32 size; +// coded_input->ReadVarint32(&size); +// +// char* text = new char[size + 1]; +// coded_input->ReadRaw(buffer, size); +// text[size] = '\0'; +// +// delete coded_input; +// delete raw_input; +// close(fd); +// +// cout << "Text is: " << text << endl; +// delete [] text; +// +// For those who are interested, varint encoding is defined as follows: +// +// The encoding operates on unsigned integers of up to 64 bits in length. +// Each byte of the encoded value has the format: +// * bits 0-6: Seven bits of the number being encoded. +// * bit 7: Zero if this is the last byte in the encoding (in which +// case all remaining bits of the number are zero) or 1 if +// more bytes follow. +// The first byte contains the least-significant 7 bits of the number, the +// second byte (if present) contains the next-least-significant 7 bits, +// and so on. So, the binary number 1011000101011 would be encoded in two +// bytes as "10101011 00101100". +// +// In theory, varint could be used to encode integers of any length. +// However, for practicality we set a limit at 64 bits. The maximum encoded +// length of a number is thus 10 bytes. + +#ifndef GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ + + +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +// Assuming windows is always little-endian. +#if !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) +#define PROTOBUF_LITTLE_ENDIAN 1 +#endif +#if _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) +// If MSVC has "/RTCc" set, it will complain about truncating casts at +// runtime. This file contains some intentional truncating casts. +#pragma runtime_checks("c", off) +#endif +#else +#include // __BYTE_ORDER +#if ((defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __LITTLE_ENDIAN)) && \ + !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) +#define PROTOBUF_LITTLE_ENDIAN 1 +#endif +#endif +#include +#include +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { + +class DescriptorPool; +class MessageFactory; +class ZeroCopyCodedInputStream; + +namespace internal { +void MapTestForceDeterministic(); +class EpsCopyByteStream; +} // namespace internal + +namespace io { + +// Defined in this file. +class CodedInputStream; +class CodedOutputStream; + +// Defined in other files. +class ZeroCopyInputStream; // zero_copy_stream.h +class ZeroCopyOutputStream; // zero_copy_stream.h + +// Class which reads and decodes binary data which is composed of varint- +// encoded integers and fixed-width pieces. Wraps a ZeroCopyInputStream. +// Most users will not need to deal with CodedInputStream. +// +// Most methods of CodedInputStream that return a bool return false if an +// underlying I/O error occurs or if the data is malformed. Once such a +// failure occurs, the CodedInputStream is broken and is no longer useful. +// After a failure, callers also should assume writes to "out" args may have +// occurred, though nothing useful can be determined from those writes. +class PROTOBUF_EXPORT CodedInputStream { + public: + // Create a CodedInputStream that reads from the given ZeroCopyInputStream. + explicit CodedInputStream(ZeroCopyInputStream* input); + + // Create a CodedInputStream that reads from the given flat array. This is + // faster than using an ArrayInputStream. PushLimit(size) is implied by + // this constructor. + explicit CodedInputStream(const uint8* buffer, int size); + + // Destroy the CodedInputStream and position the underlying + // ZeroCopyInputStream at the first unread byte. If an error occurred while + // reading (causing a method to return false), then the exact position of + // the input stream may be anywhere between the last value that was read + // successfully and the stream's byte limit. + ~CodedInputStream(); + + // Return true if this CodedInputStream reads from a flat array instead of + // a ZeroCopyInputStream. + inline bool IsFlat() const; + + // Skips a number of bytes. Returns false if an underlying read error + // occurs. + inline bool Skip(int count); + + // Sets *data to point directly at the unread part of the CodedInputStream's + // underlying buffer, and *size to the size of that buffer, but does not + // advance the stream's current position. This will always either produce + // a non-empty buffer or return false. If the caller consumes any of + // this data, it should then call Skip() to skip over the consumed bytes. + // This may be useful for implementing external fast parsing routines for + // types of data not covered by the CodedInputStream interface. + bool GetDirectBufferPointer(const void** data, int* size); + + // Like GetDirectBufferPointer, but this method is inlined, and does not + // attempt to Refresh() if the buffer is currently empty. + PROTOBUF_ALWAYS_INLINE + void GetDirectBufferPointerInline(const void** data, int* size); + + // Read raw bytes, copying them into the given buffer. + bool ReadRaw(void* buffer, int size); + + // Like ReadRaw, but reads into a string. + bool ReadString(std::string* buffer, int size); + + + // Read a 32-bit little-endian integer. + bool ReadLittleEndian32(uint32* value); + // Read a 64-bit little-endian integer. + bool ReadLittleEndian64(uint64* value); + + // These methods read from an externally provided buffer. The caller is + // responsible for ensuring that the buffer has sufficient space. + // Read a 32-bit little-endian integer. + static const uint8* ReadLittleEndian32FromArray(const uint8* buffer, + uint32* value); + // Read a 64-bit little-endian integer. + static const uint8* ReadLittleEndian64FromArray(const uint8* buffer, + uint64* value); + + // Read an unsigned integer with Varint encoding, truncating to 32 bits. + // Reading a 32-bit value is equivalent to reading a 64-bit one and casting + // it to uint32, but may be more efficient. + bool ReadVarint32(uint32* value); + // Read an unsigned integer with Varint encoding. + bool ReadVarint64(uint64* value); + + // Reads a varint off the wire into an "int". This should be used for reading + // sizes off the wire (sizes of strings, submessages, bytes fields, etc). + // + // The value from the wire is interpreted as unsigned. If its value exceeds + // the representable value of an integer on this platform, instead of + // truncating we return false. Truncating (as performed by ReadVarint32() + // above) is an acceptable approach for fields representing an integer, but + // when we are parsing a size from the wire, truncating the value would result + // in us misparsing the payload. + bool ReadVarintSizeAsInt(int* value); + + // Read a tag. This calls ReadVarint32() and returns the result, or returns + // zero (which is not a valid tag) if ReadVarint32() fails. Also, ReadTag + // (but not ReadTagNoLastTag) updates the last tag value, which can be checked + // with LastTagWas(). + // + // Always inline because this is only called in one place per parse loop + // but it is called for every iteration of said loop, so it should be fast. + // GCC doesn't want to inline this by default. + PROTOBUF_ALWAYS_INLINE uint32 ReadTag() { + return last_tag_ = ReadTagNoLastTag(); + } + + PROTOBUF_ALWAYS_INLINE uint32 ReadTagNoLastTag(); + + // This usually a faster alternative to ReadTag() when cutoff is a manifest + // constant. It does particularly well for cutoff >= 127. The first part + // of the return value is the tag that was read, though it can also be 0 in + // the cases where ReadTag() would return 0. If the second part is true + // then the tag is known to be in [0, cutoff]. If not, the tag either is + // above cutoff or is 0. (There's intentional wiggle room when tag is 0, + // because that can arise in several ways, and for best performance we want + // to avoid an extra "is tag == 0?" check here.) + PROTOBUF_ALWAYS_INLINE + std::pair ReadTagWithCutoff(uint32 cutoff) { + std::pair result = ReadTagWithCutoffNoLastTag(cutoff); + last_tag_ = result.first; + return result; + } + + PROTOBUF_ALWAYS_INLINE + std::pair ReadTagWithCutoffNoLastTag(uint32 cutoff); + + // Usually returns true if calling ReadVarint32() now would produce the given + // value. Will always return false if ReadVarint32() would not return the + // given value. If ExpectTag() returns true, it also advances past + // the varint. For best performance, use a compile-time constant as the + // parameter. + // Always inline because this collapses to a small number of instructions + // when given a constant parameter, but GCC doesn't want to inline by default. + PROTOBUF_ALWAYS_INLINE bool ExpectTag(uint32 expected); + + // Like above, except this reads from the specified buffer. The caller is + // responsible for ensuring that the buffer is large enough to read a varint + // of the expected size. For best performance, use a compile-time constant as + // the expected tag parameter. + // + // Returns a pointer beyond the expected tag if it was found, or NULL if it + // was not. + PROTOBUF_ALWAYS_INLINE + static const uint8* ExpectTagFromArray(const uint8* buffer, uint32 expected); + + // Usually returns true if no more bytes can be read. Always returns false + // if more bytes can be read. If ExpectAtEnd() returns true, a subsequent + // call to LastTagWas() will act as if ReadTag() had been called and returned + // zero, and ConsumedEntireMessage() will return true. + bool ExpectAtEnd(); + + // If the last call to ReadTag() or ReadTagWithCutoff() returned the given + // value, returns true. Otherwise, returns false. + // ReadTagNoLastTag/ReadTagWithCutoffNoLastTag do not preserve the last + // returned value. + // + // This is needed because parsers for some types of embedded messages + // (with field type TYPE_GROUP) don't actually know that they've reached the + // end of a message until they see an ENDGROUP tag, which was actually part + // of the enclosing message. The enclosing message would like to check that + // tag to make sure it had the right number, so it calls LastTagWas() on + // return from the embedded parser to check. + bool LastTagWas(uint32 expected); + void SetLastTag(uint32 tag) { last_tag_ = tag; } + + // When parsing message (but NOT a group), this method must be called + // immediately after MergeFromCodedStream() returns (if it returns true) + // to further verify that the message ended in a legitimate way. For + // example, this verifies that parsing did not end on an end-group tag. + // It also checks for some cases where, due to optimizations, + // MergeFromCodedStream() can incorrectly return true. + bool ConsumedEntireMessage(); + void SetConsumed() { legitimate_message_end_ = true; } + + // Limits ---------------------------------------------------------- + // Limits are used when parsing length-delimited embedded messages. + // After the message's length is read, PushLimit() is used to prevent + // the CodedInputStream from reading beyond that length. Once the + // embedded message has been parsed, PopLimit() is called to undo the + // limit. + + // Opaque type used with PushLimit() and PopLimit(). Do not modify + // values of this type yourself. The only reason that this isn't a + // struct with private internals is for efficiency. + typedef int Limit; + + // Places a limit on the number of bytes that the stream may read, + // starting from the current position. Once the stream hits this limit, + // it will act like the end of the input has been reached until PopLimit() + // is called. + // + // As the names imply, the stream conceptually has a stack of limits. The + // shortest limit on the stack is always enforced, even if it is not the + // top limit. + // + // The value returned by PushLimit() is opaque to the caller, and must + // be passed unchanged to the corresponding call to PopLimit(). + Limit PushLimit(int byte_limit); + + // Pops the last limit pushed by PushLimit(). The input must be the value + // returned by that call to PushLimit(). + void PopLimit(Limit limit); + + // Returns the number of bytes left until the nearest limit on the + // stack is hit, or -1 if no limits are in place. + int BytesUntilLimit() const; + + // Returns current position relative to the beginning of the input stream. + int CurrentPosition() const; + + // Total Bytes Limit ----------------------------------------------- + // To prevent malicious users from sending excessively large messages + // and causing memory exhaustion, CodedInputStream imposes a hard limit on + // the total number of bytes it will read. + + // Sets the maximum number of bytes that this CodedInputStream will read + // before refusing to continue. To prevent servers from allocating enormous + // amounts of memory to hold parsed messages, the maximum message length + // should be limited to the shortest length that will not harm usability. + // The default limit is INT_MAX (~2GB) and apps should set shorter limits + // if possible. An error will always be printed to stderr if the limit is + // reached. + // + // Note: setting a limit less than the current read position is interpreted + // as a limit on the current position. + // + // This is unrelated to PushLimit()/PopLimit(). + void SetTotalBytesLimit(int total_bytes_limit); + + PROTOBUF_DEPRECATED_MSG( + "Please use the single parameter version of SetTotalBytesLimit(). The " + "second parameter is ignored.") + void SetTotalBytesLimit(int total_bytes_limit, int) { + SetTotalBytesLimit(total_bytes_limit); + } + + // The Total Bytes Limit minus the Current Position, or -1 if the total bytes + // limit is INT_MAX. + int BytesUntilTotalBytesLimit() const; + + // Recursion Limit ------------------------------------------------- + // To prevent corrupt or malicious messages from causing stack overflows, + // we must keep track of the depth of recursion when parsing embedded + // messages and groups. CodedInputStream keeps track of this because it + // is the only object that is passed down the stack during parsing. + + // Sets the maximum recursion depth. The default is 100. + void SetRecursionLimit(int limit); + int RecursionBudget() { return recursion_budget_; } + + static int GetDefaultRecursionLimit() { return default_recursion_limit_; } + + // Increments the current recursion depth. Returns true if the depth is + // under the limit, false if it has gone over. + bool IncrementRecursionDepth(); + + // Decrements the recursion depth if possible. + void DecrementRecursionDepth(); + + // Decrements the recursion depth blindly. This is faster than + // DecrementRecursionDepth(). It should be used only if all previous + // increments to recursion depth were successful. + void UnsafeDecrementRecursionDepth(); + + // Shorthand for make_pair(PushLimit(byte_limit), --recursion_budget_). + // Using this can reduce code size and complexity in some cases. The caller + // is expected to check that the second part of the result is non-negative (to + // bail out if the depth of recursion is too high) and, if all is well, to + // later pass the first part of the result to PopLimit() or similar. + std::pair IncrementRecursionDepthAndPushLimit( + int byte_limit); + + // Shorthand for PushLimit(ReadVarint32(&length) ? length : 0). + Limit ReadLengthAndPushLimit(); + + // Helper that is equivalent to: { + // bool result = ConsumedEntireMessage(); + // PopLimit(limit); + // UnsafeDecrementRecursionDepth(); + // return result; } + // Using this can reduce code size and complexity in some cases. + // Do not use unless the current recursion depth is greater than zero. + bool DecrementRecursionDepthAndPopLimit(Limit limit); + + // Helper that is equivalent to: { + // bool result = ConsumedEntireMessage(); + // PopLimit(limit); + // return result; } + // Using this can reduce code size and complexity in some cases. + bool CheckEntireMessageConsumedAndPopLimit(Limit limit); + + // Extension Registry ---------------------------------------------- + // ADVANCED USAGE: 99.9% of people can ignore this section. + // + // By default, when parsing extensions, the parser looks for extension + // definitions in the pool which owns the outer message's Descriptor. + // However, you may call SetExtensionRegistry() to provide an alternative + // pool instead. This makes it possible, for example, to parse a message + // using a generated class, but represent some extensions using + // DynamicMessage. + + // Set the pool used to look up extensions. Most users do not need to call + // this as the correct pool will be chosen automatically. + // + // WARNING: It is very easy to misuse this. Carefully read the requirements + // below. Do not use this unless you are sure you need it. Almost no one + // does. + // + // Let's say you are parsing a message into message object m, and you want + // to take advantage of SetExtensionRegistry(). You must follow these + // requirements: + // + // The given DescriptorPool must contain m->GetDescriptor(). It is not + // sufficient for it to simply contain a descriptor that has the same name + // and content -- it must be the *exact object*. In other words: + // assert(pool->FindMessageTypeByName(m->GetDescriptor()->full_name()) == + // m->GetDescriptor()); + // There are two ways to satisfy this requirement: + // 1) Use m->GetDescriptor()->pool() as the pool. This is generally useless + // because this is the pool that would be used anyway if you didn't call + // SetExtensionRegistry() at all. + // 2) Use a DescriptorPool which has m->GetDescriptor()->pool() as an + // "underlay". Read the documentation for DescriptorPool for more + // information about underlays. + // + // You must also provide a MessageFactory. This factory will be used to + // construct Message objects representing extensions. The factory's + // GetPrototype() MUST return non-NULL for any Descriptor which can be found + // through the provided pool. + // + // If the provided factory might return instances of protocol-compiler- + // generated (i.e. compiled-in) types, or if the outer message object m is + // a generated type, then the given factory MUST have this property: If + // GetPrototype() is given a Descriptor which resides in + // DescriptorPool::generated_pool(), the factory MUST return the same + // prototype which MessageFactory::generated_factory() would return. That + // is, given a descriptor for a generated type, the factory must return an + // instance of the generated class (NOT DynamicMessage). However, when + // given a descriptor for a type that is NOT in generated_pool, the factory + // is free to return any implementation. + // + // The reason for this requirement is that generated sub-objects may be + // accessed via the standard (non-reflection) extension accessor methods, + // and these methods will down-cast the object to the generated class type. + // If the object is not actually of that type, the results would be undefined. + // On the other hand, if an extension is not compiled in, then there is no + // way the code could end up accessing it via the standard accessors -- the + // only way to access the extension is via reflection. When using reflection, + // DynamicMessage and generated messages are indistinguishable, so it's fine + // if these objects are represented using DynamicMessage. + // + // Using DynamicMessageFactory on which you have called + // SetDelegateToGeneratedFactory(true) should be sufficient to satisfy the + // above requirement. + // + // If either pool or factory is NULL, both must be NULL. + // + // Note that this feature is ignored when parsing "lite" messages as they do + // not have descriptors. + void SetExtensionRegistry(const DescriptorPool* pool, + MessageFactory* factory); + + // Get the DescriptorPool set via SetExtensionRegistry(), or NULL if no pool + // has been provided. + const DescriptorPool* GetExtensionPool(); + + // Get the MessageFactory set via SetExtensionRegistry(), or NULL if no + // factory has been provided. + MessageFactory* GetExtensionFactory(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodedInputStream); + + const uint8* buffer_; + const uint8* buffer_end_; // pointer to the end of the buffer. + ZeroCopyInputStream* input_; + int total_bytes_read_; // total bytes read from input_, including + // the current buffer + + // If total_bytes_read_ surpasses INT_MAX, we record the extra bytes here + // so that we can BackUp() on destruction. + int overflow_bytes_; + + // LastTagWas() stuff. + uint32 last_tag_; // result of last ReadTag() or ReadTagWithCutoff(). + + // This is set true by ReadTag{Fallback/Slow}() if it is called when exactly + // at EOF, or by ExpectAtEnd() when it returns true. This happens when we + // reach the end of a message and attempt to read another tag. + bool legitimate_message_end_; + + // See EnableAliasing(). + bool aliasing_enabled_; + + // Limits + Limit current_limit_; // if position = -1, no limit is applied + + // For simplicity, if the current buffer crosses a limit (either a normal + // limit created by PushLimit() or the total bytes limit), buffer_size_ + // only tracks the number of bytes before that limit. This field + // contains the number of bytes after it. Note that this implies that if + // buffer_size_ == 0 and buffer_size_after_limit_ > 0, we know we've + // hit a limit. However, if both are zero, it doesn't necessarily mean + // we aren't at a limit -- the buffer may have ended exactly at the limit. + int buffer_size_after_limit_; + + // Maximum number of bytes to read, period. This is unrelated to + // current_limit_. Set using SetTotalBytesLimit(). + int total_bytes_limit_; + + // Current recursion budget, controlled by IncrementRecursionDepth() and + // similar. Starts at recursion_limit_ and goes down: if this reaches + // -1 we are over budget. + int recursion_budget_; + // Recursion depth limit, set by SetRecursionLimit(). + int recursion_limit_; + + // See SetExtensionRegistry(). + const DescriptorPool* extension_pool_; + MessageFactory* extension_factory_; + + // Private member functions. + + // Fallback when Skip() goes past the end of the current buffer. + bool SkipFallback(int count, int original_buffer_size); + + // Advance the buffer by a given number of bytes. + void Advance(int amount); + + // Back up input_ to the current buffer position. + void BackUpInputToCurrentPosition(); + + // Recomputes the value of buffer_size_after_limit_. Must be called after + // current_limit_ or total_bytes_limit_ changes. + void RecomputeBufferLimits(); + + // Writes an error message saying that we hit total_bytes_limit_. + void PrintTotalBytesLimitError(); + + // Called when the buffer runs out to request more data. Implies an + // Advance(BufferSize()). + bool Refresh(); + + // When parsing varints, we optimize for the common case of small values, and + // then optimize for the case when the varint fits within the current buffer + // piece. The Fallback method is used when we can't use the one-byte + // optimization. The Slow method is yet another fallback when the buffer is + // not large enough. Making the slow path out-of-line speeds up the common + // case by 10-15%. The slow path is fairly uncommon: it only triggers when a + // message crosses multiple buffers. Note: ReadVarint32Fallback() and + // ReadVarint64Fallback() are called frequently and generally not inlined, so + // they have been optimized to avoid "out" parameters. The former returns -1 + // if it fails and the uint32 it read otherwise. The latter has a bool + // indicating success or failure as part of its return type. + int64 ReadVarint32Fallback(uint32 first_byte_or_zero); + int ReadVarintSizeAsIntFallback(); + std::pair ReadVarint64Fallback(); + bool ReadVarint32Slow(uint32* value); + bool ReadVarint64Slow(uint64* value); + int ReadVarintSizeAsIntSlow(); + bool ReadLittleEndian32Fallback(uint32* value); + bool ReadLittleEndian64Fallback(uint64* value); + + // Fallback/slow methods for reading tags. These do not update last_tag_, + // but will set legitimate_message_end_ if we are at the end of the input + // stream. + uint32 ReadTagFallback(uint32 first_byte_or_zero); + uint32 ReadTagSlow(); + bool ReadStringFallback(std::string* buffer, int size); + + // Return the size of the buffer. + int BufferSize() const; + + static const int kDefaultTotalBytesLimit = INT_MAX; + + static int default_recursion_limit_; // 100 by default. + + friend class google::protobuf::ZeroCopyCodedInputStream; + friend class google::protobuf::internal::EpsCopyByteStream; +}; + +// EpsCopyOutputStream wraps a ZeroCopyOutputStream and exposes a new stream, +// which has the property you can write kSlopBytes (16 bytes) from the current +// position without bounds checks. The cursor into the stream is managed by +// the user of the class and is an explicit parameter in the methods. Careful +// use of this class, ie. keep ptr a local variable, eliminates the need to +// for the compiler to sync the ptr value between register and memory. +class PROTOBUF_EXPORT EpsCopyOutputStream { + public: + enum { kSlopBytes = 16 }; + + // Initialize from a stream. + EpsCopyOutputStream(ZeroCopyOutputStream* stream, bool deterministic, + uint8** pp) + : end_(buffer_), + stream_(stream), + is_serialization_deterministic_(deterministic) { + *pp = buffer_; + } + + // Only for array serialization. No overflow protection, end_ will be the + // pointed to the end of the array. When using this the total size is already + // known, so no need to maintain the slop region. + EpsCopyOutputStream(void* data, int size, bool deterministic) + : end_(static_cast(data) + size), + buffer_end_(nullptr), + stream_(nullptr), + is_serialization_deterministic_(deterministic) {} + + // Initialize from stream but with the first buffer already given (eager). + EpsCopyOutputStream(void* data, int size, ZeroCopyOutputStream* stream, + bool deterministic, uint8** pp) + : stream_(stream), is_serialization_deterministic_(deterministic) { + *pp = SetInitialBuffer(data, size); + } + + // Flush everything that's written into the underlying ZeroCopyOutputStream + // and trims the underlying stream to the location of ptr. + uint8* Trim(uint8* ptr); + + // After this it's guaranteed you can safely write kSlopBytes to ptr. This + // will never fail! The underlying stream can produce an error. Use HadError + // to check for errors. + PROTOBUF_MUST_USE_RESULT uint8* EnsureSpace(uint8* ptr) { + if (PROTOBUF_PREDICT_FALSE(ptr >= end_)) { + return EnsureSpaceFallback(ptr); + } + return ptr; + } + + uint8* WriteRaw(const void* data, int size, uint8* ptr) { + if (PROTOBUF_PREDICT_FALSE(end_ - ptr < size)) { + return WriteRawFallback(data, size, ptr); + } + std::memcpy(ptr, data, size); + return ptr + size; + } + // Writes the buffer specified by data, size to the stream. Possibly by + // aliasing the buffer (ie. not copying the data). The caller is responsible + // to make sure the buffer is alive for the duration of the + // ZeroCopyOutputStream. + uint8* WriteRawMaybeAliased(const void* data, int size, uint8* ptr) { + if (aliasing_enabled_) { + return WriteAliasedRaw(data, size, ptr); + } else { + return WriteRaw(data, size, ptr); + } + } + + + uint8* WriteStringMaybeAliased(uint32 num, const std::string& s, uint8* ptr) { + std::ptrdiff_t size = s.size(); + if (PROTOBUF_PREDICT_FALSE( + size >= 128 || end_ - ptr + 16 - TagSize(num << 3) - 1 < size)) { + return WriteStringMaybeAliasedOutline(num, s, ptr); + } + ptr = UnsafeVarint((num << 3) | 2, ptr); + *ptr++ = static_cast(size); + std::memcpy(ptr, s.data(), size); + return ptr + size; + } + uint8* WriteBytesMaybeAliased(uint32 num, const std::string& s, uint8* ptr) { + return WriteStringMaybeAliased(num, s, ptr); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteString(uint32 num, const T& s, + uint8* ptr) { + std::ptrdiff_t size = s.size(); + if (PROTOBUF_PREDICT_FALSE( + size >= 128 || end_ - ptr + 16 - TagSize(num << 3) - 1 < size)) { + return WriteStringOutline(num, s, ptr); + } + ptr = UnsafeVarint((num << 3) | 2, ptr); + *ptr++ = static_cast(size); + std::memcpy(ptr, s.data(), size); + return ptr + size; + } + template + uint8* WriteBytes(uint32 num, const T& s, uint8* ptr) { + return WriteString(num, s, ptr); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteUInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode32); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteSInt32Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, ZigZagEncode32); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteUInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteSInt64Packed(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, ZigZagEncode64); + } + template + PROTOBUF_ALWAYS_INLINE uint8* WriteEnumPacked(int num, const T& r, int size, + uint8* ptr) { + return WriteVarintPacked(num, r, size, ptr, Encode64); + } + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteFixedPacked(int num, const T& r, + uint8* ptr) { + ptr = EnsureSpace(ptr); + constexpr auto element_size = sizeof(typename T::value_type); + auto size = r.size() * element_size; + ptr = WriteLengthDelim(num, size, ptr); + return WriteRawLittleEndian(r.data(), static_cast(size), + ptr); + } + + // Returns true if there was an underlying I/O error since this object was + // created. + bool HadError() const { return had_error_; } + + // Instructs the EpsCopyOutputStream to allow the underlying + // ZeroCopyOutputStream to hold pointers to the original structure instead of + // copying, if it supports it (i.e. output->AllowsAliasing() is true). If the + // underlying stream does not support aliasing, then enabling it has no + // affect. For now, this only affects the behavior of + // WriteRawMaybeAliased(). + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + void EnableAliasing(bool enabled); + + // See documentation on CodedOutputStream::SetSerializationDeterministic. + void SetSerializationDeterministic(bool value) { + is_serialization_deterministic_ = value; + } + + // See documentation on CodedOutputStream::IsSerializationDeterministic. + bool IsSerializationDeterministic() const { + return is_serialization_deterministic_; + } + + // The number of bytes written to the stream at position ptr, relative to the + // stream's overall position. + int64 ByteCount(uint8* ptr) const; + + + private: + uint8* end_; + uint8* buffer_end_ = buffer_; + uint8 buffer_[2 * kSlopBytes]; + ZeroCopyOutputStream* stream_; + bool had_error_ = false; + bool aliasing_enabled_ = false; // See EnableAliasing(). + bool is_serialization_deterministic_; + + uint8* EnsureSpaceFallback(uint8* ptr); + inline uint8* Next(); + int Flush(uint8* ptr); + std::ptrdiff_t GetSize(uint8* ptr) const { + GOOGLE_DCHECK(ptr <= end_ + kSlopBytes); // NOLINT + return end_ + kSlopBytes - ptr; + } + + uint8* Error() { + had_error_ = true; + // We use the patch buffer to always guarantee space to write to. + end_ = buffer_ + kSlopBytes; + return buffer_; + } + + static constexpr int TagSize(uint32 tag) { + return (tag < (1 << 7)) + ? 1 + : (tag < (1 << 14)) + ? 2 + : (tag < (1 << 21)) ? 3 : (tag < (1 << 28)) ? 4 : 5; + } + + PROTOBUF_ALWAYS_INLINE uint8* WriteTag(uint32 num, uint32 wt, uint8* ptr) { + GOOGLE_DCHECK(ptr < end_); // NOLINT + return UnsafeVarint((num << 3) | wt, ptr); + } + + PROTOBUF_ALWAYS_INLINE uint8* WriteLengthDelim(int num, uint32 size, + uint8* ptr) { + ptr = WriteTag(num, 2, ptr); + return UnsafeWriteSize(size, ptr); + } + + uint8* WriteRawFallback(const void* data, int size, uint8* ptr); + + uint8* WriteAliasedRaw(const void* data, int size, uint8* ptr); + + uint8* WriteStringMaybeAliasedOutline(uint32 num, const std::string& s, + uint8* ptr); + uint8* WriteStringOutline(uint32 num, const std::string& s, uint8* ptr); + + template + PROTOBUF_ALWAYS_INLINE uint8* WriteVarintPacked(int num, const T& r, int size, + uint8* ptr, const E& encode) { + ptr = EnsureSpace(ptr); + ptr = WriteLengthDelim(num, size, ptr); + auto it = r.data(); + auto end = it + r.size(); + do { + ptr = EnsureSpace(ptr); + ptr = UnsafeVarint(encode(*it++), ptr); + } while (it < end); + return ptr; + } + + static uint32 Encode32(uint32 v) { return v; } + static uint64 Encode64(uint64 v) { return v; } + static uint32 ZigZagEncode32(int32 v) { + return (static_cast(v) << 1) ^ static_cast(v >> 31); + } + static uint64 ZigZagEncode64(int64 v) { + return (static_cast(v) << 1) ^ static_cast(v >> 63); + } + + template + PROTOBUF_ALWAYS_INLINE static uint8* UnsafeVarint(T value, uint8* ptr) { + static_assert(std::is_unsigned::value, + "Varint serialization must be unsigned"); + if (value < 0x80) { + ptr[0] = static_cast(value); + return ptr + 1; + } + ptr[0] = static_cast(value | 0x80); + value >>= 7; + if (value < 0x80) { + ptr[1] = static_cast(value); + return ptr + 2; + } + ptr++; + do { + *ptr = static_cast(value | 0x80); + value >>= 7; + ++ptr; + } while (PROTOBUF_PREDICT_FALSE(value >= 0x80)); + *ptr++ = static_cast(value); + return ptr; + } + + PROTOBUF_ALWAYS_INLINE static uint8* UnsafeWriteSize(uint32 value, + uint8* ptr) { + while (PROTOBUF_PREDICT_FALSE(value >= 0x80)) { + *ptr = static_cast(value | 0x80); + value >>= 7; + ++ptr; + } + *ptr++ = static_cast(value); + return ptr; + } + + template + uint8* WriteRawLittleEndian(const void* data, int size, uint8* ptr); +#ifndef PROTOBUF_LITTLE_ENDIAN + uint8* WriteRawLittleEndian32(const void* data, int size, uint8* ptr); + uint8* WriteRawLittleEndian64(const void* data, int size, uint8* ptr); +#endif + + // These methods are for CodedOutputStream. Ideally they should be private + // but to match current behavior of CodedOutputStream as close as possible + // we allow it some functionality. + public: + uint8* SetInitialBuffer(void* data, int size) { + auto ptr = static_cast(data); + if (size > kSlopBytes) { + end_ = ptr + size - kSlopBytes; + buffer_end_ = nullptr; + return ptr; + } else { + end_ = buffer_ + size; + buffer_end_ = ptr; + return buffer_; + } + } + + private: + // Needed by CodedOutputStream HadError. HadError needs to flush the patch + // buffers to ensure there is no error as of yet. + uint8* FlushAndResetBuffer(uint8*); + + // The following functions mimick the old CodedOutputStream behavior as close + // as possible. They flush the current state to the stream, behave as + // the old CodedOutputStream and then return to normal operation. + bool Skip(int count, uint8** pp); + bool GetDirectBufferPointer(void** data, int* size, uint8** pp); + uint8* GetDirectBufferForNBytesAndAdvance(int size, uint8** pp); + + friend class CodedOutputStream; +}; + +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<1>(const void* data, + int size, + uint8* ptr) { + return WriteRaw(data, size, ptr); +} +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<4>(const void* data, + int size, + uint8* ptr) { +#ifdef PROTOBUF_LITTLE_ENDIAN + return WriteRaw(data, size, ptr); +#else + return WriteRawLittleEndian32(data, size, ptr); +#endif +} +template <> +inline uint8* EpsCopyOutputStream::WriteRawLittleEndian<8>(const void* data, + int size, + uint8* ptr) { +#ifdef PROTOBUF_LITTLE_ENDIAN + return WriteRaw(data, size, ptr); +#else + return WriteRawLittleEndian64(data, size, ptr); +#endif +} + +// Class which encodes and writes binary data which is composed of varint- +// encoded integers and fixed-width pieces. Wraps a ZeroCopyOutputStream. +// Most users will not need to deal with CodedOutputStream. +// +// Most methods of CodedOutputStream which return a bool return false if an +// underlying I/O error occurs. Once such a failure occurs, the +// CodedOutputStream is broken and is no longer useful. The Write* methods do +// not return the stream status, but will invalidate the stream if an error +// occurs. The client can probe HadError() to determine the status. +// +// Note that every method of CodedOutputStream which writes some data has +// a corresponding static "ToArray" version. These versions write directly +// to the provided buffer, returning a pointer past the last written byte. +// They require that the buffer has sufficient capacity for the encoded data. +// This allows an optimization where we check if an output stream has enough +// space for an entire message before we start writing and, if there is, we +// call only the ToArray methods to avoid doing bound checks for each +// individual value. +// i.e., in the example above: +// +// CodedOutputStream* coded_output = new CodedOutputStream(raw_output); +// int magic_number = 1234; +// char text[] = "Hello world!"; +// +// int coded_size = sizeof(magic_number) + +// CodedOutputStream::VarintSize32(strlen(text)) + +// strlen(text); +// +// uint8* buffer = +// coded_output->GetDirectBufferForNBytesAndAdvance(coded_size); +// if (buffer != nullptr) { +// // The output stream has enough space in the buffer: write directly to +// // the array. +// buffer = CodedOutputStream::WriteLittleEndian32ToArray(magic_number, +// buffer); +// buffer = CodedOutputStream::WriteVarint32ToArray(strlen(text), buffer); +// buffer = CodedOutputStream::WriteRawToArray(text, strlen(text), buffer); +// } else { +// // Make bound-checked writes, which will ask the underlying stream for +// // more space as needed. +// coded_output->WriteLittleEndian32(magic_number); +// coded_output->WriteVarint32(strlen(text)); +// coded_output->WriteRaw(text, strlen(text)); +// } +// +// delete coded_output; +class PROTOBUF_EXPORT CodedOutputStream { + public: + // Create an CodedOutputStream that writes to the given ZeroCopyOutputStream. + explicit CodedOutputStream(ZeroCopyOutputStream* stream) + : CodedOutputStream(stream, true) {} + CodedOutputStream(ZeroCopyOutputStream* stream, bool do_eager_refresh); + + // Destroy the CodedOutputStream and position the underlying + // ZeroCopyOutputStream immediately after the last byte written. + ~CodedOutputStream(); + + // Returns true if there was an underlying I/O error since this object was + // created. On should call Trim before this function in order to catch all + // errors. + bool HadError() { + cur_ = impl_.FlushAndResetBuffer(cur_); + GOOGLE_DCHECK(cur_); + return impl_.HadError(); + } + + // Trims any unused space in the underlying buffer so that its size matches + // the number of bytes written by this stream. The underlying buffer will + // automatically be trimmed when this stream is destroyed; this call is only + // necessary if the underlying buffer is accessed *before* the stream is + // destroyed. + void Trim() { cur_ = impl_.Trim(cur_); } + + // Skips a number of bytes, leaving the bytes unmodified in the underlying + // buffer. Returns false if an underlying write error occurs. This is + // mainly useful with GetDirectBufferPointer(). + // Note of caution, the skipped bytes may contain uninitialized data. The + // caller must make sure that the skipped bytes are properly initialized, + // otherwise you might leak bytes from your heap. + bool Skip(int count) { return impl_.Skip(count, &cur_); } + + // Sets *data to point directly at the unwritten part of the + // CodedOutputStream's underlying buffer, and *size to the size of that + // buffer, but does not advance the stream's current position. This will + // always either produce a non-empty buffer or return false. If the caller + // writes any data to this buffer, it should then call Skip() to skip over + // the consumed bytes. This may be useful for implementing external fast + // serialization routines for types of data not covered by the + // CodedOutputStream interface. + bool GetDirectBufferPointer(void** data, int* size) { + return impl_.GetDirectBufferPointer(data, size, &cur_); + } + + // If there are at least "size" bytes available in the current buffer, + // returns a pointer directly into the buffer and advances over these bytes. + // The caller may then write directly into this buffer (e.g. using the + // *ToArray static methods) rather than go through CodedOutputStream. If + // there are not enough bytes available, returns NULL. The return pointer is + // invalidated as soon as any other non-const method of CodedOutputStream + // is called. + inline uint8* GetDirectBufferForNBytesAndAdvance(int size) { + return impl_.GetDirectBufferForNBytesAndAdvance(size, &cur_); + } + + // Write raw bytes, copying them from the given buffer. + void WriteRaw(const void* buffer, int size) { + cur_ = impl_.WriteRaw(buffer, size, cur_); + } + // Like WriteRaw() but will try to write aliased data if aliasing is + // turned on. + void WriteRawMaybeAliased(const void* data, int size); + // Like WriteRaw() but writing directly to the target array. + // This is _not_ inlined, as the compiler often optimizes memcpy into inline + // copy loops. Since this gets called by every field with string or bytes + // type, inlining may lead to a significant amount of code bloat, with only a + // minor performance gain. + static uint8* WriteRawToArray(const void* buffer, int size, uint8* target); + + // Equivalent to WriteRaw(str.data(), str.size()). + void WriteString(const std::string& str); + // Like WriteString() but writing directly to the target array. + static uint8* WriteStringToArray(const std::string& str, uint8* target); + // Write the varint-encoded size of str followed by str. + static uint8* WriteStringWithSizeToArray(const std::string& str, + uint8* target); + + + // Write a 32-bit little-endian integer. + void WriteLittleEndian32(uint32 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteLittleEndian32ToArray(value, Cur())); + } + // Like WriteLittleEndian32() but writing directly to the target array. + static uint8* WriteLittleEndian32ToArray(uint32 value, uint8* target); + // Write a 64-bit little-endian integer. + void WriteLittleEndian64(uint64 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteLittleEndian64ToArray(value, Cur())); + } + // Like WriteLittleEndian64() but writing directly to the target array. + static uint8* WriteLittleEndian64ToArray(uint64 value, uint8* target); + + // Write an unsigned integer with Varint encoding. Writing a 32-bit value + // is equivalent to casting it to uint64 and writing it as a 64-bit value, + // but may be more efficient. + void WriteVarint32(uint32 value); + // Like WriteVarint32() but writing directly to the target array. + static uint8* WriteVarint32ToArray(uint32 value, uint8* target); + // Write an unsigned integer with Varint encoding. + void WriteVarint64(uint64 value); + // Like WriteVarint64() but writing directly to the target array. + static uint8* WriteVarint64ToArray(uint64 value, uint8* target); + + // Equivalent to WriteVarint32() except when the value is negative, + // in which case it must be sign-extended to a full 10 bytes. + void WriteVarint32SignExtended(int32 value); + // Like WriteVarint32SignExtended() but writing directly to the target array. + static uint8* WriteVarint32SignExtendedToArray(int32 value, uint8* target); + + // This is identical to WriteVarint32(), but optimized for writing tags. + // In particular, if the input is a compile-time constant, this method + // compiles down to a couple instructions. + // Always inline because otherwise the aformentioned optimization can't work, + // but GCC by default doesn't want to inline this. + void WriteTag(uint32 value); + // Like WriteTag() but writing directly to the target array. + PROTOBUF_ALWAYS_INLINE + static uint8* WriteTagToArray(uint32 value, uint8* target); + + // Returns the number of bytes needed to encode the given value as a varint. + static size_t VarintSize32(uint32 value); + // Returns the number of bytes needed to encode the given value as a varint. + static size_t VarintSize64(uint64 value); + + // If negative, 10 bytes. Otherwise, same as VarintSize32(). + static size_t VarintSize32SignExtended(int32 value); + + // Compile-time equivalent of VarintSize32(). + template + struct StaticVarintSize32 { + static const size_t value = + (Value < (1 << 7)) + ? 1 + : (Value < (1 << 14)) + ? 2 + : (Value < (1 << 21)) ? 3 : (Value < (1 << 28)) ? 4 : 5; + }; + + // Returns the total number of bytes written since this object was created. + int ByteCount() const { + return static_cast(impl_.ByteCount(cur_) - start_count_); + } + + // Instructs the CodedOutputStream to allow the underlying + // ZeroCopyOutputStream to hold pointers to the original structure instead of + // copying, if it supports it (i.e. output->AllowsAliasing() is true). If the + // underlying stream does not support aliasing, then enabling it has no + // affect. For now, this only affects the behavior of + // WriteRawMaybeAliased(). + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + void EnableAliasing(bool enabled) { impl_.EnableAliasing(enabled); } + + // Indicate to the serializer whether the user wants derministic + // serialization. The default when this is not called comes from the global + // default, controlled by SetDefaultSerializationDeterministic. + // + // What deterministic serialization means is entirely up to the driver of the + // serialization process (i.e. the caller of methods like WriteVarint32). In + // the case of serializing a proto buffer message using one of the methods of + // MessageLite, this means that for a given binary equal messages will always + // be serialized to the same bytes. This implies: + // + // * Repeated serialization of a message will return the same bytes. + // + // * Different processes running the same binary (including on different + // machines) will serialize equal messages to the same bytes. + // + // Note that this is *not* canonical across languages. It is also unstable + // across different builds with intervening message definition changes, due to + // unknown fields. Users who need canonical serialization (e.g. persistent + // storage in a canonical form, fingerprinting) should define their own + // canonicalization specification and implement the serializer using + // reflection APIs rather than relying on this API. + void SetSerializationDeterministic(bool value) { + impl_.SetSerializationDeterministic(value); + } + + // Return whether the user wants deterministic serialization. See above. + bool IsSerializationDeterministic() const { + return impl_.IsSerializationDeterministic(); + } + + static bool IsDefaultSerializationDeterministic() { + return default_serialization_deterministic_.load( + std::memory_order_relaxed) != 0; + } + + template + void Serialize(const Func& func); + + uint8* Cur() const { return cur_; } + void SetCur(uint8* ptr) { cur_ = ptr; } + EpsCopyOutputStream* EpsCopy() { return &impl_; } + + private: + EpsCopyOutputStream impl_; + uint8* cur_; + int64 start_count_; + static std::atomic default_serialization_deterministic_; + + // See above. Other projects may use "friend" to allow them to call this. + // After SetDefaultSerializationDeterministic() completes, all protocol + // buffer serializations will be deterministic by default. Thread safe. + // However, the meaning of "after" is subtle here: to be safe, each thread + // that wants deterministic serialization by default needs to call + // SetDefaultSerializationDeterministic() or ensure on its own that another + // thread has done so. + friend void internal::MapTestForceDeterministic(); + static void SetDefaultSerializationDeterministic() { + default_serialization_deterministic_.store(true, std::memory_order_relaxed); + } + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CodedOutputStream); +}; + +// inline methods ==================================================== +// The vast majority of varints are only one byte. These inline +// methods optimize for that case. + +inline bool CodedInputStream::ReadVarint32(uint32* value) { + uint32 v = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + v = *buffer_; + if (v < 0x80) { + *value = v; + Advance(1); + return true; + } + } + int64 result = ReadVarint32Fallback(v); + *value = static_cast(result); + return result >= 0; +} + +inline bool CodedInputStream::ReadVarint64(uint64* value) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_) && *buffer_ < 0x80) { + *value = *buffer_; + Advance(1); + return true; + } + std::pair p = ReadVarint64Fallback(); + *value = p.first; + return p.second; +} + +inline bool CodedInputStream::ReadVarintSizeAsInt(int* value) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + int v = *buffer_; + if (v < 0x80) { + *value = v; + Advance(1); + return true; + } + } + *value = ReadVarintSizeAsIntFallback(); + return *value >= 0; +} + +// static +inline const uint8* CodedInputStream::ReadLittleEndian32FromArray( + const uint8* buffer, uint32* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(value, buffer, sizeof(*value)); + return buffer + sizeof(*value); +#else + *value = (static_cast(buffer[0])) | + (static_cast(buffer[1]) << 8) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[3]) << 24); + return buffer + sizeof(*value); +#endif +} +// static +inline const uint8* CodedInputStream::ReadLittleEndian64FromArray( + const uint8* buffer, uint64* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(value, buffer, sizeof(*value)); + return buffer + sizeof(*value); +#else + uint32 part0 = (static_cast(buffer[0])) | + (static_cast(buffer[1]) << 8) | + (static_cast(buffer[2]) << 16) | + (static_cast(buffer[3]) << 24); + uint32 part1 = (static_cast(buffer[4])) | + (static_cast(buffer[5]) << 8) | + (static_cast(buffer[6]) << 16) | + (static_cast(buffer[7]) << 24); + *value = static_cast(part0) | (static_cast(part1) << 32); + return buffer + sizeof(*value); +#endif +} + +inline bool CodedInputStream::ReadLittleEndian32(uint32* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= static_cast(sizeof(*value)))) { + buffer_ = ReadLittleEndian32FromArray(buffer_, value); + return true; + } else { + return ReadLittleEndian32Fallback(value); + } +#else + return ReadLittleEndian32Fallback(value); +#endif +} + +inline bool CodedInputStream::ReadLittleEndian64(uint64* value) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= static_cast(sizeof(*value)))) { + buffer_ = ReadLittleEndian64FromArray(buffer_, value); + return true; + } else { + return ReadLittleEndian64Fallback(value); + } +#else + return ReadLittleEndian64Fallback(value); +#endif +} + +inline uint32 CodedInputStream::ReadTagNoLastTag() { + uint32 v = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + v = *buffer_; + if (v < 0x80) { + Advance(1); + return v; + } + } + v = ReadTagFallback(v); + return v; +} + +inline std::pair CodedInputStream::ReadTagWithCutoffNoLastTag( + uint32 cutoff) { + // In performance-sensitive code we can expect cutoff to be a compile-time + // constant, and things like "cutoff >= kMax1ByteVarint" to be evaluated at + // compile time. + uint32 first_byte_or_zero = 0; + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_)) { + // Hot case: buffer_ non_empty, buffer_[0] in [1, 128). + // TODO(gpike): Is it worth rearranging this? E.g., if the number of fields + // is large enough then is it better to check for the two-byte case first? + first_byte_or_zero = buffer_[0]; + if (static_cast(buffer_[0]) > 0) { + const uint32 kMax1ByteVarint = 0x7f; + uint32 tag = buffer_[0]; + Advance(1); + return std::make_pair(tag, cutoff >= kMax1ByteVarint || tag <= cutoff); + } + // Other hot case: cutoff >= 0x80, buffer_ has at least two bytes available, + // and tag is two bytes. The latter is tested by bitwise-and-not of the + // first byte and the second byte. + if (cutoff >= 0x80 && PROTOBUF_PREDICT_TRUE(buffer_ + 1 < buffer_end_) && + PROTOBUF_PREDICT_TRUE((buffer_[0] & ~buffer_[1]) >= 0x80)) { + const uint32 kMax2ByteVarint = (0x7f << 7) + 0x7f; + uint32 tag = (1u << 7) * buffer_[1] + (buffer_[0] - 0x80); + Advance(2); + // It might make sense to test for tag == 0 now, but it is so rare that + // that we don't bother. A varint-encoded 0 should be one byte unless + // the encoder lost its mind. The second part of the return value of + // this function is allowed to be either true or false if the tag is 0, + // so we don't have to check for tag == 0. We may need to check whether + // it exceeds cutoff. + bool at_or_below_cutoff = cutoff >= kMax2ByteVarint || tag <= cutoff; + return std::make_pair(tag, at_or_below_cutoff); + } + } + // Slow path + const uint32 tag = ReadTagFallback(first_byte_or_zero); + return std::make_pair(tag, static_cast(tag - 1) < cutoff); +} + +inline bool CodedInputStream::LastTagWas(uint32 expected) { + return last_tag_ == expected; +} + +inline bool CodedInputStream::ConsumedEntireMessage() { + return legitimate_message_end_; +} + +inline bool CodedInputStream::ExpectTag(uint32 expected) { + if (expected < (1 << 7)) { + if (PROTOBUF_PREDICT_TRUE(buffer_ < buffer_end_) && + buffer_[0] == expected) { + Advance(1); + return true; + } else { + return false; + } + } else if (expected < (1 << 14)) { + if (PROTOBUF_PREDICT_TRUE(BufferSize() >= 2) && + buffer_[0] == static_cast(expected | 0x80) && + buffer_[1] == static_cast(expected >> 7)) { + Advance(2); + return true; + } else { + return false; + } + } else { + // Don't bother optimizing for larger values. + return false; + } +} + +inline const uint8* CodedInputStream::ExpectTagFromArray(const uint8* buffer, + uint32 expected) { + if (expected < (1 << 7)) { + if (buffer[0] == expected) { + return buffer + 1; + } + } else if (expected < (1 << 14)) { + if (buffer[0] == static_cast(expected | 0x80) && + buffer[1] == static_cast(expected >> 7)) { + return buffer + 2; + } + } + return nullptr; +} + +inline void CodedInputStream::GetDirectBufferPointerInline(const void** data, + int* size) { + *data = buffer_; + *size = static_cast(buffer_end_ - buffer_); +} + +inline bool CodedInputStream::ExpectAtEnd() { + // If we are at a limit we know no more bytes can be read. Otherwise, it's + // hard to say without calling Refresh(), and we'd rather not do that. + + if (buffer_ == buffer_end_ && ((buffer_size_after_limit_ != 0) || + (total_bytes_read_ == current_limit_))) { + last_tag_ = 0; // Pretend we called ReadTag()... + legitimate_message_end_ = true; // ... and it hit EOF. + return true; + } else { + return false; + } +} + +inline int CodedInputStream::CurrentPosition() const { + return total_bytes_read_ - (BufferSize() + buffer_size_after_limit_); +} + +inline void CodedInputStream::Advance(int amount) { buffer_ += amount; } + +inline void CodedInputStream::SetRecursionLimit(int limit) { + recursion_budget_ += limit - recursion_limit_; + recursion_limit_ = limit; +} + +inline bool CodedInputStream::IncrementRecursionDepth() { + --recursion_budget_; + return recursion_budget_ >= 0; +} + +inline void CodedInputStream::DecrementRecursionDepth() { + if (recursion_budget_ < recursion_limit_) ++recursion_budget_; +} + +inline void CodedInputStream::UnsafeDecrementRecursionDepth() { + assert(recursion_budget_ < recursion_limit_); + ++recursion_budget_; +} + +inline void CodedInputStream::SetExtensionRegistry(const DescriptorPool* pool, + MessageFactory* factory) { + extension_pool_ = pool; + extension_factory_ = factory; +} + +inline const DescriptorPool* CodedInputStream::GetExtensionPool() { + return extension_pool_; +} + +inline MessageFactory* CodedInputStream::GetExtensionFactory() { + return extension_factory_; +} + +inline int CodedInputStream::BufferSize() const { + return static_cast(buffer_end_ - buffer_); +} + +inline CodedInputStream::CodedInputStream(ZeroCopyInputStream* input) + : buffer_(nullptr), + buffer_end_(nullptr), + input_(input), + total_bytes_read_(0), + overflow_bytes_(0), + last_tag_(0), + legitimate_message_end_(false), + aliasing_enabled_(false), + current_limit_(kint32max), + buffer_size_after_limit_(0), + total_bytes_limit_(kDefaultTotalBytesLimit), + recursion_budget_(default_recursion_limit_), + recursion_limit_(default_recursion_limit_), + extension_pool_(nullptr), + extension_factory_(nullptr) { + // Eagerly Refresh() so buffer space is immediately available. + Refresh(); +} + +inline CodedInputStream::CodedInputStream(const uint8* buffer, int size) + : buffer_(buffer), + buffer_end_(buffer + size), + input_(nullptr), + total_bytes_read_(size), + overflow_bytes_(0), + last_tag_(0), + legitimate_message_end_(false), + aliasing_enabled_(false), + current_limit_(size), + buffer_size_after_limit_(0), + total_bytes_limit_(kDefaultTotalBytesLimit), + recursion_budget_(default_recursion_limit_), + recursion_limit_(default_recursion_limit_), + extension_pool_(nullptr), + extension_factory_(nullptr) { + // Note that setting current_limit_ == size is important to prevent some + // code paths from trying to access input_ and segfaulting. +} + +inline bool CodedInputStream::IsFlat() const { return input_ == nullptr; } + +inline bool CodedInputStream::Skip(int count) { + if (count < 0) return false; // security: count is often user-supplied + + const int original_buffer_size = BufferSize(); + + if (count <= original_buffer_size) { + // Just skipping within the current buffer. Easy. + Advance(count); + return true; + } + + return SkipFallback(count, original_buffer_size); +} + +inline uint8* CodedOutputStream::WriteVarint32ToArray(uint32 value, + uint8* target) { + return EpsCopyOutputStream::UnsafeVarint(value, target); +} + +inline uint8* CodedOutputStream::WriteVarint64ToArray(uint64 value, + uint8* target) { + return EpsCopyOutputStream::UnsafeVarint(value, target); +} + +inline void CodedOutputStream::WriteVarint32SignExtended(int32 value) { + WriteVarint64(static_cast(value)); +} + +inline uint8* CodedOutputStream::WriteVarint32SignExtendedToArray( + int32 value, uint8* target) { + return WriteVarint64ToArray(static_cast(value), target); +} + +inline uint8* CodedOutputStream::WriteLittleEndian32ToArray(uint32 value, + uint8* target) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(target, &value, sizeof(value)); +#else + target[0] = static_cast(value); + target[1] = static_cast(value >> 8); + target[2] = static_cast(value >> 16); + target[3] = static_cast(value >> 24); +#endif + return target + sizeof(value); +} + +inline uint8* CodedOutputStream::WriteLittleEndian64ToArray(uint64 value, + uint8* target) { +#if defined(PROTOBUF_LITTLE_ENDIAN) + memcpy(target, &value, sizeof(value)); +#else + uint32 part0 = static_cast(value); + uint32 part1 = static_cast(value >> 32); + + target[0] = static_cast(part0); + target[1] = static_cast(part0 >> 8); + target[2] = static_cast(part0 >> 16); + target[3] = static_cast(part0 >> 24); + target[4] = static_cast(part1); + target[5] = static_cast(part1 >> 8); + target[6] = static_cast(part1 >> 16); + target[7] = static_cast(part1 >> 24); +#endif + return target + sizeof(value); +} + +inline void CodedOutputStream::WriteVarint32(uint32 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteVarint32ToArray(value, Cur())); +} + +inline void CodedOutputStream::WriteVarint64(uint64 value) { + cur_ = impl_.EnsureSpace(cur_); + SetCur(WriteVarint64ToArray(value, Cur())); +} + +inline void CodedOutputStream::WriteTag(uint32 value) { WriteVarint32(value); } + +inline uint8* CodedOutputStream::WriteTagToArray(uint32 value, uint8* target) { + return WriteVarint32ToArray(value, target); +} + +inline size_t CodedOutputStream::VarintSize32(uint32 value) { + // This computes value == 0 ? 1 : floor(log2(value)) / 7 + 1 + // Use an explicit multiplication to implement the divide of + // a number in the 1..31 range. + // Explicit OR 0x1 to avoid calling Bits::Log2FloorNonZero(0), which is + // undefined. + uint32 log2value = Bits::Log2FloorNonZero(value | 0x1); + return static_cast((log2value * 9 + 73) / 64); +} + +inline size_t CodedOutputStream::VarintSize64(uint64 value) { + // This computes value == 0 ? 1 : floor(log2(value)) / 7 + 1 + // Use an explicit multiplication to implement the divide of + // a number in the 1..63 range. + // Explicit OR 0x1 to avoid calling Bits::Log2FloorNonZero(0), which is + // undefined. + uint32 log2value = Bits::Log2FloorNonZero64(value | 0x1); + return static_cast((log2value * 9 + 73) / 64); +} + +inline size_t CodedOutputStream::VarintSize32SignExtended(int32 value) { + if (value < 0) { + return 10; // TODO(kenton): Make this a symbolic constant. + } else { + return VarintSize32(static_cast(value)); + } +} + +inline void CodedOutputStream::WriteString(const std::string& str) { + WriteRaw(str.data(), static_cast(str.size())); +} + +inline void CodedOutputStream::WriteRawMaybeAliased(const void* data, + int size) { + cur_ = impl_.WriteRawMaybeAliased(data, size, cur_); +} + +inline uint8* CodedOutputStream::WriteRawToArray(const void* data, int size, + uint8* target) { + memcpy(target, data, size); + return target + size; +} + +inline uint8* CodedOutputStream::WriteStringToArray(const std::string& str, + uint8* target) { + return WriteRawToArray(str.data(), static_cast(str.size()), target); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#if defined(_MSC_VER) && _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) +#pragma runtime_checks("c", restore) +#endif // _MSC_VER && !defined(__INTEL_COMPILER) + +#include + +#endif // GOOGLE_PROTOBUF_IO_CODED_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..cb0dac875a0720d0143d23bb2163f3f249cc5594 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/gzip_stream.h @@ -0,0 +1,207 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: brianolson@google.com (Brian Olson) +// +// This file contains the definition for classes GzipInputStream and +// GzipOutputStream. +// +// GzipInputStream decompresses data from an underlying +// ZeroCopyInputStream and provides the decompressed data as a +// ZeroCopyInputStream. +// +// GzipOutputStream is an ZeroCopyOutputStream that compresses data to +// an underlying ZeroCopyOutputStream. + +#ifndef GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ + + +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace io { + +// A ZeroCopyInputStream that reads compressed data through zlib +class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { + public: + // Format key for constructor + enum Format { + // zlib will autodetect gzip header or deflate stream + AUTO = 0, + + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + // buffer_size and format may be -1 for default of 64kB and GZIP format + explicit GzipInputStream(ZeroCopyInputStream* sub_stream, + Format format = AUTO, int buffer_size = -1); + virtual ~GzipInputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size); + void BackUp(int count); + bool Skip(int count); + int64_t ByteCount() const; + + private: + Format format_; + + ZeroCopyInputStream* sub_stream_; + + z_stream zcontext_; + int zerror_; + + void* output_buffer_; + void* output_position_; + size_t output_buffer_length_; + int64 byte_count_; + + int Inflate(int flush); + void DoNextOutput(const void** data, int* size); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GzipInputStream); +}; + +class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { + public: + // Format key for constructor + enum Format { + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + struct PROTOBUF_EXPORT Options { + // Defaults to GZIP. + Format format; + + // What size buffer to use internally. Defaults to 64kB. + int buffer_size; + + // A number between 0 and 9, where 0 is no compression and 9 is best + // compression. Defaults to Z_DEFAULT_COMPRESSION (see zlib.h). + int compression_level; + + // Defaults to Z_DEFAULT_STRATEGY. Can also be set to Z_FILTERED, + // Z_HUFFMAN_ONLY, or Z_RLE. See the documentation for deflateInit2 in + // zlib.h for definitions of these constants. + int compression_strategy; + + Options(); // Initializes with default values. + }; + + // Create a GzipOutputStream with default options. + explicit GzipOutputStream(ZeroCopyOutputStream* sub_stream); + + // Create a GzipOutputStream with the given options. + GzipOutputStream(ZeroCopyOutputStream* sub_stream, const Options& options); + + virtual ~GzipOutputStream(); + + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // Flushes data written so far to zipped data in the underlying stream. + // It is the caller's responsibility to flush the underlying stream if + // necessary. + // Compression may be less efficient stopping and starting around flushes. + // Returns true if no error. + // + // Please ensure that block size is > 6. Here is an excerpt from the zlib + // doc that explains why: + // + // In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that avail_out + // is greater than six to avoid repeated flush markers due to + // avail_out == 0 on return. + bool Flush(); + + // Writes out all data and closes the gzip stream. + // It is the caller's responsibility to close the underlying stream if + // necessary. + // Returns true if no error. + bool Close(); + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size); + void BackUp(int count); + int64_t ByteCount() const; + + private: + ZeroCopyOutputStream* sub_stream_; + // Result from calling Next() on sub_stream_ + void* sub_data_; + int sub_data_size_; + + z_stream zcontext_; + int zerror_; + void* input_buffer_; + size_t input_buffer_length_; + + // Shared constructor code. + void Init(ZeroCopyOutputStream* sub_stream, const Options& options); + + // Do some compression. + // Takes zlib flush mode. + // Returns zlib error code. + int Deflate(int flush); + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GzipOutputStream); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_GZIP_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h new file mode 100644 index 0000000000000000000000000000000000000000..bbbae7e4f95770749524b88618561c0759e3aa3b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/io_win32.h @@ -0,0 +1,144 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: laszlocsomor@google.com (Laszlo Csomor) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. + +// This file contains the declarations for Windows implementations of +// commonly used POSIX functions such as open(2) and access(2), as well +// as macro definitions for flags of these functions. +// +// By including this file you'll redefine open/access/etc. to +// ::google::protobuf::io::win32::{open/access/etc.}. +// Make sure you don't include a header that attempts to redeclare or +// redefine these functions, that'll lead to confusing compilation +// errors. It's best to #include this file as the last one to ensure that. +// +// This file is only used on Windows, it's empty on other platforms. + +#ifndef GOOGLE_PROTOBUF_IO_IO_WIN32_H__ +#define GOOGLE_PROTOBUF_IO_IO_WIN32_H__ + +#if defined(_WIN32) + +#include +#include + +#include +#include + +// Compilers on Windows other than MSVC (e.g. Cygwin, MinGW32) define the +// following functions already, except for mkdir. +namespace google { +namespace protobuf { +namespace io { +namespace win32 { + +PROTOBUF_EXPORT FILE* fopen(const char* path, const char* mode); +PROTOBUF_EXPORT int access(const char* path, int mode); +PROTOBUF_EXPORT int chdir(const char* path); +PROTOBUF_EXPORT int close(int fd); +PROTOBUF_EXPORT int dup(int fd); +PROTOBUF_EXPORT int dup2(int fd1, int fd2); +PROTOBUF_EXPORT int mkdir(const char* path, int _mode); +PROTOBUF_EXPORT int open(const char* path, int flags, int mode = 0); +PROTOBUF_EXPORT int read(int fd, void* buffer, size_t size); +PROTOBUF_EXPORT int setmode(int fd, int mode); +PROTOBUF_EXPORT int stat(const char* path, struct _stat* buffer); +PROTOBUF_EXPORT int write(int fd, const void* buffer, size_t size); +PROTOBUF_EXPORT std::wstring testonly_utf8_to_winpath(const char* path); + +enum class ExpandWildcardsResult { + kSuccess = 0, + kErrorNoMatchingFile = 1, + kErrorInputPathConversion = 2, + kErrorOutputPathConversion = 3, +}; + +// Expand wildcards in a path pattern, feed the result to a consumer function. +// +// `path` must be a valid, Windows-style path. It may be absolute, or relative +// to the current working directory, and it may contain wildcards ("*" and "?") +// in the last path segment. This function passes all matching file names to +// `consume`. The resulting paths may not be absolute nor normalized. +// +// The function returns a value from `ExpandWildcardsResult`. +PROTOBUF_EXPORT ExpandWildcardsResult ExpandWildcards( + const std::string& path, std::function consume); + +namespace strings { + +// Convert from UTF-16 to Active-Code-Page-encoded or to UTF-8-encoded text. +PROTOBUF_EXPORT bool wcs_to_mbs(const wchar_t* s, std::string* out, + bool outUtf8); + +// Convert from Active-Code-Page-encoded or UTF-8-encoded text to UTF-16. +PROTOBUF_EXPORT bool mbs_to_wcs(const char* s, std::wstring* out, bool inUtf8); + +// Convert from UTF-8-encoded text to UTF-16. +PROTOBUF_EXPORT bool utf8_to_wcs(const char* input, std::wstring* out); + +// Convert from UTF-16-encoded text to UTF-8. +PROTOBUF_EXPORT bool wcs_to_utf8(const wchar_t* input, std::string* out); + +} // namespace strings + +} // namespace win32 +} // namespace io +} // namespace protobuf +} // namespace google + +#ifndef W_OK +#define W_OK 02 // not defined by MSVC for whatever reason +#endif + +#ifndef F_OK +#define F_OK 00 // not defined by MSVC for whatever reason +#endif + +#ifndef STDIN_FILENO +#define STDIN_FILENO 0 +#endif + +#ifndef STDOUT_FILENO +#define STDOUT_FILENO 1 +#endif + +#include + +#endif // defined(_WIN32) + +#endif // GOOGLE_PROTOBUF_IO_IO_WIN32_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h new file mode 100644 index 0000000000000000000000000000000000000000..ad6985d6f019a7508ae929148a30b5b1b1d2268a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/printer.h @@ -0,0 +1,390 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Utility class for writing text to a ZeroCopyOutputStream. + +#ifndef GOOGLE_PROTOBUF_IO_PRINTER_H__ +#define GOOGLE_PROTOBUF_IO_PRINTER_H__ + + +#include +#include +#include + +#include +#include + +namespace google { +namespace protobuf { +namespace io { + +class ZeroCopyOutputStream; // zero_copy_stream.h + +// Records annotations about a Printer's output. +class PROTOBUF_EXPORT AnnotationCollector { + public: + // Annotation is a offset range and a payload pair. + typedef std::pair, std::string> Annotation; + + // Records that the bytes in file_path beginning with begin_offset and ending + // before end_offset are associated with the SourceCodeInfo-style path. + virtual void AddAnnotation(size_t begin_offset, size_t end_offset, + const std::string& file_path, + const std::vector& path) = 0; + + // TODO(gerbens) I don't see why we need virtuals here. Just a vector of + // range, payload pairs stored in a context should suffice. + virtual void AddAnnotationNew(Annotation& a) {} + + virtual ~AnnotationCollector() {} +}; + +// Records annotations about a Printer's output to the given protocol buffer, +// assuming that the buffer has an ::Annotation message exposing path, +// source_file, begin and end fields. +template +class AnnotationProtoCollector : public AnnotationCollector { + public: + // annotation_proto is the protocol buffer to which new Annotations should be + // added. It is not owned by the AnnotationProtoCollector. + explicit AnnotationProtoCollector(AnnotationProto* annotation_proto) + : annotation_proto_(annotation_proto) {} + + // Override for AnnotationCollector::AddAnnotation. + virtual void AddAnnotation(size_t begin_offset, size_t end_offset, + const std::string& file_path, + const std::vector& path) { + typename AnnotationProto::Annotation* annotation = + annotation_proto_->add_annotation(); + for (int i = 0; i < path.size(); ++i) { + annotation->add_path(path[i]); + } + annotation->set_source_file(file_path); + annotation->set_begin(begin_offset); + annotation->set_end(end_offset); + } + // Override for AnnotationCollector::AddAnnotation. + virtual void AddAnnotationNew(Annotation& a) { + auto* annotation = annotation_proto_->add_annotation(); + annotation->ParseFromString(a.second); + annotation->set_begin(a.first.first); + annotation->set_end(a.first.second); + } + + private: + // The protocol buffer to which new annotations should be added. + AnnotationProto* const annotation_proto_; +}; + +// This simple utility class assists in code generation. It basically +// allows the caller to define a set of variables and then output some +// text with variable substitutions. Example usage: +// +// Printer printer(output, '$'); +// map vars; +// vars["name"] = "Bob"; +// printer.Print(vars, "My name is $name$."); +// +// The above writes "My name is Bob." to the output stream. +// +// Printer aggressively enforces correct usage, crashing (with assert failures) +// in the case of undefined variables in debug builds. This helps greatly in +// debugging code which uses it. +// +// If a Printer is constructed with an AnnotationCollector, it will provide it +// with annotations that connect the Printer's output to paths that can identify +// various descriptors. In the above example, if person_ is a descriptor that +// identifies Bob, we can associate the output string "My name is Bob." with +// a source path pointing to that descriptor with: +// +// printer.Annotate("name", person_); +// +// The AnnotationCollector will be sent an annotation linking the output range +// covering "Bob" to the logical path provided by person_. Tools may use +// this association to (for example) link "Bob" in the output back to the +// source file that defined the person_ descriptor identifying Bob. +// +// Annotate can only examine variables substituted during the last call to +// Print. It is invalid to refer to a variable that was used multiple times +// in a single Print call. +// +// In full generality, one may specify a range of output text using a beginning +// substitution variable and an ending variable. The resulting annotation will +// span from the first character of the substituted value for the beginning +// variable to the last character of the substituted value for the ending +// variable. For example, the Annotate call above is equivalent to this one: +// +// printer.Annotate("name", "name", person_); +// +// This is useful if multiple variables combine to form a single span of output +// that should be annotated with the same source path. For example: +// +// Printer printer(output, '$'); +// map vars; +// vars["first"] = "Alice"; +// vars["last"] = "Smith"; +// printer.Print(vars, "My name is $first$ $last$."); +// printer.Annotate("first", "last", person_); +// +// This code would associate the span covering "Alice Smith" in the output with +// the person_ descriptor. +// +// Note that the beginning variable must come before (or overlap with, in the +// case of zero-sized substitution values) the ending variable. +// +// It is also sometimes useful to use variables with zero-sized values as +// markers. This avoids issues with multiple references to the same variable +// and also allows annotation ranges to span literal text from the Print +// templates: +// +// Printer printer(output, '$'); +// map vars; +// vars["foo"] = "bar"; +// vars["function"] = "call"; +// vars["mark"] = ""; +// printer.Print(vars, "$function$($foo$,$foo$)$mark$"); +// printer.Annotate("function", "mark", call_); +// +// This code associates the span covering "call(bar,bar)" in the output with the +// call_ descriptor. + +class PROTOBUF_EXPORT Printer { + public: + // Create a printer that writes text to the given output stream. Use the + // given character as the delimiter for variables. + Printer(ZeroCopyOutputStream* output, char variable_delimiter); + + // Create a printer that writes text to the given output stream. Use the + // given character as the delimiter for variables. If annotation_collector + // is not null, Printer will provide it with annotations about code written + // to the stream. annotation_collector is not owned by Printer. + Printer(ZeroCopyOutputStream* output, char variable_delimiter, + AnnotationCollector* annotation_collector); + + ~Printer(); + + // Link a substitution variable emitted by the last call to Print to the + // object described by descriptor. + template + void Annotate(const char* varname, const SomeDescriptor* descriptor) { + Annotate(varname, varname, descriptor); + } + + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the object described by descriptor. The range + // begins at begin_varname's value and ends after the last character of the + // value substituted for end_varname. + template + void Annotate(const char* begin_varname, const char* end_varname, + const SomeDescriptor* descriptor) { + if (annotation_collector_ == NULL) { + // Annotations aren't turned on for this Printer, so don't pay the cost + // of building the location path. + return; + } + std::vector path; + descriptor->GetLocationPath(&path); + Annotate(begin_varname, end_varname, descriptor->file()->name(), path); + } + + // Link a substitution variable emitted by the last call to Print to the file + // with path file_name. + void Annotate(const char* varname, const std::string& file_name) { + Annotate(varname, varname, file_name); + } + + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the file with path file_name. The range begins + // at begin_varname's value and ends after the last character of the value + // substituted for end_varname. + void Annotate(const char* begin_varname, const char* end_varname, + const std::string& file_name) { + if (annotation_collector_ == NULL) { + // Annotations aren't turned on for this Printer. + return; + } + std::vector empty_path; + Annotate(begin_varname, end_varname, file_name, empty_path); + } + + // Print some text after applying variable substitutions. If a particular + // variable in the text is not defined, this will crash. Variables to be + // substituted are identified by their names surrounded by delimiter + // characters (as given to the constructor). The variable bindings are + // defined by the given map. + void Print(const std::map& variables, + const char* text); + + // Like the first Print(), except the substitutions are given as parameters. + template + void Print(const char* text, const Args&... args) { + std::map vars; + PrintInternal(&vars, text, args...); + } + + // Indent text by two spaces. After calling Indent(), two spaces will be + // inserted at the beginning of each line of text. Indent() may be called + // multiple times to produce deeper indents. + void Indent(); + + // Reduces the current indent level by two spaces, or crashes if the indent + // level is zero. + void Outdent(); + + // Write a string to the output buffer. + // This method does not look for newlines to add indentation. + void PrintRaw(const std::string& data); + + // Write a zero-delimited string to output buffer. + // This method does not look for newlines to add indentation. + void PrintRaw(const char* data); + + // Write some bytes to the output buffer. + // This method does not look for newlines to add indentation. + void WriteRaw(const char* data, int size); + + // FormatInternal is a helper function not meant to use directly, use + // compiler::cpp::Formatter instead. This function is meant to support + // formatting text using named variables (eq. "$foo$) from a lookup map (vars) + // and variables directly supplied by arguments (eq "$1$" meaning first + // argument which is the zero index element of args). + void FormatInternal(const std::vector& args, + const std::map& vars, + const char* format); + + // True if any write to the underlying stream failed. (We don't just + // crash in this case because this is an I/O failure, not a programming + // error.) + bool failed() const { return failed_; } + + private: + // Link the output range defined by the substitution variables as emitted by + // the last call to Print to the object found at the SourceCodeInfo-style path + // in a file with path file_path. The range begins at the start of + // begin_varname's value and ends after the last character of the value + // substituted for end_varname. Note that begin_varname and end_varname + // may refer to the same variable. + void Annotate(const char* begin_varname, const char* end_varname, + const std::string& file_path, const std::vector& path); + + // Base case + void PrintInternal(std::map* vars, + const char* text) { + Print(*vars, text); + } + + template + void PrintInternal(std::map* vars, const char* text, + const char* key, const std::string& value, + const Args&... args) { + (*vars)[key] = value; + PrintInternal(vars, text, args...); + } + + // Copy size worth of bytes from data to buffer_. + void CopyToBuffer(const char* data, int size); + + void push_back(char c) { + if (failed_) return; + if (buffer_size_ == 0) { + if (!Next()) return; + } + *buffer_++ = c; + buffer_size_--; + offset_++; + } + + bool Next(); + + inline void IndentIfAtStart(); + const char* WriteVariable( + const std::vector& args, + const std::map& vars, const char* format, + int* arg_index, + std::vector* annotations); + + const char variable_delimiter_; + + ZeroCopyOutputStream* const output_; + char* buffer_; + int buffer_size_; + // The current position, in bytes, in the output stream. This is equivalent + // to the total number of bytes that have been written so far. This value is + // used to calculate annotation ranges in the substitutions_ map below. + size_t offset_; + + std::string indent_; + bool at_start_of_line_; + bool failed_; + + // A map from variable name to [start, end) offsets in the output buffer. + // These refer to the offsets used for a variable after the last call to + // Print. If a variable was used more than once, the entry used in + // this map is set to a negative-length span. For singly-used variables, the + // start offset is the beginning of the substitution; the end offset is the + // last byte of the substitution plus one (such that (end - start) is the + // length of the substituted string). + std::map > substitutions_; + + // Keeps track of the keys in substitutions_ that need to be updated when + // indents are inserted. These are keys that refer to the beginning of the + // current line. + std::vector line_start_variables_; + + // Returns true and sets range to the substitution range in the output for + // varname if varname was used once in the last call to Print. If varname + // was not used, or if it was used multiple times, returns false (and + // fails a debug assertion). + bool GetSubstitutionRange(const char* varname, + std::pair* range); + + // If non-null, annotation_collector_ is used to store annotations about + // generated code. + AnnotationCollector* const annotation_collector_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Printer); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_PRINTER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h new file mode 100644 index 0000000000000000000000000000000000000000..e05ba81b001b4bd4c5c0c59a175b02083b74e501 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/strtod.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A locale-independent version of strtod(), used to parse floating +// point default values in .proto files, where the decimal separator +// is always a dot. + +#ifndef GOOGLE_PROTOBUF_IO_STRTOD_H__ +#define GOOGLE_PROTOBUF_IO_STRTOD_H__ + +namespace google { +namespace protobuf { +namespace io { + +// A locale-independent version of the standard strtod(), which always +// uses a dot as the decimal separator. +double NoLocaleStrtod(const char* str, char** endptr); + +// Casts a double value to a float value. If the value is outside of the +// representable range of float, it will be converted to positive or negative +// infinity. +float SafeDoubleToFloat(double value); + +} // namespace io +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_IO_STRTOD_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h new file mode 100644 index 0000000000000000000000000000000000000000..984a0597e6a916eaa809333d99e138b7c962cb40 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/tokenizer.h @@ -0,0 +1,418 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// Class for parsing tokenized text from a ZeroCopyInputStream. + +#ifndef GOOGLE_PROTOBUF_IO_TOKENIZER_H__ +#define GOOGLE_PROTOBUF_IO_TOKENIZER_H__ + + +#include +#include + +#include +#include +#include + +namespace google { +namespace protobuf { +namespace io { + +class ZeroCopyInputStream; // zero_copy_stream.h + +// Defined in this file. +class ErrorCollector; +class Tokenizer; + +// By "column number", the proto compiler refers to a count of the number +// of bytes before a given byte, except that a tab character advances to +// the next multiple of 8 bytes. Note in particular that column numbers +// are zero-based, while many user interfaces use one-based column numbers. +typedef int ColumnNumber; + +// Abstract interface for an object which collects the errors that occur +// during parsing. A typical implementation might simply print the errors +// to stdout. +class PROTOBUF_EXPORT ErrorCollector { + public: + inline ErrorCollector() {} + virtual ~ErrorCollector(); + + // Indicates that there was an error in the input at the given line and + // column numbers. The numbers are zero-based, so you may want to add + // 1 to each before printing them. + virtual void AddError(int line, ColumnNumber column, + const std::string& message) = 0; + + // Indicates that there was a warning in the input at the given line and + // column numbers. The numbers are zero-based, so you may want to add + // 1 to each before printing them. + virtual void AddWarning(int line, ColumnNumber column, + const std::string& message) {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ErrorCollector); +}; + +// This class converts a stream of raw text into a stream of tokens for +// the protocol definition parser to parse. The tokens recognized are +// similar to those that make up the C language; see the TokenType enum for +// precise descriptions. Whitespace and comments are skipped. By default, +// C- and C++-style comments are recognized, but other styles can be used by +// calling set_comment_style(). +class PROTOBUF_EXPORT Tokenizer { + public: + // Construct a Tokenizer that reads and tokenizes text from the given + // input stream and writes errors to the given error_collector. + // The caller keeps ownership of input and error_collector. + Tokenizer(ZeroCopyInputStream* input, ErrorCollector* error_collector); + ~Tokenizer(); + + enum TokenType { + TYPE_START, // Next() has not yet been called. + TYPE_END, // End of input reached. "text" is empty. + + TYPE_IDENTIFIER, // A sequence of letters, digits, and underscores, not + // starting with a digit. It is an error for a number + // to be followed by an identifier with no space in + // between. + TYPE_INTEGER, // A sequence of digits representing an integer. Normally + // the digits are decimal, but a prefix of "0x" indicates + // a hex number and a leading zero indicates octal, just + // like with C numeric literals. A leading negative sign + // is NOT included in the token; it's up to the parser to + // interpret the unary minus operator on its own. + TYPE_FLOAT, // A floating point literal, with a fractional part and/or + // an exponent. Always in decimal. Again, never + // negative. + TYPE_STRING, // A quoted sequence of escaped characters. Either single + // or double quotes can be used, but they must match. + // A string literal cannot cross a line break. + TYPE_SYMBOL, // Any other printable character, like '!' or '+'. + // Symbols are always a single character, so "!+$%" is + // four tokens. + }; + + // Structure representing a token read from the token stream. + struct Token { + TokenType type; + std::string text; // The exact text of the token as it appeared in + // the input. e.g. tokens of TYPE_STRING will still + // be escaped and in quotes. + + // "line" and "column" specify the position of the first character of + // the token within the input stream. They are zero-based. + int line; + ColumnNumber column; + ColumnNumber end_column; + }; + + // Get the current token. This is updated when Next() is called. Before + // the first call to Next(), current() has type TYPE_START and no contents. + const Token& current(); + + // Return the previous token -- i.e. what current() returned before the + // previous call to Next(). + const Token& previous(); + + // Advance to the next token. Returns false if the end of the input is + // reached. + bool Next(); + + // Like Next(), but also collects comments which appear between the previous + // and next tokens. + // + // Comments which appear to be attached to the previous token are stored + // in *prev_tailing_comments. Comments which appear to be attached to the + // next token are stored in *next_leading_comments. Comments appearing in + // between which do not appear to be attached to either will be added to + // detached_comments. Any of these parameters can be NULL to simply discard + // the comments. + // + // A series of line comments appearing on consecutive lines, with no other + // tokens appearing on those lines, will be treated as a single comment. + // + // Only the comment content is returned; comment markers (e.g. //) are + // stripped out. For block comments, leading whitespace and an asterisk will + // be stripped from the beginning of each line other than the first. Newlines + // are included in the output. + // + // Examples: + // + // optional int32 foo = 1; // Comment attached to foo. + // // Comment attached to bar. + // optional int32 bar = 2; + // + // optional string baz = 3; + // // Comment attached to baz. + // // Another line attached to baz. + // + // // Comment attached to qux. + // // + // // Another line attached to qux. + // optional double qux = 4; + // + // // Detached comment. This is not attached to qux or corge + // // because there are blank lines separating it from both. + // + // optional string corge = 5; + // /* Block comment attached + // * to corge. Leading asterisks + // * will be removed. */ + // /* Block comment attached to + // * grault. */ + // optional int32 grault = 6; + bool NextWithComments(std::string* prev_trailing_comments, + std::vector* detached_comments, + std::string* next_leading_comments); + + // Parse helpers --------------------------------------------------- + + // Parses a TYPE_FLOAT token. This never fails, so long as the text actually + // comes from a TYPE_FLOAT token parsed by Tokenizer. If it doesn't, the + // result is undefined (possibly an assert failure). + static double ParseFloat(const std::string& text); + + // Parses a TYPE_STRING token. This never fails, so long as the text actually + // comes from a TYPE_STRING token parsed by Tokenizer. If it doesn't, the + // result is undefined (possibly an assert failure). + static void ParseString(const std::string& text, std::string* output); + + // Identical to ParseString, but appends to output. + static void ParseStringAppend(const std::string& text, std::string* output); + + // Parses a TYPE_INTEGER token. Returns false if the result would be + // greater than max_value. Otherwise, returns true and sets *output to the + // result. If the text is not from a Token of type TYPE_INTEGER originally + // parsed by a Tokenizer, the result is undefined (possibly an assert + // failure). + static bool ParseInteger(const std::string& text, uint64 max_value, + uint64* output); + + // Options --------------------------------------------------------- + + // Set true to allow floats to be suffixed with the letter 'f'. Tokens + // which would otherwise be integers but which have the 'f' suffix will be + // forced to be interpreted as floats. For all other purposes, the 'f' is + // ignored. + void set_allow_f_after_float(bool value) { allow_f_after_float_ = value; } + + // Valid values for set_comment_style(). + enum CommentStyle { + // Line comments begin with "//", block comments are delimited by "/*" and + // "*/". + CPP_COMMENT_STYLE, + // Line comments begin with "#". No way to write block comments. + SH_COMMENT_STYLE + }; + + // Sets the comment style. + void set_comment_style(CommentStyle style) { comment_style_ = style; } + + // Whether to require whitespace between a number and a field name. + // Default is true. Do not use this; for Google-internal cleanup only. + void set_require_space_after_number(bool require) { + require_space_after_number_ = require; + } + + // Whether to allow string literals to span multiple lines. Default is false. + // Do not use this; for Google-internal cleanup only. + void set_allow_multiline_strings(bool allow) { + allow_multiline_strings_ = allow; + } + + // External helper: validate an identifier. + static bool IsIdentifier(const std::string& text); + + // ----------------------------------------------------------------- + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Tokenizer); + + Token current_; // Returned by current(). + Token previous_; // Returned by previous(). + + ZeroCopyInputStream* input_; + ErrorCollector* error_collector_; + + char current_char_; // == buffer_[buffer_pos_], updated by NextChar(). + const char* buffer_; // Current buffer returned from input_. + int buffer_size_; // Size of buffer_. + int buffer_pos_; // Current position within the buffer. + bool read_error_; // Did we previously encounter a read error? + + // Line and column number of current_char_ within the whole input stream. + int line_; + ColumnNumber column_; + + // String to which text should be appended as we advance through it. + // Call RecordTo(&str) to start recording and StopRecording() to stop. + // E.g. StartToken() calls RecordTo(¤t_.text). record_start_ is the + // position within the current buffer where recording started. + std::string* record_target_; + int record_start_; + + // Options. + bool allow_f_after_float_; + CommentStyle comment_style_; + bool require_space_after_number_; + bool allow_multiline_strings_; + + // Since we count columns we need to interpret tabs somehow. We'll take + // the standard 8-character definition for lack of any way to do better. + // This must match the documentation of ColumnNumber. + static const int kTabWidth = 8; + + // ----------------------------------------------------------------- + // Helper methods. + + // Consume this character and advance to the next one. + void NextChar(); + + // Read a new buffer from the input. + void Refresh(); + + inline void RecordTo(std::string* target); + inline void StopRecording(); + + // Called when the current character is the first character of a new + // token (not including whitespace or comments). + inline void StartToken(); + // Called when the current character is the first character after the + // end of the last token. After this returns, current_.text will + // contain all text consumed since StartToken() was called. + inline void EndToken(); + + // Convenience method to add an error at the current line and column. + void AddError(const std::string& message) { + error_collector_->AddError(line_, column_, message); + } + + // ----------------------------------------------------------------- + // The following four methods are used to consume tokens of specific + // types. They are actually used to consume all characters *after* + // the first, since the calling function consumes the first character + // in order to decide what kind of token is being read. + + // Read and consume a string, ending when the given delimiter is + // consumed. + void ConsumeString(char delimiter); + + // Read and consume a number, returning TYPE_FLOAT or TYPE_INTEGER + // depending on what was read. This needs to know if the first + // character was a zero in order to correctly recognize hex and octal + // numbers. + // It also needs to know if the first character was a . to parse floating + // point correctly. + TokenType ConsumeNumber(bool started_with_zero, bool started_with_dot); + + // Consume the rest of a line. + void ConsumeLineComment(std::string* content); + // Consume until "*/". + void ConsumeBlockComment(std::string* content); + + enum NextCommentStatus { + // Started a line comment. + LINE_COMMENT, + + // Started a block comment. + BLOCK_COMMENT, + + // Consumed a slash, then realized it wasn't a comment. current_ has + // been filled in with a slash token. The caller should return it. + SLASH_NOT_COMMENT, + + // We do not appear to be starting a comment here. + NO_COMMENT + }; + + // If we're at the start of a new comment, consume it and return what kind + // of comment it is. + NextCommentStatus TryConsumeCommentStart(); + + // ----------------------------------------------------------------- + // These helper methods make the parsing code more readable. The + // "character classes" referred to are defined at the top of the .cc file. + // Basically it is a C++ class with one method: + // static bool InClass(char c); + // The method returns true if c is a member of this "class", like "Letter" + // or "Digit". + + // Returns true if the current character is of the given character + // class, but does not consume anything. + template + inline bool LookingAt(); + + // If the current character is in the given class, consume it and return + // true. Otherwise return false. + // e.g. TryConsumeOne() + template + inline bool TryConsumeOne(); + + // Like above, but try to consume the specific character indicated. + inline bool TryConsume(char c); + + // Consume zero or more of the given character class. + template + inline void ConsumeZeroOrMore(); + + // Consume one or more of the given character class or log the given + // error message. + // e.g. ConsumeOneOrMore("Expected digits."); + template + inline void ConsumeOneOrMore(const char* error); +}; + +// inline methods ==================================================== +inline const Tokenizer::Token& Tokenizer::current() { return current_; } + +inline const Tokenizer::Token& Tokenizer::previous() { return previous_; } + +inline void Tokenizer::ParseString(const std::string& text, + std::string* output) { + output->clear(); + ParseStringAppend(text, output); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_TOKENIZER_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..b310d3a56b8949247d9b7cc4fec78fd0a356c12c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream.h @@ -0,0 +1,258 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains the ZeroCopyInputStream and ZeroCopyOutputStream +// interfaces, which represent abstract I/O streams to and from which +// protocol buffers can be read and written. For a few simple +// implementations of these interfaces, see zero_copy_stream_impl.h. +// +// These interfaces are different from classic I/O streams in that they +// try to minimize the amount of data copying that needs to be done. +// To accomplish this, responsibility for allocating buffers is moved to +// the stream object, rather than being the responsibility of the caller. +// So, the stream can return a buffer which actually points directly into +// the final data structure where the bytes are to be stored, and the caller +// can interact directly with that buffer, eliminating an intermediate copy +// operation. +// +// As an example, consider the common case in which you are reading bytes +// from an array that is already in memory (or perhaps an mmap()ed file). +// With classic I/O streams, you would do something like: +// char buffer[BUFFER_SIZE]; +// input->Read(buffer, BUFFER_SIZE); +// DoSomething(buffer, BUFFER_SIZE); +// Then, the stream basically just calls memcpy() to copy the data from +// the array into your buffer. With a ZeroCopyInputStream, you would do +// this instead: +// const void* buffer; +// int size; +// input->Next(&buffer, &size); +// DoSomething(buffer, size); +// Here, no copy is performed. The input stream returns a pointer directly +// into the backing array, and the caller ends up reading directly from it. +// +// If you want to be able to read the old-fashion way, you can create +// a CodedInputStream or CodedOutputStream wrapping these objects and use +// their ReadRaw()/WriteRaw() methods. These will, of course, add a copy +// step, but Coded*Stream will handle buffering so at least it will be +// reasonably efficient. +// +// ZeroCopyInputStream example: +// // Read in a file and print its contents to stdout. +// int fd = open("myfile", O_RDONLY); +// ZeroCopyInputStream* input = new FileInputStream(fd); +// +// const void* buffer; +// int size; +// while (input->Next(&buffer, &size)) { +// cout.write(buffer, size); +// } +// +// delete input; +// close(fd); +// +// ZeroCopyOutputStream example: +// // Copy the contents of "infile" to "outfile", using plain read() for +// // "infile" but a ZeroCopyOutputStream for "outfile". +// int infd = open("infile", O_RDONLY); +// int outfd = open("outfile", O_WRONLY); +// ZeroCopyOutputStream* output = new FileOutputStream(outfd); +// +// void* buffer; +// int size; +// while (output->Next(&buffer, &size)) { +// int bytes = read(infd, buffer, size); +// if (bytes < size) { +// // Reached EOF. +// output->BackUp(size - bytes); +// break; +// } +// } +// +// delete output; +// close(infd); +// close(outfd); + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ + + +#include + +#include +#include + + +namespace google { +namespace protobuf { +namespace io { + +// Defined in this file. +class ZeroCopyInputStream; +class ZeroCopyOutputStream; + +// Abstract interface similar to an input stream but designed to minimize +// copying. +class PROTOBUF_EXPORT ZeroCopyInputStream { + public: + ZeroCopyInputStream() {} + virtual ~ZeroCopyInputStream() {} + + // Obtains a chunk of data from the stream. + // + // Preconditions: + // * "size" and "data" are not NULL. + // + // Postconditions: + // * If the returned value is false, there is no more data to return or + // an error occurred. All errors are permanent. + // * Otherwise, "size" points to the actual number of bytes read and "data" + // points to a pointer to a buffer containing these bytes. + // * Ownership of this buffer remains with the stream, and the buffer + // remains valid only until some other method of the stream is called + // or the stream is destroyed. + // * It is legal for the returned buffer to have zero size, as long + // as repeatedly calling Next() eventually yields a buffer with non-zero + // size. + virtual bool Next(const void** data, int* size) = 0; + + // Backs up a number of bytes, so that the next call to Next() returns + // data again that was already returned by the last call to Next(). This + // is useful when writing procedures that are only supposed to read up + // to a certain point in the input, then return. If Next() returns a + // buffer that goes beyond what you wanted to read, you can use BackUp() + // to return to the point where you intended to finish. + // + // Preconditions: + // * The last method called must have been Next(). + // * count must be less than or equal to the size of the last buffer + // returned by Next(). + // + // Postconditions: + // * The last "count" bytes of the last buffer returned by Next() will be + // pushed back into the stream. Subsequent calls to Next() will return + // the same data again before producing new data. + virtual void BackUp(int count) = 0; + + // Skips a number of bytes. Returns false if the end of the stream is + // reached or some input error occurred. In the end-of-stream case, the + // stream is advanced to the end of the stream (so ByteCount() will return + // the total size of the stream). + virtual bool Skip(int count) = 0; + + // Returns the total number of bytes read since this object was created. + virtual int64_t ByteCount() const = 0; + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ZeroCopyInputStream); +}; + +// Abstract interface similar to an output stream but designed to minimize +// copying. +class PROTOBUF_EXPORT ZeroCopyOutputStream { + public: + ZeroCopyOutputStream() {} + virtual ~ZeroCopyOutputStream() {} + + // Obtains a buffer into which data can be written. Any data written + // into this buffer will eventually (maybe instantly, maybe later on) + // be written to the output. + // + // Preconditions: + // * "size" and "data" are not NULL. + // + // Postconditions: + // * If the returned value is false, an error occurred. All errors are + // permanent. + // * Otherwise, "size" points to the actual number of bytes in the buffer + // and "data" points to the buffer. + // * Ownership of this buffer remains with the stream, and the buffer + // remains valid only until some other method of the stream is called + // or the stream is destroyed. + // * Any data which the caller stores in this buffer will eventually be + // written to the output (unless BackUp() is called). + // * It is legal for the returned buffer to have zero size, as long + // as repeatedly calling Next() eventually yields a buffer with non-zero + // size. + virtual bool Next(void** data, int* size) = 0; + + // Backs up a number of bytes, so that the end of the last buffer returned + // by Next() is not actually written. This is needed when you finish + // writing all the data you want to write, but the last buffer was bigger + // than you needed. You don't want to write a bunch of garbage after the + // end of your data, so you use BackUp() to back up. + // + // Preconditions: + // * The last method called must have been Next(). + // * count must be less than or equal to the size of the last buffer + // returned by Next(). + // * The caller must not have written anything to the last "count" bytes + // of that buffer. + // + // Postconditions: + // * The last "count" bytes of the last buffer returned by Next() will be + // ignored. + virtual void BackUp(int count) = 0; + + // Returns the total number of bytes written since this object was created. + virtual int64_t ByteCount() const = 0; + + // Write a given chunk of data to the output. Some output streams may + // implement this in a way that avoids copying. Check AllowsAliasing() before + // calling WriteAliasedRaw(). It will GOOGLE_CHECK fail if WriteAliasedRaw() is + // called on a stream that does not allow aliasing. + // + // NOTE: It is caller's responsibility to ensure that the chunk of memory + // remains live until all of the data has been consumed from the stream. + virtual bool WriteAliasedRaw(const void* data, int size); + virtual bool AllowsAliasing() const { return false; } + + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ZeroCopyOutputStream); +}; + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..b18d5451c90c08ac6e80d62534510c52051a1067 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl.h @@ -0,0 +1,343 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains common implementations of the interfaces defined in +// zero_copy_stream.h which are only included in the full (non-lite) +// protobuf library. These implementations include Unix file descriptors +// and C++ iostreams. See also: zero_copy_stream_impl_lite.h + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ + + +#include +#include + +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { +namespace io { + +// =================================================================== + +// A ZeroCopyInputStream which reads from a file descriptor. +// +// FileInputStream is preferred over using an ifstream with IstreamInputStream. +// The latter will introduce an extra layer of buffering, harming performance. +// Also, it's conceivable that FileInputStream could someday be enhanced +// to use zero-copy file descriptors on OSs which support them. +class PROTOBUF_EXPORT FileInputStream : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given Unix file descriptor. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. + explicit FileInputStream(int file_descriptor, int block_size = -1); + + // Flushes any buffers and closes the underlying file. Returns false if + // an error occurs during the process; use GetErrno() to examine the error. + // Even if an error occurs, the file descriptor is closed when this returns. + bool Close(); + + // By default, the file descriptor is not closed when the stream is + // destroyed. Call SetCloseOnDelete(true) to change that. WARNING: + // This leaves no way for the caller to detect if close() fails. If + // detecting close() errors is important to you, you should arrange + // to close the descriptor yourself. + void SetCloseOnDelete(bool value) { copying_input_.SetCloseOnDelete(value); } + + // If an I/O error has occurred on this file descriptor, this is the + // errno from that error. Otherwise, this is zero. Once an error + // occurs, the stream is broken and all subsequent operations will + // fail. + int GetErrno() const { return copying_input_.GetErrno(); } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingFileInputStream : public CopyingInputStream { + public: + CopyingFileInputStream(int file_descriptor); + ~CopyingFileInputStream() override; + + bool Close(); + void SetCloseOnDelete(bool value) { close_on_delete_ = value; } + int GetErrno() const { return errno_; } + + // implements CopyingInputStream --------------------------------- + int Read(void* buffer, int size) override; + int Skip(int count) override; + + private: + // The file descriptor. + const int file_; + bool close_on_delete_; + bool is_closed_; + + // The errno of the I/O error, if one has occurred. Otherwise, zero. + int errno_; + + // Did we try to seek once and fail? If so, we assume this file descriptor + // doesn't support seeking and won't try again. + bool previous_seek_failed_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingFileInputStream); + }; + + CopyingFileInputStream copying_input_; + CopyingInputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FileInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which writes to a file descriptor. +// +// FileOutputStream is preferred over using an ofstream with +// OstreamOutputStream. The latter will introduce an extra layer of buffering, +// harming performance. Also, it's conceivable that FileOutputStream could +// someday be enhanced to use zero-copy file descriptors on OSs which +// support them. +class PROTOBUF_EXPORT FileOutputStream : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given Unix file descriptor. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit FileOutputStream(int file_descriptor, int block_size = -1); + ~FileOutputStream() override; + + // Flushes any buffers and closes the underlying file. Returns false if + // an error occurs during the process; use GetErrno() to examine the error. + // Even if an error occurs, the file descriptor is closed when this returns. + bool Close(); + + // Flushes FileOutputStream's buffers but does not close the + // underlying file. No special measures are taken to ensure that + // underlying operating system file object is synchronized to disk. + bool Flush(); + + // By default, the file descriptor is not closed when the stream is + // destroyed. Call SetCloseOnDelete(true) to change that. WARNING: + // This leaves no way for the caller to detect if close() fails. If + // detecting close() errors is important to you, you should arrange + // to close the descriptor yourself. + void SetCloseOnDelete(bool value) { copying_output_.SetCloseOnDelete(value); } + + // If an I/O error has occurred on this file descriptor, this is the + // errno from that error. Otherwise, this is zero. Once an error + // occurs, the stream is broken and all subsequent operations will + // fail. + int GetErrno() const { return copying_output_.GetErrno(); } + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingFileOutputStream : public CopyingOutputStream { + public: + CopyingFileOutputStream(int file_descriptor); + ~CopyingFileOutputStream() override; + + bool Close(); + void SetCloseOnDelete(bool value) { close_on_delete_ = value; } + int GetErrno() const { return errno_; } + + // implements CopyingOutputStream -------------------------------- + bool Write(const void* buffer, int size) override; + + private: + // The file descriptor. + const int file_; + bool close_on_delete_; + bool is_closed_; + + // The errno of the I/O error, if one has occurred. Otherwise, zero. + int errno_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingFileOutputStream); + }; + + CopyingFileOutputStream copying_output_; + CopyingOutputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FileOutputStream); +}; + +// =================================================================== + +// A ZeroCopyInputStream which reads from a C++ istream. +// +// Note that for reading files (or anything represented by a file descriptor), +// FileInputStream is more efficient. +class PROTOBUF_EXPORT IstreamInputStream : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given C++ istream. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. + explicit IstreamInputStream(std::istream* stream, int block_size = -1); + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingIstreamInputStream : public CopyingInputStream { + public: + CopyingIstreamInputStream(std::istream* input); + ~CopyingIstreamInputStream() override; + + // implements CopyingInputStream --------------------------------- + int Read(void* buffer, int size) override; + // (We use the default implementation of Skip().) + + private: + // The stream. + std::istream* input_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingIstreamInputStream); + }; + + CopyingIstreamInputStream copying_input_; + CopyingInputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(IstreamInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which writes to a C++ ostream. +// +// Note that for writing files (or anything represented by a file descriptor), +// FileOutputStream is more efficient. +class PROTOBUF_EXPORT OstreamOutputStream : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given C++ ostream. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit OstreamOutputStream(std::ostream* stream, int block_size = -1); + ~OstreamOutputStream() override; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + class PROTOBUF_EXPORT CopyingOstreamOutputStream + : public CopyingOutputStream { + public: + CopyingOstreamOutputStream(std::ostream* output); + ~CopyingOstreamOutputStream() override; + + // implements CopyingOutputStream -------------------------------- + bool Write(const void* buffer, int size) override; + + private: + // The stream. + std::ostream* output_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingOstreamOutputStream); + }; + + CopyingOstreamOutputStream copying_output_; + CopyingOutputStreamAdaptor impl_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OstreamOutputStream); +}; + +// =================================================================== + +// A ZeroCopyInputStream which reads from several other streams in sequence. +// ConcatenatingInputStream is unable to distinguish between end-of-stream +// and read errors in the underlying streams, so it assumes any errors mean +// end-of-stream. So, if the underlying streams fail for any other reason, +// ConcatenatingInputStream may do odd things. It is suggested that you do +// not use ConcatenatingInputStream on streams that might produce read errors +// other than end-of-stream. +class PROTOBUF_EXPORT ConcatenatingInputStream : public ZeroCopyInputStream { + public: + // All streams passed in as well as the array itself must remain valid + // until the ConcatenatingInputStream is destroyed. + ConcatenatingInputStream(ZeroCopyInputStream* const streams[], int count); + ~ConcatenatingInputStream() override = default; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + // As streams are retired, streams_ is incremented and count_ is + // decremented. + ZeroCopyInputStream* const* streams_; + int stream_count_; + int64 bytes_retired_; // Bytes read from previous streams. + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ConcatenatingInputStream); +}; + +// =================================================================== + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h new file mode 100644 index 0000000000000000000000000000000000000000..83d2ac0dcf4e85f8055afae40ae5bd221fbe2383 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h @@ -0,0 +1,411 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// This file contains common implementations of the interfaces defined in +// zero_copy_stream.h which are included in the "lite" protobuf library. +// These implementations cover I/O on raw arrays and strings, as well as +// adaptors which make it easy to implement streams based on traditional +// streams. Of course, many users will probably want to write their own +// implementations of these interfaces specific to the particular I/O +// abstractions they prefer to use, but these should cover the most common +// cases. + +#ifndef GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ +#define GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ + + +#include +#include +#include + +#include +#include +#include +#include + + +#include + +namespace google { +namespace protobuf { +namespace io { + +// =================================================================== + +// A ZeroCopyInputStream backed by an in-memory array of bytes. +class PROTOBUF_EXPORT ArrayInputStream : public ZeroCopyInputStream { + public: + // Create an InputStream that returns the bytes pointed to by "data". + // "data" remains the property of the caller but must remain valid until + // the stream is destroyed. If a block_size is given, calls to Next() + // will return data blocks no larger than the given size. Otherwise, the + // first call to Next() returns the entire array. block_size is mainly + // useful for testing; in production you would probably never want to set + // it. + ArrayInputStream(const void* data, int size, int block_size = -1); + ~ArrayInputStream() override = default; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + const uint8* const data_; // The byte array. + const int size_; // Total size of the array. + const int block_size_; // How many bytes to return at a time. + + int position_; + int last_returned_size_; // How many bytes we returned last time Next() + // was called (used for error checking only). + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayInputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream backed by an in-memory array of bytes. +class PROTOBUF_EXPORT ArrayOutputStream : public ZeroCopyOutputStream { + public: + // Create an OutputStream that writes to the bytes pointed to by "data". + // "data" remains the property of the caller but must remain valid until + // the stream is destroyed. If a block_size is given, calls to Next() + // will return data blocks no larger than the given size. Otherwise, the + // first call to Next() returns the entire array. block_size is mainly + // useful for testing; in production you would probably never want to set + // it. + ArrayOutputStream(void* data, int size, int block_size = -1); + ~ArrayOutputStream() override = default; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + uint8* const data_; // The byte array. + const int size_; // Total size of the array. + const int block_size_; // How many bytes to return at a time. + + int position_; + int last_returned_size_; // How many bytes we returned last time Next() + // was called (used for error checking only). + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayOutputStream); +}; + +// =================================================================== + +// A ZeroCopyOutputStream which appends bytes to a string. +class PROTOBUF_EXPORT StringOutputStream : public ZeroCopyOutputStream { + public: + // Create a StringOutputStream which appends bytes to the given string. + // The string remains property of the caller, but it is mutated in arbitrary + // ways and MUST NOT be accessed in any way until you're done with the + // stream. Either be sure there's no further usage, or (safest) destroy the + // stream before using the contents. + // + // Hint: If you call target->reserve(n) before creating the stream, + // the first call to Next() will return at least n bytes of buffer + // space. + explicit StringOutputStream(std::string* target); + ~StringOutputStream() override = default; + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + static const int kMinimumSize = 16; + + std::string* target_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(StringOutputStream); +}; + +// Note: There is no StringInputStream. Instead, just create an +// ArrayInputStream as follows: +// ArrayInputStream input(str.data(), str.size()); + +// =================================================================== + +// A generic traditional input stream interface. +// +// Lots of traditional input streams (e.g. file descriptors, C stdio +// streams, and C++ iostreams) expose an interface where every read +// involves copying bytes into a buffer. If you want to take such an +// interface and make a ZeroCopyInputStream based on it, simply implement +// CopyingInputStream and then use CopyingInputStreamAdaptor. +// +// CopyingInputStream implementations should avoid buffering if possible. +// CopyingInputStreamAdaptor does its own buffering and will read data +// in large blocks. +class PROTOBUF_EXPORT CopyingInputStream { + public: + virtual ~CopyingInputStream() {} + + // Reads up to "size" bytes into the given buffer. Returns the number of + // bytes read. Read() waits until at least one byte is available, or + // returns zero if no bytes will ever become available (EOF), or -1 if a + // permanent read error occurred. + virtual int Read(void* buffer, int size) = 0; + + // Skips the next "count" bytes of input. Returns the number of bytes + // actually skipped. This will always be exactly equal to "count" unless + // EOF was reached or a permanent read error occurred. + // + // The default implementation just repeatedly calls Read() into a scratch + // buffer. + virtual int Skip(int count); +}; + +// A ZeroCopyInputStream which reads from a CopyingInputStream. This is +// useful for implementing ZeroCopyInputStreams that read from traditional +// streams. Note that this class is not really zero-copy. +// +// If you want to read from file descriptors or C++ istreams, this is +// already implemented for you: use FileInputStream or IstreamInputStream +// respectively. +class PROTOBUF_EXPORT CopyingInputStreamAdaptor : public ZeroCopyInputStream { + public: + // Creates a stream that reads from the given CopyingInputStream. + // If a block_size is given, it specifies the number of bytes that + // should be read and returned with each call to Next(). Otherwise, + // a reasonable default is used. The caller retains ownership of + // copying_stream unless SetOwnsCopyingStream(true) is called. + explicit CopyingInputStreamAdaptor(CopyingInputStream* copying_stream, + int block_size = -1); + ~CopyingInputStreamAdaptor() override; + + // Call SetOwnsCopyingStream(true) to tell the CopyingInputStreamAdaptor to + // delete the underlying CopyingInputStream when it is destroyed. + void SetOwnsCopyingStream(bool value) { owns_copying_stream_ = value; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + // Insures that buffer_ is not NULL. + void AllocateBufferIfNeeded(); + // Frees the buffer and resets buffer_used_. + void FreeBuffer(); + + // The underlying copying stream. + CopyingInputStream* copying_stream_; + bool owns_copying_stream_; + + // True if we have seen a permanent error from the underlying stream. + bool failed_; + + // The current position of copying_stream_, relative to the point where + // we started reading. + int64 position_; + + // Data is read into this buffer. It may be NULL if no buffer is currently + // in use. Otherwise, it points to an array of size buffer_size_. + std::unique_ptr buffer_; + const int buffer_size_; + + // Number of valid bytes currently in the buffer (i.e. the size last + // returned by Next()). 0 <= buffer_used_ <= buffer_size_. + int buffer_used_; + + // Number of bytes in the buffer which were backed up over by a call to + // BackUp(). These need to be returned again. + // 0 <= backup_bytes_ <= buffer_used_ + int backup_bytes_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingInputStreamAdaptor); +}; + +// =================================================================== + +// A generic traditional output stream interface. +// +// Lots of traditional output streams (e.g. file descriptors, C stdio +// streams, and C++ iostreams) expose an interface where every write +// involves copying bytes from a buffer. If you want to take such an +// interface and make a ZeroCopyOutputStream based on it, simply implement +// CopyingOutputStream and then use CopyingOutputStreamAdaptor. +// +// CopyingOutputStream implementations should avoid buffering if possible. +// CopyingOutputStreamAdaptor does its own buffering and will write data +// in large blocks. +class PROTOBUF_EXPORT CopyingOutputStream { + public: + virtual ~CopyingOutputStream() {} + + // Writes "size" bytes from the given buffer to the output. Returns true + // if successful, false on a write error. + virtual bool Write(const void* buffer, int size) = 0; +}; + +// A ZeroCopyOutputStream which writes to a CopyingOutputStream. This is +// useful for implementing ZeroCopyOutputStreams that write to traditional +// streams. Note that this class is not really zero-copy. +// +// If you want to write to file descriptors or C++ ostreams, this is +// already implemented for you: use FileOutputStream or OstreamOutputStream +// respectively. +class PROTOBUF_EXPORT CopyingOutputStreamAdaptor : public ZeroCopyOutputStream { + public: + // Creates a stream that writes to the given Unix file descriptor. + // If a block_size is given, it specifies the size of the buffers + // that should be returned by Next(). Otherwise, a reasonable default + // is used. + explicit CopyingOutputStreamAdaptor(CopyingOutputStream* copying_stream, + int block_size = -1); + ~CopyingOutputStreamAdaptor() override; + + // Writes all pending data to the underlying stream. Returns false if a + // write error occurred on the underlying stream. (The underlying + // stream itself is not necessarily flushed.) + bool Flush(); + + // Call SetOwnsCopyingStream(true) to tell the CopyingOutputStreamAdaptor to + // delete the underlying CopyingOutputStream when it is destroyed. + void SetOwnsCopyingStream(bool value) { owns_copying_stream_ = value; } + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + // Write the current buffer, if it is present. + bool WriteBuffer(); + // Insures that buffer_ is not NULL. + void AllocateBufferIfNeeded(); + // Frees the buffer. + void FreeBuffer(); + + // The underlying copying stream. + CopyingOutputStream* copying_stream_; + bool owns_copying_stream_; + + // True if we have seen a permanent error from the underlying stream. + bool failed_; + + // The current position of copying_stream_, relative to the point where + // we started writing. + int64 position_; + + // Data is written from this buffer. It may be NULL if no buffer is + // currently in use. Otherwise, it points to an array of size buffer_size_. + std::unique_ptr buffer_; + const int buffer_size_; + + // Number of valid bytes currently in the buffer (i.e. the size last + // returned by Next()). When BackUp() is called, we just reduce this. + // 0 <= buffer_used_ <= buffer_size_. + int buffer_used_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CopyingOutputStreamAdaptor); +}; + +// =================================================================== + +// A ZeroCopyInputStream which wraps some other stream and limits it to +// a particular byte count. +class PROTOBUF_EXPORT LimitingInputStream : public ZeroCopyInputStream { + public: + LimitingInputStream(ZeroCopyInputStream* input, int64 limit); + ~LimitingInputStream() override; + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + + private: + ZeroCopyInputStream* input_; + int64 limit_; // Decreases as we go, becomes negative if we overshoot. + int64 prior_bytes_read_; // Bytes read on underlying stream at construction + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(LimitingInputStream); +}; + + +// =================================================================== + +// mutable_string_data() and as_string_data() are workarounds to improve +// the performance of writing new data to an existing string. Unfortunately +// the methods provided by the string class are suboptimal, and using memcpy() +// is mildly annoying because it requires its pointer args to be non-NULL even +// if we ask it to copy 0 bytes. Furthermore, string_as_array() has the +// property that it always returns NULL if its arg is the empty string, exactly +// what we want to avoid if we're using it in conjunction with memcpy()! +// With C++11, the desired memcpy() boils down to memcpy(..., &(*s)[0], size), +// where s is a string*. Without C++11, &(*s)[0] is not guaranteed to be safe, +// so we use string_as_array(), and live with the extra logic that tests whether +// *s is empty. + +// Return a pointer to mutable characters underlying the given string. The +// return value is valid until the next time the string is resized. We +// trust the caller to treat the return value as an array of length s->size(). +inline char* mutable_string_data(std::string* s) { + // This should be simpler & faster than string_as_array() because the latter + // is guaranteed to return NULL when *s is empty, so it has to check for that. + return &(*s)[0]; +} + +// as_string_data(s) is equivalent to +// ({ char* p = mutable_string_data(s); make_pair(p, p != NULL); }) +// Sometimes it's faster: in some scenarios p cannot be NULL, and then the +// code can avoid that check. +inline std::pair as_string_data(std::string* s) { + char* p = mutable_string_data(s); + return std::make_pair(p, true); +} + +} // namespace io +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_IO_ZERO_COPY_STREAM_IMPL_LITE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h new file mode 100644 index 0000000000000000000000000000000000000000..65d62941aa61faeae9eb0ca6f69b28b0b1e3b151 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/bytestream.h @@ -0,0 +1,356 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file declares the ByteSink and ByteSource abstract interfaces. These +// interfaces represent objects that consume (ByteSink) or produce (ByteSource) +// a sequence of bytes. Using these abstract interfaces in your APIs can help +// make your code work with a variety of input and output types. +// +// This file also declares the following commonly used implementations of these +// interfaces. +// +// ByteSink: +// UncheckedArrayByteSink Writes to an array, without bounds checking +// CheckedArrayByteSink Writes to an array, with bounds checking +// GrowingArrayByteSink Allocates and writes to a growable buffer +// StringByteSink Writes to an STL string +// NullByteSink Consumes a never-ending stream of bytes +// +// ByteSource: +// ArrayByteSource Reads from an array or string/StringPiece +// LimitedByteSource Limits the number of bytes read from an + +#ifndef GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ +#define GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ + +#include +#include + +#include +#include + +#include + +class CordByteSink; + +namespace google { +namespace protobuf { +namespace strings { + +// An abstract interface for an object that consumes a sequence of bytes. This +// interface offers a way to append data as well as a Flush() function. +// +// Example: +// +// string my_data; +// ... +// ByteSink* sink = ... +// sink->Append(my_data.data(), my_data.size()); +// sink->Flush(); +// +class PROTOBUF_EXPORT ByteSink { + public: + ByteSink() {} + virtual ~ByteSink() {} + + // Appends the "n" bytes starting at "bytes". + virtual void Append(const char* bytes, size_t n) = 0; + + // Flushes internal buffers. The default implementation does nothing. ByteSink + // subclasses may use internal buffers that require calling Flush() at the end + // of the stream. + virtual void Flush(); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ByteSink); +}; + +// An abstract interface for an object that produces a fixed-size sequence of +// bytes. +// +// Example: +// +// ByteSource* source = ... +// while (source->Available() > 0) { +// StringPiece data = source->Peek(); +// ... do something with "data" ... +// source->Skip(data.length()); +// } +// +class PROTOBUF_EXPORT ByteSource { + public: + ByteSource() {} + virtual ~ByteSource() {} + + // Returns the number of bytes left to read from the source. Available() + // should decrease by N each time Skip(N) is called. Available() may not + // increase. Available() returning 0 indicates that the ByteSource is + // exhausted. + // + // Note: Size() may have been a more appropriate name as it's more + // indicative of the fixed-size nature of a ByteSource. + virtual size_t Available() const = 0; + + // Returns a StringPiece of the next contiguous region of the source. Does not + // reposition the source. The returned region is empty iff Available() == 0. + // + // The returned region is valid until the next call to Skip() or until this + // object is destroyed, whichever occurs first. + // + // The length of the returned StringPiece will be <= Available(). + virtual StringPiece Peek() = 0; + + // Skips the next n bytes. Invalidates any StringPiece returned by a previous + // call to Peek(). + // + // REQUIRES: Available() >= n + virtual void Skip(size_t n) = 0; + + // Writes the next n bytes in this ByteSource to the given ByteSink, and + // advances this ByteSource past the copied bytes. The default implementation + // of this method just copies the bytes normally, but subclasses might + // override CopyTo to optimize certain cases. + // + // REQUIRES: Available() >= n + virtual void CopyTo(ByteSink* sink, size_t n); + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ByteSource); +}; + +// +// Some commonly used implementations of ByteSink +// + +// Implementation of ByteSink that writes to an unsized byte array. No +// bounds-checking is performed--it is the caller's responsibility to ensure +// that the destination array is large enough. +// +// Example: +// +// char buf[10]; +// UncheckedArrayByteSink sink(buf); +// sink.Append("hi", 2); // OK +// sink.Append(data, 100); // WOOPS! Overflows buf[10]. +// +class PROTOBUF_EXPORT UncheckedArrayByteSink : public ByteSink { + public: + explicit UncheckedArrayByteSink(char* dest) : dest_(dest) {} + virtual void Append(const char* data, size_t n) override; + + // Returns the current output pointer so that a caller can see how many bytes + // were produced. + // + // Note: this method is not part of the ByteSink interface. + char* CurrentDestination() const { return dest_; } + + private: + char* dest_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(UncheckedArrayByteSink); +}; + +// Implementation of ByteSink that writes to a sized byte array. This sink will +// not write more than "capacity" bytes to outbuf. Once "capacity" bytes are +// appended, subsequent bytes will be ignored and Overflowed() will return true. +// Overflowed() does not cause a runtime error (i.e., it does not CHECK fail). +// +// Example: +// +// char buf[10]; +// CheckedArrayByteSink sink(buf, 10); +// sink.Append("hi", 2); // OK +// sink.Append(data, 100); // Will only write 8 more bytes +// +class PROTOBUF_EXPORT CheckedArrayByteSink : public ByteSink { + public: + CheckedArrayByteSink(char* outbuf, size_t capacity); + virtual void Append(const char* bytes, size_t n) override; + + // Returns the number of bytes actually written to the sink. + size_t NumberOfBytesWritten() const { return size_; } + + // Returns true if any bytes were discarded, i.e., if there was an + // attempt to write more than 'capacity' bytes. + bool Overflowed() const { return overflowed_; } + + private: + char* outbuf_; + const size_t capacity_; + size_t size_; + bool overflowed_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CheckedArrayByteSink); +}; + +// Implementation of ByteSink that allocates an internal buffer (a char array) +// and expands it as needed to accommodate appended data (similar to a string), +// and allows the caller to take ownership of the internal buffer via the +// GetBuffer() method. The buffer returned from GetBuffer() must be deleted by +// the caller with delete[]. GetBuffer() also sets the internal buffer to be +// empty, and subsequent appends to the sink will create a new buffer. The +// destructor will free the internal buffer if GetBuffer() was not called. +// +// Example: +// +// GrowingArrayByteSink sink(10); +// sink.Append("hi", 2); +// sink.Append(data, n); +// const char* buf = sink.GetBuffer(); // Ownership transferred +// delete[] buf; +// +class PROTOBUF_EXPORT GrowingArrayByteSink : public strings::ByteSink { + public: + explicit GrowingArrayByteSink(size_t estimated_size); + virtual ~GrowingArrayByteSink(); + virtual void Append(const char* bytes, size_t n) override; + + // Returns the allocated buffer, and sets nbytes to its size. The caller takes + // ownership of the buffer and must delete it with delete[]. + char* GetBuffer(size_t* nbytes); + + private: + void Expand(size_t amount); + void ShrinkToFit(); + + size_t capacity_; + char* buf_; + size_t size_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GrowingArrayByteSink); +}; + +// Implementation of ByteSink that appends to the given string. +// Existing contents of "dest" are not modified; new data is appended. +// +// Example: +// +// string dest = "Hello "; +// StringByteSink sink(&dest); +// sink.Append("World", 5); +// assert(dest == "Hello World"); +// +class PROTOBUF_EXPORT StringByteSink : public ByteSink { + public: + explicit StringByteSink(string* dest) : dest_(dest) {} + virtual void Append(const char* data, size_t n) override; + + private: + string* dest_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(StringByteSink); +}; + +// Implementation of ByteSink that discards all data. +// +// Example: +// +// NullByteSink sink; +// sink.Append(data, data.size()); // All data ignored. +// +class PROTOBUF_EXPORT NullByteSink : public ByteSink { + public: + NullByteSink() {} + void Append(const char* /*data*/, size_t /*n*/) override {} + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(NullByteSink); +}; + +// +// Some commonly used implementations of ByteSource +// + +// Implementation of ByteSource that reads from a StringPiece. +// +// Example: +// +// string data = "Hello"; +// ArrayByteSource source(data); +// assert(source.Available() == 5); +// assert(source.Peek() == "Hello"); +// +class PROTOBUF_EXPORT ArrayByteSource : public ByteSource { + public: + explicit ArrayByteSource(StringPiece s) : input_(s) {} + + virtual size_t Available() const override; + virtual StringPiece Peek() override; + virtual void Skip(size_t n) override; + + private: + StringPiece input_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ArrayByteSource); +}; + +// Implementation of ByteSource that wraps another ByteSource, limiting the +// number of bytes returned. +// +// The caller maintains ownership of the underlying source, and may not use the +// underlying source while using the LimitByteSource object. The underlying +// source's pointer is advanced by n bytes every time this LimitByteSource +// object is advanced by n. +// +// Example: +// +// string data = "Hello World"; +// ArrayByteSource abs(data); +// assert(abs.Available() == data.size()); +// +// LimitByteSource limit(abs, 5); +// assert(limit.Available() == 5); +// assert(limit.Peek() == "Hello"); +// +class PROTOBUF_EXPORT LimitByteSource : public ByteSource { + public: + // Returns at most "limit" bytes from "source". + LimitByteSource(ByteSource* source, size_t limit); + + virtual size_t Available() const override; + virtual StringPiece Peek() override; + virtual void Skip(size_t n) override; + + // We override CopyTo so that we can forward to the underlying source, in + // case it has an efficient implementation of CopyTo. + virtual void CopyTo(ByteSink* sink, size_t n) override; + + private: + ByteSource* source_; + size_t limit_; +}; + +} // namespace strings +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_BYTESTREAM_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h new file mode 100644 index 0000000000000000000000000000000000000000..731d46fc821ba3055073a5f66046892c7485903d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/callback.h @@ -0,0 +1,588 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#ifndef GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ +#define GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ + +#include + +#include + +#include + +// =================================================================== +// emulates google3/base/callback.h + +namespace google { +namespace protobuf { + +// Abstract interface for a callback. When calling an RPC, you must provide +// a Closure to call when the procedure completes. See the Service interface +// in service.h. +// +// To automatically construct a Closure which calls a particular function or +// method with a particular set of parameters, use the NewCallback() function. +// Example: +// void FooDone(const FooResponse* response) { +// ... +// } +// +// void CallFoo() { +// ... +// // When done, call FooDone() and pass it a pointer to the response. +// Closure* callback = NewCallback(&FooDone, response); +// // Make the call. +// service->Foo(controller, request, response, callback); +// } +// +// Example that calls a method: +// class Handler { +// public: +// ... +// +// void FooDone(const FooResponse* response) { +// ... +// } +// +// void CallFoo() { +// ... +// // When done, call FooDone() and pass it a pointer to the response. +// Closure* callback = NewCallback(this, &Handler::FooDone, response); +// // Make the call. +// service->Foo(controller, request, response, callback); +// } +// }; +// +// Currently NewCallback() supports binding zero, one, or two arguments. +// +// Callbacks created with NewCallback() automatically delete themselves when +// executed. They should be used when a callback is to be called exactly +// once (usually the case with RPC callbacks). If a callback may be called +// a different number of times (including zero), create it with +// NewPermanentCallback() instead. You are then responsible for deleting the +// callback (using the "delete" keyword as normal). +// +// Note that NewCallback() is a bit touchy regarding argument types. Generally, +// the values you provide for the parameter bindings must exactly match the +// types accepted by the callback function. For example: +// void Foo(string s); +// NewCallback(&Foo, "foo"); // WON'T WORK: const char* != string +// NewCallback(&Foo, string("foo")); // WORKS +// Also note that the arguments cannot be references: +// void Foo(const string& s); +// string my_str; +// NewCallback(&Foo, my_str); // WON'T WORK: Can't use references. +// However, correctly-typed pointers will work just fine. +class PROTOBUF_EXPORT Closure { + public: + Closure() {} + virtual ~Closure(); + + virtual void Run() = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Closure); +}; + +template +class ResultCallback { + public: + ResultCallback() {} + virtual ~ResultCallback() {} + + virtual R Run() = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback); +}; + +template +class PROTOBUF_EXPORT ResultCallback1 { + public: + ResultCallback1() {} + virtual ~ResultCallback1() {} + + virtual R Run(A1) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback1); +}; + +template +class PROTOBUF_EXPORT ResultCallback2 { + public: + ResultCallback2() {} + virtual ~ResultCallback2() {} + + virtual R Run(A1,A2) = 0; + + private: + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ResultCallback2); +}; + +namespace internal { + +class PROTOBUF_EXPORT FunctionClosure0 : public Closure { + public: + typedef void (*FunctionType)(); + + FunctionClosure0(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionClosure0(); + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class MethodClosure0 : public Closure { + public: + typedef void (Class::*MethodType)(); + + MethodClosure0(Class* object, MethodType method, bool self_deleting) + : object_(object), method_(method), self_deleting_(self_deleting) {} + ~MethodClosure0() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; +}; + +template +class FunctionClosure1 : public Closure { + public: + typedef void (*FunctionType)(Arg1 arg1); + + FunctionClosure1(FunctionType function, bool self_deleting, + Arg1 arg1) + : function_(function), self_deleting_(self_deleting), + arg1_(arg1) {} + ~FunctionClosure1() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(arg1_); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; + Arg1 arg1_; +}; + +template +class MethodClosure1 : public Closure { + public: + typedef void (Class::*MethodType)(Arg1 arg1); + + MethodClosure1(Class* object, MethodType method, bool self_deleting, + Arg1 arg1) + : object_(object), method_(method), self_deleting_(self_deleting), + arg1_(arg1) {} + ~MethodClosure1() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(arg1_); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; + Arg1 arg1_; +}; + +template +class FunctionClosure2 : public Closure { + public: + typedef void (*FunctionType)(Arg1 arg1, Arg2 arg2); + + FunctionClosure2(FunctionType function, bool self_deleting, + Arg1 arg1, Arg2 arg2) + : function_(function), self_deleting_(self_deleting), + arg1_(arg1), arg2_(arg2) {} + ~FunctionClosure2() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + function_(arg1_, arg2_); + if (needs_delete) delete this; + } + + private: + FunctionType function_; + bool self_deleting_; + Arg1 arg1_; + Arg2 arg2_; +}; + +template +class MethodClosure2 : public Closure { + public: + typedef void (Class::*MethodType)(Arg1 arg1, Arg2 arg2); + + MethodClosure2(Class* object, MethodType method, bool self_deleting, + Arg1 arg1, Arg2 arg2) + : object_(object), method_(method), self_deleting_(self_deleting), + arg1_(arg1), arg2_(arg2) {} + ~MethodClosure2() {} + + void Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + (object_->*method_)(arg1_, arg2_); + if (needs_delete) delete this; + } + + private: + Class* object_; + MethodType method_; + bool self_deleting_; + Arg1 arg1_; + Arg2 arg2_; +}; + +template +class FunctionResultCallback_0_0 : public ResultCallback { + public: + typedef R (*FunctionType)(); + + FunctionResultCallback_0_0(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionResultCallback_0_0() {} + + R Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class FunctionResultCallback_1_0 : public ResultCallback { + public: + typedef R (*FunctionType)(P1); + + FunctionResultCallback_1_0(FunctionType function, bool self_deleting, + P1 p1) + : function_(function), self_deleting_(self_deleting), p1_(p1) {} + ~FunctionResultCallback_1_0() {} + + R Run() override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(p1_); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; + P1 p1_; +}; + +template +class FunctionResultCallback_0_1 : public ResultCallback1 { + public: + typedef R (*FunctionType)(Arg1 arg1); + + FunctionResultCallback_0_1(FunctionType function, bool self_deleting) + : function_(function), self_deleting_(self_deleting) {} + ~FunctionResultCallback_0_1() {} + + R Run(Arg1 a1) override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(a1); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; +}; + +template +class FunctionResultCallback_1_1 : public ResultCallback1 { + public: + typedef R (*FunctionType)(P1, A1); + + FunctionResultCallback_1_1(FunctionType function, bool self_deleting, + P1 p1) + : function_(function), self_deleting_(self_deleting), p1_(p1) {} + ~FunctionResultCallback_1_1() {} + + R Run(A1 a1) override { + bool needs_delete = self_deleting_; // read in case callback deletes + R result = function_(p1_, a1); + if (needs_delete) delete this; + return result; + } + + private: + FunctionType function_; + bool self_deleting_; + P1 p1_; +}; + +template +struct InternalConstRef { + typedef typename std::remove_reference::type base_type; + typedef const base_type& type; +}; + +template +class MethodResultCallback_0_0 : public ResultCallback { + public: + typedef R (T::*MethodType)(); + MethodResultCallback_0_0(T* object, MethodType method, bool self_deleting) + : object_(object), + method_(method), + self_deleting_(self_deleting) {} + ~MethodResultCallback_0_0() {} + + R Run() { + bool needs_delete = self_deleting_; + R result = (object_->*method_)(); + if (needs_delete) delete this; + return result; + } + + private: + T* object_; + MethodType method_; + bool self_deleting_; +}; + +template +class MethodResultCallback_6_2 : public ResultCallback2 { + public: + typedef R (T::*MethodType)(P1, P2, P3, P4, P5, P6, A1, A2); + MethodResultCallback_6_2(T* object, MethodType method, bool self_deleting, + P1 p1, P2 p2, P3 p3, P4 p4, P5 p5, P6 p6) + : object_(object), + method_(method), + self_deleting_(self_deleting), + p1_(p1), + p2_(p2), + p3_(p3), + p4_(p4), + p5_(p5), + p6_(p6) {} + ~MethodResultCallback_6_2() {} + + R Run(A1 a1, A2 a2) override { + bool needs_delete = self_deleting_; + R result = (object_->*method_)(p1_, p2_, p3_, p4_, p5_, p6_, a1, a2); + if (needs_delete) delete this; + return result; + } + + private: + T* object_; + MethodType method_; + bool self_deleting_; + typename std::remove_reference::type p1_; + typename std::remove_reference::type p2_; + typename std::remove_reference::type p3_; + typename std::remove_reference::type p4_; + typename std::remove_reference::type p5_; + typename std::remove_reference::type p6_; +}; + +} // namespace internal + +// See Closure. +inline Closure* NewCallback(void (*function)()) { + return new internal::FunctionClosure0(function, true); +} + +// See Closure. +inline Closure* NewPermanentCallback(void (*function)()) { + return new internal::FunctionClosure0(function, false); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)()) { + return new internal::MethodClosure0(object, method, true); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(Class* object, void (Class::*method)()) { + return new internal::MethodClosure0(object, method, false); +} + +// See Closure. +template +inline Closure* NewCallback(void (*function)(Arg1), + Arg1 arg1) { + return new internal::FunctionClosure1(function, true, arg1); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(void (*function)(Arg1), + Arg1 arg1) { + return new internal::FunctionClosure1(function, false, arg1); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)(Arg1), + Arg1 arg1) { + return new internal::MethodClosure1(object, method, true, arg1); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(Class* object, void (Class::*method)(Arg1), + Arg1 arg1) { + return new internal::MethodClosure1(object, method, false, arg1); +} + +// See Closure. +template +inline Closure* NewCallback(void (*function)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::FunctionClosure2( + function, true, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewPermanentCallback(void (*function)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::FunctionClosure2( + function, false, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewCallback(Class* object, void (Class::*method)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::MethodClosure2( + object, method, true, arg1, arg2); +} + +// See Closure. +template +inline Closure* NewPermanentCallback( + Class* object, void (Class::*method)(Arg1, Arg2), + Arg1 arg1, Arg2 arg2) { + return new internal::MethodClosure2( + object, method, false, arg1, arg2); +} + +// See ResultCallback +template +inline ResultCallback* NewCallback(R (*function)()) { + return new internal::FunctionResultCallback_0_0(function, true); +} + +// See ResultCallback +template +inline ResultCallback* NewPermanentCallback(R (*function)()) { + return new internal::FunctionResultCallback_0_0(function, false); +} + +// See ResultCallback +template +inline ResultCallback* NewCallback(R (*function)(P1), P1 p1) { + return new internal::FunctionResultCallback_1_0( + function, true, p1); +} + +// See ResultCallback +template +inline ResultCallback* NewPermanentCallback( + R (*function)(P1), P1 p1) { + return new internal::FunctionResultCallback_1_0( + function, false, p1); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewCallback(R (*function)(A1)) { + return new internal::FunctionResultCallback_0_1(function, true); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewPermanentCallback(R (*function)(A1)) { + return new internal::FunctionResultCallback_0_1(function, false); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewCallback(R (*function)(P1, A1), P1 p1) { + return new internal::FunctionResultCallback_1_1( + function, true, p1); +} + +// See ResultCallback1 +template +inline ResultCallback1* NewPermanentCallback( + R (*function)(P1, A1), P1 p1) { + return new internal::FunctionResultCallback_1_1( + function, false, p1); +} + +// See MethodResultCallback_0_0 +template +inline ResultCallback* NewPermanentCallback( + T1* object, R (T2::*function)()) { + return new internal::MethodResultCallback_0_0(object, function, false); +} + +// See MethodResultCallback_6_2 +template +inline ResultCallback2* NewPermanentCallback( + T* object, R (T::*function)(P1, P2, P3, P4, P5, P6, A1, A2), + typename internal::InternalConstRef::type p1, + typename internal::InternalConstRef::type p2, + typename internal::InternalConstRef::type p3, + typename internal::InternalConstRef::type p4, + typename internal::InternalConstRef::type p5, + typename internal::InternalConstRef::type p6) { + return new internal::MethodResultCallback_6_2(object, function, false, + p1, p2, p3, p4, p5, p6); +} + +// A function which does nothing. Useful for creating no-op callbacks, e.g.: +// Closure* nothing = NewCallback(&DoNothing); +void PROTOBUF_EXPORT DoNothing(); + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_CALLBACK_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h new file mode 100644 index 0000000000000000000000000000000000000000..b77ca87d78970e2431d5cc65626b81d9275b1d41 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/casts.h @@ -0,0 +1,144 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_CASTS_H__ +#define GOOGLE_PROTOBUF_CASTS_H__ + +#include + +#include +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Use implicit_cast as a safe version of static_cast or const_cast +// for upcasting in the type hierarchy (i.e. casting a pointer to Foo +// to a pointer to SuperclassOfFoo or casting a pointer to Foo to +// a const pointer to Foo). +// When you use implicit_cast, the compiler checks that the cast is safe. +// Such explicit implicit_casts are necessary in surprisingly many +// situations where C++ demands an exact type match instead of an +// argument type convertable to a target type. +// +// The From type can be inferred, so the preferred syntax for using +// implicit_cast is the same as for static_cast etc.: +// +// implicit_cast(expr) +// +// implicit_cast would have been part of the C++ standard library, +// but the proposal was submitted too late. It will probably make +// its way into the language in the future. +template +inline To implicit_cast(From const &f) { + return f; +} + +// When you upcast (that is, cast a pointer from type Foo to type +// SuperclassOfFoo), it's fine to use implicit_cast<>, since upcasts +// always succeed. When you downcast (that is, cast a pointer from +// type Foo to type SubclassOfFoo), static_cast<> isn't safe, because +// how do you know the pointer is really of type SubclassOfFoo? It +// could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, +// when you downcast, you should use this macro. In debug mode, we +// use dynamic_cast<> to double-check the downcast is legal (we die +// if it's not). In normal mode, we do the efficient static_cast<> +// instead. Thus, it's important to test in debug mode to make sure +// the cast is legal! +// This is the only place in the code we should use dynamic_cast<>. +// In particular, you SHOULDN'T be using dynamic_cast<> in order to +// do RTTI (eg code like this: +// if (dynamic_cast(foo)) HandleASubclass1Object(foo); +// if (dynamic_cast(foo)) HandleASubclass2Object(foo); +// You should design the code some other way not to need this. + +template // use like this: down_cast(foo); +inline To down_cast(From* f) { // so we only accept pointers + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + implicit_cast(0); + } + +#if !defined(NDEBUG) && PROTOBUF_RTTI + assert(f == nullptr || dynamic_cast(f) != nullptr); // RTTI: debug mode only! +#endif + return static_cast(f); +} + +template // use like this: down_cast(foo); +inline To down_cast(From& f) { + typedef typename std::remove_reference::type* ToAsPointer; + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + implicit_cast(0); + } + +#if !defined(NDEBUG) && PROTOBUF_RTTI + // RTTI: debug mode only! + assert(dynamic_cast(&f) != nullptr); +#endif + return *static_cast(&f); +} + +template +inline To bit_cast(const From& from) { + GOOGLE_COMPILE_ASSERT(sizeof(From) == sizeof(To), + bit_cast_with_different_sizes); + To dest; + memcpy(&dest, &from, sizeof(dest)); + return dest; +} + +} // namespace internal + +// We made these internal so that they would show up as such in the docs, +// but we don't want to stick "internal::" in front of them everywhere. +using internal::implicit_cast; +using internal::down_cast; +using internal::bit_cast; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_CASTS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h new file mode 100644 index 0000000000000000000000000000000000000000..ba25746d319f09787a79db713b6e6255c02d1299 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/fastmem.h @@ -0,0 +1,162 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Fast memory copying and comparison routines. +// strings::fastmemcmp_inlined() replaces memcmp() +// strings::memcpy_inlined() replaces memcpy() +// strings::memeq(a, b, n) replaces memcmp(a, b, n) == 0 +// +// strings::*_inlined() routines are inline versions of the +// routines exported by this module. Sometimes using the inlined +// versions is faster. Measure before using the inlined versions. +// +// Performance measurement: +// strings::fastmemcmp_inlined +// Analysis: memcmp, fastmemcmp_inlined, fastmemcmp +// 2012-01-30 + +#ifndef GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ +#define GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ + +#include +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { + +// Return true if the n bytes at a equal the n bytes at b. +// The regions are allowed to overlap. +// +// The performance is similar to the performance memcmp(), but faster for +// moderately-sized inputs, or inputs that share a common prefix and differ +// somewhere in their last 8 bytes. Further optimizations can be added later +// if it makes sense to do so.:w +inline bool memeq(const char* a, const char* b, size_t n) { + size_t n_rounded_down = n & ~static_cast(7); + if (PROTOBUF_PREDICT_FALSE(n_rounded_down == 0)) { // n <= 7 + return memcmp(a, b, n) == 0; + } + // n >= 8 + uint64 u = GOOGLE_UNALIGNED_LOAD64(a) ^ GOOGLE_UNALIGNED_LOAD64(b); + uint64 v = GOOGLE_UNALIGNED_LOAD64(a + n - 8) ^ GOOGLE_UNALIGNED_LOAD64(b + n - 8); + if ((u | v) != 0) { // The first or last 8 bytes differ. + return false; + } + a += 8; + b += 8; + n = n_rounded_down - 8; + if (n > 128) { + // As of 2012, memcmp on x86-64 uses a big unrolled loop with SSE2 + // instructions, and while we could try to do something faster, it + // doesn't seem worth pursuing. + return memcmp(a, b, n) == 0; + } + for (; n >= 16; n -= 16) { + uint64 x = GOOGLE_UNALIGNED_LOAD64(a) ^ GOOGLE_UNALIGNED_LOAD64(b); + uint64 y = GOOGLE_UNALIGNED_LOAD64(a + 8) ^ GOOGLE_UNALIGNED_LOAD64(b + 8); + if ((x | y) != 0) { + return false; + } + a += 16; + b += 16; + } + // n must be 0 or 8 now because it was a multiple of 8 at the top of the loop. + return n == 0 || GOOGLE_UNALIGNED_LOAD64(a) == GOOGLE_UNALIGNED_LOAD64(b); +} + +inline int fastmemcmp_inlined(const char *a, const char *b, size_t n) { + if (n >= 64) { + return memcmp(a, b, n); + } + const char* a_limit = a + n; + while (a + sizeof(uint64) <= a_limit && + GOOGLE_UNALIGNED_LOAD64(a) == GOOGLE_UNALIGNED_LOAD64(b)) { + a += sizeof(uint64); + b += sizeof(uint64); + } + if (a + sizeof(uint32) <= a_limit && + GOOGLE_UNALIGNED_LOAD32(a) == GOOGLE_UNALIGNED_LOAD32(b)) { + a += sizeof(uint32); + b += sizeof(uint32); + } + while (a < a_limit) { + int d = + static_cast(static_cast(*a++) - static_cast(*b++)); + if (d) return d; + } + return 0; +} + +// The standard memcpy operation is slow for variable small sizes. +// This implementation inlines the optimal realization for sizes 1 to 16. +// To avoid code bloat don't use it in case of not performance-critical spots, +// nor when you don't expect very frequent values of size <= 16. +inline void memcpy_inlined(char *dst, const char *src, size_t size) { + // Compiler inlines code with minimal amount of data movement when third + // parameter of memcpy is a constant. + switch (size) { + case 1: memcpy(dst, src, 1); break; + case 2: memcpy(dst, src, 2); break; + case 3: memcpy(dst, src, 3); break; + case 4: memcpy(dst, src, 4); break; + case 5: memcpy(dst, src, 5); break; + case 6: memcpy(dst, src, 6); break; + case 7: memcpy(dst, src, 7); break; + case 8: memcpy(dst, src, 8); break; + case 9: memcpy(dst, src, 9); break; + case 10: memcpy(dst, src, 10); break; + case 11: memcpy(dst, src, 11); break; + case 12: memcpy(dst, src, 12); break; + case 13: memcpy(dst, src, 13); break; + case 14: memcpy(dst, src, 14); break; + case 15: memcpy(dst, src, 15); break; + case 16: memcpy(dst, src, 16); break; + default: memcpy(dst, src, size); break; + } +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_FASTMEM_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h new file mode 100644 index 0000000000000000000000000000000000000000..4d61f3d44fb19dc1a8415c99dc8c96ddba846862 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/hash.h @@ -0,0 +1,127 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + +#ifndef GOOGLE_PROTOBUF_STUBS_HASH_H__ +#define GOOGLE_PROTOBUF_STUBS_HASH_H__ + +#include +#include +#include +#include + +# define GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_START \ + namespace google { \ + namespace protobuf { +# define GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_END }} + +namespace google { +namespace protobuf { + +template +struct hash : public std::hash {}; + +template +struct hash { + inline size_t operator()(const Key* key) const { + return reinterpret_cast(key); + } +}; + +// Unlike the old SGI version, the TR1 "hash" does not special-case char*. So, +// we go ahead and provide our own implementation. +template <> +struct hash { + inline size_t operator()(const char* str) const { + size_t result = 0; + for (; *str != '\0'; str++) { + result = 5 * result + static_cast(*str); + } + return result; + } +}; + +template<> +struct hash { + size_t operator()(bool x) const { + return static_cast(x); + } +}; + +template <> +struct hash { + inline size_t operator()(const std::string& key) const { + return hash()(key.c_str()); + } + + static const size_t bucket_size = 4; + static const size_t min_buckets = 8; + inline bool operator()(const std::string& a, const std::string& b) const { + return a < b; + } +}; + +template +struct hash > { + inline size_t operator()(const std::pair& key) const { + size_t first_hash = hash()(key.first); + size_t second_hash = hash()(key.second); + + // FIXME(kenton): What is the best way to compute this hash? I have + // no idea! This seems a bit better than an XOR. + return first_hash * ((1 << 16) - 1) + second_hash; + } + + static const size_t bucket_size = 4; + static const size_t min_buckets = 8; + inline bool operator()(const std::pair& a, + const std::pair& b) const { + return a < b; + } +}; + +// Used by GCC/SGI STL only. (Why isn't this provided by the standard +// library? :( ) +struct streq { + inline bool operator()(const char* a, const char* b) const { + return strcmp(a, b) == 0; + } +}; + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_HASH_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..318d1a435d94b1f64ac7921bbc36ed01b676bc9d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/logging.h @@ -0,0 +1,246 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_LOGGING_H_ +#define GOOGLE_PROTOBUF_STUBS_LOGGING_H_ + +#include +#include + +#include + +// =================================================================== +// emulates google3/base/logging.h + +namespace google { +namespace protobuf { + +enum LogLevel { + LOGLEVEL_INFO, // Informational. This is never actually used by + // libprotobuf. + LOGLEVEL_WARNING, // Warns about issues that, although not technically a + // problem now, could cause problems in the future. For + // example, a // warning will be printed when parsing a + // message that is near the message size limit. + LOGLEVEL_ERROR, // An error occurred which should never happen during + // normal use. + LOGLEVEL_FATAL, // An error occurred from which the library cannot + // recover. This usually indicates a programming error + // in the code which calls the library, especially when + // compiled in debug mode. + +#ifdef NDEBUG + LOGLEVEL_DFATAL = LOGLEVEL_ERROR +#else + LOGLEVEL_DFATAL = LOGLEVEL_FATAL +#endif +}; + +class StringPiece; +namespace util { +class Status; +} +class uint128; +namespace internal { + +class LogFinisher; + +class PROTOBUF_EXPORT LogMessage { + public: + LogMessage(LogLevel level, const char* filename, int line); + ~LogMessage(); + + LogMessage& operator<<(const std::string& value); + LogMessage& operator<<(const char* value); + LogMessage& operator<<(char value); + LogMessage& operator<<(int value); + LogMessage& operator<<(uint value); + LogMessage& operator<<(long value); + LogMessage& operator<<(unsigned long value); + LogMessage& operator<<(long long value); + LogMessage& operator<<(unsigned long long value); + LogMessage& operator<<(double value); + LogMessage& operator<<(void* value); + LogMessage& operator<<(const StringPiece& value); + LogMessage& operator<<(const util::Status& status); + LogMessage& operator<<(const uint128& value); + + private: + friend class LogFinisher; + void Finish(); + + LogLevel level_; + const char* filename_; + int line_; + std::string message_; +}; + +// Used to make the entire "LOG(BLAH) << etc." expression have a void return +// type and print a newline after each message. +class PROTOBUF_EXPORT LogFinisher { + public: + void operator=(LogMessage& other); +}; + +template +bool IsOk(T status) { return status.ok(); } +template<> +inline bool IsOk(bool status) { return status; } + +} // namespace internal + +// Undef everything in case we're being mixed with some other Google library +// which already defined them itself. Presumably all Google libraries will +// support the same syntax for these so it should not be a big deal if they +// end up using our definitions instead. +#undef GOOGLE_LOG +#undef GOOGLE_LOG_IF + +#undef GOOGLE_CHECK +#undef GOOGLE_CHECK_OK +#undef GOOGLE_CHECK_EQ +#undef GOOGLE_CHECK_NE +#undef GOOGLE_CHECK_LT +#undef GOOGLE_CHECK_LE +#undef GOOGLE_CHECK_GT +#undef GOOGLE_CHECK_GE +#undef GOOGLE_CHECK_NOTNULL + +#undef GOOGLE_DLOG +#undef GOOGLE_DCHECK +#undef GOOGLE_DCHECK_OK +#undef GOOGLE_DCHECK_EQ +#undef GOOGLE_DCHECK_NE +#undef GOOGLE_DCHECK_LT +#undef GOOGLE_DCHECK_LE +#undef GOOGLE_DCHECK_GT +#undef GOOGLE_DCHECK_GE + +#define GOOGLE_LOG(LEVEL) \ + ::google::protobuf::internal::LogFinisher() = \ + ::google::protobuf::internal::LogMessage( \ + ::google::protobuf::LOGLEVEL_##LEVEL, __FILE__, __LINE__) +#define GOOGLE_LOG_IF(LEVEL, CONDITION) \ + !(CONDITION) ? (void)0 : GOOGLE_LOG(LEVEL) + +#define GOOGLE_CHECK(EXPRESSION) \ + GOOGLE_LOG_IF(FATAL, !(EXPRESSION)) << "CHECK failed: " #EXPRESSION ": " +#define GOOGLE_CHECK_OK(A) GOOGLE_CHECK(::google::protobuf::internal::IsOk(A)) +#define GOOGLE_CHECK_EQ(A, B) GOOGLE_CHECK((A) == (B)) +#define GOOGLE_CHECK_NE(A, B) GOOGLE_CHECK((A) != (B)) +#define GOOGLE_CHECK_LT(A, B) GOOGLE_CHECK((A) < (B)) +#define GOOGLE_CHECK_LE(A, B) GOOGLE_CHECK((A) <= (B)) +#define GOOGLE_CHECK_GT(A, B) GOOGLE_CHECK((A) > (B)) +#define GOOGLE_CHECK_GE(A, B) GOOGLE_CHECK((A) >= (B)) + +namespace internal { +template +T* CheckNotNull(const char* /* file */, int /* line */, + const char* name, T* val) { + if (val == nullptr) { + GOOGLE_LOG(FATAL) << name; + } + return val; +} +} // namespace internal +#define GOOGLE_CHECK_NOTNULL(A) \ + ::google::protobuf::internal::CheckNotNull( \ + __FILE__, __LINE__, "'" #A "' must not be nullptr", (A)) + +#ifdef NDEBUG + +#define GOOGLE_DLOG(LEVEL) GOOGLE_LOG_IF(LEVEL, false) + +#define GOOGLE_DCHECK(EXPRESSION) while(false) GOOGLE_CHECK(EXPRESSION) +#define GOOGLE_DCHECK_OK(E) GOOGLE_DCHECK(::google::protobuf::internal::IsOk(E)) +#define GOOGLE_DCHECK_EQ(A, B) GOOGLE_DCHECK((A) == (B)) +#define GOOGLE_DCHECK_NE(A, B) GOOGLE_DCHECK((A) != (B)) +#define GOOGLE_DCHECK_LT(A, B) GOOGLE_DCHECK((A) < (B)) +#define GOOGLE_DCHECK_LE(A, B) GOOGLE_DCHECK((A) <= (B)) +#define GOOGLE_DCHECK_GT(A, B) GOOGLE_DCHECK((A) > (B)) +#define GOOGLE_DCHECK_GE(A, B) GOOGLE_DCHECK((A) >= (B)) + +#else // NDEBUG + +#define GOOGLE_DLOG GOOGLE_LOG + +#define GOOGLE_DCHECK GOOGLE_CHECK +#define GOOGLE_DCHECK_OK GOOGLE_CHECK_OK +#define GOOGLE_DCHECK_EQ GOOGLE_CHECK_EQ +#define GOOGLE_DCHECK_NE GOOGLE_CHECK_NE +#define GOOGLE_DCHECK_LT GOOGLE_CHECK_LT +#define GOOGLE_DCHECK_LE GOOGLE_CHECK_LE +#define GOOGLE_DCHECK_GT GOOGLE_CHECK_GT +#define GOOGLE_DCHECK_GE GOOGLE_CHECK_GE + +#endif // !NDEBUG + +typedef void LogHandler(LogLevel level, const char* filename, int line, + const std::string& message); + +// The protobuf library sometimes writes warning and error messages to +// stderr. These messages are primarily useful for developers, but may +// also help end users figure out a problem. If you would prefer that +// these messages be sent somewhere other than stderr, call SetLogHandler() +// to set your own handler. This returns the old handler. Set the handler +// to nullptr to ignore log messages (but see also LogSilencer, below). +// +// Obviously, SetLogHandler is not thread-safe. You should only call it +// at initialization time, and probably not from library code. If you +// simply want to suppress log messages temporarily (e.g. because you +// have some code that tends to trigger them frequently and you know +// the warnings are not important to you), use the LogSilencer class +// below. +PROTOBUF_EXPORT LogHandler* SetLogHandler(LogHandler* new_func); + +// Create a LogSilencer if you want to temporarily suppress all log +// messages. As long as any LogSilencer objects exist, non-fatal +// log messages will be discarded (the current LogHandler will *not* +// be called). Constructing a LogSilencer is thread-safe. You may +// accidentally suppress log messages occurring in another thread, but +// since messages are generally for debugging purposes only, this isn't +// a big deal. If you want to intercept log messages, use SetLogHandler(). +class PROTOBUF_EXPORT LogSilencer { + public: + LogSilencer(); + ~LogSilencer(); +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_LOGGING_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..581790c6d72796fcf48170bd1aaa0ff2fee8e5b8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/macros.h @@ -0,0 +1,125 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_MACROS_H__ +#define GOOGLE_PROTOBUF_MACROS_H__ + +#include + +namespace google { +namespace protobuf { + +#undef GOOGLE_DISALLOW_EVIL_CONSTRUCTORS +#define GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(TypeName) \ + TypeName(const TypeName&); \ + void operator=(const TypeName&) + +#undef GOOGLE_DISALLOW_IMPLICIT_CONSTRUCTORS +#define GOOGLE_DISALLOW_IMPLICIT_CONSTRUCTORS(TypeName) \ + TypeName(); \ + TypeName(const TypeName&); \ + void operator=(const TypeName&) + +// =================================================================== +// from google3/base/basictypes.h + +// The GOOGLE_ARRAYSIZE(arr) macro returns the # of elements in an array arr. +// The expression is a compile-time constant, and therefore can be +// used in defining new arrays, for example. +// +// GOOGLE_ARRAYSIZE catches a few type errors. If you see a compiler error +// +// "warning: division by zero in ..." +// +// when using GOOGLE_ARRAYSIZE, you are (wrongfully) giving it a pointer. +// You should only use GOOGLE_ARRAYSIZE on statically allocated arrays. +// +// The following comments are on the implementation details, and can +// be ignored by the users. +// +// ARRAYSIZE(arr) works by inspecting sizeof(arr) (the # of bytes in +// the array) and sizeof(*(arr)) (the # of bytes in one array +// element). If the former is divisible by the latter, perhaps arr is +// indeed an array, in which case the division result is the # of +// elements in the array. Otherwise, arr cannot possibly be an array, +// and we generate a compiler error to prevent the code from +// compiling. +// +// Since the size of bool is implementation-defined, we need to cast +// !(sizeof(a) & sizeof(*(a))) to size_t in order to ensure the final +// result has type size_t. +// +// This macro is not perfect as it wrongfully accepts certain +// pointers, namely where the pointer size is divisible by the pointee +// size. Since all our code has to go through a 32-bit compiler, +// where a pointer is 4 bytes, this means all pointers to a type whose +// size is 3 or greater than 4 will be (righteously) rejected. +// +// Kudos to Jorg Brown for this simple and elegant implementation. + +#undef GOOGLE_ARRAYSIZE +#define GOOGLE_ARRAYSIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) + +// The COMPILE_ASSERT macro can be used to verify that a compile time +// expression is true. For example, you could use it to verify the +// size of a static array: +// +// COMPILE_ASSERT(ARRAYSIZE(content_type_names) == CONTENT_NUM_TYPES, +// content_type_names_incorrect_size); +// +// or to make sure a struct is smaller than a certain size: +// +// COMPILE_ASSERT(sizeof(foo) < 128, foo_too_large); +// +// The second argument to the macro is the name of the variable. If +// the expression is false, most compilers will issue a warning/error +// containing the name of the variable. + +namespace internal { + +template +struct CompileAssert { +}; + +} // namespace internal + +#define GOOGLE_COMPILE_ASSERT(expr, msg) static_assert(expr, #msg) + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_MACROS_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h new file mode 100644 index 0000000000000000000000000000000000000000..17f6b90aa0ad3372f40b2fa5dfe23be2157f647e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/map_util.h @@ -0,0 +1,774 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2014 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/util/gtl/map_util.h +// Author: Anton Carver + +#ifndef GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ + +#include +#include +#include +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { +// Local implementation of RemoveConst to avoid including base/type_traits.h. +template struct RemoveConst { typedef T type; }; +template struct RemoveConst : RemoveConst {}; +} // namespace internal + +// +// Find*() +// + +// Returns a const reference to the value associated with the given key if it +// exists. Crashes otherwise. +// +// This is intended as a replacement for operator[] as an rvalue (for reading) +// when the key is guaranteed to exist. +// +// operator[] for lookup is discouraged for several reasons: +// * It has a side-effect of inserting missing keys +// * It is not thread-safe (even when it is not inserting, it can still +// choose to resize the underlying storage) +// * It invalidates iterators (when it chooses to resize) +// * It default constructs a value object even if it doesn't need to +// +// This version assumes the key is printable, and includes it in the fatal log +// message. +template +const typename Collection::value_type::second_type& +FindOrDie(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& +FindOrDie(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found: " << key; + return it->second; +} + +// Same as FindOrDie above, but doesn't log the key on failure. +template +const typename Collection::value_type::second_type& +FindOrDieNoPrint(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& +FindOrDieNoPrint(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +// Details: http://go/findwithdefault +template +const typename Collection::value_type::second_type& +FindWithDefault(const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return value; + } + return it->second; +} + +// Returns a pointer to the const value associated with the given key if it +// exists, or nullptr otherwise. +template +const typename Collection::value_type::second_type* +FindOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Same as above but returns a pointer to the non-const value. +template +typename Collection::value_type::second_type* +FindOrNull(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + return &it->second; +} + +// Returns the pointer value associated with the given key. If none is found, +// nullptr is returned. The function is designed to be used with a map of keys to +// pointers. +// +// This function does not distinguish between a missing key and a key mapped +// to nullptr. +template +typename Collection::value_type::second_type +FindPtrOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Same as above, except takes non-const reference to collection. +// +// This function is needed for containers that propagate constness to the +// pointee, such as boost::ptr_map. +template +typename Collection::value_type::second_type +FindPtrOrNull(Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + if (it == collection.end()) { + return typename Collection::value_type::second_type(); + } + return it->second; +} + +// Finds the pointer value associated with the given key in a map whose values +// are linked_ptrs. Returns nullptr if key is not found. +template +typename Collection::value_type::second_type::element_type* +FindLinkedPtrOrNull(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return 0; + } + // Since linked_ptr::get() is a const member returning a non const, + // we do not need a version of this function taking a non const collection. + return it->second.get(); +} + +// Same as above, but dies if the key is not found. +template +typename Collection::value_type::second_type::element_type& +FindLinkedPtrOrDie(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + GOOGLE_CHECK(it != collection.end()) << "key not found: " << key; + // Since linked_ptr::operator*() is a const member returning a non const, + // we do not need a version of this function taking a non const collection. + return *it->second; +} + +// Finds the value associated with the given key and copies it to *value (if not +// nullptr). Returns false if the key was not found, true otherwise. +template +bool FindCopy(const Collection& collection, + const Key& key, + Value* const value) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return false; + } + if (value) { + *value = it->second; + } + return true; +} + +// +// Contains*() +// + +// Returns true if and only if the given collection contains the given key. +template +bool ContainsKey(const Collection& collection, const Key& key) { + return collection.find(key) != collection.end(); +} + +// Returns true if and only if the given collection contains the given key-value +// pair. +template +bool ContainsKeyValuePair(const Collection& collection, + const Key& key, + const Value& value) { + typedef typename Collection::const_iterator const_iterator; + std::pair range = collection.equal_range(key); + for (const_iterator it = range.first; it != range.second; ++it) { + if (it->second == value) { + return true; + } + } + return false; +} + +// +// Insert*() +// + +// Inserts the given key-value pair into the collection. Returns true if and +// only if the key from the given pair didn't previously exist. Otherwise, the +// value in the map is replaced with the value from the given pair. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type& vt) { + std::pair ret = collection->insert(vt); + if (!ret.second) { + // update + ret.first->second = vt.second; + return false; + } + return true; +} + +// Same as above, except that the key and value are passed separately. +template +bool InsertOrUpdate(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertOrUpdate( + collection, typename Collection::value_type(key, value)); +} + +// Inserts/updates all the key-value pairs from the range defined by the +// iterators "first" and "last" into the given collection. +template +void InsertOrUpdateMany(Collection* const collection, + InputIterator first, InputIterator last) { + for (; first != last; ++first) { + InsertOrUpdate(collection, *first); + } +} + +// Change the value associated with a particular key in a map or hash_map +// of the form map which owns the objects pointed to by the +// value pointers. If there was an existing value for the key, it is deleted. +// True indicates an insert took place, false indicates an update + delete. +template +bool InsertAndDeleteExisting( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + std::pair ret = + collection->insert(typename Collection::value_type(key, value)); + if (!ret.second) { + delete ret.first->second; + ret.first->second = value; + return false; + } + return true; +} + +// Inserts the given key and value into the given collection if and only if the +// given key did NOT already exist in the collection. If the key previously +// existed in the collection, the value is not changed. Returns true if the +// key-value pair was inserted; returns false if the key was already present. +template +bool InsertIfNotPresent(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).second; +} + +// Same as above except the key and value are passed separately. +template +bool InsertIfNotPresent( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return InsertIfNotPresent( + collection, typename Collection::value_type(key, value)); +} + +// Same as above except dies if the key already exists in the collection. +template +void InsertOrDie(Collection* const collection, + const typename Collection::value_type& value) { + GOOGLE_CHECK(InsertIfNotPresent(collection, value)) + << "duplicate value: " << value; +} + +// Same as above except doesn't log the value on error. +template +void InsertOrDieNoPrint(Collection* const collection, + const typename Collection::value_type& value) { + GOOGLE_CHECK(InsertIfNotPresent(collection, value)) << "duplicate value."; +} + +// Inserts the key-value pair into the collection. Dies if key was already +// present. +template +void InsertOrDie(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + GOOGLE_CHECK(InsertIfNotPresent(collection, key, data)) + << "duplicate key: " << key; +} + +// Same as above except doesn't log the key on error. +template +void InsertOrDieNoPrint( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + GOOGLE_CHECK(InsertIfNotPresent(collection, key, data)) << "duplicate key."; +} + +// Inserts a new key and default-initialized value. Dies if the key was already +// present. Returns a reference to the value. Example usage: +// +// map m; +// SomeProto& proto = InsertKeyOrDie(&m, 3); +// proto.set_field("foo"); +template +typename Collection::value_type::second_type& InsertKeyOrDie( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type value_type; + std::pair res = + collection->insert(value_type(key, typename value_type::second_type())); + GOOGLE_CHECK(res.second) << "duplicate key: " << key; + return res.first->second; +} + +// +// Lookup*() +// + +// Looks up a given key and value pair in a collection and inserts the key-value +// pair if it's not already present. Returns a reference to the value associated +// with the key. +template +typename Collection::value_type::second_type& +LookupOrInsert(Collection* const collection, + const typename Collection::value_type& vt) { + return collection->insert(vt).first->second; +} + +// Same as above except the key-value are passed separately. +template +typename Collection::value_type::second_type& +LookupOrInsert(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + return LookupOrInsert( + collection, typename Collection::value_type(key, value)); +} + +// Counts the number of equivalent elements in the given "sequence", and stores +// the results in "count_map" with element as the key and count as the value. +// +// Example: +// vector v = {"a", "b", "c", "a", "b"}; +// map m; +// AddTokenCounts(v, 1, &m); +// assert(m["a"] == 2); +// assert(m["b"] == 2); +// assert(m["c"] == 1); +template +void AddTokenCounts( + const Sequence& sequence, + const typename Collection::value_type::second_type& increment, + Collection* const count_map) { + for (typename Sequence::const_iterator it = sequence.begin(); + it != sequence.end(); ++it) { + typename Collection::value_type::second_type& value = + LookupOrInsert(count_map, *it, + typename Collection::value_type::second_type()); + value += increment; + } +} + +// Returns a reference to the value associated with key. If not found, a value +// is default constructed on the heap and added to the map. +// +// This function is useful for containers of the form map, where +// inserting a new key, value pair involves constructing a new heap-allocated +// Value, and storing a pointer to that in the collection. +template +typename Collection::value_type::second_type& +LookupOrInsertNew(Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename std::iterator_traits< + typename Collection::value_type::second_type>::value_type Element; + std::pair ret = + collection->insert(typename Collection::value_type( + key, + static_cast(nullptr))); + if (ret.second) { + ret.first->second = new Element(); + } + return ret.first->second; +} + +// Same as above but constructs the value using the single-argument constructor +// and the given "arg". +template +typename Collection::value_type::second_type& +LookupOrInsertNew(Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename std::iterator_traits< + typename Collection::value_type::second_type>::value_type Element; + std::pair ret = + collection->insert(typename Collection::value_type( + key, + static_cast(nullptr))); + if (ret.second) { + ret.first->second = new Element(arg); + } + return ret.first->second; +} + +// Lookup of linked/shared pointers is used in two scenarios: +// +// Use LookupOrInsertNewLinkedPtr if the container owns the elements. +// In this case it is fine working with the raw pointer as long as it is +// guaranteed that no other thread can delete/update an accessed element. +// A mutex will need to lock the container operation as well as the use +// of the returned elements. Finding an element may be performed using +// FindLinkedPtr*(). +// +// Use LookupOrInsertNewSharedPtr if the container does not own the elements +// for their whole lifetime. This is typically the case when a reader allows +// parallel updates to the container. In this case a Mutex only needs to lock +// container operations, but all element operations must be performed on the +// shared pointer. Finding an element must be performed using FindPtr*() and +// cannot be done with FindLinkedPtr*() even though it compiles. + +// Lookup a key in a map or hash_map whose values are linked_ptrs. If it is +// missing, set collection[key].reset(new Value::element_type) and return that. +// Value::element_type must be default constructable. +template +typename Collection::value_type::second_type::element_type* +LookupOrInsertNewLinkedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type::second_type Value; + std::pair ret = + collection->insert(typename Collection::value_type(key, Value())); + if (ret.second) { + ret.first->second.reset(new typename Value::element_type); + } + return ret.first->second.get(); +} + +// A variant of LookupOrInsertNewLinkedPtr where the value is constructed using +// a single-parameter constructor. Note: the constructor argument is computed +// even if it will not be used, so only values cheap to compute should be passed +// here. On the other hand it does not matter how expensive the construction of +// the actual stored value is, as that only occurs if necessary. +template +typename Collection::value_type::second_type::element_type* +LookupOrInsertNewLinkedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename Collection::value_type::second_type Value; + std::pair ret = + collection->insert(typename Collection::value_type(key, Value())); + if (ret.second) { + ret.first->second.reset(new typename Value::element_type(arg)); + } + return ret.first->second.get(); +} + +// Lookup a key in a map or hash_map whose values are shared_ptrs. If it is +// missing, set collection[key].reset(new Value::element_type). Unlike +// LookupOrInsertNewLinkedPtr, this function returns the shared_ptr instead of +// the raw pointer. Value::element_type must be default constructable. +template +typename Collection::value_type::second_type& +LookupOrInsertNewSharedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typedef typename Collection::value_type::second_type SharedPtr; + typedef typename Collection::value_type::second_type::element_type Element; + std::pair ret = + collection->insert(typename Collection::value_type(key, SharedPtr())); + if (ret.second) { + ret.first->second.reset(new Element()); + } + return ret.first->second; +} + +// A variant of LookupOrInsertNewSharedPtr where the value is constructed using +// a single-parameter constructor. Note: the constructor argument is computed +// even if it will not be used, so only values cheap to compute should be passed +// here. On the other hand it does not matter how expensive the construction of +// the actual stored value is, as that only occurs if necessary. +template +typename Collection::value_type::second_type& +LookupOrInsertNewSharedPtr( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const Arg& arg) { + typedef typename Collection::value_type::second_type SharedPtr; + typedef typename Collection::value_type::second_type::element_type Element; + std::pair ret = + collection->insert(typename Collection::value_type(key, SharedPtr())); + if (ret.second) { + ret.first->second.reset(new Element(arg)); + } + return ret.first->second; +} + +// +// Misc Utility Functions +// + +// Updates the value associated with the given key. If the key was not already +// present, then the key-value pair are inserted and "previous" is unchanged. If +// the key was already present, the value is updated and "*previous" will +// contain a copy of the old value. +// +// InsertOrReturnExisting has complementary behavior that returns the +// address of an already existing value, rather than updating it. +template +bool UpdateReturnCopy(Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value, + typename Collection::value_type::second_type* previous) { + std::pair ret = + collection->insert(typename Collection::value_type(key, value)); + if (!ret.second) { + // update + if (previous) { + *previous = ret.first->second; + } + ret.first->second = value; + return true; + } + return false; +} + +// Same as above except that the key and value are passed as a pair. +template +bool UpdateReturnCopy(Collection* const collection, + const typename Collection::value_type& vt, + typename Collection::value_type::second_type* previous) { + std::pair ret = collection->insert(vt); + if (!ret.second) { + // update + if (previous) { + *previous = ret.first->second; + } + ret.first->second = vt.second; + return true; + } + return false; +} + +// Tries to insert the given key-value pair into the collection. Returns nullptr if +// the insert succeeds. Otherwise, returns a pointer to the existing value. +// +// This complements UpdateReturnCopy in that it allows to update only after +// verifying the old value and still insert quickly without having to look up +// twice. Unlike UpdateReturnCopy this also does not come with the issue of an +// undefined previous* in case new data was inserted. +template +typename Collection::value_type::second_type* InsertOrReturnExisting( + Collection* const collection, const typename Collection::value_type& vt) { + std::pair ret = collection->insert(vt); + if (ret.second) { + return nullptr; // Inserted, no existing previous value. + } else { + return &ret.first->second; // Return address of already existing value. + } +} + +// Same as above, except for explicit key and data. +template +typename Collection::value_type::second_type* InsertOrReturnExisting( + Collection* const collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& data) { + return InsertOrReturnExisting(collection, + typename Collection::value_type(key, data)); +} + +// Erases the collection item identified by the given key, and returns the value +// associated with that key. It is assumed that the value (i.e., the +// mapped_type) is a pointer. Returns nullptr if the key was not found in the +// collection. +// +// Examples: +// map my_map; +// +// One line cleanup: +// delete EraseKeyReturnValuePtr(&my_map, "abc"); +// +// Use returned value: +// std::unique_ptr value_ptr( +// EraseKeyReturnValuePtr(&my_map, "abc")); +// if (value_ptr.get()) +// value_ptr->DoSomething(); +// +template +typename Collection::value_type::second_type EraseKeyReturnValuePtr( + Collection* const collection, + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection->find(key); + if (it == collection->end()) { + return nullptr; + } + typename Collection::value_type::second_type v = it->second; + collection->erase(it); + return v; +} + +// Inserts all the keys from map_container into key_container, which must +// support insert(MapContainer::key_type). +// +// Note: any initial contents of the key_container are not cleared. +template +void InsertKeysFromMap(const MapContainer& map_container, + KeyContainer* key_container) { + GOOGLE_CHECK(key_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->insert(it->first); + } +} + +// Appends all the keys from map_container into key_container, which must +// support push_back(MapContainer::key_type). +// +// Note: any initial contents of the key_container are not cleared. +template +void AppendKeysFromMap(const MapContainer& map_container, + KeyContainer* key_container) { + GOOGLE_CHECK(key_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->push_back(it->first); + } +} + +// A more specialized overload of AppendKeysFromMap to optimize reallocations +// for the common case in which we're appending keys to a vector and hence can +// (and sometimes should) call reserve() first. +// +// (It would be possible to play SFINAE games to call reserve() for any +// container that supports it, but this seems to get us 99% of what we need +// without the complexity of a SFINAE-based solution.) +template +void AppendKeysFromMap(const MapContainer& map_container, + std::vector* key_container) { + GOOGLE_CHECK(key_container != nullptr); + // We now have the opportunity to call reserve(). Calling reserve() every + // time is a bad idea for some use cases: libstdc++'s implementation of + // vector<>::reserve() resizes the vector's backing store to exactly the + // given size (unless it's already at least that big). Because of this, + // the use case that involves appending a lot of small maps (total size + // N) one by one to a vector would be O(N^2). But never calling reserve() + // loses the opportunity to improve the use case of adding from a large + // map to an empty vector (this improves performance by up to 33%). A + // number of heuristics are possible; see the discussion in + // cl/34081696. Here we use the simplest one. + if (key_container->empty()) { + key_container->reserve(map_container.size()); + } + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + key_container->push_back(it->first); + } +} + +// Inserts all the values from map_container into value_container, which must +// support push_back(MapContainer::mapped_type). +// +// Note: any initial contents of the value_container are not cleared. +template +void AppendValuesFromMap(const MapContainer& map_container, + ValueContainer* value_container) { + GOOGLE_CHECK(value_container != nullptr); + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + value_container->push_back(it->second); + } +} + +// A more specialized overload of AppendValuesFromMap to optimize reallocations +// for the common case in which we're appending values to a vector and hence +// can (and sometimes should) call reserve() first. +// +// (It would be possible to play SFINAE games to call reserve() for any +// container that supports it, but this seems to get us 99% of what we need +// without the complexity of a SFINAE-based solution.) +template +void AppendValuesFromMap(const MapContainer& map_container, + std::vector* value_container) { + GOOGLE_CHECK(value_container != nullptr); + // See AppendKeysFromMap for why this is done. + if (value_container->empty()) { + value_container->reserve(map_container.size()); + } + for (typename MapContainer::const_iterator it = map_container.begin(); + it != map_container.end(); ++it) { + value_container->push_back(it->second); + } +} + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_MAP_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h new file mode 100644 index 0000000000000000000000000000000000000000..2193d4493920a0d9aa9067605f933a0339c55bad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/mutex.h @@ -0,0 +1,191 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_MUTEX_H_ +#define GOOGLE_PROTOBUF_STUBS_MUTEX_H_ + +#include + +#ifdef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + +#include + +// GetMessage conflicts with GeneratedMessageReflection::GetMessage(). +#ifdef GetMessage +#undef GetMessage +#endif + +#endif + +#include + +// Define thread-safety annotations for use below, if we are building with +// Clang. +#if defined(__clang__) && !defined(SWIG) +#define GOOGLE_PROTOBUF_ACQUIRE(...) \ + __attribute__((acquire_capability(__VA_ARGS__))) +#define GOOGLE_PROTOBUF_RELEASE(...) \ + __attribute__((release_capability(__VA_ARGS__))) +#define GOOGLE_PROTOBUF_CAPABILITY(x) __attribute__((capability(x))) +#else +#define GOOGLE_PROTOBUF_ACQUIRE(...) +#define GOOGLE_PROTOBUF_RELEASE(...) +#define GOOGLE_PROTOBUF_CAPABILITY(x) +#endif + +#include + +// =================================================================== +// emulates google3/base/mutex.h +namespace google { +namespace protobuf { +namespace internal { + +#define GOOGLE_PROTOBUF_LINKER_INITIALIZED + +#ifdef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + +// This class is a lightweight replacement for std::mutex on Windows platforms. +// std::mutex does not work on Windows XP SP2 with the latest VC++ libraries, +// because it utilizes the Concurrency Runtime that is only supported on Windows +// XP SP3 and above. +class PROTOBUF_EXPORT CriticalSectionLock { + public: + CriticalSectionLock() { InitializeCriticalSection(&critical_section_); } + ~CriticalSectionLock() { DeleteCriticalSection(&critical_section_); } + void lock() { EnterCriticalSection(&critical_section_); } + void unlock() { LeaveCriticalSection(&critical_section_); } + + private: + CRITICAL_SECTION critical_section_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CriticalSectionLock); +}; + +#endif + +// Mutex is a natural type to wrap. As both google and other organization have +// specialized mutexes. gRPC also provides an injection mechanism for custom +// mutexes. +class GOOGLE_PROTOBUF_CAPABILITY("mutex") PROTOBUF_EXPORT WrappedMutex { + public: + WrappedMutex() = default; + void Lock() GOOGLE_PROTOBUF_ACQUIRE() { mu_.lock(); } + void Unlock() GOOGLE_PROTOBUF_RELEASE() { mu_.unlock(); } + // Crash if this Mutex is not held exclusively by this thread. + // May fail to crash when it should; will never crash when it should not. + void AssertHeld() const {} + + private: +#ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + std::mutex mu_; +#else // ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP + CriticalSectionLock mu_; +#endif // #ifndef GOOGLE_PROTOBUF_SUPPORT_WINDOWS_XP +}; + +using Mutex = WrappedMutex; + +// MutexLock(mu) acquires mu when constructed and releases it when destroyed. +class PROTOBUF_EXPORT MutexLock { + public: + explicit MutexLock(Mutex *mu) : mu_(mu) { this->mu_->Lock(); } + ~MutexLock() { this->mu_->Unlock(); } + private: + Mutex *const mu_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MutexLock); +}; + +// TODO(kenton): Implement these? Hard to implement portably. +typedef MutexLock ReaderMutexLock; +typedef MutexLock WriterMutexLock; + +// MutexLockMaybe is like MutexLock, but is a no-op when mu is nullptr. +class PROTOBUF_EXPORT MutexLockMaybe { + public: + explicit MutexLockMaybe(Mutex *mu) : + mu_(mu) { if (this->mu_ != nullptr) { this->mu_->Lock(); } } + ~MutexLockMaybe() { if (this->mu_ != nullptr) { this->mu_->Unlock(); } } + private: + Mutex *const mu_; + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(MutexLockMaybe); +}; + +#if defined(GOOGLE_PROTOBUF_NO_THREADLOCAL) +template +class ThreadLocalStorage { + public: + ThreadLocalStorage() { + pthread_key_create(&key_, &ThreadLocalStorage::Delete); + } + ~ThreadLocalStorage() { + pthread_key_delete(key_); + } + T* Get() { + T* result = static_cast(pthread_getspecific(key_)); + if (result == nullptr) { + result = new T(); + pthread_setspecific(key_, result); + } + return result; + } + private: + static void Delete(void* value) { + delete static_cast(value); + } + pthread_key_t key_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ThreadLocalStorage); +}; +#endif + +} // namespace internal + +// We made these internal so that they would show up as such in the docs, +// but we don't want to stick "internal::" in front of them everywhere. +using internal::Mutex; +using internal::MutexLock; +using internal::ReaderMutexLock; +using internal::WriterMutexLock; +using internal::MutexLockMaybe; + +} // namespace protobuf +} // namespace google + +#undef GOOGLE_PROTOBUF_ACQUIRE +#undef GOOGLE_PROTOBUF_RELEASE + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_MUTEX_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h new file mode 100644 index 0000000000000000000000000000000000000000..66ba5987a0d85f7cbf6dea96dee596f0ba0495fc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/once.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_ONCE_H__ +#define GOOGLE_PROTOBUF_STUBS_ONCE_H__ + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace internal { + +using once_flag = std::once_flag; +template +void call_once(Args&&... args ) { + std::call_once(std::forward(args)...); +} + +} // namespace internal +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_ONCE_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..f5d154fff83ed0d2e4ffc2113ee1746a5aa109c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/platform_macros.h @@ -0,0 +1,139 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2012 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ +#define GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ + +#define GOOGLE_PROTOBUF_PLATFORM_ERROR \ +#error "Host platform was not detected as supported by protobuf" + +// Processor architecture detection. For more info on what's defined, see: +// http://msdn.microsoft.com/en-us/library/b0084kay.aspx +// http://www.agner.org/optimize/calling_conventions.pdf +// or with gcc, run: "echo | gcc -E -dM -" +#if defined(_M_X64) || defined(__x86_64__) +#define GOOGLE_PROTOBUF_ARCH_X64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(_M_IX86) || defined(__i386__) +#define GOOGLE_PROTOBUF_ARCH_IA32 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(__QNX__) +#define GOOGLE_PROTOBUF_ARCH_ARM_QNX 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(_M_ARM) || defined(__ARMEL__) +#define GOOGLE_PROTOBUF_ARCH_ARM 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(_M_ARM64) +#define GOOGLE_PROTOBUF_ARCH_ARM 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__aarch64__) +#define GOOGLE_PROTOBUF_ARCH_AARCH64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__mips__) +#if defined(__LP64__) +#define GOOGLE_PROTOBUF_ARCH_MIPS64 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#else +#define GOOGLE_PROTOBUF_ARCH_MIPS 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#endif +#elif defined(__pnacl__) +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(sparc) +#define GOOGLE_PROTOBUF_ARCH_SPARC 1 +#if defined(__sparc_v9__) || defined(__sparcv9) || defined(__arch64__) +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#else +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#endif +#elif defined(_POWER) || defined(__powerpc64__) || defined(__PPC64__) +#define GOOGLE_PROTOBUF_ARCH_POWER 1 +#define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +#elif defined(__PPC__) +#define GOOGLE_PROTOBUF_ARCH_PPC 1 +#define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +#elif defined(__GNUC__) +# if (((__GNUC__ == 4) && (__GNUC_MINOR__ >= 7)) || (__GNUC__ > 4)) +// We fallback to the generic Clang/GCC >= 4.7 implementation in atomicops.h +# elif defined(__clang__) +# if !__has_extension(c_atomic) +GOOGLE_PROTOBUF_PLATFORM_ERROR +# endif +// We fallback to the generic Clang/GCC >= 4.7 implementation in atomicops.h +# endif +# if __LP64__ +# define GOOGLE_PROTOBUF_ARCH_64_BIT 1 +# else +# define GOOGLE_PROTOBUF_ARCH_32_BIT 1 +# endif +#else +GOOGLE_PROTOBUF_PLATFORM_ERROR +#endif + +#if defined(__APPLE__) +#define GOOGLE_PROTOBUF_OS_APPLE +#include +#include +#if TARGET_OS_IPHONE +#define GOOGLE_PROTOBUF_OS_IPHONE +#endif +#elif defined(__EMSCRIPTEN__) +#define GOOGLE_PROTOBUF_OS_EMSCRIPTEN +#elif defined(__native_client__) +#define GOOGLE_PROTOBUF_OS_NACL +#elif defined(sun) +#define GOOGLE_PROTOBUF_OS_SOLARIS +#elif defined(_AIX) +#define GOOGLE_PROTOBUF_OS_AIX +#elif defined(__ANDROID__) +#define GOOGLE_PROTOBUF_OS_ANDROID +#endif + +#undef GOOGLE_PROTOBUF_PLATFORM_ERROR + +#if defined(GOOGLE_PROTOBUF_OS_ANDROID) || defined(GOOGLE_PROTOBUF_OS_IPHONE) || defined(__OpenBSD__) +// Android ndk does not support the __thread keyword very well yet. Here +// we use pthread_key_create()/pthread_getspecific()/... methods for +// TLS support on android. +// iOS and OpenBSD also do not support the __thread keyword. +#define GOOGLE_PROTOBUF_NO_THREADLOCAL +#endif + +#if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 1070 +// __thread keyword requires at least 10.7 +#define GOOGLE_PROTOBUF_NO_THREADLOCAL +#endif + +#endif // GOOGLE_PROTOBUF_PLATFORM_MACROS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h new file mode 100644 index 0000000000000000000000000000000000000000..a46e6de0e64aa94ad2683f1742b8a56320fb86ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/port.h @@ -0,0 +1,410 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_STUBS_PORT_H_ +#define GOOGLE_PROTOBUF_STUBS_PORT_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#undef PROTOBUF_LITTLE_ENDIAN +#ifdef _WIN32 + // Assuming windows is always little-endian. + // TODO(xiaofeng): The PROTOBUF_LITTLE_ENDIAN is not only used for + // optimization but also for correctness. We should define an + // different macro to test the big-endian code path in coded_stream. + #if !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) + #define PROTOBUF_LITTLE_ENDIAN 1 + #endif + #if _MSC_VER >= 1300 && !defined(__INTEL_COMPILER) + // If MSVC has "/RTCc" set, it will complain about truncating casts at + // runtime. This file contains some intentional truncating casts. + #pragma runtime_checks("c", off) + #endif +#else + #include // __BYTE_ORDER + #if defined(__OpenBSD__) + #include + #endif + #if ((defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)) || \ + (defined(__BYTE_ORDER) && __BYTE_ORDER == __LITTLE_ENDIAN) || \ + (defined(BYTE_ORDER) && BYTE_ORDER == LITTLE_ENDIAN)) && \ + !defined(PROTOBUF_DISABLE_LITTLE_ENDIAN_OPT_FOR_TEST) + #define PROTOBUF_LITTLE_ENDIAN 1 + #endif +#endif + +// These #includes are for the byte swap functions declared later on. +#ifdef _MSC_VER +#include // NOLINT(build/include) +#include +#elif defined(__APPLE__) +#include +#elif defined(__GLIBC__) || defined(__BIONIC__) || defined(__CYGWIN__) +#include // IWYU pragma: export +#endif + +// Legacy: some users reference these (internal-only) macros even though we +// don't need them any more. +#if defined(_MSC_VER) && defined(PROTOBUF_USE_DLLS) + #ifdef LIBPROTOBUF_EXPORTS + #define LIBPROTOBUF_EXPORT __declspec(dllexport) + #else + #define LIBPROTOBUF_EXPORT __declspec(dllimport) + #endif + #ifdef LIBPROTOC_EXPORTS + #define LIBPROTOC_EXPORT __declspec(dllexport) + #else + #define LIBPROTOC_EXPORT __declspec(dllimport) + #endif +#else + #define LIBPROTOBUF_EXPORT + #define LIBPROTOC_EXPORT +#endif + +#define PROTOBUF_RUNTIME_DEPRECATED(message) PROTOBUF_DEPRECATED_MSG(message) +#define GOOGLE_PROTOBUF_RUNTIME_DEPRECATED(message) \ + PROTOBUF_DEPRECATED_MSG(message) + +// =================================================================== +// from google3/base/port.h + +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900)) +// Define this to 1 if the code is compiled in C++11 mode; leave it +// undefined otherwise. Do NOT define it to 0 -- that causes +// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'. +#define LANG_CXX11 1 +#else +#error "Protobuf requires at least C++11." +#endif + +namespace google { +namespace protobuf { + +using ConstStringParam = const std::string &; + +typedef unsigned int uint; + +typedef int8_t int8; +typedef int16_t int16; +typedef int32_t int32; +typedef int64_t int64; + +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; + +static const int32 kint32max = 0x7FFFFFFF; +static const int32 kint32min = -kint32max - 1; +static const int64 kint64max = PROTOBUF_LONGLONG(0x7FFFFFFFFFFFFFFF); +static const int64 kint64min = -kint64max - 1; +static const uint32 kuint32max = 0xFFFFFFFFu; +static const uint64 kuint64max = PROTOBUF_ULONGLONG(0xFFFFFFFFFFFFFFFF); + +#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) ||\ + defined(MEMORY_SANITIZER) + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus +uint16_t __sanitizer_unaligned_load16(const void *p); +uint32_t __sanitizer_unaligned_load32(const void *p); +uint64_t __sanitizer_unaligned_load64(const void *p); +void __sanitizer_unaligned_store16(void *p, uint16_t v); +void __sanitizer_unaligned_store32(void *p, uint32_t v); +void __sanitizer_unaligned_store64(void *p, uint64_t v); +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +inline uint16 GOOGLE_UNALIGNED_LOAD16(const void *p) { + return __sanitizer_unaligned_load16(p); +} + +inline uint32 GOOGLE_UNALIGNED_LOAD32(const void *p) { + return __sanitizer_unaligned_load32(p); +} + +inline uint64 GOOGLE_UNALIGNED_LOAD64(const void *p) { + return __sanitizer_unaligned_load64(p); +} + +inline void GOOGLE_UNALIGNED_STORE16(void *p, uint16 v) { + __sanitizer_unaligned_store16(p, v); +} + +inline void GOOGLE_UNALIGNED_STORE32(void *p, uint32 v) { + __sanitizer_unaligned_store32(p, v); +} + +inline void GOOGLE_UNALIGNED_STORE64(void *p, uint64 v) { + __sanitizer_unaligned_store64(p, v); +} + +#elif defined(GOOGLE_PROTOBUF_USE_UNALIGNED) && GOOGLE_PROTOBUF_USE_UNALIGNED + +#define GOOGLE_UNALIGNED_LOAD16(_p) (*reinterpret_cast(_p)) +#define GOOGLE_UNALIGNED_LOAD32(_p) (*reinterpret_cast(_p)) +#define GOOGLE_UNALIGNED_LOAD64(_p) (*reinterpret_cast(_p)) + +#define GOOGLE_UNALIGNED_STORE16(_p, _val) (*reinterpret_cast(_p) = (_val)) +#define GOOGLE_UNALIGNED_STORE32(_p, _val) (*reinterpret_cast(_p) = (_val)) +#define GOOGLE_UNALIGNED_STORE64(_p, _val) (*reinterpret_cast(_p) = (_val)) + +#else +inline uint16 GOOGLE_UNALIGNED_LOAD16(const void *p) { + uint16 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline uint32 GOOGLE_UNALIGNED_LOAD32(const void *p) { + uint32 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline uint64 GOOGLE_UNALIGNED_LOAD64(const void *p) { + uint64 t; + memcpy(&t, p, sizeof t); + return t; +} + +inline void GOOGLE_UNALIGNED_STORE16(void *p, uint16 v) { + memcpy(p, &v, sizeof v); +} + +inline void GOOGLE_UNALIGNED_STORE32(void *p, uint32 v) { + memcpy(p, &v, sizeof v); +} + +inline void GOOGLE_UNALIGNED_STORE64(void *p, uint64 v) { + memcpy(p, &v, sizeof v); +} +#endif + +#if defined(GOOGLE_PROTOBUF_OS_NACL) \ + || (defined(__ANDROID__) && defined(__clang__) \ + && (__clang_major__ == 3 && __clang_minor__ == 8) \ + && (__clang_patchlevel__ < 275480)) +# define GOOGLE_PROTOBUF_USE_PORTABLE_LOG2 +#endif + +// The following guarantees declaration of the byte swap functions. +#ifdef _MSC_VER +#define bswap_16(x) _byteswap_ushort(x) +#define bswap_32(x) _byteswap_ulong(x) +#define bswap_64(x) _byteswap_uint64(x) + +#elif defined(__APPLE__) +// Mac OS X / Darwin features +#define bswap_16(x) OSSwapInt16(x) +#define bswap_32(x) OSSwapInt32(x) +#define bswap_64(x) OSSwapInt64(x) + +#elif !defined(__GLIBC__) && !defined(__BIONIC__) && !defined(__CYGWIN__) + +#ifndef bswap_16 +static inline uint16 bswap_16(uint16 x) { + return static_cast(((x & 0xFF) << 8) | ((x & 0xFF00) >> 8)); +} +#define bswap_16(x) bswap_16(x) +#endif + +#ifndef bswap_32 +static inline uint32 bswap_32(uint32 x) { + return (((x & 0xFF) << 24) | + ((x & 0xFF00) << 8) | + ((x & 0xFF0000) >> 8) | + ((x & 0xFF000000) >> 24)); +} +#define bswap_32(x) bswap_32(x) +#endif + +#ifndef bswap_64 +static inline uint64 bswap_64(uint64 x) { + return (((x & PROTOBUF_ULONGLONG(0xFF)) << 56) | + ((x & PROTOBUF_ULONGLONG(0xFF00)) << 40) | + ((x & PROTOBUF_ULONGLONG(0xFF0000)) << 24) | + ((x & PROTOBUF_ULONGLONG(0xFF000000)) << 8) | + ((x & PROTOBUF_ULONGLONG(0xFF00000000)) >> 8) | + ((x & PROTOBUF_ULONGLONG(0xFF0000000000)) >> 24) | + ((x & PROTOBUF_ULONGLONG(0xFF000000000000)) >> 40) | + ((x & PROTOBUF_ULONGLONG(0xFF00000000000000)) >> 56)); +} +#define bswap_64(x) bswap_64(x) +#endif + +#endif + +// =================================================================== +// from google3/util/bits/bits.h + +class Bits { + public: + static uint32 Log2FloorNonZero(uint32 n) { +#if defined(__GNUC__) + return 31 ^ static_cast(__builtin_clz(n)); +#elif defined(_MSC_VER) + unsigned long where; + _BitScanReverse(&where, n); + return where; +#else + return Log2FloorNonZero_Portable(n); +#endif + } + + static uint32 Log2FloorNonZero64(uint64 n) { + // Older versions of clang run into an instruction-selection failure when + // it encounters __builtin_clzll: + // https://bugs.chromium.org/p/nativeclient/issues/detail?id=4395 + // This includes arm-nacl-clang and clang in older Android NDK versions. + // To work around this, when we build with those we use the portable + // implementation instead. +#if defined(__GNUC__) && !defined(GOOGLE_PROTOBUF_USE_PORTABLE_LOG2) + return 63 ^ static_cast(__builtin_clzll(n)); +#elif defined(_MSC_VER) && defined(_M_X64) + unsigned long where; + _BitScanReverse64(&where, n); + return where; +#else + return Log2FloorNonZero64_Portable(n); +#endif + } + private: + static int Log2FloorNonZero_Portable(uint32 n) { + if (n == 0) + return -1; + int log = 0; + uint32 value = n; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32 x = value >> shift; + if (x != 0) { + value = x; + log += shift; + } + } + assert(value == 1); + return log; + } + + static int Log2FloorNonZero64_Portable(uint64 n) { + const uint32 topbits = static_cast(n >> 32); + if (topbits == 0) { + // Top bits are zero, so scan in bottom bits + return static_cast(Log2FloorNonZero(static_cast(n))); + } else { + return 32 + static_cast(Log2FloorNonZero(topbits)); + } + } +}; + +// =================================================================== +// from google3/util/endian/endian.h +PROTOBUF_EXPORT uint32 ghtonl(uint32 x); + +class BigEndian { + public: +#ifdef PROTOBUF_LITTLE_ENDIAN + + static uint16 FromHost16(uint16 x) { return bswap_16(x); } + static uint16 ToHost16(uint16 x) { return bswap_16(x); } + + static uint32 FromHost32(uint32 x) { return bswap_32(x); } + static uint32 ToHost32(uint32 x) { return bswap_32(x); } + + static uint64 FromHost64(uint64 x) { return bswap_64(x); } + static uint64 ToHost64(uint64 x) { return bswap_64(x); } + + static bool IsLittleEndian() { return true; } + +#else + + static uint16 FromHost16(uint16 x) { return x; } + static uint16 ToHost16(uint16 x) { return x; } + + static uint32 FromHost32(uint32 x) { return x; } + static uint32 ToHost32(uint32 x) { return x; } + + static uint64 FromHost64(uint64 x) { return x; } + static uint64 ToHost64(uint64 x) { return x; } + + static bool IsLittleEndian() { return false; } + +#endif /* ENDIAN */ + + // Functions to do unaligned loads and stores in big-endian order. + static uint16 Load16(const void *p) { + return ToHost16(GOOGLE_UNALIGNED_LOAD16(p)); + } + + static void Store16(void *p, uint16 v) { + GOOGLE_UNALIGNED_STORE16(p, FromHost16(v)); + } + + static uint32 Load32(const void *p) { + return ToHost32(GOOGLE_UNALIGNED_LOAD32(p)); + } + + static void Store32(void *p, uint32 v) { + GOOGLE_UNALIGNED_STORE32(p, FromHost32(v)); + } + + static uint64 Load64(const void *p) { + return ToHost64(GOOGLE_UNALIGNED_LOAD64(p)); + } + + static void Store64(void *p, uint64 v) { + GOOGLE_UNALIGNED_STORE64(p, FromHost64(v)); + } +}; + +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_PORT_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h new file mode 100644 index 0000000000000000000000000000000000000000..cf15c91a33ed90712c77d8ce62185caf82948cc4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/status.h @@ -0,0 +1,130 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#ifndef GOOGLE_PROTOBUF_STUBS_STATUS_H_ +#define GOOGLE_PROTOBUF_STUBS_STATUS_H_ + +#include +#include + +#include +#include + +#include + +namespace google { +namespace protobuf { +namespace util { +namespace error { +// These values must match error codes defined in google/rpc/code.proto. +enum Code { + OK = 0, + CANCELLED = 1, + UNKNOWN = 2, + INVALID_ARGUMENT = 3, + DEADLINE_EXCEEDED = 4, + NOT_FOUND = 5, + ALREADY_EXISTS = 6, + PERMISSION_DENIED = 7, + UNAUTHENTICATED = 16, + RESOURCE_EXHAUSTED = 8, + FAILED_PRECONDITION = 9, + ABORTED = 10, + OUT_OF_RANGE = 11, + UNIMPLEMENTED = 12, + INTERNAL = 13, + UNAVAILABLE = 14, + DATA_LOSS = 15, +}; +} // namespace error + +class PROTOBUF_EXPORT Status { + public: + // Creates a "successful" status. + Status(); + + // Create a status in the canonical error space with the specified + // code, and error message. If "code == 0", error_message is + // ignored and a Status object identical to Status::OK is + // constructed. + Status(error::Code error_code, StringPiece error_message); + Status(const Status&); + Status& operator=(const Status& x); + ~Status() {} + + // Some pre-defined Status objects + static const Status OK; // Identical to 0-arg constructor + static const Status CANCELLED; + static const Status UNKNOWN; + + // Accessor + bool ok() const { + return error_code_ == error::OK; + } + int error_code() const { + return error_code_; + } + error::Code code() const { + return error_code_; + } + StringPiece error_message() const { + return error_message_; + } + StringPiece message() const { + return error_message_; + } + + bool operator==(const Status& x) const; + bool operator!=(const Status& x) const { + return !operator==(x); + } + + // Return a combination of the error code name and message. + string ToString() const; + + private: + error::Code error_code_; + string error_message_; +}; + +// Prints a human-readable representation of 'x' to 'os'. +PROTOBUF_EXPORT std::ostream& operator<<(std::ostream& os, const Status& x); + +} // namespace util +} // namespace protobuf +} // namespace google + +#include + +#endif // GOOGLE_PROTOBUF_STUBS_STATUS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h new file mode 100644 index 0000000000000000000000000000000000000000..89ca9b2fdfa10baff42682ed41ad66e0efed44f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stl_util.h @@ -0,0 +1,76 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/util/gtl/stl_util.h + +#ifndef GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ + +#include + +namespace google { +namespace protobuf { + +// Inside Google, this function implements a horrible, disgusting hack in which +// we reach into the string's private implementation and resize it without +// initializing the new bytes. In some cases doing this can significantly +// improve performance. However, since it's totally non-portable it has no +// place in open source code. Feel free to fill this function in with your +// own disgusting hack if you want the perf boost. +inline void STLStringResizeUninitialized(string* s, size_t new_size) { + s->resize(new_size); +} + +// Return a mutable char* pointing to a string's internal buffer, +// which may not be null-terminated. Writing through this pointer will +// modify the string. +// +// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the +// next call to a string method that invalidates iterators. +// +// As of 2006-04, there is no standard-blessed way of getting a +// mutable reference to a string's internal buffer. However, issue 530 +// (http://www.open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#530) +// proposes this as the method. According to Matt Austern, this should +// already work on all current implementations. +inline char* string_as_array(string* str) { + // DO NOT USE const_cast(str->data())! See the unittest for why. + return str->empty() ? nullptr : &*str->begin(); +} + +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_STUBS_STL_UTIL_H__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h new file mode 100644 index 0000000000000000000000000000000000000000..b1c17f2605f7511f072eb7c66b22c127b25830dc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/stringpiece.h @@ -0,0 +1,494 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A StringPiece points to part or all of a string, Cord, double-quoted string +// literal, or other string-like object. A StringPiece does *not* own the +// string to which it points. A StringPiece is not null-terminated. +// +// You can use StringPiece as a function or method parameter. A StringPiece +// parameter can receive a double-quoted string literal argument, a "const +// char*" argument, a string argument, or a StringPiece argument with no data +// copying. Systematic use of StringPiece for arguments reduces data +// copies and strlen() calls. +// +// Prefer passing StringPieces by value: +// void MyFunction(StringPiece arg); +// If circumstances require, you may also pass by const reference: +// void MyFunction(const StringPiece& arg); // not preferred +// Both of these have the same lifetime semantics. Passing by value +// generates slightly smaller code. For more discussion, see the thread +// go/stringpiecebyvalue on c-users. +// +// StringPiece is also suitable for local variables if you know that +// the lifetime of the underlying object is longer than the lifetime +// of your StringPiece variable. +// +// Beware of binding a StringPiece to a temporary: +// StringPiece sp = obj.MethodReturningString(); // BAD: lifetime problem +// +// This code is okay: +// string str = obj.MethodReturningString(); // str owns its contents +// StringPiece sp(str); // GOOD, because str outlives sp +// +// StringPiece is sometimes a poor choice for a return value and usually a poor +// choice for a data member. If you do use a StringPiece this way, it is your +// responsibility to ensure that the object pointed to by the StringPiece +// outlives the StringPiece. +// +// A StringPiece may represent just part of a string; thus the name "Piece". +// For example, when splitting a string, vector is a natural data +// type for the output. For another example, a Cord is a non-contiguous, +// potentially very long string-like object. The Cord class has an interface +// that iteratively provides StringPiece objects that point to the +// successive pieces of a Cord object. +// +// A StringPiece is not null-terminated. If you write code that scans a +// StringPiece, you must check its length before reading any characters. +// Common idioms that work on null-terminated strings do not work on +// StringPiece objects. +// +// There are several ways to create a null StringPiece: +// StringPiece() +// StringPiece(nullptr) +// StringPiece(nullptr, 0) +// For all of the above, sp.data() == nullptr, sp.length() == 0, +// and sp.empty() == true. Also, if you create a StringPiece with +// a non-null pointer then sp.data() != nullptr. Once created, +// sp.data() will stay either nullptr or not-nullptr, except if you call +// sp.clear() or sp.set(). +// +// Thus, you can use StringPiece(nullptr) to signal an out-of-band value +// that is different from other StringPiece values. This is similar +// to the way that const char* p1 = nullptr; is different from +// const char* p2 = "";. +// +// There are many ways to create an empty StringPiece: +// StringPiece() +// StringPiece(nullptr) +// StringPiece(nullptr, 0) +// StringPiece("") +// StringPiece("", 0) +// StringPiece("abcdef", 0) +// StringPiece("abcdef"+6, 0) +// For all of the above, sp.length() will be 0 and sp.empty() will be true. +// For some empty StringPiece values, sp.data() will be nullptr. +// For some empty StringPiece values, sp.data() will not be nullptr. +// +// Be careful not to confuse: null StringPiece and empty StringPiece. +// The set of empty StringPieces properly includes the set of null StringPieces. +// That is, every null StringPiece is an empty StringPiece, +// but some non-null StringPieces are empty Stringpieces too. +// +// All empty StringPiece values compare equal to each other. +// Even a null StringPieces compares equal to a non-null empty StringPiece: +// StringPiece() == StringPiece("", 0) +// StringPiece(nullptr) == StringPiece("abc", 0) +// StringPiece(nullptr, 0) == StringPiece("abcdef"+6, 0) +// +// Look carefully at this example: +// StringPiece("") == nullptr +// True or false? TRUE, because StringPiece::operator== converts +// the right-hand side from nullptr to StringPiece(nullptr), +// and then compares two zero-length spans of characters. +// However, we are working to make this example produce a compile error. +// +// Suppose you want to write: +// bool TestWhat?(StringPiece sp) { return sp == nullptr; } // BAD +// Do not do that. Write one of these instead: +// bool TestNull(StringPiece sp) { return sp.data() == nullptr; } +// bool TestEmpty(StringPiece sp) { return sp.empty(); } +// The intent of TestWhat? is unclear. Did you mean TestNull or TestEmpty? +// Right now, TestWhat? behaves likes TestEmpty. +// We are working to make TestWhat? produce a compile error. +// TestNull is good to test for an out-of-band signal. +// TestEmpty is good to test for an empty StringPiece. +// +// Caveats (again): +// (1) The lifetime of the pointed-to string (or piece of a string) +// must be longer than the lifetime of the StringPiece. +// (2) There may or may not be a '\0' character after the end of +// StringPiece data. +// (3) A null StringPiece is empty. +// An empty StringPiece may or may not be a null StringPiece. + +#ifndef GOOGLE_PROTOBUF_STUBS_STRINGPIECE_H_ +#define GOOGLE_PROTOBUF_STUBS_STRINGPIECE_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace google { +namespace protobuf { +// StringPiece has *two* size types. +// StringPiece::size_type +// is unsigned +// is 32 bits in LP32, 64 bits in LP64, 64 bits in LLP64 +// no future changes intended +// stringpiece_ssize_type +// is signed +// is 32 bits in LP32, 64 bits in LP64, 64 bits in LLP64 +// future changes intended: http://go/64BitStringPiece +// +typedef std::string::difference_type stringpiece_ssize_type; + +// STRINGPIECE_CHECK_SIZE protects us from 32-bit overflows. +// TODO(mec): delete this after stringpiece_ssize_type goes 64 bit. +#if !defined(NDEBUG) +#define STRINGPIECE_CHECK_SIZE 1 +#elif defined(_FORTIFY_SOURCE) && _FORTIFY_SOURCE > 0 +#define STRINGPIECE_CHECK_SIZE 1 +#else +#define STRINGPIECE_CHECK_SIZE 0 +#endif + +class PROTOBUF_EXPORT StringPiece { + private: + const char* ptr_; + stringpiece_ssize_type length_; + + // Prevent overflow in debug mode or fortified mode. + // sizeof(stringpiece_ssize_type) may be smaller than sizeof(size_t). + static stringpiece_ssize_type CheckedSsizeTFromSizeT(size_t size) { +#if STRINGPIECE_CHECK_SIZE > 0 +#ifdef max +#undef max +#endif + if (size > static_cast( + std::numeric_limits::max())) { + // Some people grep for this message in logs + // so take care if you ever change it. + LogFatalSizeTooBig(size, "size_t to int conversion"); + } +#endif + return static_cast(size); + } + + // Out-of-line error path. + static void LogFatalSizeTooBig(size_t size, const char* details); + + public: + // We provide non-explicit singleton constructors so users can pass + // in a "const char*" or a "string" wherever a "StringPiece" is + // expected. + // + // Style guide exception granted: + // http://goto/style-guide-exception-20978288 + StringPiece() : ptr_(nullptr), length_(0) {} + + StringPiece(const char* str) // NOLINT(runtime/explicit) + : ptr_(str), length_(0) { + if (str != nullptr) { + length_ = CheckedSsizeTFromSizeT(strlen(str)); + } + } + + template + StringPiece( // NOLINT(runtime/explicit) + const std::basic_string, Allocator>& str) + : ptr_(str.data()), length_(0) { + length_ = CheckedSsizeTFromSizeT(str.size()); + } + + StringPiece(const char* offset, stringpiece_ssize_type len) + : ptr_(offset), length_(len) { + assert(len >= 0); + } + + // Substring of another StringPiece. + // pos must be non-negative and <= x.length(). + StringPiece(StringPiece x, stringpiece_ssize_type pos); + // Substring of another StringPiece. + // pos must be non-negative and <= x.length(). + // len must be non-negative and will be pinned to at most x.length() - pos. + StringPiece(StringPiece x, + stringpiece_ssize_type pos, + stringpiece_ssize_type len); + + // data() may return a pointer to a buffer with embedded NULs, and the + // returned buffer may or may not be null terminated. Therefore it is + // typically a mistake to pass data() to a routine that expects a NUL + // terminated string. + const char* data() const { return ptr_; } + stringpiece_ssize_type size() const { return length_; } + stringpiece_ssize_type length() const { return length_; } + bool empty() const { return length_ == 0; } + + void clear() { + ptr_ = nullptr; + length_ = 0; + } + + void set(const char* data, stringpiece_ssize_type len) { + assert(len >= 0); + ptr_ = data; + length_ = len; + } + + void set(const char* str) { + ptr_ = str; + if (str != nullptr) + length_ = CheckedSsizeTFromSizeT(strlen(str)); + else + length_ = 0; + } + + void set(const void* data, stringpiece_ssize_type len) { + ptr_ = reinterpret_cast(data); + length_ = len; + } + + char operator[](stringpiece_ssize_type i) const { + assert(0 <= i); + assert(i < length_); + return ptr_[i]; + } + + void remove_prefix(stringpiece_ssize_type n) { + assert(length_ >= n); + ptr_ += n; + length_ -= n; + } + + void remove_suffix(stringpiece_ssize_type n) { + assert(length_ >= n); + length_ -= n; + } + + // returns {-1, 0, 1} + int compare(StringPiece x) const { + const stringpiece_ssize_type min_size = + length_ < x.length_ ? length_ : x.length_; + int r = memcmp(ptr_, x.ptr_, static_cast(min_size)); + if (r < 0) return -1; + if (r > 0) return 1; + if (length_ < x.length_) return -1; + if (length_ > x.length_) return 1; + return 0; + } + + std::string as_string() const { return ToString(); } + // We also define ToString() here, since many other string-like + // interfaces name the routine that converts to a C++ string + // "ToString", and it's confusing to have the method that does that + // for a StringPiece be called "as_string()". We also leave the + // "as_string()" method defined here for existing code. + std::string ToString() const { + if (ptr_ == nullptr) return ""; + return std::string(data(), static_cast(size())); + } + + explicit operator std::string() const { return ToString(); } + + void CopyToString(std::string* target) const; + void AppendToString(std::string* target) const; + + bool starts_with(StringPiece x) const { + return (length_ >= x.length_) && + (memcmp(ptr_, x.ptr_, static_cast(x.length_)) == 0); + } + + bool ends_with(StringPiece x) const { + return ((length_ >= x.length_) && + (memcmp(ptr_ + (length_-x.length_), x.ptr_, + static_cast(x.length_)) == 0)); + } + + // Checks whether StringPiece starts with x and if so advances the beginning + // of it to past the match. It's basically a shortcut for starts_with + // followed by remove_prefix. + bool Consume(StringPiece x); + // Like above but for the end of the string. + bool ConsumeFromEnd(StringPiece x); + + // standard STL container boilerplate + typedef char value_type; + typedef const char* pointer; + typedef const char& reference; + typedef const char& const_reference; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + static const size_type npos; + typedef const char* const_iterator; + typedef const char* iterator; + typedef std::reverse_iterator const_reverse_iterator; + typedef std::reverse_iterator reverse_iterator; + iterator begin() const { return ptr_; } + iterator end() const { return ptr_ + length_; } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(ptr_ + length_); + } + const_reverse_iterator rend() const { + return const_reverse_iterator(ptr_); + } + stringpiece_ssize_type max_size() const { return length_; } + stringpiece_ssize_type capacity() const { return length_; } + + // cpplint.py emits a false positive [build/include_what_you_use] + stringpiece_ssize_type copy(char* buf, size_type n, size_type pos = 0) const; // NOLINT + + bool contains(StringPiece s) const; + + stringpiece_ssize_type find(StringPiece s, size_type pos = 0) const; + stringpiece_ssize_type find(char c, size_type pos = 0) const; + stringpiece_ssize_type rfind(StringPiece s, size_type pos = npos) const; + stringpiece_ssize_type rfind(char c, size_type pos = npos) const; + + stringpiece_ssize_type find_first_of(StringPiece s, size_type pos = 0) const; + stringpiece_ssize_type find_first_of(char c, size_type pos = 0) const { + return find(c, pos); + } + stringpiece_ssize_type find_first_not_of(StringPiece s, + size_type pos = 0) const; + stringpiece_ssize_type find_first_not_of(char c, size_type pos = 0) const; + stringpiece_ssize_type find_last_of(StringPiece s, + size_type pos = npos) const; + stringpiece_ssize_type find_last_of(char c, size_type pos = npos) const { + return rfind(c, pos); + } + stringpiece_ssize_type find_last_not_of(StringPiece s, + size_type pos = npos) const; + stringpiece_ssize_type find_last_not_of(char c, size_type pos = npos) const; + + StringPiece substr(size_type pos, size_type n = npos) const; +}; + +// This large function is defined inline so that in a fairly common case where +// one of the arguments is a literal, the compiler can elide a lot of the +// following comparisons. +inline bool operator==(StringPiece x, StringPiece y) { + stringpiece_ssize_type len = x.size(); + if (len != y.size()) { + return false; + } + + return x.data() == y.data() || len <= 0 || + memcmp(x.data(), y.data(), static_cast(len)) == 0; +} + +inline bool operator!=(StringPiece x, StringPiece y) { + return !(x == y); +} + +inline bool operator<(StringPiece x, StringPiece y) { + const stringpiece_ssize_type min_size = + x.size() < y.size() ? x.size() : y.size(); + const int r = memcmp(x.data(), y.data(), static_cast(min_size)); + return (r < 0) || (r == 0 && x.size() < y.size()); +} + +inline bool operator>(StringPiece x, StringPiece y) { + return y < x; +} + +inline bool operator<=(StringPiece x, StringPiece y) { + return !(x > y); +} + +inline bool operator>=(StringPiece x, StringPiece y) { + return !(x < y); +} + +// allow StringPiece to be logged +extern std::ostream& operator<<(std::ostream& o, StringPiece piece); + +namespace internal { +// StringPiece is not a POD and can not be used in an union (pre C++11). We +// need a POD version of it. +struct StringPiecePod { + // Create from a StringPiece. + static StringPiecePod CreateFromStringPiece(StringPiece str) { + StringPiecePod pod; + pod.data_ = str.data(); + pod.size_ = str.size(); + return pod; + } + + // Cast to StringPiece. + operator StringPiece() const { return StringPiece(data_, size_); } + + bool operator==(const char* value) const { + return StringPiece(data_, size_) == StringPiece(value); + } + + char operator[](stringpiece_ssize_type i) const { + assert(0 <= i); + assert(i < size_); + return data_[i]; + } + + const char* data() const { return data_; } + + stringpiece_ssize_type size() const { + return size_; + } + + std::string ToString() const { + return std::string(data_, static_cast(size_)); + } + + explicit operator std::string() const { return ToString(); } + + private: + const char* data_; + stringpiece_ssize_type size_; +}; + +} // namespace internal +} // namespace protobuf +} // namespace google + +GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_START +template<> struct hash { + size_t operator()(const StringPiece& s) const { + size_t result = 0; + for (const char *str = s.data(), *end = str + s.size(); str < end; str++) { + result = 5 * result + static_cast(*str); + } + return result; + } +}; +GOOGLE_PROTOBUF_HASH_NAMESPACE_DECLARATION_END + +#include + +#endif // STRINGS_STRINGPIECE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h new file mode 100644 index 0000000000000000000000000000000000000000..c5fdd08e00c66772b4ba1054c149dca12fcc8e9a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/stubs/strutil.h @@ -0,0 +1,952 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// from google3/strings/strutil.h + +#ifndef GOOGLE_PROTOBUF_STUBS_STRUTIL_H__ +#define GOOGLE_PROTOBUF_STUBS_STRUTIL_H__ + +#include +#include +#include + +#include +#include +#include + +namespace google { +namespace protobuf { + +#if defined(_MSC_VER) && _MSC_VER < 1800 +#define strtoll _strtoi64 +#define strtoull _strtoui64 +#elif defined(__DECCXX) && defined(__osf__) +// HP C++ on Tru64 does not have strtoll, but strtol is already 64-bit. +#define strtoll strtol +#define strtoull strtoul +#endif + +// ---------------------------------------------------------------------- +// ascii_isalnum() +// Check if an ASCII character is alphanumeric. We can't use ctype's +// isalnum() because it is affected by locale. This function is applied +// to identifiers in the protocol buffer language, not to natural-language +// strings, so locale should not be taken into account. +// ascii_isdigit() +// Like above, but only accepts digits. +// ascii_isspace() +// Check if the character is a space character. +// ---------------------------------------------------------------------- + +inline bool ascii_isalnum(char c) { + return ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + ('0' <= c && c <= '9'); +} + +inline bool ascii_isdigit(char c) { + return ('0' <= c && c <= '9'); +} + +inline bool ascii_isspace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || + c == '\r'; +} + +inline bool ascii_isupper(char c) { + return c >= 'A' && c <= 'Z'; +} + +inline bool ascii_islower(char c) { + return c >= 'a' && c <= 'z'; +} + +inline char ascii_toupper(char c) { + return ascii_islower(c) ? c - ('a' - 'A') : c; +} + +inline char ascii_tolower(char c) { + return ascii_isupper(c) ? c + ('a' - 'A') : c; +} + +inline int hex_digit_to_int(char c) { + /* Assume ASCII. */ + int x = static_cast(c); + if (x > '9') { + x += 9; + } + return x & 0xf; +} + +// ---------------------------------------------------------------------- +// HasPrefixString() +// Check if a string begins with a given prefix. +// StripPrefixString() +// Given a string and a putative prefix, returns the string minus the +// prefix string if the prefix matches, otherwise the original +// string. +// ---------------------------------------------------------------------- +inline bool HasPrefixString(StringPiece str, StringPiece prefix) { + return str.size() >= prefix.size() && + memcmp(str.data(), prefix.data(), prefix.size()) == 0; +} + +inline string StripPrefixString(const string& str, const string& prefix) { + if (HasPrefixString(str, prefix)) { + return str.substr(prefix.size()); + } else { + return str; + } +} + +// ---------------------------------------------------------------------- +// HasSuffixString() +// Return true if str ends in suffix. +// StripSuffixString() +// Given a string and a putative suffix, returns the string minus the +// suffix string if the suffix matches, otherwise the original +// string. +// ---------------------------------------------------------------------- +inline bool HasSuffixString(StringPiece str, StringPiece suffix) { + return str.size() >= suffix.size() && + memcmp(str.data() + str.size() - suffix.size(), suffix.data(), + suffix.size()) == 0; +} + +inline string StripSuffixString(const string& str, const string& suffix) { + if (HasSuffixString(str, suffix)) { + return str.substr(0, str.size() - suffix.size()); + } else { + return str; + } +} + +// ---------------------------------------------------------------------- +// ReplaceCharacters +// Replaces any occurrence of the character 'remove' (or the characters +// in 'remove') with the character 'replacewith'. +// Good for keeping html characters or protocol characters (\t) out +// of places where they might cause a problem. +// StripWhitespace +// Removes whitespaces from both ends of the given string. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void ReplaceCharacters(string* s, const char* remove, + char replacewith); + +PROTOBUF_EXPORT void StripWhitespace(string* s); + +// ---------------------------------------------------------------------- +// LowerString() +// UpperString() +// ToUpper() +// Convert the characters in "s" to lowercase or uppercase. ASCII-only: +// these functions intentionally ignore locale because they are applied to +// identifiers used in the Protocol Buffer language, not to natural-language +// strings. +// ---------------------------------------------------------------------- + +inline void LowerString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // tolower() changes based on locale. We don't want this! + if ('A' <= *i && *i <= 'Z') *i += 'a' - 'A'; + } +} + +inline void UpperString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // toupper() changes based on locale. We don't want this! + if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + } +} + +inline void ToUpper(string* s) { UpperString(s); } + +inline string ToUpper(const string& s) { + string out = s; + UpperString(&out); + return out; +} + +// ---------------------------------------------------------------------- +// StringReplace() +// Give me a string and two patterns "old" and "new", and I replace +// the first instance of "old" in the string with "new", if it +// exists. RETURN a new string, regardless of whether the replacement +// happened or not. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT string StringReplace(const string& s, const string& oldsub, + const string& newsub, bool replace_all); + +// ---------------------------------------------------------------------- +// SplitStringUsing() +// Split a string using a character delimiter. Append the components +// to 'result'. If there are consecutive delimiters, this function skips +// over all of them. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void SplitStringUsing(StringPiece full, const char* delim, + std::vector* res); + +// Split a string using one or more byte delimiters, presented +// as a nul-terminated c string. Append the components to 'result'. +// If there are consecutive delimiters, this function will return +// corresponding empty strings. If you want to drop the empty +// strings, try SplitStringUsing(). +// +// If "full" is the empty string, yields an empty string as the only value. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void SplitStringAllowEmpty(StringPiece full, const char* delim, + std::vector* result); + +// ---------------------------------------------------------------------- +// Split() +// Split a string using a character delimiter. +// ---------------------------------------------------------------------- +inline std::vector Split(StringPiece full, const char* delim, + bool skip_empty = true) { + std::vector result; + if (skip_empty) { + SplitStringUsing(full, delim, &result); + } else { + SplitStringAllowEmpty(full, delim, &result); + } + return result; +} + +// ---------------------------------------------------------------------- +// JoinStrings() +// These methods concatenate a vector of strings into a C++ string, using +// the C-string "delim" as a separator between components. There are two +// flavors of the function, one flavor returns the concatenated string, +// another takes a pointer to the target string. In the latter case the +// target string is cleared and overwritten. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void JoinStrings(const std::vector& components, + const char* delim, string* result); + +inline string JoinStrings(const std::vector& components, + const char* delim) { + string result; + JoinStrings(components, delim, &result); + return result; +} + +// ---------------------------------------------------------------------- +// UnescapeCEscapeSequences() +// Copies "source" to "dest", rewriting C-style escape sequences +// -- '\n', '\r', '\\', '\ooo', etc -- to their ASCII +// equivalents. "dest" must be sufficiently large to hold all +// the characters in the rewritten string (i.e. at least as large +// as strlen(source) + 1 should be safe, since the replacements +// are always shorter than the original escaped sequences). It's +// safe for source and dest to be the same. RETURNS the length +// of dest. +// +// It allows hex sequences \xhh, or generally \xhhhhh with an +// arbitrary number of hex digits, but all of them together must +// specify a value of a single byte (e.g. \x0045 is equivalent +// to \x45, and \x1234 is erroneous). +// +// It also allows escape sequences of the form \uhhhh (exactly four +// hex digits, upper or lower case) or \Uhhhhhhhh (exactly eight +// hex digits, upper or lower case) to specify a Unicode code +// point. The dest array will contain the UTF8-encoded version of +// that code-point (e.g., if source contains \u2019, then dest will +// contain the three bytes 0xE2, 0x80, and 0x99). +// +// Errors: In the first form of the call, errors are reported with +// LOG(ERROR). The same is true for the second form of the call if +// the pointer to the string std::vector is nullptr; otherwise, error +// messages are stored in the std::vector. In either case, the effect on +// the dest array is not defined, but rest of the source will be +// processed. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT int UnescapeCEscapeSequences(const char* source, char* dest); +PROTOBUF_EXPORT int UnescapeCEscapeSequences(const char* source, char* dest, + std::vector* errors); + +// ---------------------------------------------------------------------- +// UnescapeCEscapeString() +// This does the same thing as UnescapeCEscapeSequences, but creates +// a new string. The caller does not need to worry about allocating +// a dest buffer. This should be used for non performance critical +// tasks such as printing debug messages. It is safe for src and dest +// to be the same. +// +// The second call stores its errors in a supplied string vector. +// If the string vector pointer is nullptr, it reports the errors with LOG(). +// +// In the first and second calls, the length of dest is returned. In the +// the third call, the new string is returned. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT int UnescapeCEscapeString(const string& src, string* dest); +PROTOBUF_EXPORT int UnescapeCEscapeString(const string& src, string* dest, + std::vector* errors); +PROTOBUF_EXPORT string UnescapeCEscapeString(const string& src); + +// ---------------------------------------------------------------------- +// CEscape() +// Escapes 'src' using C-style escape sequences and returns the resulting +// string. +// +// Escaped chars: \n, \r, \t, ", ', \, and !isprint(). +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string CEscape(const string& src); + +// ---------------------------------------------------------------------- +// CEscapeAndAppend() +// Escapes 'src' using C-style escape sequences, and appends the escaped +// string to 'dest'. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT void CEscapeAndAppend(StringPiece src, string* dest); + +namespace strings { +// Like CEscape() but does not escape bytes with the upper bit set. +PROTOBUF_EXPORT string Utf8SafeCEscape(const string& src); + +// Like CEscape() but uses hex (\x) escapes instead of octals. +PROTOBUF_EXPORT string CHexEscape(const string& src); +} // namespace strings + +// ---------------------------------------------------------------------- +// strto32() +// strtou32() +// strto64() +// strtou64() +// Architecture-neutral plug compatible replacements for strtol() and +// strtoul(). Long's have different lengths on ILP-32 and LP-64 +// platforms, so using these is safer, from the point of view of +// overflow behavior, than using the standard libc functions. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT int32 strto32_adaptor(const char* nptr, char** endptr, + int base); +PROTOBUF_EXPORT uint32 strtou32_adaptor(const char* nptr, char** endptr, + int base); + +inline int32 strto32(const char *nptr, char **endptr, int base) { + if (sizeof(int32) == sizeof(long)) + return strtol(nptr, endptr, base); + else + return strto32_adaptor(nptr, endptr, base); +} + +inline uint32 strtou32(const char *nptr, char **endptr, int base) { + if (sizeof(uint32) == sizeof(unsigned long)) + return strtoul(nptr, endptr, base); + else + return strtou32_adaptor(nptr, endptr, base); +} + +// For now, long long is 64-bit on all the platforms we care about, so these +// functions can simply pass the call to strto[u]ll. +inline int64 strto64(const char *nptr, char **endptr, int base) { + GOOGLE_COMPILE_ASSERT(sizeof(int64) == sizeof(long long), + sizeof_int64_is_not_sizeof_long_long); + return strtoll(nptr, endptr, base); +} + +inline uint64 strtou64(const char *nptr, char **endptr, int base) { + GOOGLE_COMPILE_ASSERT(sizeof(uint64) == sizeof(unsigned long long), + sizeof_uint64_is_not_sizeof_long_long); + return strtoull(nptr, endptr, base); +} + +// ---------------------------------------------------------------------- +// safe_strtob() +// safe_strto32() +// safe_strtou32() +// safe_strto64() +// safe_strtou64() +// safe_strtof() +// safe_strtod() +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT bool safe_strtob(StringPiece str, bool* value); + +PROTOBUF_EXPORT bool safe_strto32(const string& str, int32* value); +PROTOBUF_EXPORT bool safe_strtou32(const string& str, uint32* value); +inline bool safe_strto32(const char* str, int32* value) { + return safe_strto32(string(str), value); +} +inline bool safe_strto32(StringPiece str, int32* value) { + return safe_strto32(str.ToString(), value); +} +inline bool safe_strtou32(const char* str, uint32* value) { + return safe_strtou32(string(str), value); +} +inline bool safe_strtou32(StringPiece str, uint32* value) { + return safe_strtou32(str.ToString(), value); +} + +PROTOBUF_EXPORT bool safe_strto64(const string& str, int64* value); +PROTOBUF_EXPORT bool safe_strtou64(const string& str, uint64* value); +inline bool safe_strto64(const char* str, int64* value) { + return safe_strto64(string(str), value); +} +inline bool safe_strto64(StringPiece str, int64* value) { + return safe_strto64(str.ToString(), value); +} +inline bool safe_strtou64(const char* str, uint64* value) { + return safe_strtou64(string(str), value); +} +inline bool safe_strtou64(StringPiece str, uint64* value) { + return safe_strtou64(str.ToString(), value); +} + +PROTOBUF_EXPORT bool safe_strtof(const char* str, float* value); +PROTOBUF_EXPORT bool safe_strtod(const char* str, double* value); +inline bool safe_strtof(const string& str, float* value) { + return safe_strtof(str.c_str(), value); +} +inline bool safe_strtod(const string& str, double* value) { + return safe_strtod(str.c_str(), value); +} +inline bool safe_strtof(StringPiece str, float* value) { + return safe_strtof(str.ToString(), value); +} +inline bool safe_strtod(StringPiece str, double* value) { + return safe_strtod(str.ToString(), value); +} + +// ---------------------------------------------------------------------- +// FastIntToBuffer() +// FastHexToBuffer() +// FastHex64ToBuffer() +// FastHex32ToBuffer() +// FastTimeToBuffer() +// These are intended for speed. FastIntToBuffer() assumes the +// integer is non-negative. FastHexToBuffer() puts output in +// hex rather than decimal. FastTimeToBuffer() puts the output +// into RFC822 format. +// +// FastHex64ToBuffer() puts a 64-bit unsigned value in hex-format, +// padded to exactly 16 bytes (plus one byte for '\0') +// +// FastHex32ToBuffer() puts a 32-bit unsigned value in hex-format, +// padded to exactly 8 bytes (plus one byte for '\0') +// +// All functions take the output buffer as an arg. +// They all return a pointer to the beginning of the output, +// which may not be the beginning of the input buffer. +// ---------------------------------------------------------------------- + +// Suggested buffer size for FastToBuffer functions. Also works with +// DoubleToBuffer() and FloatToBuffer(). +static const int kFastToBufferSize = 32; + +PROTOBUF_EXPORT char* FastInt32ToBuffer(int32 i, char* buffer); +PROTOBUF_EXPORT char* FastInt64ToBuffer(int64 i, char* buffer); +char* FastUInt32ToBuffer(uint32 i, char* buffer); // inline below +char* FastUInt64ToBuffer(uint64 i, char* buffer); // inline below +PROTOBUF_EXPORT char* FastHexToBuffer(int i, char* buffer); +PROTOBUF_EXPORT char* FastHex64ToBuffer(uint64 i, char* buffer); +PROTOBUF_EXPORT char* FastHex32ToBuffer(uint32 i, char* buffer); + +// at least 22 bytes long +inline char* FastIntToBuffer(int i, char* buffer) { + return (sizeof(i) == 4 ? + FastInt32ToBuffer(i, buffer) : FastInt64ToBuffer(i, buffer)); +} +inline char* FastUIntToBuffer(unsigned int i, char* buffer) { + return (sizeof(i) == 4 ? + FastUInt32ToBuffer(i, buffer) : FastUInt64ToBuffer(i, buffer)); +} +inline char* FastLongToBuffer(long i, char* buffer) { + return (sizeof(i) == 4 ? + FastInt32ToBuffer(i, buffer) : FastInt64ToBuffer(i, buffer)); +} +inline char* FastULongToBuffer(unsigned long i, char* buffer) { + return (sizeof(i) == 4 ? + FastUInt32ToBuffer(i, buffer) : FastUInt64ToBuffer(i, buffer)); +} + +// ---------------------------------------------------------------------- +// FastInt32ToBufferLeft() +// FastUInt32ToBufferLeft() +// FastInt64ToBufferLeft() +// FastUInt64ToBufferLeft() +// +// Like the Fast*ToBuffer() functions above, these are intended for speed. +// Unlike the Fast*ToBuffer() functions, however, these functions write +// their output to the beginning of the buffer (hence the name, as the +// output is left-aligned). The caller is responsible for ensuring that +// the buffer has enough space to hold the output. +// +// Returns a pointer to the end of the string (i.e. the null character +// terminating the string). +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT char* FastInt32ToBufferLeft(int32 i, char* buffer); +PROTOBUF_EXPORT char* FastUInt32ToBufferLeft(uint32 i, char* buffer); +PROTOBUF_EXPORT char* FastInt64ToBufferLeft(int64 i, char* buffer); +PROTOBUF_EXPORT char* FastUInt64ToBufferLeft(uint64 i, char* buffer); + +// Just define these in terms of the above. +inline char* FastUInt32ToBuffer(uint32 i, char* buffer) { + FastUInt32ToBufferLeft(i, buffer); + return buffer; +} +inline char* FastUInt64ToBuffer(uint64 i, char* buffer) { + FastUInt64ToBufferLeft(i, buffer); + return buffer; +} + +inline string SimpleBtoa(bool value) { + return value ? "true" : "false"; +} + +// ---------------------------------------------------------------------- +// SimpleItoa() +// Description: converts an integer to a string. +// +// Return value: string +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string SimpleItoa(int i); +PROTOBUF_EXPORT string SimpleItoa(unsigned int i); +PROTOBUF_EXPORT string SimpleItoa(long i); +PROTOBUF_EXPORT string SimpleItoa(unsigned long i); +PROTOBUF_EXPORT string SimpleItoa(long long i); +PROTOBUF_EXPORT string SimpleItoa(unsigned long long i); + +// ---------------------------------------------------------------------- +// SimpleDtoa() +// SimpleFtoa() +// DoubleToBuffer() +// FloatToBuffer() +// Description: converts a double or float to a string which, if +// passed to NoLocaleStrtod(), will produce the exact same original double +// (except in case of NaN; all NaNs are considered the same value). +// We try to keep the string short but it's not guaranteed to be as +// short as possible. +// +// DoubleToBuffer() and FloatToBuffer() write the text to the given +// buffer and return it. The buffer must be at least +// kDoubleToBufferSize bytes for doubles and kFloatToBufferSize +// bytes for floats. kFastToBufferSize is also guaranteed to be large +// enough to hold either. +// +// Return value: string +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string SimpleDtoa(double value); +PROTOBUF_EXPORT string SimpleFtoa(float value); + +PROTOBUF_EXPORT char* DoubleToBuffer(double i, char* buffer); +PROTOBUF_EXPORT char* FloatToBuffer(float i, char* buffer); + +// In practice, doubles should never need more than 24 bytes and floats +// should never need more than 14 (including null terminators), but we +// overestimate to be safe. +static const int kDoubleToBufferSize = 32; +static const int kFloatToBufferSize = 24; + +namespace strings { + +enum PadSpec { + NO_PAD = 1, + ZERO_PAD_2, + ZERO_PAD_3, + ZERO_PAD_4, + ZERO_PAD_5, + ZERO_PAD_6, + ZERO_PAD_7, + ZERO_PAD_8, + ZERO_PAD_9, + ZERO_PAD_10, + ZERO_PAD_11, + ZERO_PAD_12, + ZERO_PAD_13, + ZERO_PAD_14, + ZERO_PAD_15, + ZERO_PAD_16, +}; + +struct Hex { + uint64 value; + enum PadSpec spec; + template + explicit Hex(Int v, PadSpec s = NO_PAD) + : spec(s) { + // Prevent sign-extension by casting integers to + // their unsigned counterparts. +#ifdef LANG_CXX11 + static_assert( + sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8, + "Unknown integer type"); +#endif + value = sizeof(v) == 1 ? static_cast(v) + : sizeof(v) == 2 ? static_cast(v) + : sizeof(v) == 4 ? static_cast(v) + : static_cast(v); + } +}; + +struct PROTOBUF_EXPORT AlphaNum { + const char *piece_data_; // move these to string_ref eventually + size_t piece_size_; // move these to string_ref eventually + + char digits[kFastToBufferSize]; + + // No bool ctor -- bools convert to an integral type. + // A bool ctor would also convert incoming pointers (bletch). + + AlphaNum(int i32) + : piece_data_(digits), + piece_size_(FastInt32ToBufferLeft(i32, digits) - &digits[0]) {} + AlphaNum(unsigned int u32) + : piece_data_(digits), + piece_size_(FastUInt32ToBufferLeft(u32, digits) - &digits[0]) {} + AlphaNum(long long i64) + : piece_data_(digits), + piece_size_(FastInt64ToBufferLeft(i64, digits) - &digits[0]) {} + AlphaNum(unsigned long long u64) + : piece_data_(digits), + piece_size_(FastUInt64ToBufferLeft(u64, digits) - &digits[0]) {} + + // Note: on some architectures, "long" is only 32 bits, not 64, but the + // performance hit of using FastInt64ToBufferLeft to handle 32-bit values + // is quite minor. + AlphaNum(long i64) + : piece_data_(digits), + piece_size_(FastInt64ToBufferLeft(i64, digits) - &digits[0]) {} + AlphaNum(unsigned long u64) + : piece_data_(digits), + piece_size_(FastUInt64ToBufferLeft(u64, digits) - &digits[0]) {} + + AlphaNum(float f) + : piece_data_(digits), piece_size_(strlen(FloatToBuffer(f, digits))) {} + AlphaNum(double f) + : piece_data_(digits), piece_size_(strlen(DoubleToBuffer(f, digits))) {} + + AlphaNum(Hex hex); + + AlphaNum(const char* c_str) + : piece_data_(c_str), piece_size_(strlen(c_str)) {} + // TODO: Add a string_ref constructor, eventually + // AlphaNum(const StringPiece &pc) : piece(pc) {} + + AlphaNum(const string& str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + AlphaNum(StringPiece str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + AlphaNum(internal::StringPiecePod str) + : piece_data_(str.data()), piece_size_(str.size()) {} + + size_t size() const { return piece_size_; } + const char *data() const { return piece_data_; } + + private: + // Use ":" not ':' + AlphaNum(char c); // NOLINT(runtime/explicit) + + // Disallow copy and assign. + AlphaNum(const AlphaNum&); + void operator=(const AlphaNum&); +}; + +} // namespace strings + +using strings::AlphaNum; + +// ---------------------------------------------------------------------- +// StrCat() +// This merges the given strings or numbers, with no delimiter. This +// is designed to be the fastest possible way to construct a string out +// of a mix of raw C strings, strings, bool values, +// and numeric values. +// +// Don't use this for user-visible strings. The localization process +// works poorly on strings built up out of fragments. +// +// For clarity and performance, don't use StrCat when appending to a +// string. In particular, avoid using any of these (anti-)patterns: +// str.append(StrCat(...) +// str += StrCat(...) +// str = StrCat(str, ...) +// where the last is the worse, with the potential to change a loop +// from a linear time operation with O(1) dynamic allocations into a +// quadratic time operation with O(n) dynamic allocations. StrAppend +// is a better choice than any of the above, subject to the restriction +// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may +// be a reference into str. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g, const AlphaNum& h); +PROTOBUF_EXPORT string StrCat(const AlphaNum& a, const AlphaNum& b, + const AlphaNum& c, const AlphaNum& d, + const AlphaNum& e, const AlphaNum& f, + const AlphaNum& g, const AlphaNum& h, + const AlphaNum& i); + +inline string StrCat(const AlphaNum& a) { return string(a.data(), a.size()); } + +// ---------------------------------------------------------------------- +// StrAppend() +// Same as above, but adds the output to the given string. +// WARNING: For speed, StrAppend does not try to check each of its input +// arguments to be sure that they are not a subset of the string being +// appended to. That is, while this will work: +// +// string s = "foo"; +// s += s; +// +// This will not (necessarily) work: +// +// string s = "foo"; +// StrAppend(&s, s); +// +// Note: while StrCat supports appending up to 9 arguments, StrAppend +// is currently limited to 4. That's rarely an issue except when +// automatically transforming StrCat to StrAppend, and can easily be +// worked around as consecutive calls to StrAppend are quite efficient. +// ---------------------------------------------------------------------- + +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b, const AlphaNum& c); +PROTOBUF_EXPORT void StrAppend(string* dest, const AlphaNum& a, + const AlphaNum& b, const AlphaNum& c, + const AlphaNum& d); + +// ---------------------------------------------------------------------- +// Join() +// These methods concatenate a range of components into a C++ string, using +// the C-string "delim" as a separator between components. +// ---------------------------------------------------------------------- +template +void Join(Iterator start, Iterator end, + const char* delim, string* result) { + for (Iterator it = start; it != end; ++it) { + if (it != start) { + result->append(delim); + } + StrAppend(result, *it); + } +} + +template +string Join(const Range& components, + const char* delim) { + string result; + Join(components.begin(), components.end(), delim, &result); + return result; +} + +// ---------------------------------------------------------------------- +// ToHex() +// Return a lower-case hex string representation of the given integer. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT string ToHex(uint64 num); + +// ---------------------------------------------------------------------- +// GlobalReplaceSubstring() +// Replaces all instances of a substring in a string. Does nothing +// if 'substring' is empty. Returns the number of replacements. +// +// NOTE: The string pieces must not overlap s. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT int GlobalReplaceSubstring(const string& substring, + const string& replacement, + string* s); + +// ---------------------------------------------------------------------- +// Base64Unescape() +// Converts "src" which is encoded in Base64 to its binary equivalent and +// writes it to "dest". If src contains invalid characters, dest is cleared +// and the function returns false. Returns true on success. +// ---------------------------------------------------------------------- +PROTOBUF_EXPORT bool Base64Unescape(StringPiece src, string* dest); + +// ---------------------------------------------------------------------- +// WebSafeBase64Unescape() +// This is a variation of Base64Unescape which uses '-' instead of '+', and +// '_' instead of '/'. src is not null terminated, instead specify len. I +// recommend that slen +struct identity_ { + typedef T type; +}; + +// integral_constant, defined in tr1, is a wrapper for an integer +// value. We don't really need this generality; we could get away +// with hardcoding the integer type to bool. We use the fully +// general integer_constant for compatibility with tr1. + +template +struct integral_constant { + static const T value = v; + typedef T value_type; + typedef integral_constant type; +}; + +template const T integral_constant::value; + + +// Abbreviations: true_type and false_type are structs that represent boolean +// true and false values. Also define the boost::mpl versions of those names, +// true_ and false_. +typedef integral_constant true_type; +typedef integral_constant false_type; +typedef true_type true_; +typedef false_type false_; + +// if_ is a templatized conditional statement. +// if_ is a compile time evaluation of cond. +// if_<>::type contains A if cond is true, B otherwise. +template +struct if_{ + typedef A type; +}; + +template +struct if_ { + typedef B type; +}; + + +// type_equals_ is a template type comparator, similar to Loki IsSameType. +// type_equals_::value is true iff "A" is the same type as "B". +// +// New code should prefer base::is_same, defined in base/type_traits.h. +// It is functionally identical, but is_same is the standard spelling. +template +struct type_equals_ : public false_ { +}; + +template +struct type_equals_ : public true_ { +}; + +// and_ is a template && operator. +// and_::value evaluates "A::value && B::value". +template +struct and_ : public integral_constant { +}; + +// or_ is a template || operator. +// or_::value evaluates "A::value || B::value". +template +struct or_ : public integral_constant { +}; + + +} // namespace internal +} // namespace protobuf +} // namespace google + +#endif // GOOGLE_PROTOBUF_TEMPLATE_UTIL_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_conduit_v1.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_conduit_v1.h new file mode 100644 index 0000000000000000000000000000000000000000..33eac37d379bb92f0af1c49cc98ae29777c4b20f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_conduit_v1.h @@ -0,0 +1,121 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Copyright (c) 2024 The pybind Community. + +/* The pybind11_conduit_v1 feature enables type-safe interoperability between + +* different independent Python/C++ bindings systems, + +* including pybind11 versions with different PYBIND11_INTERNALS_VERSION's. + + * NOTE: The conduit feature + only covers from-Python-to-C++ conversions, it + does not cover from-C++-to-Python conversions. + (For the latter, a different feature would have to be added.) + +The naming of the feature is a bit misleading: + +* The feature is in no way tied to pybind11 internals. + +* It just happens to originate from pybind11 and currently still lives there. + +* The only external dependency is . + +The implementation is a VERY light-weight dependency. It is designed to be +compatible with any ISO C++11 (or higher) compiler, and does NOT require +C++ Exception Handling to be enabled. + +Please see https://github.com/pybind/pybind11/pull/5296 for more background. + +The implementation involves a + +def _pybind11_conduit_v1_( + self, + pybind11_platform_abi_id: bytes, + cpp_type_info_capsule: capsule, + pointer_kind: bytes) -> capsule + +method that is meant to be added to Python objects wrapping C++ objects +(e.g. pybind11::class_-wrapped types). + +The design of the _pybind11_conduit_v1_ feature provides two layers of +protection against C++ ABI mismatches: + +* The first and most important layer is that the pybind11_platform_abi_id's + must match between extensions. — This will never be perfect, but is the same + pragmatic approach used in pybind11 since 2017 + (https://github.com/pybind/pybind11/commit/96997a4b9d4ec3d389a570604394af5d5eee2557, + PYBIND11_INTERNALS_ID). + +* The second layer is that the typeid(std::type_info).name()'s must match + between extensions. + +The implementation below (which is shorter than this comment!), serves as a +battle-tested specification. The main API is this one function: + +auto *cpp_pointer = pybind11_conduit_v1::get_type_pointer_ephemeral(py_obj); + +It is meant to be a minimalistic reference implementation, intentionally +without comprehensive error reporting. It is expected that major bindings +systems will roll their own, compatible implementations, potentially with +system-specific error reporting. The essential specifications all bindings +systems need to agree on are merely: + +* PYBIND11_PLATFORM_ABI_ID (const char* literal). + +* The cpp_type_info capsule (see below: a void *ptr and a const char *name). + +* The cpp_conduit capsule (see below: a void *ptr and a const char *name). + +* "raw_pointer_ephemeral" means: the lifetime of the pointer is the lifetime + of the py_obj. + +*/ + +// THIS MUST STAY AT THE TOP! +#include "pybind11_platform_abi_id.h" + +#include +#include + +namespace pybind11_conduit_v1 { + +inline void *get_raw_pointer_ephemeral(PyObject *py_obj, const std::type_info *cpp_type_info) { + PyObject *cpp_type_info_capsule + = PyCapsule_New(const_cast(static_cast(cpp_type_info)), + typeid(std::type_info).name(), + nullptr); + if (cpp_type_info_capsule == nullptr) { + return nullptr; + } + PyObject *cpp_conduit = PyObject_CallMethod(py_obj, + "_pybind11_conduit_v1_", + "yOy", + PYBIND11_PLATFORM_ABI_ID, + cpp_type_info_capsule, + "raw_pointer_ephemeral"); + Py_DECREF(cpp_type_info_capsule); + if (cpp_conduit == nullptr) { + return nullptr; + } + void *raw_ptr = PyCapsule_GetPointer(cpp_conduit, cpp_type_info->name()); + Py_DECREF(cpp_conduit); + if (PyErr_Occurred()) { + return nullptr; + } + return raw_ptr; +} + +template +T *get_type_pointer_ephemeral(PyObject *py_obj) { + void *raw_ptr = get_raw_pointer_ephemeral(py_obj, &typeid(T)); + if (raw_ptr == nullptr) { + return nullptr; + } + return static_cast(raw_ptr); +} + +} // namespace pybind11_conduit_v1 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_platform_abi_id.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_platform_abi_id.h new file mode 100644 index 0000000000000000000000000000000000000000..27965a9429662a05fc61de7b01abe4ddf1cb0261 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/pybind11_platform_abi_id.h @@ -0,0 +1,92 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// Copyright (c) 2024 The pybind Community. + +// To maximize reusability: +// DO NOT ADD CODE THAT REQUIRES C++ EXCEPTION HANDLING. + +#include "wrap_include_python_h.h" + +// Implementation details. DO NOT USE ELSEWHERE. (Unfortunately we cannot #undef them.) +// This is duplicated here to maximize portability. +#define PYBIND11_PLATFORM_ABI_ID_STRINGIFY(x) #x +#define PYBIND11_PLATFORM_ABI_ID_TOSTRING(x) PYBIND11_PLATFORM_ABI_ID_STRINGIFY(x) + +#ifdef PYBIND11_COMPILER_TYPE +// // To maintain backward compatibility (see PR #5439). +# define PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE "" +#else +# define PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE "_" +# if defined(__MINGW32__) +# define PYBIND11_COMPILER_TYPE "mingw" +# elif defined(__CYGWIN__) +# define PYBIND11_COMPILER_TYPE "gcc_cygwin" +# elif defined(_MSC_VER) +# define PYBIND11_COMPILER_TYPE "msvc" +# elif defined(__clang__) || defined(__GNUC__) +# define PYBIND11_COMPILER_TYPE "system" // Assumed compatible with system compiler. +# else +# error "Unknown PYBIND11_COMPILER_TYPE: PLEASE REVISE THIS CODE." +# endif +#endif + +// PR #5439 made this macro obsolete. However, there are many manipulations of this macro in the +// wild. Therefore, to maintain backward compatibility, it is kept around. +#ifndef PYBIND11_STDLIB +# define PYBIND11_STDLIB "" +#endif + +#ifndef PYBIND11_BUILD_ABI +# if defined(_MSC_VER) // See PR #4953. +# if defined(_MT) && defined(_DLL) // Corresponding to CL command line options /MD or /MDd. +# if (_MSC_VER) / 100 == 19 +# define PYBIND11_BUILD_ABI "_md_mscver19" +# else +# error "Unknown major version for MSC_VER: PLEASE REVISE THIS CODE." +# endif +# elif defined(_MT) // Corresponding to CL command line options /MT or /MTd. +# define PYBIND11_BUILD_ABI "_mt_mscver" PYBIND11_PLATFORM_ABI_ID_TOSTRING(_MSC_VER) +# else +# if (_MSC_VER) / 100 == 19 +# define PYBIND11_BUILD_ABI "_none_mscver19" +# else +# error "Unknown major version for MSC_VER: PLEASE REVISE THIS CODE." +# endif +# endif +# elif defined(_LIBCPP_ABI_VERSION) // https://libcxx.llvm.org/DesignDocs/ABIVersioning.html +# define PYBIND11_BUILD_ABI \ + "_libcpp_abi" PYBIND11_PLATFORM_ABI_ID_TOSTRING(_LIBCPP_ABI_VERSION) +# elif defined(_GLIBCXX_USE_CXX11_ABI) // See PR #5439. +# if defined(__NVCOMPILER) +// // Assume that NVHPC is in the 1xxx ABI family. +// // THIS ASSUMPTION IS NOT FUTURE PROOF but apparently the best we can do. +// // Please let us know if there is a way to validate the assumption here. +# elif !defined(__GXX_ABI_VERSION) +# error \ + "Unknown platform or compiler (_GLIBCXX_USE_CXX11_ABI): PLEASE REVISE THIS CODE." +# endif +# if defined(__GXX_ABI_VERSION) && __GXX_ABI_VERSION < 1002 || __GXX_ABI_VERSION >= 2000 +# error "Unknown platform or compiler (__GXX_ABI_VERSION): PLEASE REVISE THIS CODE." +# endif +# define PYBIND11_BUILD_ABI \ + "_libstdcpp_gxx_abi_1xxx_use_cxx11_abi_" PYBIND11_PLATFORM_ABI_ID_TOSTRING( \ + _GLIBCXX_USE_CXX11_ABI) +# else +# error "Unknown platform or compiler: PLEASE REVISE THIS CODE." +# endif +#endif + +// On MSVC, debug and release builds are not ABI-compatible! +#if defined(_MSC_VER) && defined(_DEBUG) +# define PYBIND11_BUILD_TYPE "_debug" +#else +# define PYBIND11_BUILD_TYPE "" +#endif + +#define PYBIND11_PLATFORM_ABI_ID \ + PYBIND11_COMPILER_TYPE PYBIND11_STDLIB PYBIND11_BUILD_ABI PYBIND11_BUILD_TYPE + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/wrap_include_python_h.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/wrap_include_python_h.h new file mode 100644 index 0000000000000000000000000000000000000000..0f15321f18755987e8b4670c17746d1531cb1c63 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/conduit/wrap_include_python_h.h @@ -0,0 +1,77 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// Copyright (c) 2024 The pybind Community. + +// STRONG REQUIREMENT: +// This header is a wrapper around `#include `, therefore it +// MUST BE INCLUDED BEFORE ANY STANDARD HEADERS are included. +// See also: +// https://docs.python.org/3/c-api/intro.html#include-files +// Quoting from there: +// Note: Since Python may define some pre-processor definitions which affect +// the standard headers on some systems, you must include Python.h before +// any standard headers are included. + +// To maximize reusability: +// DO NOT ADD CODE THAT REQUIRES C++ EXCEPTION HANDLING. + +// Disable linking to pythonX_d.lib on Windows in debug mode. +#if defined(_MSC_VER) && defined(_DEBUG) && !defined(Py_DEBUG) +// Workaround for a VS 2022 issue. +// See https://github.com/pybind/pybind11/pull/3497 for full context. +// NOTE: This workaround knowingly violates the Python.h include order +// requirement (see above). +# include +# if _MSVC_STL_VERSION >= 143 +# include +# endif +# define PYBIND11_DEBUG_MARKER +# undef _DEBUG +#endif + +// Don't let Python.h #define (v)snprintf as macro because they are implemented +// properly in Visual Studio since 2015. +#if defined(_MSC_VER) +# define HAVE_SNPRINTF 1 +#endif + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4505) +// C4505: 'PySlice_GetIndicesEx': unreferenced local function has been removed +#endif + +#include +#include +#include + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif + +#if defined(PYBIND11_DEBUG_MARKER) +# define _DEBUG 1 +# undef PYBIND11_DEBUG_MARKER +#endif + +// Python #defines overrides on all sorts of core functions, which +// tends to wreak havok in C++ codebases that expect these to work +// like regular functions (potentially with several overloads). +#if defined(isalnum) +# undef isalnum +# undef isalpha +# undef islower +# undef isspace +# undef isupper +# undef tolower +# undef toupper +#endif + +#if defined(copysign) +# undef copysign +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h new file mode 100644 index 0000000000000000000000000000000000000000..bf42e2a63ab21270b947d2b5abe585f9bcfda9a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/class.h @@ -0,0 +1,828 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + pybind11/detail/class.h: Python C API implementation details for py::class_ + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#include + +#include "exception_translation.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +#if !defined(PYPY_VERSION) +# define PYBIND11_BUILTIN_QUALNAME +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) +#else +// In PyPy, we still set __qualname__ so that we can produce reliable function type +// signatures; in CPython this macro expands to nothing: +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) \ + setattr((PyObject *) obj, "__qualname__", nameobj) +#endif + +inline std::string get_fully_qualified_tp_name(PyTypeObject *type) { +#if !defined(PYPY_VERSION) + return type->tp_name; +#else + auto module_name = handle((PyObject *) type).attr("__module__").cast(); + if (module_name == PYBIND11_BUILTINS_MODULE) + return type->tp_name; + else + return std::move(module_name) + "." + type->tp_name; +#endif +} + +inline PyTypeObject *type_incref(PyTypeObject *type) { + Py_INCREF(type); + return type; +} + +#if !defined(PYPY_VERSION) + +/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. +extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { + return PyProperty_Type.tp_descr_get(self, cls, cls); +} + +/// `pybind11_static_property.__set__()`: Just like the above `__get__()`. +extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { + PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); + return PyProperty_Type.tp_descr_set(self, cls, value); +} + +// Forward declaration to use in `make_static_property_type()` +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type); + +/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` + methods are modified to always use the object type instead of a concrete instance. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + constexpr auto *name = "pybind11_static_property"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) { + pybind11_fail("make_static_property_type(): error allocating type!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +# ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +# endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyProperty_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + type->tp_descr_get = pybind11_static_get; + type->tp_descr_set = pybind11_static_set; + +# if PY_VERSION_HEX >= 0x030C0000 + // Since Python-3.12 property-derived types are required to + // have dynamic attributes (to set `__doc__`) + enable_dynamic_attributes(heap_type); +# endif + + if (PyType_Ready(type) < 0) { + pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); + } + + setattr((PyObject *) type, "__module__", str(PYBIND11_DUMMY_MODULE_NAME)); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +#else // PYPY + +/** PyPy has some issues with the above C API, so we evaluate Python code instead. + This function will only be called once so performance isn't really a concern. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + auto d = dict(); + PyObject *result = PyRun_String(R"(\ +class pybind11_static_property(property): + def __get__(self, obj, cls): + return property.__get__(self, cls, cls) + + def __set__(self, obj, value): + cls = obj if isinstance(obj, type) else type(obj) + property.__set__(self, cls, value) +)", + Py_file_input, + d.ptr(), + d.ptr()); + if (result == nullptr) + throw error_already_set(); + Py_DECREF(result); + return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); +} + +#endif // PYPY + +/** Types with static properties need to handle `Type.static_prop = x` in a specific way. + By default, Python replaces the `static_property` itself, but for wrapped C++ types + we need to call `static_property.__set__()` in order to propagate the new value to + the underlying C++ data structure. */ +extern "C" inline int pybind11_meta_setattro(PyObject *obj, PyObject *name, PyObject *value) { + // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw + // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + + // The following assignment combinations are possible: + // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` + // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` + // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment + auto *const static_prop = (PyObject *) get_internals().static_property_type; + const auto call_descr_set = (descr != nullptr) && (value != nullptr) + && (PyObject_IsInstance(descr, static_prop) != 0) + && (PyObject_IsInstance(value, static_prop) == 0); + if (call_descr_set) { + // Call `static_property.__set__()` instead of replacing the `static_property`. +#if !defined(PYPY_VERSION) + return Py_TYPE(descr)->tp_descr_set(descr, obj, value); +#else + if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { + Py_DECREF(result); + return 0; + } else { + return -1; + } +#endif + } else { + // Replace existing attribute. + return PyType_Type.tp_setattro(obj, name, value); + } +} + +/** + * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing + * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, + * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here + * to do a special case bypass for PyInstanceMethod_Types. + */ +extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + if (descr && PyInstanceMethod_Check(descr)) { + Py_INCREF(descr); + return descr; + } + return PyType_Type.tp_getattro(obj, name); +} + +/// metaclass `__call__` function that is used to create all pybind11 objects. +extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) { + + // use the default metaclass call to create/initialize the object + PyObject *self = PyType_Type.tp_call(type, args, kwargs); + if (self == nullptr) { + return nullptr; + } + + // Ensure that the base __init__ function(s) were called + values_and_holders vhs(self); + for (const auto &vh : vhs) { + if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) { + PyErr_Format(PyExc_TypeError, + "%.200s.__init__() must be called when overriding __init__", + get_fully_qualified_tp_name(vh.type->type).c_str()); + Py_DECREF(self); + return nullptr; + } + } + + return self; +} + +/// Cleanup the type-info for a pybind11-registered type. +extern "C" inline void pybind11_meta_dealloc(PyObject *obj) { + with_internals([obj](internals &internals) { + auto *type = (PyTypeObject *) obj; + + // A pybind11-registered type will: + // 1) be found in internals.registered_types_py + // 2) have exactly one associated `detail::type_info` + auto found_type = internals.registered_types_py.find(type); + if (found_type != internals.registered_types_py.end() && found_type->second.size() == 1 + && found_type->second[0]->type == type) { + + auto *tinfo = found_type->second[0]; + auto tindex = std::type_index(*tinfo->cpptype); + internals.direct_conversions.erase(tindex); + + if (tinfo->module_local) { + get_local_internals().registered_types_cpp.erase(tindex); + } else { + internals.registered_types_cpp.erase(tindex); + } + internals.registered_types_py.erase(tinfo->type); + + // Actually just `std::erase_if`, but that's only available in C++20 + auto &cache = internals.inactive_override_cache; + for (auto it = cache.begin(), last = cache.end(); it != last;) { + if (it->first == (PyObject *) tinfo->type) { + it = cache.erase(it); + } else { + ++it; + } + } + + delete tinfo; + } + }); + + PyType_Type.tp_dealloc(obj); +} + +/** This metaclass is assigned by default to all pybind11 types and is required in order + for static properties to function correctly. Users may override this using `py::metaclass`. + Return value: New reference. */ +inline PyTypeObject *make_default_metaclass() { + constexpr auto *name = "pybind11_type"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) { + pybind11_fail("make_default_metaclass(): error allocating metaclass!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyType_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_call = pybind11_meta_call; + + type->tp_setattro = pybind11_meta_setattro; + type->tp_getattro = pybind11_meta_getattro; + + type->tp_dealloc = pybind11_meta_dealloc; + + if (PyType_Ready(type) < 0) { + pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); + } + + setattr((PyObject *) type, "__module__", str(PYBIND11_DUMMY_MODULE_NAME)); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +/// For multiple inheritance types we need to recursively register/deregister base pointers for any +/// base classes with pointers that are difference from the instance value pointer so that we can +/// correctly recognize an offset base class pointer. This calls a function with any offset base +/// ptrs. +inline void traverse_offset_bases(void *valueptr, + const detail::type_info *tinfo, + instance *self, + bool (*f)(void * /*parentptr*/, instance * /*self*/)) { + for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { + if (auto *parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { + for (auto &c : parent_tinfo->implicit_casts) { + if (c.first == tinfo->cpptype) { + auto *parentptr = c.second(valueptr); + if (parentptr != valueptr) { + f(parentptr, self); + } + traverse_offset_bases(parentptr, parent_tinfo, self, f); + break; + } + } + } + } +} + +#ifdef Py_GIL_DISABLED +inline void enable_try_inc_ref(PyObject *obj) { + // TODO: Replace with PyUnstable_Object_EnableTryIncRef when available. + // See https://github.com/python/cpython/issues/128844 + if (_Py_IsImmortal(obj)) { + return; + } + for (;;) { + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared); + if ((shared & _Py_REF_SHARED_FLAG_MASK) != 0) { + // Nothing to do if it's in WEAKREFS, QUEUED, or MERGED states. + return; + } + if (_Py_atomic_compare_exchange_ssize( + &obj->ob_ref_shared, &shared, shared | _Py_REF_MAYBE_WEAKREF)) { + return; + } + } +} +#endif + +inline bool register_instance_impl(void *ptr, instance *self) { +#ifdef Py_GIL_DISABLED + enable_try_inc_ref(reinterpret_cast(self)); +#endif + with_instance_map(ptr, [&](instance_map &instances) { instances.emplace(ptr, self); }); + return true; // unused, but gives the same signature as the deregister func +} +inline bool deregister_instance_impl(void *ptr, instance *self) { + return with_instance_map(ptr, [&](instance_map &instances) { + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + if (self == it->second) { + instances.erase(it); + return true; + } + } + return false; + }); +} + +inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { + register_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) { + traverse_offset_bases(valptr, tinfo, self, register_instance_impl); + } +} + +inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { + bool ret = deregister_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) { + traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); + } + return ret; +} + +/// Instance creation function for all pybind11 types. It allocates the internal instance layout +/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is +/// cast to a reference or pointer), and initialization is done by an `__init__` function. +inline PyObject *make_new_instance(PyTypeObject *type) { +#if defined(PYPY_VERSION) + // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first + // inherited object is a plain Python type (i.e. not derived from an extension type). Fix it. + ssize_t instance_size = static_cast(sizeof(instance)); + if (type->tp_basicsize < instance_size) { + type->tp_basicsize = instance_size; + } +#endif + PyObject *self = type->tp_alloc(type, 0); + auto *inst = reinterpret_cast(self); + // Allocate the value/holder internals: + inst->allocate_layout(); + + return self; +} + +/// Instance creation function for all pybind11 types. It only allocates space for the +/// C++ object, but doesn't call the constructor -- an `__init__` function must do that. +extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { + return make_new_instance(type); +} + +/// An `__init__` function constructs the C++ object. Users should provide at least one +/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the +/// following default function will be used which simply throws an exception. +extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { + PyTypeObject *type = Py_TYPE(self); + std::string msg = get_fully_qualified_tp_name(type) + ": No constructor defined!"; + set_error(PyExc_TypeError, msg.c_str()); + return -1; +} + +inline void add_patient(PyObject *nurse, PyObject *patient) { + auto *instance = reinterpret_cast(nurse); + instance->has_patients = true; + Py_INCREF(patient); + + with_internals([&](internals &internals) { internals.patients[nurse].push_back(patient); }); +} + +inline void clear_patients(PyObject *self) { + auto *instance = reinterpret_cast(self); + std::vector patients; + + with_internals([&](internals &internals) { + auto pos = internals.patients.find(self); + + if (pos == internals.patients.end()) { + pybind11_fail( + "FATAL: Internal consistency check failed: Invalid clear_patients() call."); + } + + // Clearing the patients can cause more Python code to run, which + // can invalidate the iterator. Extract the vector of patients + // from the unordered_map first. + patients = std::move(pos->second); + internals.patients.erase(pos); + }); + + instance->has_patients = false; + for (PyObject *&patient : patients) { + Py_CLEAR(patient); + } +} + +/// Clears all internal data from the instance and removes it from registered instances in +/// preparation for deallocation. +inline void clear_instance(PyObject *self) { + auto *instance = reinterpret_cast(self); + + // Deallocate any values/holders, if present: + for (auto &v_h : values_and_holders(instance)) { + if (v_h) { + + // We have to deregister before we call dealloc because, for virtual MI types, we still + // need to be able to get the parent pointers. + if (v_h.instance_registered() + && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) { + pybind11_fail( + "pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); + } + + if (instance->owned || v_h.holder_constructed()) { + v_h.type->dealloc(v_h); + } + } else if (v_h.holder_constructed()) { + v_h.type->dealloc(v_h); // Disowned instance. + } + } + // Deallocate the value/holder layout internals: + instance->deallocate_layout(); + + if (instance->weakrefs) { + PyObject_ClearWeakRefs(self); + } + + PyObject **dict_ptr = _PyObject_GetDictPtr(self); + if (dict_ptr) { + Py_CLEAR(*dict_ptr); + } + + if (instance->has_patients) { + clear_patients(self); + } +} + +/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` +/// to destroy the C++ object itself, while the rest is Python bookkeeping. +extern "C" inline void pybind11_object_dealloc(PyObject *self) { + auto *type = Py_TYPE(self); + + // If this is a GC tracked object, untrack it first + // Note that the track call is implicitly done by the + // default tp_alloc, which we never override. + if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) { + PyObject_GC_UnTrack(self); + } + + clear_instance(self); + + type->tp_free(self); + + // This was not needed before Python 3.8 (Python issue 35810) + // https://github.com/pybind/pybind11/issues/1946 + Py_DECREF(type); +} + +PYBIND11_WARNING_PUSH +PYBIND11_WARNING_DISABLE_GCC("-Wredundant-decls") + +std::string error_string(); + +PYBIND11_WARNING_POP + +/** Create the type which can be used as a common base for all classes. This is + needed in order to satisfy Python's requirements for multiple inheritance. + Return value: New reference. */ +inline PyObject *make_object_base_type(PyTypeObject *metaclass) { + constexpr auto *name = "pybind11_object"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) { + pybind11_fail("make_object_base_type(): error allocating type!"); + } + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyBaseObject_Type); + type->tp_basicsize = static_cast(sizeof(instance)); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_new = pybind11_object_new; + type->tp_init = pybind11_object_init; + type->tp_dealloc = pybind11_object_dealloc; + + /* Support weak references (needed for the keep_alive feature) */ + type->tp_weaklistoffset = offsetof(instance, weakrefs); + + if (PyType_Ready(type) < 0) { + pybind11_fail("PyType_Ready failed in make_object_base_type(): " + error_string()); + } + + setattr((PyObject *) type, "__module__", str(PYBIND11_DUMMY_MODULE_NAME)); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + return (PyObject *) heap_type; +} + +/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. +extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { +#if PY_VERSION_HEX >= 0x030D0000 + PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#endif +// https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse +#if PY_VERSION_HEX >= 0x03090000 + Py_VISIT(Py_TYPE(self)); +#endif + return 0; +} + +/// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" inline int pybind11_clear(PyObject *self) { +#if PY_VERSION_HEX >= 0x030D0000 + PyObject_ClearManagedDict(self); +#else + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#endif + return 0; +} + +/// Give instances of this type a `__dict__` and opt into garbage collection. +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { + auto *type = &heap_type->ht_type; + type->tp_flags |= Py_TPFLAGS_HAVE_GC; +#ifdef PYBIND11_BACKWARD_COMPATIBILITY_TP_DICTOFFSET + type->tp_dictoffset = type->tp_basicsize; // place dict at the end + type->tp_basicsize += (ssize_t) sizeof(PyObject *); // and allocate enough space for it +#else + type->tp_flags |= Py_TPFLAGS_MANAGED_DICT; +#endif + type->tp_traverse = pybind11_traverse; + type->tp_clear = pybind11_clear; + + static PyGetSetDef getset[] + = {{"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + type->tp_getset = getset; +} + +/// buffer_protocol: Fill in the view as specified by flags. +extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). + type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) { + break; + } + } + if (view == nullptr || !tinfo || !tinfo->get_buffer) { + if (view) { + view->obj = nullptr; + } + set_error(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); + return -1; + } + std::memset(view, 0, sizeof(Py_buffer)); + std::unique_ptr info = nullptr; + try { + info.reset(tinfo->get_buffer(obj, tinfo->get_buffer_data)); + } catch (...) { + try_translate_exceptions(); + raise_from(PyExc_BufferError, "Error getting buffer"); + return -1; + } + if (info == nullptr) { + pybind11_fail("FATAL UNEXPECTED SITUATION: tinfo->get_buffer() returned nullptr."); + } + + if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE && info->readonly) { + // view->obj = nullptr; // Was just memset to 0, so not necessary + set_error(PyExc_BufferError, "Writable buffer requested for readonly storage"); + return -1; + } + + // Fill in all the information, and then downgrade as requested by the caller, or raise an + // error if that's not possible. + view->itemsize = info->itemsize; + view->len = view->itemsize; + for (auto s : info->shape) { + view->len *= s; + } + view->ndim = static_cast(info->ndim); + view->shape = info->shape.data(); + view->strides = info->strides.data(); + view->readonly = static_cast(info->readonly); + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format.c_str()); + } + + // Note, all contiguity flags imply PyBUF_STRIDES and lower. + if ((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS) { + if (PyBuffer_IsContiguous(view, 'C') == 0) { + std::memset(view, 0, sizeof(Py_buffer)); + set_error(PyExc_BufferError, + "C-contiguous buffer requested for discontiguous storage"); + return -1; + } + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS) { + if (PyBuffer_IsContiguous(view, 'F') == 0) { + std::memset(view, 0, sizeof(Py_buffer)); + set_error(PyExc_BufferError, + "Fortran-contiguous buffer requested for discontiguous storage"); + return -1; + } + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS) { + if (PyBuffer_IsContiguous(view, 'A') == 0) { + std::memset(view, 0, sizeof(Py_buffer)); + set_error(PyExc_BufferError, "Contiguous buffer requested for discontiguous storage"); + return -1; + } + + } else if ((flags & PyBUF_STRIDES) != PyBUF_STRIDES) { + // If no strides are requested, the buffer must be C-contiguous. + // https://docs.python.org/3/c-api/buffer.html#contiguity-requests + if (PyBuffer_IsContiguous(view, 'C') == 0) { + std::memset(view, 0, sizeof(Py_buffer)); + set_error(PyExc_BufferError, + "C-contiguous buffer requested for discontiguous storage"); + return -1; + } + + view->strides = nullptr; + + // Since this is a contiguous buffer, it can also pretend to be 1D. + if ((flags & PyBUF_ND) != PyBUF_ND) { + view->shape = nullptr; + view->ndim = 0; + } + } + + // Set these after all checks so they don't leak out into the caller, and can be automatically + // cleaned up on error. + view->buf = info->ptr; + view->internal = info.release(); + view->obj = obj; + Py_INCREF(view->obj); + return 0; +} + +/// buffer_protocol: Release the resources of the buffer. +extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { + delete (buffer_info *) view->internal; +} + +/// Give this type a buffer interface. +inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { + heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; + + heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; + heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; +} + +/** Create a brand new Python type according to the `type_record` specification. + Return value: New reference. */ +inline PyObject *make_new_python_type(const type_record &rec) { + auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); + + auto qualname = name; + if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) { + qualname = reinterpret_steal( + PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); + } + + object module_ = get_module_name_if_available(rec.scope); + const auto *full_name = c_str( +#if !defined(PYPY_VERSION) + module_ ? str(module_).cast() + "." + rec.name : +#endif + rec.name); + + char *tp_doc = nullptr; + if (rec.doc && options::show_user_defined_docstrings()) { + /* Allocate memory for docstring (Python will free this later on) */ + size_t size = std::strlen(rec.doc) + 1; +#if PY_VERSION_HEX >= 0x030D0000 + tp_doc = (char *) PyMem_MALLOC(size); +#else + tp_doc = (char *) PyObject_MALLOC(size); +#endif + std::memcpy((void *) tp_doc, rec.doc, size); + } + + auto &internals = get_internals(); + auto bases = tuple(rec.bases); + auto *base = (bases.empty()) ? internals.instance_base : bases[0].ptr(); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto *metaclass + = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() : internals.default_metaclass; + + auto *heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) { + pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); + } + + heap_type->ht_name = name.release().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = qualname.inc_ref().ptr(); +#endif + + auto *type = &heap_type->ht_type; + type->tp_name = full_name; + type->tp_doc = tp_doc; + type->tp_base = type_incref((PyTypeObject *) base); + type->tp_basicsize = static_cast(sizeof(instance)); + if (!bases.empty()) { + type->tp_bases = bases.release().ptr(); + } + + /* Don't inherit base __init__ */ + type->tp_init = pybind11_object_init; + + /* Supported protocols */ + type->tp_as_number = &heap_type->as_number; + type->tp_as_sequence = &heap_type->as_sequence; + type->tp_as_mapping = &heap_type->as_mapping; + type->tp_as_async = &heap_type->as_async; + + /* Flags */ + type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE; + if (!rec.is_final) { + type->tp_flags |= Py_TPFLAGS_BASETYPE; + } + + if (rec.dynamic_attr) { + enable_dynamic_attributes(heap_type); + } + + if (rec.buffer_protocol) { + enable_buffer_protocol(heap_type); + } + + if (rec.custom_type_setup_callback) { + rec.custom_type_setup_callback(heap_type); + } + + if (PyType_Ready(type) < 0) { + pybind11_fail(std::string(rec.name) + ": PyType_Ready failed: " + error_string()); + } + + assert(!rec.dynamic_attr || PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + + /* Register type with the parent scope */ + if (rec.scope) { + setattr(rec.scope, rec.name, (PyObject *) type); + } else { + Py_INCREF(type); // Keep it alive forever (reference leak) + } + + if (module_) { // Needed by pydoc + setattr((PyObject *) type, "__module__", module_); + } + + PYBIND11_SET_OLDPY_QUALNAME(type, qualname); + + return (PyObject *) type; +} + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h new file mode 100644 index 0000000000000000000000000000000000000000..ba15312ba2d53245fe228aafbb9e63f178330108 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/detail/common.h @@ -0,0 +1,1353 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + pybind11/detail/common.h -- Basic macros + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include +#if PY_VERSION_HEX < 0x03080000 +# error "PYTHON < 3.8 IS UNSUPPORTED. pybind11 v2.13 was the last to support Python 3.7." +#endif + +// Similar to Python's convention: https://docs.python.org/3/c-api/apiabiversion.html +// See also: https://github.com/python/cpython/blob/HEAD/Include/patchlevel.h +/* -- start version constants -- */ +#define PYBIND11_VERSION_MAJOR 3 +#define PYBIND11_VERSION_MINOR 0 +#define PYBIND11_VERSION_MICRO 1 +// ALPHA = 0xA, BETA = 0xB, GAMMA = 0xC (release candidate), FINAL = 0xF (stable release) +// - The release level is set to "alpha" for development versions. +// Use 0xA0 (LEVEL=0xA, SERIAL=0) for development versions. +// - For stable releases, set the serial to 0. +#define PYBIND11_VERSION_RELEASE_LEVEL PY_RELEASE_LEVEL_FINAL +#define PYBIND11_VERSION_RELEASE_SERIAL 0 +// String version of (micro, release level, release serial), e.g.: 0a0, 0b1, 0rc1, 0 +#define PYBIND11_VERSION_PATCH 1 +/* -- end version constants -- */ + +#if !defined(Py_PACK_FULL_VERSION) +// Stable API since Python 3.14.0a4 +# define Py_PACK_FULL_VERSION(X, Y, Z, LEVEL, SERIAL) \ + ((((X) & 0xff) << 24) | (((Y) & 0xff) << 16) | (((Z) & 0xff) << 8) \ + | (((LEVEL) & 0xf) << 4) | (((SERIAL) & 0xf) << 0)) +#endif +// Version as a single 4-byte hex number, e.g. 0x030C04B5 == 3.12.4b5. +#define PYBIND11_VERSION_HEX \ + Py_PACK_FULL_VERSION(PYBIND11_VERSION_MAJOR, \ + PYBIND11_VERSION_MINOR, \ + PYBIND11_VERSION_MICRO, \ + PYBIND11_VERSION_RELEASE_LEVEL, \ + PYBIND11_VERSION_RELEASE_SERIAL) + +#include "pybind11_namespace_macros.h" + +#if !(defined(_MSC_VER) && __cplusplus == 199711L) +# if __cplusplus >= 201402L +# define PYBIND11_CPP14 +# if __cplusplus >= 201703L +# define PYBIND11_CPP17 +# if __cplusplus >= 202002L +# define PYBIND11_CPP20 +// Please update tests/pybind11_tests.cpp `cpp_std()` when adding a macro here. +# endif +# endif +# endif +#elif defined(_MSC_VER) && __cplusplus == 199711L +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully +// implemented). Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 +// or newer. +# if _MSVC_LANG >= 201402L +# define PYBIND11_CPP14 +# if _MSVC_LANG > 201402L +# define PYBIND11_CPP17 +# if _MSVC_LANG >= 202002L +# define PYBIND11_CPP20 +# endif +# endif +# endif +#endif + +// These PYBIND11_HAS_... macros are consolidated in pybind11/detail/common.h +// to simplify backward compatibility handling for users (e.g., via #ifdef checks): +#define PYBIND11_HAS_TYPE_CASTER_STD_FUNCTION_SPECIALIZATIONS 1 +#define PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT 1 +#define PYBIND11_HAS_CPP_CONDUIT 1 +#define PYBIND11_HAS_NATIVE_ENUM 1 + +#if defined(PYBIND11_CPP17) && defined(__has_include) +# if __has_include() +# define PYBIND11_HAS_FILESYSTEM 1 +# elif __has_include() +# define PYBIND11_HAS_EXPERIMENTAL_FILESYSTEM 1 +# endif +#endif + +#if defined(__cpp_lib_launder) && !(defined(_MSC_VER) && (_MSC_VER < 1914)) +# define PYBIND11_STD_LAUNDER std::launder +# define PYBIND11_HAS_STD_LAUNDER 1 +#else +# define PYBIND11_STD_LAUNDER +# define PYBIND11_HAS_STD_LAUNDER 0 +#endif + +#if defined(PYBIND11_CPP20) +# define PYBIND11_CONSTINIT constinit +# define PYBIND11_DTOR_CONSTEXPR constexpr +#else +# define PYBIND11_CONSTINIT +# define PYBIND11_DTOR_CONSTEXPR +#endif + +// Compiler version assertions +#if defined(__INTEL_COMPILER) +# if __INTEL_COMPILER < 1800 +# error pybind11 requires Intel C++ compiler v18 or newer +# elif __INTEL_COMPILER < 1900 && defined(PYBIND11_CPP14) +# error pybind11 supports only C++11 with Intel C++ compiler v18. Use v19 or newer for C++14. +# endif +/* The following pragma cannot be pop'ed: + https://community.intel.com/t5/Intel-C-Compiler/Inline-and-no-inline-warning/td-p/1216764 */ +# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" +#elif defined(__clang__) && !defined(__apple_build_version__) +# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) +# error pybind11 requires clang 3.3 or newer +# endif +#elif defined(__clang__) +// Apple changes clang version macros to its Xcode version; the first Xcode release based on +// (upstream) clang 3.3 was Xcode 5: +# if __clang_major__ < 5 +# error pybind11 requires Xcode/clang 5.0 or newer +# endif +#elif defined(__GNUG__) +# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8) +# error pybind11 requires gcc 4.8 or newer +# endif +#elif defined(_MSC_VER) +# if _MSC_VER < 1910 +# error pybind11 2.10+ requires MSVC 2017 or newer +# endif +#endif + +#if !defined(PYBIND11_EXPORT) +# if defined(WIN32) || defined(_WIN32) +# define PYBIND11_EXPORT __declspec(dllexport) +# else +# define PYBIND11_EXPORT __attribute__((visibility("default"))) +# endif +#endif + +// For CUDA, GCC7, GCC8: +// PYBIND11_NOINLINE_FORCED is incompatible with `-Wattributes -Werror`. +// When defining PYBIND11_NOINLINE_FORCED, it is best to also use `-Wno-attributes`. +// However, the measured shared-library size saving when using noinline are only +// 1.7% for CUDA, -0.2% for GCC7, and 0.0% for GCC8 (using -DCMAKE_BUILD_TYPE=MinSizeRel, +// the default under pybind11/tests). +#if !defined(PYBIND11_NOINLINE_FORCED) \ + && (defined(__CUDACC__) || (defined(__GNUC__) && (__GNUC__ == 7 || __GNUC__ == 8))) +# define PYBIND11_NOINLINE_DISABLED +#endif + +// The PYBIND11_NOINLINE macro is for function DEFINITIONS. +// In contrast, FORWARD DECLARATIONS should never use this macro: +// https://stackoverflow.com/questions/9317473/forward-declaration-of-inline-functions +#if defined(PYBIND11_NOINLINE_DISABLED) // Option for maximum portability and experimentation. +# define PYBIND11_NOINLINE inline +#elif defined(_MSC_VER) +# define PYBIND11_NOINLINE __declspec(noinline) inline +#else +# define PYBIND11_NOINLINE __attribute__((noinline)) inline +#endif + +#if defined(__MINGW32__) +// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared +// whether it is used or not +# define PYBIND11_DEPRECATED(reason) +#elif defined(PYBIND11_CPP14) +# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]] +#else +# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason))) +#endif + +#if defined(PYBIND11_CPP17) +# define PYBIND11_MAYBE_UNUSED [[maybe_unused]] +#elif defined(_MSC_VER) && !defined(__clang__) +# define PYBIND11_MAYBE_UNUSED +#else +# define PYBIND11_MAYBE_UNUSED __attribute__((__unused__)) +#endif + +// https://en.cppreference.com/w/c/chrono/localtime +#if defined(__STDC_LIB_EXT1__) && !defined(__STDC_WANT_LIB_EXT1__) +# define __STDC_WANT_LIB_EXT1__ +#endif + +#ifdef __has_include +// std::optional (but including it in c++14 mode isn't allowed) +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_OPTIONAL 1 +# endif +// std::experimental::optional (but not allowed in c++11 mode) +# if defined(PYBIND11_CPP14) && (__has_include() && \ + !__has_include()) +# define PYBIND11_HAS_EXP_OPTIONAL 1 +# endif +// std::variant +# if defined(PYBIND11_CPP17) && __has_include() +# define PYBIND11_HAS_VARIANT 1 +# endif +#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) +# define PYBIND11_HAS_OPTIONAL 1 +# define PYBIND11_HAS_VARIANT 1 +#endif + +#if defined(PYBIND11_CPP17) \ + && ((defined(__has_include) && __has_include()) || defined(_MSC_VER)) +# define PYBIND11_HAS_STRING_VIEW 1 +#endif + +#if (defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT) +# define PYBIND11_SIMPLE_GIL_MANAGEMENT +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__has_include) +# if __has_include() +# include +# endif +#endif + +// For libc++, the exceptions should be exported, +// otherwise, the exception translation would be incorrect. +// IMPORTANT: This code block must stay BELOW the #include above (see PR #5390). +#if !defined(PYBIND11_EXPORT_EXCEPTION) +# if defined(_LIBCPP_EXCEPTION) +# define PYBIND11_EXPORT_EXCEPTION PYBIND11_EXPORT +# else +# define PYBIND11_EXPORT_EXCEPTION +# endif +#endif + +// Must be after including or one of the other headers specified by the standard +#if defined(__cpp_lib_char8_t) && __cpp_lib_char8_t >= 201811L +# define PYBIND11_HAS_U8STRING 1 +#endif + +// See description of PR #4246: +#if !defined(PYBIND11_NO_ASSERT_GIL_HELD_INCREF_DECREF) && !defined(NDEBUG) \ + && !defined(PYPY_VERSION) && !defined(PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF) +# define PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF +#endif + +// Slightly faster code paths are available when PYBIND11_HAS_SUBINTERPRETER_SUPPORT is *not* +// defined, so avoid defining it for implementations that do not support subinterpreters. However, +// defining it unnecessarily is not expected to break anything. +// This can be overridden by the user with -DPYBIND11_HAS_SUBINTERPRETER_SUPPORT=1 or 0 +#ifndef PYBIND11_HAS_SUBINTERPRETER_SUPPORT +# if PY_VERSION_HEX >= 0x030C0000 && !defined(PYPY_VERSION) && !defined(GRAALVM_PYTHON) +# define PYBIND11_HAS_SUBINTERPRETER_SUPPORT 1 +# endif +#else +# if PYBIND11_HAS_SUBINTERPRETER_SUPPORT == 0 +# undef PYBIND11_HAS_SUBINTERPRETER_SUPPORT +# endif +#endif + +// 3.13 Compatibility +#if 0x030D0000 <= PY_VERSION_HEX +# define PYBIND11_TYPE_IS_TYPE_HINT "typing.TypeIs" +# define PYBIND11_CAPSULE_TYPE_TYPE_HINT "types.CapsuleType" +#else +# define PYBIND11_TYPE_IS_TYPE_HINT "typing_extensions.TypeIs" +# define PYBIND11_CAPSULE_TYPE_TYPE_HINT "typing_extensions.CapsuleType" +#endif + +// 3.12 Compatibility +#if 0x030C0000 <= PY_VERSION_HEX +# define PYBIND11_BUFFER_TYPE_HINT "collections.abc.Buffer" +#else +# define PYBIND11_BUFFER_TYPE_HINT "typing_extensions.Buffer" +#endif + +// 3.11 Compatibility +#if 0x030B0000 <= PY_VERSION_HEX +# define PYBIND11_NEVER_TYPE_HINT "typing.Never" +#else +# define PYBIND11_NEVER_TYPE_HINT "typing_extensions.Never" +#endif + +// 3.10 Compatibility +#if 0x030A0000 <= PY_VERSION_HEX +# define PYBIND11_TYPE_GUARD_TYPE_HINT "typing.TypeGuard" +#else +# define PYBIND11_TYPE_GUARD_TYPE_HINT "typing_extensions.TypeGuard" +#endif + +// #define PYBIND11_STR_LEGACY_PERMISSIVE +// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject +// (probably surprising and never documented, but this was the +// legacy behavior until and including v2.6.x). As a side-effect, +// pybind11::isinstance() is true for both pybind11::str and +// pybind11::bytes. +// If UNDEFINED, pybind11::str can only hold PyUnicodeObject, and +// pybind11::isinstance() is true only for pybind11::str. +// However, for Python 2 only (!), the pybind11::str caster +// implicitly decoded bytes to PyUnicodeObject. This was to ease +// the transition from the legacy behavior to the non-permissive +// behavior. + +/// Compatibility macros for Python 2 / Python 3 versions TODO: remove +#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr) +#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check +#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION +#define PYBIND11_BYTES_CHECK PyBytes_Check +#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString +#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize +#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize +#define PYBIND11_BYTES_AS_STRING PyBytes_AsString +#define PYBIND11_BYTES_SIZE PyBytes_Size +#define PYBIND11_LONG_CHECK(o) PyLong_Check(o) +#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o) +#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) (o)) +#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) (o)) +#define PYBIND11_BYTES_NAME "bytes" +#define PYBIND11_STRING_NAME "str" +#define PYBIND11_SLICE_OBJECT PyObject +#define PYBIND11_FROM_STRING PyUnicode_FromString +#define PYBIND11_STR_TYPE ::pybind11::str +#define PYBIND11_BOOL_ATTR "__bool__" +#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) +#define PYBIND11_BUILTINS_MODULE "builtins" +// Providing a separate declaration to make Clang's -Wmissing-prototypes happy. +// See comment for PYBIND11_MODULE below for why this is marked "maybe unused". +#define PYBIND11_PLUGIN_DECL(name) \ + extern "C" PYBIND11_MAYBE_UNUSED PYBIND11_EXPORT PyObject *PyInit_##name(); +#define PYBIND11_PLUGIN_IMPL(name) \ + PYBIND11_PLUGIN_DECL(name) \ + extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() + +#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code +#define PYBIND11_STRINGIFY(x) #x +#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x) +#define PYBIND11_CONCAT(first, second) first##second +#define PYBIND11_ENSURE_INTERNALS_READY \ + { \ + pybind11::detail::get_internals_pp_manager().unref(); \ + pybind11::detail::get_internals(); \ + } + +#if !defined(GRAALVM_PYTHON) +# define PYBIND11_PYCFUNCTION_GET_DOC(func) ((func)->m_ml->ml_doc) +# define PYBIND11_PYCFUNCTION_SET_DOC(func, doc) \ + do { \ + (func)->m_ml->ml_doc = (doc); \ + } while (0) +#else +# define PYBIND11_PYCFUNCTION_GET_DOC(func) (GraalPyCFunction_GetDoc((PyObject *) (func))) +# define PYBIND11_PYCFUNCTION_SET_DOC(func, doc) \ + do { \ + GraalPyCFunction_SetDoc((PyObject *) (func), (doc)); \ + } while (0) +#endif + +#define PYBIND11_CHECK_PYTHON_VERSION \ + { \ + const char *compiled_ver \ + = PYBIND11_TOSTRING(PY_MAJOR_VERSION) "." PYBIND11_TOSTRING(PY_MINOR_VERSION); \ + const char *runtime_ver = Py_GetVersion(); \ + size_t len = std::strlen(compiled_ver); \ + if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \ + || (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \ + PyErr_Format(PyExc_ImportError, \ + "Python version mismatch: module was compiled for Python %s, " \ + "but the interpreter version is incompatible: %s.", \ + compiled_ver, \ + runtime_ver); \ + return nullptr; \ + } \ + } + +#define PYBIND11_CATCH_INIT_EXCEPTIONS \ + catch (pybind11::error_already_set & e) { \ + pybind11::raise_from(e, PyExc_ImportError, "initialization failed"); \ + } \ + catch (const std::exception &e) { \ + ::pybind11::set_error(PyExc_ImportError, e.what()); \ + } + +/** \rst + ***Deprecated in favor of PYBIND11_MODULE*** + + This macro creates the entry point that will be invoked when the Python interpreter + imports a plugin library. Please create a `module_` in the function body and return + the pointer to its underlying Python object at the end. + + .. code-block:: cpp + + PYBIND11_PLUGIN(example) { + pybind11::module_ m("example", "pybind11 example plugin"); + /// Set up bindings here + return m.ptr(); + } +\endrst */ +#define PYBIND11_PLUGIN(name) \ + PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \ + static PyObject *pybind11_init(); \ + PYBIND11_PLUGIN_IMPL(name) { \ + PYBIND11_CHECK_PYTHON_VERSION \ + PYBIND11_ENSURE_INTERNALS_READY \ + try { \ + return pybind11_init(); \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + return nullptr; \ + } \ + PyObject *pybind11_init() + +// this push is for the next several macros +PYBIND11_WARNING_PUSH +PYBIND11_WARNING_DISABLE_CLANG("-Wgnu-zero-variadic-macro-arguments") + +/** +Create a PyInit_ function for this module. + +Note that this is run once for each (sub-)interpreter the module is imported into, including +possibly concurrently. The PyModuleDef is allowed to be static, but the PyObject* resulting from +PyModuleDef_Init should be treated like any other PyObject (so not shared across interpreters). + */ +#define PYBIND11_MODULE_PYINIT(name, pre_init, ...) \ + static int PYBIND11_CONCAT(pybind11_exec_, name)(PyObject *); \ + PYBIND11_PLUGIN_IMPL(name) { \ + PYBIND11_CHECK_PYTHON_VERSION \ + pre_init; \ + PYBIND11_ENSURE_INTERNALS_READY \ + static ::pybind11::detail::slots_array mod_def_slots = ::pybind11::detail::init_slots( \ + &PYBIND11_CONCAT(pybind11_exec_, name), ##__VA_ARGS__); \ + static PyModuleDef def{/* m_base */ PyModuleDef_HEAD_INIT, \ + /* m_name */ PYBIND11_TOSTRING(name), \ + /* m_doc */ nullptr, \ + /* m_size */ 0, \ + /* m_methods */ nullptr, \ + /* m_slots */ mod_def_slots.data(), \ + /* m_traverse */ nullptr, \ + /* m_clear */ nullptr, \ + /* m_free */ nullptr}; \ + return PyModuleDef_Init(&def); \ + } + +#define PYBIND11_MODULE_EXEC(name, variable) \ + static void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ &); \ + int PYBIND11_CONCAT(pybind11_exec_, name)(PyObject * pm) { \ + try { \ + auto m = pybind11::reinterpret_borrow<::pybind11::module_>(pm); \ + if (!pybind11::detail::get_cached_module(m.attr("__spec__").attr("name"))) { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + pybind11::detail::cache_completed_module(m); \ + } \ + return 0; \ + } \ + PYBIND11_CATCH_INIT_EXCEPTIONS \ + return -1; \ + } \ + void PYBIND11_CONCAT(pybind11_init_, name)(::pybind11::module_ \ + & variable) // NOLINT(bugprone-macro-parentheses) + +/** \rst + This macro creates the entry point that will be invoked when the Python interpreter + imports an extension module. The module name is given as the first argument and it + should not be in quotes. The second macro argument defines a variable of type + ``py::module_`` which can be used to initialize the module. + + The entry point is marked as "maybe unused" to aid dead-code detection analysis: + since the entry point is typically only looked up at runtime and not referenced + during translation, it would otherwise appear as unused ("dead") code. + + .. code-block:: cpp + + PYBIND11_MODULE(example, m) { + m.doc() = "pybind11 example module"; + + // Add bindings here + m.def("foo", []() { + return "Hello, World!"; + }); + } + + The third and subsequent macro arguments are optional (available since 2.13.0), and + can be used to mark the extension module as supporting various Python features. + + - ``mod_gil_not_used()`` + - ``multiple_interpreters::per_interpreter_gil()`` + - ``multiple_interpreters::shared_gil()`` + - ``multiple_interpreters::not_supported()`` + + .. code-block:: cpp + + PYBIND11_MODULE(example, m, py::mod_gil_not_used()) { + m.doc() = "pybind11 example module safe to run without the GIL"; + m.def("foo", []() { + return "Hello, Free-threaded World!"; + }); + } + +\endrst */ +#define PYBIND11_MODULE(name, variable, ...) \ + PYBIND11_MODULE_PYINIT( \ + name, (pybind11::detail::get_num_interpreters_seen() += 1), ##__VA_ARGS__) \ + PYBIND11_MODULE_EXEC(name, variable) + +// pop gnu-zero-variadic-macro-arguments +PYBIND11_WARNING_POP + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +using ssize_t = Py_ssize_t; +using size_t = std::size_t; + +template +inline ssize_t ssize_t_cast(const IntType &val) { + static_assert(sizeof(IntType) <= sizeof(ssize_t), "Implicit narrowing is not permitted."); + return static_cast(val); +} + +/// Approach used to cast a previously unknown C++ instance into a Python object +enum class return_value_policy : uint8_t { + /** This is the default return value policy, which falls back to the policy + return_value_policy::take_ownership when the return value is a pointer. + Otherwise, it uses return_value::move or return_value::copy for rvalue + and lvalue references, respectively. See below for a description of what + all of these different policies do. */ + automatic = 0, + + /** As above, but use policy return_value_policy::reference when the return + value is a pointer. This is the default conversion policy for function + arguments when calling Python functions manually from C++ code (i.e. via + handle::operator()). You probably won't need to use this. */ + automatic_reference, + + /** Reference an existing object (i.e. do not create a new copy) and take + ownership. Python will call the destructor and delete operator when the + object's reference count reaches zero. Undefined behavior ensues when + the C++ side does the same.. */ + take_ownership, + + /** Create a new copy of the returned object, which will be owned by + Python. This policy is comparably safe because the lifetimes of the two + instances are decoupled. */ + copy, + + /** Use std::move to move the return value contents into a new instance + that will be owned by Python. This policy is comparably safe because the + lifetimes of the two instances (move source and destination) are + decoupled. */ + move, + + /** Reference an existing object, but do not take ownership. The C++ side + is responsible for managing the object's lifetime and deallocating it + when it is no longer used. Warning: undefined behavior will ensue when + the C++ side deletes an object that is still referenced and used by + Python. */ + reference, + + /** This policy only applies to methods and properties. It references the + object without taking ownership similar to the above + return_value_policy::reference policy. In contrast to that policy, the + function or property's implicit this argument (called the parent) is + considered to be the owner of the return value (the child). + pybind11 then couples the lifetime of the parent to the child via a + reference relationship that ensures that the parent cannot be garbage + collected while Python is still using the child. More advanced + variations of this scheme are also possible using combinations of + return_value_policy::reference and the keep_alive call policy */ + reference_internal +}; + +PYBIND11_NAMESPACE_BEGIN(detail) + +inline static constexpr int log2(size_t n, int k = 0) { + return (n <= 1) ? k : log2(n >> 1, k + 1); +} + +// Returns the size as a multiple of sizeof(void *), rounded up. +inline static constexpr size_t size_in_ptrs(size_t s) { + return 1 + ((s - 1) >> log2(sizeof(void *))); +} + +/** + * The space to allocate for simple layout instance holders (see below) in multiple of the size of + * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required + * to holder either a std::unique_ptr or std::shared_ptr (which is almost always + * sizeof(std::shared_ptr)). + */ +constexpr size_t instance_simple_holder_in_ptrs() { + static_assert(sizeof(std::shared_ptr) >= sizeof(std::unique_ptr), + "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs"); + return size_in_ptrs(sizeof(std::shared_ptr)); +} + +// Forward declarations +struct type_info; +struct value_and_holder; + +struct nonsimple_values_and_holders { + void **values_and_holders; + uint8_t *status; +}; + +/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof') +struct instance { + PyObject_HEAD + /// Storage for pointers and holder; see simple_layout, below, for a description + union { + void *simple_value_holder[1 + instance_simple_holder_in_ptrs()]; + nonsimple_values_and_holders nonsimple; + }; + /// Weak references + PyObject *weakrefs; + /// If true, the pointer is owned which means we're free to manage it with a holder. + bool owned : 1; + /** + * An instance has two possible value/holder layouts. + * + * Simple layout (when this flag is true), means the `simple_value_holder` is set with a + * pointer and the holder object governing that pointer, i.e. [val1*][holder]. This layout is + * applied whenever there is no python-side multiple inheritance of bound C++ types *and* the + * type's holder will fit in the default space (which is large enough to hold either a + * std::unique_ptr or std::shared_ptr). + * + * Non-simple layout applies when using custom holders that require more space than + * `shared_ptr` (which is typically the size of two pointers), or when multiple inheritance is + * used on the python side. Non-simple layout allocates the required amount of memory to have + * multiple bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is + * set to a pointer to allocated space of the required space to hold a sequence of value + * pointers and holders followed `status`, a set of bit flags (1 byte each), i.e. + * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple + * of `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the beginning of + * the [bb...] block (but not independently allocated). + * + * Status bits indicate whether the associated holder is constructed (& + * status_holder_constructed) and whether the value pointer is registered (& + * status_instance_registered) in `registered_instances`. + */ + bool simple_layout : 1; + /// For simple layout, tracks whether the holder has been constructed + bool simple_holder_constructed : 1; + /// For simple layout, tracks whether the instance is registered in `registered_instances` + bool simple_instance_registered : 1; + /// If true, get_internals().patients has an entry for this object + bool has_patients : 1; + /// If true, this Python object needs to be kept alive for the lifetime of the C++ value. + bool is_alias : 1; + + /// Initializes all of the above type/values/holders data (but not the instance values + /// themselves) + void allocate_layout(); + + /// Destroys/deallocates all of the above + void deallocate_layout(); + + /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type` + /// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if + /// `throw_if_missing` is false. + value_and_holder get_value_and_holder(const type_info *find_type = nullptr, + bool throw_if_missing = true); + + /// Bit values for the non-simple status flags + static constexpr uint8_t status_holder_constructed = 1; + static constexpr uint8_t status_instance_registered = 2; +}; + +static_assert(std::is_standard_layout::value, + "Internal error: `pybind11::detail::instance` is not standard layout!"); + +// Some older compilers (e.g. gcc 9.4.0) require +// static_assert(always_false::value, "..."); +// instead of +// static_assert(false, "..."); +// to trigger the static_assert() in a template only if it is actually instantiated. +template +struct always_false : std::false_type {}; + +/// from __cpp_future__ import (convenient aliases from C++14/17) +#if defined(PYBIND11_CPP14) +using std::conditional_t; +using std::enable_if_t; +using std::remove_cv_t; +using std::remove_reference_t; +#else +template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; +template +using remove_cv_t = typename std::remove_cv::type; +template +using remove_reference_t = typename std::remove_reference::type; +#endif + +#if defined(PYBIND11_CPP20) && defined(__cpp_lib_remove_cvref) +using std::remove_cvref; +using std::remove_cvref_t; +#else +template +struct remove_cvref { + using type = remove_cv_t>; +}; +template +using remove_cvref_t = typename remove_cvref::type; +#endif + +/// Example usage: is_same_ignoring_cvref::value +template +using is_same_ignoring_cvref = std::is_same, U>; + +/// Index sequences +#if defined(PYBIND11_CPP14) +using std::index_sequence; +using std::make_index_sequence; +#else +template +struct index_sequence {}; +// Comments about the algorithm below. +// +// Credit: This is based on an algorithm by taocpp here: +// https://github.com/taocpp/sequences/blob/main/include/tao/seq/make_integer_sequence.hpp +// but significantly simplified. +// +// We build up a sequence S by repeatedly doubling its length and sometimes adding 1 to the end. +// E.g. if the current S is 0...3, then we either go to 0...7 or 0...8 on the next pass. +// The goal is to end with S = 0...N-1. +// The key insight is that the times we need to add an additional digit to S correspond +// exactly to the 1's in the binary representation of the number N. +// +// Invariants: +// - digit is a power of 2 +// - N_digit_is_1 is whether N's binary representation has a 1 in that digit's position. +// - end <= N +// - S is 0...end-1. +// - if digit > 0, end * digit * 2 <= N < (end+1) * digit * 2 +// +// The process starts with digit > N, end = 0, and S is empty. +// The process concludes with digit=0, in which case, end == N and S is 0...N-1. + +template // N_digit_is_1=false +struct make_index_sequence_impl + : make_index_sequence_impl { +}; +template +struct make_index_sequence_impl + : make_index_sequence_impl {}; +template +struct make_index_sequence_impl<0, false, N, end, S...> { + using type = index_sequence; +}; +constexpr size_t next_power_of_2(size_t N) { return N == 0 ? 1 : next_power_of_2(N >> 1) << 1; } +template +using make_index_sequence = + typename make_index_sequence_impl::type; +#endif + +/// Make an index sequence of the indices of true arguments +template +struct select_indices_impl { + using type = ISeq; +}; +template +struct select_indices_impl, I, B, Bs...> + : select_indices_impl, index_sequence>, + I + 1, + Bs...> {}; +template +using select_indices = typename select_indices_impl, 0, Bs...>::type; + +/// Backports of std::bool_constant and std::negation to accommodate older compilers +template +using bool_constant = std::integral_constant; +template +struct negation : bool_constant {}; + +// PGI/Intel cannot detect operator delete with the "compatible" void_t impl, so +// using the new one (C++14 defect, so generally works on newer compilers, even +// if not in C++17 mode) +#if defined(__PGIC__) || defined(__INTEL_COMPILER) +template +using void_t = void; +#else +template +struct void_t_impl { + using type = void; +}; +template +using void_t = typename void_t_impl::type; +#endif + +/// Compile-time all/any/none of that check the boolean value of all template types +#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916)) +template +using all_of = bool_constant<(Ts::value && ...)>; +template +using any_of = bool_constant<(Ts::value || ...)>; +#elif !defined(_MSC_VER) +template +struct bools {}; +template +using all_of = std::is_same, bools>; +template +using any_of = negation...>>; +#else +// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit +// at a slight loss of compilation efficiency). +template +using all_of = std::conjunction; +template +using any_of = std::disjunction; +#endif +template +using none_of = negation>; + +template class... Predicates> +using satisfies_all_of = all_of...>; +template class... Predicates> +using satisfies_any_of = any_of...>; +template class... Predicates> +using satisfies_none_of = none_of...>; + +/// Strip the class from a method type +template +struct remove_class {}; +template +struct remove_class { + using type = R(A...); +}; +template +struct remove_class { + using type = R(A...); +}; +#ifdef __cpp_noexcept_function_type +template +struct remove_class { + using type = R(A...); +}; +template +struct remove_class { + using type = R(A...); +}; +#endif +/// Helper template to strip away type modifiers +template +struct intrinsic_type { + using type = T; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +struct intrinsic_type { + using type = typename intrinsic_type::type; +}; +template +using intrinsic_t = typename intrinsic_type::type; + +/// Helper type to replace 'void' in some expressions +struct void_type {}; + +/// Helper template which holds a list of types +template +struct type_list {}; + +/// Compile-time integer sum +#ifdef __cpp_fold_expressions +template +constexpr size_t constexpr_sum(Ts... ns) { + return (0 + ... + size_t{ns}); +} +#else +constexpr size_t constexpr_sum() { return 0; } +template +constexpr size_t constexpr_sum(T n, Ts... ns) { + return size_t{n} + constexpr_sum(ns...); +} +#endif + +PYBIND11_NAMESPACE_BEGIN(constexpr_impl) +/// Implementation details for constexpr functions +constexpr int first(int i) { return i; } +template +constexpr int first(int i, T v, Ts... vs) { + return v ? i : first(i + 1, vs...); +} + +constexpr int last(int /*i*/, int result) { return result; } +template +constexpr int last(int i, int result, T v, Ts... vs) { + return last(i + 1, v ? i : result, vs...); +} +PYBIND11_NAMESPACE_END(constexpr_impl) + +/// Return the index of the first type in Ts which satisfies Predicate. +/// Returns sizeof...(Ts) if none match. +template