diff --git a/.gitattributes b/.gitattributes index 678ae2acd91c5ef42af69be09a10c0beaf4dd9f2..7963ab5de736ac4250a429ef86292a023bf951f9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -267,3 +267,6 @@ source/tvm_ffi/lib/libtvm_ffi.so filter=lfs diff=lfs merge=lfs -text source/tvm_ffi/lib/libtvm_ffi_testing.so filter=lfs diff=lfs merge=lfs -text source/uvloop/loop.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text source/watchfiles/_rust_notify.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +source/yaml/_yaml.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +source/yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +source/zmq/backend/cython/_zmq.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/source/watchfiles-1.1.1.dist-info/INSTALLER b/source/watchfiles-1.1.1.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/source/watchfiles-1.1.1.dist-info/METADATA b/source/watchfiles-1.1.1.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..ed8d9040d978122b5f2a537e8086392efdfe9448 --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/METADATA @@ -0,0 +1,148 @@ +Metadata-Version: 2.4 +Name: watchfiles +Version: 1.1.1 +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3.14 +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Information Technology +Classifier: Intended Audience :: System Administrators +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: MacOS +Classifier: Environment :: MacOS X +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Topic :: System :: Filesystems +Classifier: Framework :: AnyIO +Requires-Dist: anyio>=3.0.0 +License-File: LICENSE +Summary: Simple, modern and high performance file watching and code reload in python. +Home-Page: https://github.com/samuelcolvin/watchfiles +Author-email: Samuel Colvin +License: MIT +Requires-Python: >=3.9 +Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM +Project-URL: Homepage, https://github.com/samuelcolvin/watchfiles +Project-URL: Documentation, https://watchfiles.helpmanual.io +Project-URL: Funding, https://github.com/sponsors/samuelcolvin +Project-URL: Source, https://github.com/samuelcolvin/watchfiles +Project-URL: Changelog, https://github.com/samuelcolvin/watchfiles/releases + +# watchfiles + +[![CI](https://github.com/samuelcolvin/watchfiles/actions/workflows/ci.yml/badge.svg)](https://github.com/samuelcolvin/watchfiles/actions/workflows/ci.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/samuelcolvin/watchfiles/branch/main/graph/badge.svg)](https://codecov.io/gh/samuelcolvin/watchfiles) +[![pypi](https://img.shields.io/pypi/v/watchfiles.svg)](https://pypi.python.org/pypi/watchfiles) +[![CondaForge](https://img.shields.io/conda/v/conda-forge/watchfiles.svg)](https://anaconda.org/conda-forge/watchfiles) +[![license](https://img.shields.io/github/license/samuelcolvin/watchfiles.svg)](https://github.com/samuelcolvin/watchfiles/blob/main/LICENSE) + +Simple, modern and high performance file watching and code reload in python. + +--- + +**Documentation**: [watchfiles.helpmanual.io](https://watchfiles.helpmanual.io) + +**Source Code**: [github.com/samuelcolvin/watchfiles](https://github.com/samuelcolvin/watchfiles) + +--- + +Underlying file system notifications are handled by the [Notify](https://github.com/notify-rs/notify) rust library. + +This package was previously named "watchgod", +see [the migration guide](https://watchfiles.helpmanual.io/migrating/) for more information. + +## Installation + +**watchfiles** requires Python 3.9 - 3.14. + +```bash +pip install watchfiles +``` + +Binaries are available for most architectures on Linux, MacOS and Windows ([learn more](https://watchfiles.helpmanual.io/#installation)). + +Otherwise, you can install from source which requires Rust stable to be installed. + +## Usage + +Here are some examples of what **watchfiles** can do: + +### `watch` Usage + +```py +from watchfiles import watch + +for changes in watch('./path/to/dir'): + print(changes) +``` +See [`watch` docs](https://watchfiles.helpmanual.io/api/watch/#watchfiles.watch) for more details. + +### `awatch` Usage + +```py +import asyncio +from watchfiles import awatch + +async def main(): + async for changes in awatch('/path/to/dir'): + print(changes) + +asyncio.run(main()) +``` +See [`awatch` docs](https://watchfiles.helpmanual.io/api/watch/#watchfiles.awatch) for more details. + +### `run_process` Usage + +```py +from watchfiles import run_process + +def foobar(a, b, c): + ... + +if __name__ == '__main__': + run_process('./path/to/dir', target=foobar, args=(1, 2, 3)) +``` +See [`run_process` docs](https://watchfiles.helpmanual.io/api/run_process/#watchfiles.run_process) for more details. + +### `arun_process` Usage + +```py +import asyncio +from watchfiles import arun_process + +def foobar(a, b, c): + ... + +async def main(): + await arun_process('./path/to/dir', target=foobar, args=(1, 2, 3)) + +if __name__ == '__main__': + asyncio.run(main()) +``` +See [`arun_process` docs](https://watchfiles.helpmanual.io/api/run_process/#watchfiles.arun_process) for more details. + +## CLI + +**watchfiles** also comes with a CLI for running and reloading code. To run `some command` when files in `src` change: + +``` +watchfiles "some command" src +``` + +For more information, see [the CLI docs](https://watchfiles.helpmanual.io/cli/). + +Or run + +```bash +watchfiles --help +``` + diff --git a/source/watchfiles-1.1.1.dist-info/RECORD b/source/watchfiles-1.1.1.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..606cd6260a03cb40b59402332f805a1f1b817076 --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/RECORD @@ -0,0 +1,24 @@ +../../bin/watchfiles,sha256=UmgepAyVu9Gw-Yp6nEG9ks2cXHYq9nd7hBmVixDPM7s,211 +watchfiles-1.1.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +watchfiles-1.1.1.dist-info/METADATA,sha256=h34wYtQyezaYEn9GQWg8z4d9JCViFLT4vmKd_ip6WF8,4874 +watchfiles-1.1.1.dist-info/RECORD,, +watchfiles-1.1.1.dist-info/WHEEL,sha256=AUS7tHOBvWg1bDsPcHg1j3P_rKxqebEdeR--lIGHkyI,129 +watchfiles-1.1.1.dist-info/entry_points.txt,sha256=s1Dpa2d_KKBy-jKREWW60Z3GoRZ3JpCEo_9iYDt6hOQ,48 +watchfiles-1.1.1.dist-info/licenses/LICENSE,sha256=T9eDVbZ84md-3p-29jolDzd7t-IgiBNqX0aZrbS8Bp8,1091 +watchfiles/__init__.py,sha256=IRlM9KOSedMzF1fvLr7yEHPVS-UFERNThlB-tmWI8yU,364 +watchfiles/__main__.py,sha256=JgErYkiskih8Y6oRwowALtR-rwQhAAdqOYWjQraRIPI,59 +watchfiles/__pycache__/__init__.cpython-312.pyc,, +watchfiles/__pycache__/__main__.cpython-312.pyc,, +watchfiles/__pycache__/cli.cpython-312.pyc,, +watchfiles/__pycache__/filters.cpython-312.pyc,, +watchfiles/__pycache__/main.cpython-312.pyc,, +watchfiles/__pycache__/run.cpython-312.pyc,, +watchfiles/__pycache__/version.cpython-312.pyc,, +watchfiles/_rust_notify.cpython-312-x86_64-linux-gnu.so,sha256=sJsIMMJW0QyNqKUFF2eg4YaVUywMgJiAjdubGmyjAo0,1124288 +watchfiles/_rust_notify.pyi,sha256=q5FQkXgBJEFPt9RCf7my4wP5RM1FwSVpqf221csyebg,4753 +watchfiles/cli.py,sha256=DHMI0LfT7hOrWai_Y4RP_vvTvVdtcDaioixXLiv2pG4,7707 +watchfiles/filters.py,sha256=U0zXGOeg9dMHkT51-56BKpRrWIu95lPq0HDR_ZB4oDE,5139 +watchfiles/main.py,sha256=-pbJBFBA34VEXMt8VGcaPTQHAjsGhPf7Psu1gP_HnKk,15235 +watchfiles/py.typed,sha256=MS4Na3to9VTGPy_8wBQM_6mNKaX4qIpi5-w7_LZB-8I,69 +watchfiles/run.py,sha256=TLXb2y_xYx-t3xyszVQWHoGyG7RCb107Q0NoIcSWmjQ,15348 +watchfiles/version.py,sha256=NRWUnkZ32DamsNKV20EetagIGTLDMMUnqDWVGFFA2WQ,85 diff --git a/source/watchfiles-1.1.1.dist-info/WHEEL b/source/watchfiles-1.1.1.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..ecc076195c46d7bd21f36f202ceb5977d3100be8 --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: maturin (1.9.6) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64 diff --git a/source/watchfiles-1.1.1.dist-info/entry_points.txt b/source/watchfiles-1.1.1.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..51642969b76b6d4a8c0e9437a0ddae58772e835b --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +watchfiles=watchfiles.cli:cli diff --git a/source/watchfiles-1.1.1.dist-info/licenses/LICENSE b/source/watchfiles-1.1.1.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bdd9b15521a6e7860883f7d5fea2ed29e1c0422e --- /dev/null +++ b/source/watchfiles-1.1.1.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 to present Samuel Colvin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/source/websockets-16.0.dist-info/INSTALLER b/source/websockets-16.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/source/websockets-16.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/source/websockets-16.0.dist-info/METADATA b/source/websockets-16.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..53b147e4b45cc4ec0502fd8388f8c65f6e89c97c --- /dev/null +++ b/source/websockets-16.0.dist-info/METADATA @@ -0,0 +1,179 @@ +Metadata-Version: 2.4 +Name: websockets +Version: 16.0 +Summary: An implementation of the WebSocket Protocol (RFC 6455 & 7692) +Author-email: Aymeric Augustin +License-Expression: BSD-3-Clause +Project-URL: Homepage, https://github.com/python-websockets/websockets +Project-URL: Changelog, https://websockets.readthedocs.io/en/stable/project/changelog.html +Project-URL: Documentation, https://websockets.readthedocs.io/ +Project-URL: Funding, https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme +Project-URL: Tracker, https://github.com/python-websockets/websockets/issues +Keywords: WebSocket +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3.14 +Requires-Python: >=3.10 +Description-Content-Type: text/x-rst +License-File: LICENSE +Dynamic: description +Dynamic: description-content-type +Dynamic: license-file + +.. image:: logo/horizontal.svg + :width: 480px + :alt: websockets + +|licence| |version| |pyversions| |tests| |docs| |openssf| + +.. |licence| image:: https://img.shields.io/pypi/l/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |version| image:: https://img.shields.io/pypi/v/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg + :target: https://pypi.python.org/pypi/websockets + +.. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests + :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml + +.. |docs| image:: https://img.shields.io/readthedocs/websockets.svg + :target: https://websockets.readthedocs.io/ + +.. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge + :target: https://bestpractices.coreinfrastructure.org/projects/6475 + +What is ``websockets``? +----------------------- + +websockets is a library for building WebSocket_ servers and clients in Python +with a focus on correctness, simplicity, robustness, and performance. + +.. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API + +Built on top of ``asyncio``, Python's standard asynchronous I/O framework, the +default implementation provides an elegant coroutine-based API. + +An implementation on top of ``threading`` and a Sans-I/O implementation are also +available. + +`Documentation is available on Read the Docs. `_ + +.. copy-pasted because GitHub doesn't support the include directive + +Here's an echo server with the ``asyncio`` API: + +.. code:: python + + #!/usr/bin/env python + + import asyncio + from websockets.asyncio.server import serve + + async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + async def main(): + async with serve(echo, "localhost", 8765) as server: + await server.serve_forever() + + asyncio.run(main()) + +Here's how a client sends and receives messages with the ``threading`` API: + +.. code:: python + + #!/usr/bin/env python + + from websockets.sync.client import connect + + def hello(): + with connect("ws://localhost:8765") as websocket: + websocket.send("Hello world!") + message = websocket.recv() + print(f"Received: {message}") + + hello() + + +Does that look good? + +`Get started with the tutorial! `_ + +Why should I use ``websockets``? +-------------------------------- + +The development of ``websockets`` is shaped by four principles: + +1. **Correctness**: ``websockets`` is heavily tested for compliance with + :rfc:`6455`. Continuous integration fails under 100% branch coverage. + +2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and + ``await ws.send(msg)``. ``websockets`` takes care of managing connections + so you can focus on your application. + +3. **Robustness**: ``websockets`` is built for production. For example, it was + the only library to `handle backpressure correctly`_ before the issue + became widely known in the Python community. + +4. **Performance**: memory usage is optimized and configurable. A C extension + accelerates expensive operations. It's pre-compiled for Linux, macOS and + Windows and packaged in the wheel format for each system and Python version. + +Documentation is a first class concern in the project. Head over to `Read the +Docs`_ and see for yourself. + +.. _Read the Docs: https://websockets.readthedocs.io/ +.. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers + +Why shouldn't I use ``websockets``? +----------------------------------- + +* If you prefer callbacks over coroutines: ``websockets`` was created to + provide the best coroutine-based API to manage WebSocket connections in + Python. Pick another library for a callback-based API. + +* If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims + at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol + and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP + is minimal — just enough for an HTTP health check. + + If you want to do both in the same server, look at HTTP + WebSocket servers + that build on top of ``websockets`` to support WebSocket connections, like + uvicorn_ or Sanic_. + +.. _uvicorn: https://www.uvicorn.org/ +.. _Sanic: https://sanic.dev/en/ + +What else? +---------- + +Bug reports, patches and suggestions are welcome! + +To report a security vulnerability, please use the `Tidelift security +contact`_. Tidelift will coordinate the fix and disclosure. + +.. _Tidelift security contact: https://tidelift.com/security + +For anything else, please open an issue_ or send a `pull request`_. + +.. _issue: https://github.com/python-websockets/websockets/issues/new +.. _pull request: https://github.com/python-websockets/websockets/compare/ + +Participants must uphold the `Contributor Covenant code of conduct`_. + +.. _Contributor Covenant code of conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md + +``websockets`` is released under the `BSD license`_. + +.. _BSD license: https://github.com/python-websockets/websockets/blob/main/LICENSE diff --git a/source/websockets-16.0.dist-info/RECORD b/source/websockets-16.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..12b7b579d092059cc1c953be50a0960473306c25 --- /dev/null +++ b/source/websockets-16.0.dist-info/RECORD @@ -0,0 +1,108 @@ +../../bin/websockets,sha256=jIwwGFqaK2AvxQf01v-BJ3EMdDyCxBvpgWufffO9SyU,213 +websockets-16.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +websockets-16.0.dist-info/METADATA,sha256=JcDvWo8DVSw5uoDAFbk9N8fJXuRJvnrcLXVBFyBjwN8,6799 +websockets-16.0.dist-info/RECORD,, +websockets-16.0.dist-info/WHEEL,sha256=mX4U4odf6w47aVjwZUmTYd1MF9BbrhVLKlaWSvZwHEk,186 +websockets-16.0.dist-info/entry_points.txt,sha256=Dnhn4dm5EsI4ZMAsHldGF6CwBXZrGXnR7cnK2-XR7zY,51 +websockets-16.0.dist-info/licenses/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514 +websockets-16.0.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11 +websockets/__init__.py,sha256=AC2Hq92uSc_WOo9_xvITpGshJ7Dy0Md5m2_ywsdSt_Y,7058 +websockets/__main__.py,sha256=wu5N2wk8mvBgyvr2ghmQf4prezAe0_i-p123VVreyYc,62 +websockets/__pycache__/__init__.cpython-312.pyc,, +websockets/__pycache__/__main__.cpython-312.pyc,, +websockets/__pycache__/auth.cpython-312.pyc,, +websockets/__pycache__/cli.cpython-312.pyc,, +websockets/__pycache__/client.cpython-312.pyc,, +websockets/__pycache__/connection.cpython-312.pyc,, +websockets/__pycache__/datastructures.cpython-312.pyc,, +websockets/__pycache__/exceptions.cpython-312.pyc,, +websockets/__pycache__/frames.cpython-312.pyc,, +websockets/__pycache__/headers.cpython-312.pyc,, +websockets/__pycache__/http.cpython-312.pyc,, +websockets/__pycache__/http11.cpython-312.pyc,, +websockets/__pycache__/imports.cpython-312.pyc,, +websockets/__pycache__/protocol.cpython-312.pyc,, +websockets/__pycache__/proxy.cpython-312.pyc,, +websockets/__pycache__/server.cpython-312.pyc,, +websockets/__pycache__/streams.cpython-312.pyc,, +websockets/__pycache__/typing.cpython-312.pyc,, +websockets/__pycache__/uri.cpython-312.pyc,, +websockets/__pycache__/utils.cpython-312.pyc,, +websockets/__pycache__/version.cpython-312.pyc,, +websockets/asyncio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +websockets/asyncio/__pycache__/__init__.cpython-312.pyc,, +websockets/asyncio/__pycache__/async_timeout.cpython-312.pyc,, +websockets/asyncio/__pycache__/client.cpython-312.pyc,, +websockets/asyncio/__pycache__/compatibility.cpython-312.pyc,, +websockets/asyncio/__pycache__/connection.cpython-312.pyc,, +websockets/asyncio/__pycache__/messages.cpython-312.pyc,, +websockets/asyncio/__pycache__/router.cpython-312.pyc,, +websockets/asyncio/__pycache__/server.cpython-312.pyc,, +websockets/asyncio/async_timeout.py,sha256=N-6Mubyiaoh66PAXGvCzhgxCM-7V2XiRnH32Xi6J6TE,8971 +websockets/asyncio/client.py,sha256=e4xlgtzb3v29M2vN-UDiyoUtThg--d5GqKg3lt2pDdE,30850 +websockets/asyncio/compatibility.py,sha256=gkenDDhzNbm6_iXV5Edvbvp6uHZYdrTvGNjt8P_JtyQ,786 +websockets/asyncio/connection.py,sha256=87RdVURijJk8V-ShWAWfTEyhW5Z1YUXKV8ezUzxt5L0,49099 +websockets/asyncio/messages.py,sha256=u2M5WKY9xPyw8G3nKoXfdO5K41hrTnf4MdizVHzgdM4,11129 +websockets/asyncio/router.py,sha256=S-69vszK-SqUCcZbXXPOnux-eH2fTHYC2JNh7tOtmmA,7520 +websockets/asyncio/server.py,sha256=wQ9oBc0WBOIzbXKDYJ8UhXRTeoXrSfLu6CWCrUl-vck,37941 +websockets/auth.py,sha256=U_Jwmn59ZRQ6EecpOvMizQCG_ZbAvgUf1ik7haZRC3c,568 +websockets/cli.py,sha256=YnegH59z93JxSVIGiXiWhR3ktgI6k1_pf_BRLanxKrQ,5336 +websockets/client.py,sha256=fljI5k5oQ-Sfm53MCoyTlr2jFtOOIuO13H9bbtpBPes,13789 +websockets/connection.py,sha256=OLiMVkNd25_86sB8Q7CrCwBoXy9nA0OCgdgLRA8WUR8,323 +websockets/datastructures.py,sha256=Uq2CpjmXak9_pPWcOqh36rzJMo8eCi2lVPTFWDvK5sA,5518 +websockets/exceptions.py,sha256=bgaMdqQGGZosAEULeCB30XW2YnwomWa3c8YOrEfeOoY,12859 +websockets/extensions/__init__.py,sha256=QkZsxaJVllVSp1uhdD5uPGibdbx_091GrVVfS5LXcpw,98 +websockets/extensions/__pycache__/__init__.cpython-312.pyc,, +websockets/extensions/__pycache__/base.cpython-312.pyc,, +websockets/extensions/__pycache__/permessage_deflate.cpython-312.pyc,, +websockets/extensions/base.py,sha256=JNfyk543C7VuPH0QOobiqKoGrzjJILje6sz5ILvOPl4,2903 +websockets/extensions/permessage_deflate.py,sha256=AkuhkAKFo5lqJQMXnckbSs9b2KBBrOFsE1DHIcbLL3k,25770 +websockets/frames.py,sha256=5IK4GZpl8ukr0bZ_UA_jjjztK09yYQAl9m5NVmGLiK0,12889 +websockets/headers.py,sha256=yQnPljVZwV1_V-pOSRKNLG_u827wFC1h72cciojcQ8M,16046 +websockets/http.py,sha256=T1tNLmbkFCneXQ6qepBmsVVDXyP9i500IVzTJTeBMR4,659 +websockets/http11.py,sha256=T8ai5BcBGkV0n9It63oDeNpmtQMyg8Cpav5rf_yT0r4,15619 +websockets/imports.py,sha256=T_B9TUmHoceKMQ-PNphdQQAH2XdxAxwSQNeQEgqILkE,2795 +websockets/legacy/__init__.py,sha256=wQ5zRIENGUS_5eKNAX9CRE7x1TwKapKimrQFFWN9Sxs,276 +websockets/legacy/__pycache__/__init__.cpython-312.pyc,, +websockets/legacy/__pycache__/auth.cpython-312.pyc,, +websockets/legacy/__pycache__/client.cpython-312.pyc,, +websockets/legacy/__pycache__/exceptions.cpython-312.pyc,, +websockets/legacy/__pycache__/framing.cpython-312.pyc,, +websockets/legacy/__pycache__/handshake.cpython-312.pyc,, +websockets/legacy/__pycache__/http.cpython-312.pyc,, +websockets/legacy/__pycache__/protocol.cpython-312.pyc,, +websockets/legacy/__pycache__/server.cpython-312.pyc,, +websockets/legacy/auth.py,sha256=DcQcCSeVeP93JcH8vFWE0HIJL-X-f23LZ0DsJpav1So,6531 +websockets/legacy/client.py,sha256=fV2mbiU9rciXhJfAEKVSm0GztJDUbDpRQ-K5EMbkuQ0,26815 +websockets/legacy/exceptions.py,sha256=ViEjpoT09fzx_Zqf0aNGDVtRDNjXaOw0gdCta3LkjFc,1924 +websockets/legacy/framing.py,sha256=r9P1wiXv_1XuAVQw8SOPkuE9d4eZ0r_JowAkz9-WV4w,6366 +websockets/legacy/handshake.py,sha256=2Nzr5AN2xvDC5EdNP-kB3lOcrAaUNlYuj_-hr_jv7pM,5285 +websockets/legacy/http.py,sha256=cOCQmDWhIKQmm8UWGXPW7CDZg03wjogCsb0LP9oetNQ,7061 +websockets/legacy/protocol.py,sha256=ajtVXDb-lEm9BN0NF3iEaTI_b1q5fBCKTB9wvUoGOxY,63632 +websockets/legacy/server.py,sha256=7mwY-yD0ljNF93oPYumTWD7OIVbCWtaEOw1FFJBhIAM,45251 +websockets/protocol.py,sha256=vTqjPIg2HmO-bSxsczuEmWMxPTxPXU1hmVUjqnahV44,27247 +websockets/proxy.py,sha256=oFrbEYtasYWv-WDcniObD9nBR5Q5qkHpyCVLngx7WMQ,4969 +websockets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +websockets/server.py,sha256=E4SWBA8WZRmAOpsUm-oCqacBGZre9e0iDmDIrfpV21Q,21790 +websockets/speedups.c,sha256=u_dncR4M38EX6He_fzb1TY6D3Hke67ZpoHLLhZZ0hvQ,5920 +websockets/speedups.cpython-312-x86_64-linux-gnu.so,sha256=F8FiVerlQi_Z0YSsuY_ASEHvWcddXkyyRa3ylkV80B0,38048 +websockets/speedups.pyi,sha256=unjvBNg-uW4c7z-9OW4WiSzZk_QH2bLEcjYAMuoSgBI,102 +websockets/streams.py,sha256=pXqga7ttjuF6lChWYiWLSfUlt3FCaQpEX1ae_jvcCeQ,4071 +websockets/sync/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +websockets/sync/__pycache__/__init__.cpython-312.pyc,, +websockets/sync/__pycache__/client.cpython-312.pyc,, +websockets/sync/__pycache__/connection.cpython-312.pyc,, +websockets/sync/__pycache__/messages.cpython-312.pyc,, +websockets/sync/__pycache__/router.cpython-312.pyc,, +websockets/sync/__pycache__/server.cpython-312.pyc,, +websockets/sync/__pycache__/utils.cpython-312.pyc,, +websockets/sync/client.py,sha256=_2Erytw1f3f9O_u2jLtS1oNV4HsHUi_h3lGvT9ZEaDQ,22108 +websockets/sync/connection.py,sha256=1pJYEMRHLWIN7538vJcIeFVnvSXVrD0n1xrfX7wDNSc,41868 +websockets/sync/messages.py,sha256=yZV1zhY07ZD0vRF5b1yDa7ug0rbA5UDOCCCQmWwAcds,12858 +websockets/sync/router.py,sha256=BqKSAKNZYtRWiOxol9qYeyfgyXRrMNJ6FrTTZLNcXMg,7172 +websockets/sync/server.py,sha256=s07HNK_2s1kLN62Uqc77uvND0z7C0YTXGePsCiBtXaE,27655 +websockets/sync/utils.py,sha256=TtW-ncYFvJmiSW2gO86ngE2BVsnnBdL-4H88kWNDYbg,1107 +websockets/typing.py,sha256=A6xh4m65pRzKAbuOs0kFuGhL4DWIIko-ppS4wvJVc0Q,1946 +websockets/uri.py,sha256=2fFMw-AbKJ5HVHNCuw1Rx1MnkCkNWRpogxWhhM30EU4,3125 +websockets/utils.py,sha256=AwhS4UmlbKv7meAaR7WNbUqD5JFoStOP1bAyo9sRMus,1197 +websockets/version.py,sha256=IhaztWxysdY-pd-0nOubnnPduvySSvdoBwrQdJKtZ2g,3202 diff --git a/source/websockets-16.0.dist-info/WHEEL b/source/websockets-16.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..9921a02fb6f1f01a39d88a07edea51e0d980e0dd --- /dev/null +++ b/source/websockets-16.0.dist-info/WHEEL @@ -0,0 +1,7 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_5_x86_64 +Tag: cp312-cp312-manylinux1_x86_64 +Tag: cp312-cp312-manylinux_2_28_x86_64 + diff --git a/source/websockets-16.0.dist-info/entry_points.txt b/source/websockets-16.0.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..60cd61ca0074591bf4dd607f11ee22310a3daab4 --- /dev/null +++ b/source/websockets-16.0.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +websockets = websockets.cli:main diff --git a/source/websockets-16.0.dist-info/licenses/LICENSE b/source/websockets-16.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5d61ece22a75a759aed8e52af280eca28d35d6bf --- /dev/null +++ b/source/websockets-16.0.dist-info/licenses/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) Aymeric Augustin and contributors + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/source/websockets-16.0.dist-info/top_level.txt b/source/websockets-16.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..14774b465e97f655dbcaa60d97c8a9aa72e7d51b --- /dev/null +++ b/source/websockets-16.0.dist-info/top_level.txt @@ -0,0 +1 @@ +websockets diff --git a/source/websockets/__init__.py b/source/websockets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f90aff5b958ab3a19ecbb3e5ccad676e3719d9ab --- /dev/null +++ b/source/websockets/__init__.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +# Importing the typing module would conflict with websockets.typing. +from typing import TYPE_CHECKING + +from .imports import lazy_import +from .version import version as __version__ # noqa: F401 + + +__all__ = [ + # .asyncio.client + "connect", + "unix_connect", + "ClientConnection", + # .asyncio.router + "route", + "unix_route", + "Router", + # .asyncio.server + "basic_auth", + "broadcast", + "serve", + "unix_serve", + "ServerConnection", + "Server", + # .client + "ClientProtocol", + # .datastructures + "Headers", + "HeadersLike", + "MultipleValuesError", + # .exceptions + "ConcurrencyError", + "ConnectionClosed", + "ConnectionClosedError", + "ConnectionClosedOK", + "DuplicateParameter", + "InvalidHandshake", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidMessage", + "InvalidOrigin", + "InvalidParameterName", + "InvalidParameterValue", + "InvalidProxy", + "InvalidProxyMessage", + "InvalidProxyStatus", + "InvalidState", + "InvalidStatus", + "InvalidUpgrade", + "InvalidURI", + "NegotiationError", + "PayloadTooBig", + "ProtocolError", + "ProxyError", + "SecurityError", + "WebSocketException", + # .frames + "Close", + "CloseCode", + "Frame", + "Opcode", + # .http11 + "Request", + "Response", + # .protocol + "Protocol", + "Side", + "State", + # .server + "ServerProtocol", + # .typing + "Data", + "ExtensionName", + "ExtensionParameter", + "LoggerLike", + "StatusLike", + "Origin", + "Subprotocol", +] + +# When type checking, import non-deprecated aliases eagerly. Else, import on demand. +if TYPE_CHECKING: + from .asyncio.client import ClientConnection, connect, unix_connect + from .asyncio.router import Router, route, unix_route + from .asyncio.server import ( + Server, + ServerConnection, + basic_auth, + broadcast, + serve, + unix_serve, + ) + from .client import ClientProtocol + from .datastructures import Headers, HeadersLike, MultipleValuesError + from .exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + DuplicateParameter, + InvalidHandshake, + InvalidHeader, + InvalidHeaderFormat, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidParameterName, + InvalidParameterValue, + InvalidProxy, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidState, + InvalidStatus, + InvalidUpgrade, + InvalidURI, + NegotiationError, + PayloadTooBig, + ProtocolError, + ProxyError, + SecurityError, + WebSocketException, + ) + from .frames import Close, CloseCode, Frame, Opcode + from .http11 import Request, Response + from .protocol import Protocol, Side, State + from .server import ServerProtocol + from .typing import ( + Data, + ExtensionName, + ExtensionParameter, + LoggerLike, + Origin, + StatusLike, + Subprotocol, + ) +else: + lazy_import( + globals(), + aliases={ + # .asyncio.client + "connect": ".asyncio.client", + "unix_connect": ".asyncio.client", + "ClientConnection": ".asyncio.client", + # .asyncio.router + "route": ".asyncio.router", + "unix_route": ".asyncio.router", + "Router": ".asyncio.router", + # .asyncio.server + "basic_auth": ".asyncio.server", + "broadcast": ".asyncio.server", + "serve": ".asyncio.server", + "unix_serve": ".asyncio.server", + "ServerConnection": ".asyncio.server", + "Server": ".asyncio.server", + # .client + "ClientProtocol": ".client", + # .datastructures + "Headers": ".datastructures", + "HeadersLike": ".datastructures", + "MultipleValuesError": ".datastructures", + # .exceptions + "ConcurrencyError": ".exceptions", + "ConnectionClosed": ".exceptions", + "ConnectionClosedError": ".exceptions", + "ConnectionClosedOK": ".exceptions", + "DuplicateParameter": ".exceptions", + "InvalidHandshake": ".exceptions", + "InvalidHeader": ".exceptions", + "InvalidHeaderFormat": ".exceptions", + "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", + "InvalidOrigin": ".exceptions", + "InvalidParameterName": ".exceptions", + "InvalidParameterValue": ".exceptions", + "InvalidProxy": ".exceptions", + "InvalidProxyMessage": ".exceptions", + "InvalidProxyStatus": ".exceptions", + "InvalidState": ".exceptions", + "InvalidStatus": ".exceptions", + "InvalidUpgrade": ".exceptions", + "InvalidURI": ".exceptions", + "NegotiationError": ".exceptions", + "PayloadTooBig": ".exceptions", + "ProtocolError": ".exceptions", + "ProxyError": ".exceptions", + "SecurityError": ".exceptions", + "WebSocketException": ".exceptions", + # .frames + "Close": ".frames", + "CloseCode": ".frames", + "Frame": ".frames", + "Opcode": ".frames", + # .http11 + "Request": ".http11", + "Response": ".http11", + # .protocol + "Protocol": ".protocol", + "Side": ".protocol", + "State": ".protocol", + # .server + "ServerProtocol": ".server", + # .typing + "Data": ".typing", + "ExtensionName": ".typing", + "ExtensionParameter": ".typing", + "LoggerLike": ".typing", + "Origin": ".typing", + "StatusLike": ".typing", + "Subprotocol": ".typing", + }, + deprecated_aliases={ + # deprecated in 9.0 - 2021-09-01 + "framing": ".legacy", + "handshake": ".legacy", + "parse_uri": ".uri", + "WebSocketURI": ".uri", + # deprecated in 14.0 - 2024-11-09 + # .legacy.auth + "BasicAuthWebSocketServerProtocol": ".legacy.auth", + "basic_auth_protocol_factory": ".legacy.auth", + # .legacy.client + "WebSocketClientProtocol": ".legacy.client", + # .legacy.exceptions + "AbortHandshake": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + # .legacy.protocol + "WebSocketCommonProtocol": ".legacy.protocol", + # .legacy.server + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + }, + ) diff --git a/source/websockets/__main__.py b/source/websockets/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f05ddc225577125ac018702cfe4de55a1aacd71 --- /dev/null +++ b/source/websockets/__main__.py @@ -0,0 +1,5 @@ +from .cli import main + + +if __name__ == "__main__": + main() diff --git a/source/websockets/asyncio/__init__.py b/source/websockets/asyncio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/websockets/asyncio/async_timeout.py b/source/websockets/asyncio/async_timeout.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffa899695637829dd5d3c7b58c68683000fc35d --- /dev/null +++ b/source/websockets/asyncio/async_timeout.py @@ -0,0 +1,282 @@ +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +if sys.version_info >= (3, 11): + from typing import final +else: + # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + # Licensed under the Python Software Foundation License (PSF-2.0) + + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + + # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py + + +if sys.version_info >= (3, 11): + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + task.uncancel() + +else: + + def _uncancel_task(task: "asyncio.Task[object]") -> None: + pass + + +__version__ = "4.0.3" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._task: Optional["asyncio.Task[object]"] = None + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + self._task = None + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + self._task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + assert self._task is not None + _uncancel_task(self._task) + self._timeout_handler = None + self._task = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self) -> None: + assert self._task is not None + self._task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None + + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py diff --git a/source/websockets/asyncio/client.py b/source/websockets/asyncio/client.py new file mode 100644 index 0000000000000000000000000000000000000000..05947f3a07e99e3b02fe3e5fe24be3a01bc58092 --- /dev/null +++ b/source/websockets/asyncio/client.py @@ -0,0 +1,804 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import socket +import ssl as ssl_module +import traceback +import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence +from types import TracebackType +from typing import Any, Callable, Literal, cast + +from ..client import ClientProtocol, backoff +from ..datastructures import HeadersLike +from ..exceptions import ( + InvalidMessage, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http11 import USER_AGENT, Response +from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request +from ..streams import StreamReader +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .compatibility import TimeoutError, asyncio_timeout +from .connection import Connection + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + + +class ClientConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments have the same meaning as in :func:`connect`. + + Args: + protocol: Sans-I/O connection. + + """ + + def __init__( + self, + protocol: ClientProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + ) + self.response_rcvd: asyncio.Future[None] = self.loop.create_future() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers.setdefault("User-Agent", user_agent_header) + self.protocol.send_request(self.request) + + await asyncio.wait( + [self.response_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + +def process_exception(exc: Exception) -> Exception | None: + """ + Determine whether a connection error is retryable or fatal. + + When reconnecting automatically with ``async for ... in connect(...)``, if a + connection attempt fails, :func:`process_exception` is called to determine + whether to retry connecting or to raise the exception. + + This function defines the default behavior, which is to retry on: + + * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network + errors; + * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500, + 502, 503, or 504: server or proxy errors. + + All other exceptions are considered fatal. + + You can change this behavior with the ``process_exception`` argument of + :func:`connect`. + + Return :obj:`None` if the exception is retryable i.e. when the error could + be transient and trying to reconnect with the same parameters could succeed. + The exception will be logged at the ``INFO`` level. + + Return an exception, either ``exc`` or a new exception, if the exception is + fatal i.e. when trying to reconnect will most likely produce the same error. + That exception will be raised, breaking out of the retry loop. + + """ + # This catches python-socks' ProxyConnectionError and ProxyTimeoutError. + # Remove asyncio.TimeoutError when dropping Python < 3.11. + if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): + return None + if isinstance(exc, InvalidStatus) and exc.response.status_code in [ + 500, # Internal Server Error + 502, # Bad Gateway + 503, # Service Unavailable + 504, # Gateway Timeout + ]: + return None + return exc + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + from websockets.asyncio.client import connect + + async with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.exceptions.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + + Args: + uri: URI of the WebSocket server. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings. + When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS + context is created with :func:`~ssl.create_default_context`. + + * You can set ``server_hostname`` to override the host name from ``uri`` in + the TLS handshake. + + * You can set ``host`` and ``port`` to connect to a different host and port + from those found in ``uri``. This only changes the destination of the TCP + connection. The host name from ``uri`` is still used in the TLS handshake + for secure connections and in the ``Host`` header. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_connection` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_connection` method) to create a suitable + client socket and customize it. + + When using a proxy: + + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. + * Use the standard keyword arguments for configuring TLS between the proxy + and the WebSocket server: ``ssl``, ``server_hostname``, + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. + * Other keyword arguments are used only for connecting to the proxy. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + def __init__( + self, + uri: str, + *, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, + process_exception: Callable[[Exception], Exception | None] = process_exception, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to loop.create_connection + **kwargs: Any, + ) -> None: + self.uri = uri + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if logger is None: + logger = logging.getLogger("websockets.client") + + if create_connection is None: + create_connection = ClientConnection + + def protocol_factory(uri: WebSocketURI) -> ClientConnection: + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ClientProtocol( + uri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + ) + return connection + + self.proxy = proxy + self.protocol_factory = protocol_factory + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.connection_kwargs = kwargs + + async def create_connection(self) -> ClientConnection: + """Create TCP or Unix connection.""" + loop = asyncio.get_running_loop() + kwargs = self.connection_kwargs.copy() + + ws_uri = parse_uri(self.uri) + + proxy = self.proxy + if kwargs.get("unix", False): + proxy = None + if kwargs.get("sock") is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + + def factory() -> ClientConnection: + return self.protocol_factory(ws_uri) + + if ws_uri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", ws_uri.host) + if kwargs.get("ssl") is None: + raise ValueError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") + + if kwargs.pop("unix", False): + _, connection = await loop.create_unix_connection(factory, **kwargs) + elif proxy is not None: + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = await connect_socks_proxy( + proxy_parsed, + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + # Initialize WebSocket connection via the proxy. + _, connection = await loop.create_connection( + factory, + sock=sock, + **kwargs, + ) + elif proxy_parsed.scheme[:4] == "http": + # Split keyword arguments between the proxy and the server. + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} + for key, value in all_kwargs.items(): + if key.startswith("ssl") or key == "server_hostname": + kwargs[key] = value + elif key.startswith("proxy_"): + proxy_kwargs[key[6:]] = value + else: + proxy_kwargs[key] = value + # Validate the proxy_ssl argument. + if proxy_parsed.scheme == "https": + proxy_kwargs.setdefault("ssl", True) + if proxy_kwargs.get("ssl") is None: + raise ValueError( + "proxy_ssl=None is incompatible with an https:// proxy" + ) + else: + if proxy_kwargs.get("ssl") is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + transport = await connect_http_proxy( + proxy_parsed, + ws_uri, + user_agent_header=self.user_agent_header, + **proxy_kwargs, + ) + # Initialize WebSocket connection via the proxy. + connection = factory() + transport.set_protocol(connection) + ssl = kwargs.pop("ssl", None) + if ssl is True: + ssl = ssl_module.create_default_context() + if ssl is not None: + new_transport = await loop.start_tls( + transport, connection, ssl, **kwargs + ) + assert new_transport is not None # help mypy + transport = new_transport + connection.connection_made(transport) + else: + raise AssertionError("unsupported proxy") + else: + # Connect to the server directly. + if kwargs.get("sock") is None: + kwargs.setdefault("host", ws_uri.host) + kwargs.setdefault("port", ws_uri.port) + # Initialize WebSocket connection. + _, connection = await loop.create_connection(factory, **kwargs) + return connection + + def process_redirect(self, exc: Exception) -> Exception | str: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_ws_uri = parse_uri(self.uri) + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_ws_uri = parse_uri(new_uri) + + # If connect() received a socket, it is closed and cannot be reused. + if self.connection_kwargs.get("sock") is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting socket" + ) + + # TLS downgrade is forbidden. + if old_ws_uri.secure and not new_ws_uri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port + ): + # Cross-origin redirects on Unix sockets don't quite make sense. + if self.connection_kwargs.get("unix", False): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with a Unix socket" + ) + + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.connection_kwargs.get("host") is not None + or self.connection_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, ClientConnection]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> ClientConnection: + try: + async with asyncio_timeout(self.open_timeout): + for _ in range(MAX_REDIRECTS): + self.connection = await self.create_connection() + try: + await self.connection.handshake( + self.additional_headers, + self.user_agent_header, + ) + except asyncio.CancelledError: + self.connection.transport.abort() + raise + except Exception as exc: + # Always close the connection even though keep-alive is + # the default in HTTP/1.1 because create_connection ties + # opening the network connection with initializing the + # protocol. In the current design of connect(), there is + # no easy way to reuse the network connection that works + # in every case nor to reinitialize the protocol. + self.connection.transport.abort() + + uri_or_exc = self.process_redirect(exc) + # Response is a valid redirect; follow it. + if isinstance(uri_or_exc, str): + self.uri = uri_or_exc + continue + # Response isn't a valid redirect; raise the exception. + if uri_or_exc is exc: + raise + else: + raise uri_or_exc from exc + + else: + self.connection.start_keepalive() + return self.connection + else: + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + + except TimeoutError as exc: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during opening handshake") from exc + + # ... = yield from connect(...) - remove when dropping Python < 3.11 + + __iter__ = __await__ + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.connection.close() + + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float] | None = None + while True: + try: + async with self as protocol: + yield protocol + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "connect failed; reconnecting in %.1f seconds: %s", + delay, + traceback.format_exception_only(exc)[0].strip(), + ) + await asyncio.sleep(delay) + continue + + else: + # The connection succeeded. Reset backoff. + delays = None + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> connect: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` argument is provided, to + ``wss://localhost/``. + + """ + if uri is None: + if kwargs.get("ssl") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.async_.asyncio import Proxy as SocksProxy + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("connecting through a SOCKS proxy requires python-socks") + +else: + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + # connect() is documented to raise OSError. + # socks_proxy.connect() doesn't raise TimeoutError; it gets canceled. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc + + +class HTTPProxyConnection(asyncio.Protocol): + def __init__( + self, + ws_uri: WebSocketURI, + proxy: Proxy, + user_agent_header: str | None = None, + ): + self.ws_uri = ws_uri + self.proxy = proxy + self.user_agent_header = user_agent_header + + self.reader = StreamReader() + self.parser = Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + proxy=True, + ) + + loop = asyncio.get_running_loop() + self.response: asyncio.Future[Response] = loop.create_future() + + def run_parser(self) -> None: + try: + next(self.parser) + except StopIteration as exc: + response = exc.value + if 200 <= response.status_code < 300: + self.response.set_result(response) + else: + self.response.set_exception(InvalidProxyStatus(response)) + except Exception as exc: + proxy_exc = InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) + proxy_exc.__cause__ = exc + self.response.set_exception(proxy_exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.transport.write( + prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header) + ) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + self.run_parser() + + def eof_received(self) -> None: + self.reader.feed_eof() + self.run_parser() + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.feed_eof() + if exc is not None: + self.response.set_exception(exc) + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = None, + **kwargs: Any, +) -> asyncio.Transport: + transport, protocol = await asyncio.get_running_loop().create_connection( + lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header), + proxy.host, + proxy.port, + **kwargs, + ) + + try: + # This raises exceptions if the connection to the proxy fails. + await protocol.response + except Exception: + transport.close() + raise + + return transport diff --git a/source/websockets/asyncio/compatibility.py b/source/websockets/asyncio/compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..e17000069d530bdde8de5194e0d8257a5c5d1770 --- /dev/null +++ b/source/websockets/asyncio/compatibility.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import sys + + +__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"] + + +if sys.version_info[:2] >= (3, 11): + TimeoutError = TimeoutError + aiter = aiter + anext = anext + from asyncio import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) + +else: # Python < 3.11 + from asyncio import TimeoutError + + def aiter(async_iterable): + return type(async_iterable).__aiter__(async_iterable) + + async def anext(async_iterator): + return await type(async_iterator).__anext__(async_iterator) + + from .async_timeout import ( + timeout as asyncio_timeout, # noqa: F401 + timeout_at as asyncio_timeout_at, # noqa: F401 + ) diff --git a/source/websockets/asyncio/connection.py b/source/websockets/asyncio/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..205a2be50871cbc2cb85144421587a57a58bc725 --- /dev/null +++ b/source/websockets/asyncio/connection.py @@ -0,0 +1,1247 @@ +from __future__ import annotations + +import asyncio +import collections +import contextlib +import logging +import random +import struct +import sys +import traceback +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from types import TracebackType +from typing import Any, Literal, cast, overload + +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from .compatibility import ( + TimeoutError, + aiter, + anext, + asyncio_timeout, + asyncio_timeout_at, +) +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(asyncio.Protocol): + """ + :mod:`asyncio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.asyncio.client.ClientConnection` or + :class:`~websockets.asyncio.server.ServerConnection`. + + """ + + def __init__( + self, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, + ) -> None: + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + if isinstance(max_queue, int) or max_queue is None: + self.max_queue_high, self.max_queue_low = max_queue, None + else: + self.max_queue_high, self.max_queue_low = max_queue + if isinstance(write_limit, int): + self.write_limit_high, self.write_limit_low = write_limit, None + else: + self.write_limit_high, self.write_limit_low = write_limit + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Event loop running this connection. + self.loop = asyncio.get_running_loop() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages: Assembler # initialized in connection_made + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Whether we are busy sending a fragmented message. + self.send_in_progress: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + self.latency: float = 0.0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0.0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Task that sends keepalive pings. None when ping_interval is None. + self.keepalive_task: asyncio.Task[None] | None = None + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + + # Adapted from asyncio.FlowControlMixin. + self.paused: bool = False + self.drain_waiters: collections.deque[asyncio.Future[None]] = ( + collections.deque() + ) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.transport.get_extra_info("peername") + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + return strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + yield bytestrings (:class:`bytes`). This improves performance + when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + yield strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from self.recv_exc + + async def send( + self, + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable or asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and the + effect is more obvious: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.send_in_progress is not None: + await asyncio.shield(self.send_in_progress) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set_result(None) + self.send_in_progress = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() + try: + # First fragment. + if isinstance(chunk, str): + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set_result(None) + self.send_in_progress = None + + else: + raise TypeError("data must be str, bytes, iterable, or async iterable") + + async def close( + self, + code: CloseCode | int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.send_in_progress is not None: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: DataLike | None = None) -> Awaitable[float]: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_received = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_received + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = self.loop.create_future() + ping_timestamp = self.loop.time() + # The event loop's default clock is time.monotonic(). Its resolution + # is a bit low on Windows (~16ms). This is improved in Python 3.13. + self.pending_pings[data] = (pong_received, ping_timestamp) + self.protocol.send_ping(data) + return pong_received + + async def pong(self, data: DataLike = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = self.loop.time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items(): + ping_ids.append(ping_id) + latency = pong_timestamp - ping_timestamp + if not pong_received.done(): + pong_received.set_result(latency) + if ping_id == data: + self.latency = latency + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def terminate_pending_pings(self) -> None: + """ + Raise ConnectionClosed in pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + exc = self.protocol.close_exc + + for pong_received, _ping_timestamp in self.pending_pings.values(): + if not pong_received.done(): + pong_received.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_received.cancel() + + self.pending_pings.clear() + + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + await asyncio.sleep(self.ping_interval - latency) + + # This cannot raise ConnectionClosed when the connection is + # closing because ping(), via send_context(), waits for the + # connection to be closed before raising ConnectionClosed. + # However, connection_lost() cancels keepalive_task before + # it gets a chance to resume excuting. + pong_received = await self.ping() + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + # connection_lost cancels keepalive immediately + # after setting a ConnectionClosed exception on + # pong_received. A CancelledError is raised here, + # not a ConnectionClosed exception. + latency = await pong_received + if self.debug: + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + raise AssertionError( + "send_context() should wait for connection_lost(), " + "which cancels keepalive()" + ) + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.keepalive_task = self.loop.create_task(self.keepalive()) + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN + # (or CONNECTING), self.close_deadline is still None. + assert self.close_deadline is None + if self.close_timeout is not None: + self.close_deadline = self.loop.time() + self.close_timeout + # Write outgoing data to the socket with flow control. + try: + self.send_data() + await self.drain() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None: + if self.close_timeout is not None: + self.close_deadline = self.loop.time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + try: + async with asyncio_timeout_at(self.close_deadline): + await asyncio.shield(self.connection_lost_waiter) + except TimeoutError: + # There's no risk of overwriting another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + self.transport.abort() + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + """ + for data in self.protocol.data_to_send(): + if data: + self.transport.write(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # OSError is plausible. uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except Exception: # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + This method must be called only from connection callbacks. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + # asyncio.Protocol methods + + # Connection callbacks + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.recv_messages = Assembler( + self.max_queue_high, + self.max_queue_low, + pause=transport.pause_reading, + resume=transport.resume_reading, + ) + transport.set_write_buffer_limits( + self.write_limit_high, + self.write_limit_low, + ) + self.transport = transport + + def connection_lost(self, exc: Exception | None) -> None: + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + self.set_recv_exc(exc) + + # Abort recv() and pending pings with a ConnectionClosed exception. + self.recv_messages.close() + self.terminate_pending_pings() + + if self.keepalive_task is not None: + self.keepalive_task.cancel() + + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: # pragma: no cover + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + # Flow control callbacks + + def pause_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert not self.paused + self.paused = True + + def resume_writing(self) -> None: # pragma: no cover + # Adapted from asyncio.streams.FlowControlMixin + assert self.paused + self.paused = False + for waiter in self.drain_waiters: + if not waiter.done(): + waiter.set_result(None) + + async def drain(self) -> None: # pragma: no cover + # We don't check if the connection is closed because we call drain() + # immediately after write() and write() would fail in that case. + + # Adapted from asyncio.streams.StreamWriter + # Yield to the event loop so that connection_lost() may be called. + if self.transport.is_closing(): + await asyncio.sleep(0) + + # Adapted from asyncio.streams.FlowControlMixin + if self.paused: + waiter = self.loop.create_future() + self.drain_waiters.append(waiter) + try: + await waiter + finally: + self.drain_waiters.remove(waiter) + + # Streaming protocol callbacks + + def data_received(self, data: bytes) -> None: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the transport. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("! error while sending data", exc_info=True) + self.set_recv_exc(exc) + + # If needed, set the close deadline based on the close timeout. + if self.protocol.close_expected(): + if self.close_deadline is None: + if self.close_timeout is not None: + self.close_deadline = self.loop.time() + self.close_timeout + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + def eof_received(self) -> None: + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream and it handles errors by itself. + self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + # The WebSocket protocol has its own closing handshake: endpoints close + # the TCP or TLS connection after sending and receiving a close frame. + # As a consequence, they never need to write after receiving EOF, so + # there's no reason to keep the transport open by returning True. + # Besides, that doesn't work on TLS connections. + + +# broadcast() is defined in the connection module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# Connection class. + + +def broadcast( + connections: Iterable[Connection], + message: DataLike, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~websockets.asyncio.connection.Connection.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if isinstance(message, str): + send_method = "send_text" + message = message.encode() + elif isinstance(message, BytesLike): + send_method = "send_binary" + else: + raise TypeError("data must be str or bytes") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions: list[Exception] = [] + + for connection in connections: + exception: Exception + + if connection.protocol.state is not OPEN: + continue + + if connection.send_in_progress is not None: + if raise_exceptions: + exception = ConcurrencyError("sending a fragmented message") + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + continue + + try: + # Call connection.protocol.send_text or send_binary. + # Either way, message is already converted to bytes. + getattr(connection.protocol, send_method)(message) + connection.send_data() + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + connection.logger.warning( + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only(write_exception)[0].strip(), + ) + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.asyncio.server" diff --git a/source/websockets/asyncio/messages.py b/source/websockets/asyncio/messages.py new file mode 100644 index 0000000000000000000000000000000000000000..f27cb2e7c8724434652f15be0a7efbdd49cf8951 --- /dev/null +++ b/source/websockets/asyncio/messages.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +from collections.abc import AsyncIterator, Iterable +from typing import Any, Callable, Generic, Literal, TypeVar, overload + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class SimpleQueue(Generic[T]): + """ + Simplified version of :class:`asyncio.Queue`. + + Provides only the subset of functionality needed by :class:`Assembler`. + + """ + + def __init__(self) -> None: + self.loop = asyncio.get_running_loop() + self.get_waiter: asyncio.Future[None] | None = None + self.queue: collections.deque[T] = collections.deque() + + def __len__(self) -> int: + return len(self.queue) + + def put(self, item: T) -> None: + """Put an item into the queue.""" + self.queue.append(item) + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_result(None) + + async def get(self, block: bool = True) -> T: + """Remove and return an item from the queue, waiting if necessary.""" + if not self.queue: + if not block: + raise EOFError("stream of frames ended") + assert self.get_waiter is None, "cannot call get() concurrently" + self.get_waiter = self.loop.create_future() + try: + await self.get_waiter + finally: + self.get_waiter.cancel() + self.get_waiter = None + return self.queue.popleft() + + def reset(self, items: Iterable[T]) -> None: + """Put back items into an empty, idle queue.""" + assert self.get_waiter is None, "cannot reset() while get() is running" + assert not self.queue, "cannot reset() while queue isn't empty" + self.queue.extend(items) + + def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" + if self.get_waiter is not None and not self.get_waiter.done(): + self.get_waiter.set_exception(EOFError("stream of frames ended")) + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming frames. + self.frames: SimpleQueue[Frame] = SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is canceled. + + try: + # Fetch the first frame. + frame = await self.frames.get(not self.closed) + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Fetch subsequent frames for fragmented messages. + while not frame.fin: + try: + frame = await self.frames.get(not self.closed) + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + # This converts frame.data to bytes when it's a bytearray. + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is canceled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # Yield the first frame. + try: + frame = await self.frames.get(not self.closed) + except asyncio.CancelledError: + self.get_in_progress = False + raise + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + # Yield subsequent frames for fragmented messages. + while not frame.fin: + # We cannot handle asyncio.CancelledError because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise ConcurrencyError. + frame = await self.frames.get(not self.closed) + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled. + if self.high is None: + return + + # Check for "> high" to support high = 0. + if len(self.frames) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled. + if self.low is None: + return + + # Check for "<= low" to support low = 0. + if len(self.frames) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.frames.abort() diff --git a/source/websockets/asyncio/router.py b/source/websockets/asyncio/router.py new file mode 100644 index 0000000000000000000000000000000000000000..49d2a40feee03ebcda5e05dc43eebeb27ccecab7 --- /dev/null +++ b/source/websockets/asyncio/router.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Awaitable, Callable, Literal + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +try: + from werkzeug.exceptions import NotFound + from werkzeug.routing import Map, RequestRedirect + +except ImportError: + + def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, + ) -> Awaitable[Server]: + raise ImportError("route() requires werkzeug") + + def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, + ) -> Awaitable[Server]: + raise ImportError("unix_route() requires werkzeug") + +else: + + def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, + ) -> Awaitable[Server]: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.asyncio.router import route + from werkzeug.routing import Map, Rule + + async def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with route(url_map, ...) as server: + await stop + + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When + not provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without enabling TLS. + Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] = router.route_request + else: + + async def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if isinstance(response, Awaitable): + response = await response + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, + ) -> Awaitable[Server]: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.asyncio.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + async def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return await connection.handler(connection, **connection.handler_kwargs) diff --git a/source/websockets/asyncio/server.py b/source/websockets/asyncio/server.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9bd807f102e05bbc3de212768bfed1679d8764 --- /dev/null +++ b/source/websockets/asyncio/server.py @@ -0,0 +1,997 @@ +from __future__ import annotations + +import asyncio +import hmac +import http +import logging +import re +import socket +import sys +from collections.abc import Awaitable, Generator, Iterable, Sequence +from types import TracebackType +from typing import Any, Callable, Mapping, cast + +from ..exceptions import InvalidHeader +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) +from ..http11 import SERVER, Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol +from .compatibility import asyncio_timeout +from .connection import Connection, broadcast + + +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "ServerConnection", + "Server", + "basic_auth", +] + + +class ServerConnection(Connection): + """ + :mod:`asyncio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, + and ``write_limit`` arguments have the same meaning as in :func:`serve`. + + Args: + protocol: Sans-I/O connection. + server: Server that manages this connection. + + """ + + def __init__( + self, + protocol: ServerProtocol, + server: Server, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + ) + self.server = server + self.request_rcvd: asyncio.Future[None] = self.loop.create_future() + self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() + + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + ) -> None: + """ + Perform the opening handshake. + + """ + await asyncio.wait( + [self.request_rcvd, self.connection_lost_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.is_serving(): + self.response = self.protocol.accept(self.request) + else: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set_result(None) + # Later events - frames. + else: + super().process_event(event) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + self.server.start_connection_handler(self) + + +class Server: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. ``process_response`` may be a function or a + coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + *, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + open_timeout: float | None = 10, + logger: LoggerLike | None = None, + ) -> None: + self.loop = asyncio.get_running_loop() + self.handler = handler + self.process_request = process_request + self.process_response = process_response + self.server_header = server_header + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.handlers: dict[ServerConnection, asyncio.Task[None]] = {} + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] = self.loop.create_future() + + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + It can be useful in combination with :func:`~broadcast`. + + """ + return {connection for connection in self.handlers if connection.state is OPEN} + + def wrap(self, server: asyncio.Server) -> None: + """ + Attach to a given :class:`asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`Server` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + async def conn_handler(self, connection: ServerConnection) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller that can handle exceptions, + it attempts to log relevant ones. + + It guarantees that the TCP connection is closed before exiting. + + """ + try: + async with asyncio_timeout(self.open_timeout): + try: + await connection.handshake( + self.process_request, + self.process_response, + self.server_header, + ) + except asyncio.CancelledError: + connection.transport.abort() + raise + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.transport.abort() + return + + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.transport.abort() + return + + try: + connection.start_keepalive() + await self.handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + await connection.close(CloseCode.INTERNAL_ERROR) + else: + await connection.close() + + except TimeoutError: + # When the opening handshake times out, there's nothing to log. + pass + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. + connection.transport.abort() + + finally: + # Registration is tied to the lifecycle of conn_handler() because + # the server waits for connection handlers to terminate, even if + # all connections are already closed. + del self.handlers[connection] + + def start_connection_handler(self, connection: ServerConnection) -> None: + """ + Register a connection with this server. + + """ + # The connection must be registered in self.handlers immediately. + # If it was registered in conn_handler(), a race condition could + # happen when closing the server after scheduling conn_handler() + # but before it starts executing. + self.handlers[connection] = self.loop.create_task(self.conn_handler(connection)) + + def close( + self, + close_connections: bool = True, + code: CloseCode | int = CloseCode.GOING_AWAY, + reason: str = "", + ) -> None: + """ + Close the server. + + * Close the underlying :class:`asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, close + existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with code 1001 (going away). + ``code`` and ``reason`` can be customized, for example to use code + 1012 (service restart). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections, code, reason) + ) + + async def _close( + self, + close_connections: bool = True, + code: CloseCode | int = CloseCode.GOING_AWAY, + reason: str = "", + ) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`asyncio.Server` object to stop accepting new connections and + then closes open connections. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + # After server.close(), handshake() closes OPENING connections with an + # HTTP 503 error. + + if close_connections: + # Close OPEN connections with code 1001 by default. + close_tasks = [ + asyncio.create_task(connection.close(code, reason)) + for connection in self.handlers + if connection.protocol.state is not CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.handlers: + await asyncio.wait(self.handlers.values()) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~ServerConnection.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~ServerConnection.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: # pragma: no cover + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> tuple[socket.socket, ...]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> Server: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +# This is spelled in lower case because it's exposed as a callable in the API. +class serve: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler`` coroutine. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This coroutine returns a :class:`Server` whose API mirrors + :class:`asyncio.Server`. Treat it as an asynchronous context manager to + ensure that the server will be closed:: + + from websockets.asyncio.server import serve + + def handler(websocket): + ... + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with serve(handler, host, port): + await stop + + Alternatively, call :meth:`~Server.serve_forever` to serve requests and + cancel it to stop the server:: + + server = await serve(handler, host, port) + await server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_request`` may be a function or a coroutine. + process_response: Intercept the response during the opening handshake. + Return an HTTP response to force the response or :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + ``process_response`` may be a function or a coroutine. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + write_limit: High-water mark of write buffer in bytes. It is passed to + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults + to 32 KiB. You may pass a ``(high, low)`` tuple to set the + high-water and low-water marks. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to provide a preexisting TCP socket. You may call + :func:`socket.create_server` (not to be confused with the event loop's + :meth:`~asyncio.loop.create_server` method) to create a suitable server + socket and customize it. + + * You can set ``start_serving`` to ``False`` to start accepting connections + only after you call :meth:`~Server.start_serving()` or + :meth:`~Server.serve_forever()`. + + """ + + def __init__( + self, + handler: Callable[[ServerConnection], Awaitable[None]], + host: str | None = None, + port: int | None = None, + *, + # WebSocket + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + compression: str | None = "deflate", + # HTTP + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + write_limit: int | tuple[int, int | None] = 2**15, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + # Other keyword arguments are passed to loop.create_server + **kwargs: Any, + ) -> None: + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + self.server = Server( + handler, + process_request=process_request, + process_response=process_response, + server_header=server_header, + open_timeout=open_timeout, + logger=logger, + ) + + if kwargs.get("ssl") is not None: + kwargs.setdefault("ssl_handshake_timeout", open_timeout) + if sys.version_info[:2] >= (3, 11): # pragma: no branch + kwargs.setdefault("ssl_shutdown_timeout", close_timeout) + + def factory() -> ServerConnection: + """ + Create an asyncio protocol for managing a WebSocket connection. + + """ + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # This is a protocol in the Sans-I/O implementation of websockets. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + # This is a connection in websockets and a protocol in asyncio. + connection = create_connection( + protocol, + self.server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + write_limit=write_limit, + ) + return connection + + loop = asyncio.get_running_loop() + if kwargs.pop("unix", False): + self.create_server = loop.create_unix_server(factory, **kwargs) + else: + # mypy cannot tell that kwargs must provide sock when port is None. + self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type] + + # async with serve(...) as ...: ... + + async def __aenter__(self) -> Server: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.server.close() + await self.server.wait_closed() + + # ... = await serve(...) + + def __await__(self) -> Generator[Any, None, Server]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> Server: + server = await self.create_server + self.server.wrap(server) + return self.server + + # ... = yield from serve(...) - remove when dropping Python < 3.11 + + __iter__ = __await__ + + +def unix_serve( + handler: Callable[[ServerConnection], Awaitable[None]], + path: str | None = None, + **kwargs: Any, +) -> Awaitable[Server]: + """ + Create a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None, +) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.asyncio.server import basic_auth, serve + + async with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function or coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. + + """ + if (credentials is None) == (check_credentials is None): + raise ValueError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + async def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + valid_credentials = check_credentials(username, password) + if isinstance(valid_credentials, Awaitable): + valid_credentials = await valid_credentials + + if not valid_credentials: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/source/websockets/auth.py b/source/websockets/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..15b70a3727b2eb3202fc87173ad2fc8b742cf72c --- /dev/null +++ b/source/websockets/auth.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import warnings + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.auth import * + from .legacy.auth import __all__ # noqa: F401 + + +warnings.warn( # deprecated in 14.0 - 2024-11-09 + "websockets.auth, an alias for websockets.legacy.auth, is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/source/websockets/cli.py b/source/websockets/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..e084b62a9ac250a7b20a1b7a7a18fe10e53db3a6 --- /dev/null +++ b/source/websockets/cli.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import argparse +import asyncio +import os +import sys +from typing import Generator + +from .asyncio.client import ClientConnection, connect +from .asyncio.messages import SimpleQueue +from .exceptions import ConnectionClosed +from .frames import Close +from .streams import StreamReader +from .version import version as websockets_version + + +__all__ = ["main"] + + +def print_during_input(string: str) -> None: + sys.stdout.write( + # Save cursor position + "\N{ESC}7" + # Add a new line + "\N{LINE FEED}" + # Move cursor up + "\N{ESC}[A" + # Insert blank line, scroll last line down + "\N{ESC}[L" + # Print string in the inserted blank line + f"{string}\N{LINE FEED}" + # Restore cursor position + "\N{ESC}8" + # Move cursor down + "\N{ESC}[B" + ) + sys.stdout.flush() + + +def print_over_input(string: str) -> None: + sys.stdout.write( + # Move cursor to beginning of line + "\N{CARRIAGE RETURN}" + # Delete current line + "\N{ESC}[K" + # Print string + f"{string}\N{LINE FEED}" + ) + sys.stdout.flush() + + +class ReadLines(asyncio.Protocol): + def __init__(self) -> None: + self.reader = StreamReader() + self.messages: SimpleQueue[str] = SimpleQueue() + + def parse(self) -> Generator[None, None, None]: + while True: + sys.stdout.write("> ") + sys.stdout.flush() + line = yield from self.reader.read_line(sys.maxsize) + self.messages.put(line.decode().rstrip("\r\n")) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.parser = self.parse() + next(self.parser) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + next(self.parser) + + def eof_received(self) -> None: + self.reader.feed_eof() + # next(self.parser) isn't useful and would raise EOFError. + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.discard() + self.messages.abort() + + +async def print_incoming_messages(websocket: ClientConnection) -> None: + async for message in websocket: + if isinstance(message, str): + print_during_input("< " + message) + else: + print_during_input("< (binary) " + message.hex()) + + +async def send_outgoing_messages( + websocket: ClientConnection, + messages: SimpleQueue[str], +) -> None: + while True: + try: + message = await messages.get() + except EOFError: + break + try: + await websocket.send(message) + except ConnectionClosed: # pragma: no cover + break + + +async def interactive_client(uri: str) -> None: + try: + websocket = await connect(uri) + except Exception as exc: + print(f"Failed to connect to {uri}: {exc}.") + sys.exit(1) + else: + print(f"Connected to {uri}.") + + loop = asyncio.get_running_loop() + transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin) + incoming = asyncio.create_task( + print_incoming_messages(websocket), + ) + outgoing = asyncio.create_task( + send_outgoing_messages(websocket, protocol.messages), + ) + try: + await asyncio.wait( + [incoming, outgoing], + # Clean up and exit when the server closes the connection + # or the user enters EOT (^D), whichever happens first. + return_when=asyncio.FIRST_COMPLETED, + ) + # asyncio.run() cancels the main task when the user triggers SIGINT (^C). + # https://docs.python.org/3/library/asyncio-runner.html#handling-keyboard-interruption + # Clean up and exit without re-raising CancelledError to prevent Python + # from raising KeyboardInterrupt and displaying a stack track. + except asyncio.CancelledError: # pragma: no cover + pass + finally: + incoming.cancel() + outgoing.cancel() + transport.close() + + await websocket.close() + assert websocket.close_code is not None and websocket.close_reason is not None + close_status = Close(websocket.close_code, websocket.close_reason) + print_over_input(f"Connection closed: {close_status}.") + + +def main(argv: list[str] | None = None) -> None: + parser = argparse.ArgumentParser( + prog="websockets", + description="Interactive WebSocket client.", + add_help=False, + ) + group = parser.add_mutually_exclusive_group() + group.add_argument("--version", action="store_true") + group.add_argument("uri", metavar="", nargs="?") + args = parser.parse_args(argv) + + if args.version: + print(f"websockets {websockets_version}") + return + + if args.uri is None: + parser.print_usage() + sys.exit(2) + + # Enable VT100 to support ANSI escape codes in Command Prompt on Windows. + # See https://github.com/python/cpython/issues/74261 for why this works. + if sys.platform == "win32": + os.system("") + + try: + import readline # noqa: F401 + except ImportError: # readline isn't available on all platforms + pass + + # Remove the try/except block when dropping Python < 3.11. + try: + asyncio.run(interactive_client(args.uri)) + except KeyboardInterrupt: # pragma: no cover + pass diff --git a/source/websockets/client.py b/source/websockets/client.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbcda60ccfc6646ed4b81d0f9681d0c6b391510 --- /dev/null +++ b/source/websockets/client.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import os +import random +import warnings +from collections.abc import Generator, Sequence +from typing import Any + +from .datastructures import Headers, MultipleValuesError +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidMessage, + InvalidStatus, + InvalidUpgrade, + NegotiationError, +) +from .extensions import ClientExtensionFactory, Extension +from .headers import ( + build_authorization_basic, + build_extension, + build_host, + build_subprotocol, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http11 import Request, Response +from .imports import lazy_import +from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State +from .typing import ( + ConnectionOption, + ExtensionHeader, + LoggerLike, + Origin, + Subprotocol, + UpgradeProtocol, +) +from .uri import WebSocketURI +from .utils import accept_key, generate_key + + +__all__ = ["ClientProtocol"] + + +class ClientProtocol(Protocol): + """ + Sans-I/O implementation of a WebSocket client connection. + + Args: + uri: URI of the WebSocket server, parsed + with :func:`~websockets.uri.parse_uri`. + origin: Value of the ``Origin`` header. This is useful when connecting + to a server that validates the ``Origin`` header to defend against + Cross-Site WebSocket Hijacking attacks. + extensions: List of supported extensions, in order in which they + should be tried. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + logger: Logger for this connection; + defaults to ``logging.getLogger("websockets.client")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + uri: WebSocketURI, + *, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + state: State = CONNECTING, + max_size: int | None | tuple[int | None, int | None] = 2**20, + logger: LoggerLike | None = None, + ) -> None: + super().__init__( + side=CLIENT, + state=state, + max_size=max_size, + logger=logger, + ) + self.uri = uri + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.key = generate_key() + + def connect(self) -> Request: + """ + Create a handshake request to open a connection. + + You must send the handshake request with :meth:`send_request`. + + You can modify it before sending it, for example to add HTTP headers. + + Returns: + WebSocket handshake request event to send to the server. + + """ + headers = Headers() + headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure) + if self.uri.user_info: + headers["Authorization"] = build_authorization_basic(*self.uri.user_info) + if self.origin is not None: + headers["Origin"] = self.origin + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = self.key + headers["Sec-WebSocket-Version"] = "13" + if self.available_extensions is not None: + headers["Sec-WebSocket-Extensions"] = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in self.available_extensions + ] + ) + if self.available_subprotocols is not None: + headers["Sec-WebSocket-Protocol"] = build_subprotocol( + self.available_subprotocols + ) + return Request(self.uri.resource_name, headers) + + def process_response(self, response: Response) -> None: + """ + Check a handshake response. + + Args: + request: WebSocket handshake response received from the server. + + Raises: + InvalidHandshake: If the handshake response is invalid. + + """ + + if response.status_code != 101: + raise InvalidStatus(response) + + headers = response.headers + + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Accept") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None + if s_w_accept != accept_key(self.key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) + + self.extensions = self.process_extensions(headers) + self.subprotocol = self.process_subprotocol(headers) + + def process_extensions(self, headers: Headers) -> list[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + :rfc:`6455` leaves the rules up to the specification of each + extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: WebSocket handshake response headers. + + Returns: + List of accepted extensions. + + Raises: + InvalidHandshake: To abort the handshake. + + """ + accepted_extensions: list[Extension] = [] + + extensions = headers.get_all("Sec-WebSocket-Extensions") + + if extensions: + if self.available_extensions is None: + raise NegotiationError("no extensions supported") + + parsed_extensions: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in extensions], [] + ) + + for name, response_params in parsed_extensions: + for extension_factory in self.available_extensions: + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + If provided, check that it contains exactly one supported subprotocol. + + Args: + headers: WebSocket handshake response headers. + + Returns: + Subprotocol, if one was selected. + + """ + subprotocol: Subprotocol | None = None + + subprotocols = headers.get_all("Sec-WebSocket-Protocol") + + if subprotocols: + if self.available_subprotocols is None: + raise NegotiationError("no subprotocols supported") + + parsed_subprotocols: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in subprotocols], [] + ) + if len(parsed_subprotocols) > 1: + raise InvalidHeader( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_subprotocols)}", + ) + + subprotocol = parsed_subprotocols[0] + if subprotocol not in self.available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + def send_request(self, request: Request) -> None: + """ + Send a handshake request to the server. + + Args: + request: WebSocket handshake request event. + + """ + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + + self.writes.append(request.serialize()) + + def parse(self) -> Generator[None]: + if self.state is CONNECTING: + try: + response = yield from Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + ) + except Exception as exc: + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP response" + ) + self.handshake_exc.__cause__ = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("< HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + if response.body: + self.logger.debug("< [body] (%d bytes)", len(response.body)) + + try: + self.process_response(response) + except InvalidHandshake as exc: + response._exception = exc + self.events.append(response) + self.handshake_exc = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + assert self.state is CONNECTING + self.state = OPEN + self.events.append(response) + + yield from super().parse() + + +class ClientConnection(ClientProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( # deprecated in 11.0 - 2023-04-02 + "ClientConnection was renamed to ClientProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + + +BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) +BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) +BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) +BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) + + +def backoff( + initial_delay: float = BACKOFF_INITIAL_DELAY, + min_delay: float = BACKOFF_MIN_DELAY, + max_delay: float = BACKOFF_MAX_DELAY, + factor: float = BACKOFF_FACTOR, +) -> Generator[float]: + """ + Generate a series of backoff delays between reconnection attempts. + + Yields: + How many seconds to wait before retrying to connect. + + """ + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + yield random.random() * initial_delay + delay = min_delay + while delay < max_delay: + yield delay + delay *= factor + while True: + yield max_delay + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketClientProtocol": ".legacy.client", + "connect": ".legacy.client", + "unix_connect": ".legacy.client", + }, +) diff --git a/source/websockets/connection.py b/source/websockets/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..5e78e34479224d0332b165badd67a8933e0c73db --- /dev/null +++ b/source/websockets/connection.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import warnings + +from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401 + + +warnings.warn( # deprecated in 11.0 - 2023-04-02 + "websockets.connection was renamed to websockets.protocol " + "and Connection was renamed to Protocol", + DeprecationWarning, +) diff --git a/source/websockets/datastructures.py b/source/websockets/datastructures.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5d66d9aef166965dc66a592ef5298ad2a21da3 --- /dev/null +++ b/source/websockets/datastructures.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Mapping, MutableMapping +from typing import Any, Protocol + + +__all__ = [ + "Headers", + "HeadersLike", + "MultipleValuesError", +] + + +class MultipleValuesError(LookupError): + """ + Exception raised when :class:`Headers` has multiple values for a key. + + """ + + def __str__(self) -> str: + # Implement the same logic as KeyError_str in Objects/exceptions.c. + if len(self.args) == 1: + return repr(self.args[0]) + return super().__str__() + + +class Headers(MutableMapping[str, str]): + """ + Efficient data structure for manipulating HTTP headers. + + A :class:`list` of ``(name, values)`` is inefficient for lookups. + + A :class:`dict` doesn't suffice because header names are case-insensitive + and multiple occurrences of headers with the same name are possible. + + :class:`Headers` stores HTTP headers in a hybrid data structure to provide + efficient insertions and lookups while preserving the original data. + + In order to account for multiple values with minimal hassle, + :class:`Headers` follows this logic: + + - When getting a header with ``headers[name]``: + - if there's no value, :exc:`KeyError` is raised; + - if there's exactly one value, it's returned; + - if there's more than one value, :exc:`MultipleValuesError` is raised. + + - When setting a header with ``headers[name] = value``, the value is + appended to the list of values for that header. + + - When deleting a header with ``del headers[name]``, all values for that + header are removed (this is slow). + + Other methods for manipulating headers are consistent with this logic. + + As long as no header occurs multiple times, :class:`Headers` behaves like + :class:`dict`, except keys are lower-cased to provide case-insensitivity. + + Two methods support manipulating multiple values explicitly: + + - :meth:`get_all` returns a list of all values for a header; + - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs. + + """ + + __slots__ = ["_dict", "_list"] + + # Like dict, Headers accepts an optional "mapping or iterable" argument. + def __init__(self, *args: HeadersLike, **kwargs: str) -> None: + self._dict: dict[str, list[str]] = {} + self._list: list[tuple[str, str]] = [] + self.update(*args, **kwargs) + + def __str__(self) -> str: + return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._list!r})" + + def copy(self) -> Headers: + copy = self.__class__() + copy._dict = self._dict.copy() + copy._list = self._list.copy() + return copy + + def serialize(self) -> bytes: + # Since headers only contain ASCII characters, we can keep this simple. + return str(self).encode() + + # Collection methods + + def __contains__(self, key: object) -> bool: + return isinstance(key, str) and key.lower() in self._dict + + def __iter__(self) -> Iterator[str]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + # MutableMapping methods + + def __getitem__(self, key: str) -> str: + value = self._dict[key.lower()] + if len(value) == 1: + return value[0] + else: + raise MultipleValuesError(key) + + def __setitem__(self, key: str, value: str) -> None: + self._dict.setdefault(key.lower(), []).append(value) + self._list.append((key, value)) + + def __delitem__(self, key: str) -> None: + key_lower = key.lower() + self._dict.__delitem__(key_lower) + # This is inefficient. Fortunately deleting HTTP headers is uncommon. + self._list = [(k, v) for k, v in self._list if k.lower() != key_lower] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Headers): + return NotImplemented + return self._dict == other._dict + + def clear(self) -> None: + """ + Remove all headers. + + """ + self._dict = {} + self._list = [] + + def update(self, *args: HeadersLike, **kwargs: str) -> None: + """ + Update from a :class:`Headers` instance and/or keyword arguments. + + """ + args = tuple( + arg.raw_items() if isinstance(arg, Headers) else arg for arg in args + ) + super().update(*args, **kwargs) + + # Methods for handling multiple values + + def get_all(self, key: str) -> list[str]: + """ + Return the (possibly empty) list of all values for a header. + + Args: + key: Header name. + + """ + return self._dict.get(key.lower(), []) + + def raw_items(self) -> Iterator[tuple[str, str]]: + """ + Return an iterator of all values as ``(name, value)`` pairs. + + """ + return iter(self._list) + + +# copy of _typeshed.SupportsKeysAndGetItem. +class SupportsKeysAndGetItem(Protocol): # pragma: no cover + """ + Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods. + + """ + + def keys(self) -> Iterable[str]: ... + + def __getitem__(self, key: str) -> str: ... + + +HeadersLike = ( + Headers | Mapping[str, str] | Iterable[tuple[str, str]] | SupportsKeysAndGetItem +) +""" +Types accepted where :class:`Headers` is expected. + +In addition to :class:`Headers` itself, this includes dict-like types where both +keys and values are :class:`str`. + +""" diff --git a/source/websockets/exceptions.py b/source/websockets/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..a88deaa66f6acf8d487d36fc0777c29016d0903d --- /dev/null +++ b/source/websockets/exceptions.py @@ -0,0 +1,473 @@ +""" +:mod:`websockets.exceptions` defines the following hierarchy of exceptions. + +* :exc:`WebSocketException` + * :exc:`ConnectionClosed` + * :exc:`ConnectionClosedOK` + * :exc:`ConnectionClosedError` + * :exc:`InvalidURI` + * :exc:`InvalidProxy` + * :exc:`InvalidHandshake` + * :exc:`SecurityError` + * :exc:`ProxyError` + * :exc:`InvalidProxyMessage` + * :exc:`InvalidProxyStatus` + * :exc:`InvalidMessage` + * :exc:`InvalidStatus` + * :exc:`InvalidStatusCode` (legacy) + * :exc:`InvalidHeader` + * :exc:`InvalidHeaderFormat` + * :exc:`InvalidHeaderValue` + * :exc:`InvalidOrigin` + * :exc:`InvalidUpgrade` + * :exc:`NegotiationError` + * :exc:`DuplicateParameter` + * :exc:`InvalidParameterName` + * :exc:`InvalidParameterValue` + * :exc:`AbortHandshake` (legacy) + * :exc:`RedirectHandshake` (legacy) + * :exc:`ProtocolError` (Sans-I/O) + * :exc:`PayloadTooBig` (Sans-I/O) + * :exc:`InvalidState` (Sans-I/O) + * :exc:`ConcurrencyError` + +""" + +from __future__ import annotations + +import warnings + +from .imports import lazy_import + + +__all__ = [ + "WebSocketException", + "ConnectionClosed", + "ConnectionClosedOK", + "ConnectionClosedError", + "InvalidURI", + "InvalidProxy", + "InvalidHandshake", + "SecurityError", + "ProxyError", + "InvalidProxyMessage", + "InvalidProxyStatus", + "InvalidMessage", + "InvalidStatus", + "InvalidHeader", + "InvalidHeaderFormat", + "InvalidHeaderValue", + "InvalidOrigin", + "InvalidUpgrade", + "NegotiationError", + "DuplicateParameter", + "InvalidParameterName", + "InvalidParameterValue", + "ProtocolError", + "PayloadTooBig", + "InvalidState", + "ConcurrencyError", +] + + +class WebSocketException(Exception): + """ + Base class for all exceptions defined by websockets. + + """ + + +class ConnectionClosed(WebSocketException): + """ + Raised when trying to interact with a closed connection. + + Attributes: + rcvd: If a close frame was received, its code and reason are available + in ``rcvd.code`` and ``rcvd.reason``. + sent: If a close frame was sent, its code and reason are available + in ``sent.code`` and ``sent.reason``. + rcvd_then_sent: If close frames were received and sent, this attribute + tells in which order this happened, from the perspective of this + side of the connection. + + """ + + def __init__( + self, + rcvd: frames.Close | None, + sent: frames.Close | None, + rcvd_then_sent: bool | None = None, + ) -> None: + self.rcvd = rcvd + self.sent = sent + self.rcvd_then_sent = rcvd_then_sent + assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None) + + def __str__(self) -> str: + if self.rcvd is None: + if self.sent is None: + return "no close frame received or sent" + else: + return f"sent {self.sent}; no close frame received" + else: + if self.sent is None: + return f"received {self.rcvd}; no close frame sent" + else: + if self.rcvd_then_sent: + return f"received {self.rcvd}; then sent {self.sent}" + else: + return f"sent {self.sent}; then received {self.rcvd}" + + # code and reason attributes are provided for backwards-compatibility + + @property + def code(self) -> int: + warnings.warn( # deprecated in 13.1 - 2024-09-21 + "ConnectionClosed.code is deprecated; " + "use Protocol.close_code or ConnectionClosed.rcvd.code", + DeprecationWarning, + ) + if self.rcvd is None: + return frames.CloseCode.ABNORMAL_CLOSURE + return self.rcvd.code + + @property + def reason(self) -> str: + warnings.warn( # deprecated in 13.1 - 2024-09-21 + "ConnectionClosed.reason is deprecated; " + "use Protocol.close_reason or ConnectionClosed.rcvd.reason", + DeprecationWarning, + ) + if self.rcvd is None: + return "" + return self.rcvd.reason + + +class ConnectionClosedOK(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated properly. + + A close code with code 1000 (OK) or 1001 (going away) or without a code was + received and sent. + + """ + + +class ConnectionClosedError(ConnectionClosed): + """ + Like :exc:`ConnectionClosed`, when the connection terminated with an error. + + A close frame with a code other than 1000 (OK) or 1001 (going away) was + received or sent, or the closing handshake didn't complete properly. + + """ + + +class InvalidURI(WebSocketException): + """ + Raised when connecting to a URI that isn't a valid WebSocket URI. + + """ + + def __init__(self, uri: str, msg: str) -> None: + self.uri = uri + self.msg = msg + + def __str__(self) -> str: + return f"{self.uri} isn't a valid URI: {self.msg}" + + +class InvalidProxy(WebSocketException): + """ + Raised when connecting via a proxy that isn't valid. + + """ + + def __init__(self, proxy: str, msg: str) -> None: + self.proxy = proxy + self.msg = msg + + def __str__(self) -> str: + return f"{self.proxy} isn't a valid proxy: {self.msg}" + + +class InvalidHandshake(WebSocketException): + """ + Base class for exceptions raised when the opening handshake fails. + + """ + + +class SecurityError(InvalidHandshake): + """ + Raised when a handshake request or response breaks a security rule. + + Security limits can be configured with :doc:`environment variables + <../reference/variables>`. + + """ + + +class ProxyError(InvalidHandshake): + """ + Raised when failing to connect to a proxy. + + """ + + +class InvalidProxyMessage(ProxyError): + """ + Raised when an HTTP proxy response is malformed. + + """ + + +class InvalidProxyStatus(ProxyError): + """ + Raised when an HTTP proxy rejects the connection. + + """ + + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return f"proxy rejected connection: HTTP {self.response.status_code:d}" + + +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ + + +class InvalidStatus(InvalidHandshake): + """ + Raised when a handshake response rejects the WebSocket upgrade. + + """ + + def __init__(self, response: http11.Response) -> None: + self.response = response + + def __str__(self) -> str: + return ( + f"server rejected WebSocket connection: HTTP {self.response.status_code:d}" + ) + + +class InvalidHeader(InvalidHandshake): + """ + Raised when an HTTP header doesn't have a valid format or value. + + """ + + def __init__(self, name: str, value: str | None = None) -> None: + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing {self.name} header" + elif self.value == "": + return f"empty {self.name} header" + else: + return f"invalid {self.name} header: {self.value}" + + +class InvalidHeaderFormat(InvalidHeader): + """ + Raised when an HTTP header cannot be parsed. + + The format of the header doesn't match the grammar for that header. + + """ + + def __init__(self, name: str, error: str, header: str, pos: int) -> None: + super().__init__(name, f"{error} at {pos} in {header}") + + +class InvalidHeaderValue(InvalidHeader): + """ + Raised when an HTTP header has a wrong value. + + The format of the header is correct but the value isn't acceptable. + + """ + + +class InvalidOrigin(InvalidHeader): + """ + Raised when the Origin header in a request isn't allowed. + + """ + + def __init__(self, origin: str | None) -> None: + super().__init__("Origin", origin) + + +class InvalidUpgrade(InvalidHeader): + """ + Raised when the Upgrade or Connection header isn't correct. + + """ + + +class NegotiationError(InvalidHandshake): + """ + Raised when negotiating an extension or a subprotocol fails. + + """ + + +class DuplicateParameter(NegotiationError): + """ + Raised when a parameter name is repeated in an extension header. + + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: + return f"duplicate parameter: {self.name}" + + +class InvalidParameterName(NegotiationError): + """ + Raised when a parameter name in an extension header is invalid. + + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: + return f"invalid parameter name: {self.name}" + + +class InvalidParameterValue(NegotiationError): + """ + Raised when a parameter value in an extension header is invalid. + + """ + + def __init__(self, name: str, value: str | None) -> None: + self.name = name + self.value = value + + def __str__(self) -> str: + if self.value is None: + return f"missing value for parameter {self.name}" + elif self.value == "": + return f"empty value for parameter {self.name}" + else: + return f"invalid value for parameter {self.name}: {self.value}" + + +class ProtocolError(WebSocketException): + """ + Raised when receiving or sending a frame that breaks the protocol. + + The Sans-I/O implementation raises this exception when: + + * receiving or sending a frame that contains invalid data; + * receiving or sending an invalid sequence of frames. + + """ + + +class PayloadTooBig(WebSocketException): + """ + Raised when parsing a frame with a payload that exceeds the maximum size. + + The Sans-I/O layer uses this exception internally. It doesn't bubble up to + the I/O layer. + + The :meth:`~websockets.extensions.Extension.decode` method of extensions + must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit. + + """ + + def __init__( + self, + size_or_message: int | None | str, + max_size: int | None = None, + current_size: int | None = None, + ) -> None: + if isinstance(size_or_message, str): + assert max_size is None + assert current_size is None + warnings.warn( # deprecated in 14.0 - 2024-11-09 + "PayloadTooBig(message) is deprecated; " + "change to PayloadTooBig(size, max_size)", + DeprecationWarning, + ) + self.message: str | None = size_or_message + else: + self.message = None + self.size: int | None = size_or_message + assert max_size is not None + self.max_size: int = max_size + self.current_size: int | None = None + self.set_current_size(current_size) + + def __str__(self) -> str: + if self.message is not None: + return self.message + else: + message = "frame " + if self.size is not None: + message += f"with {self.size} bytes " + if self.current_size is not None: + message += f"after reading {self.current_size} bytes " + message += f"exceeds limit of {self.max_size} bytes" + return message + + def set_current_size(self, current_size: int | None) -> None: + assert self.current_size is None + if current_size is not None: + self.max_size += current_size + self.current_size = current_size + + +class InvalidState(WebSocketException, AssertionError): + """ + Raised when sending a frame is forbidden in the current state. + + Specifically, the Sans-I/O layer raises this exception when: + + * sending a data frame to a connection in a state other + :attr:`~websockets.protocol.State.OPEN`; + * sending a control frame to a connection in a state other than + :attr:`~websockets.protocol.State.OPEN` or + :attr:`~websockets.protocol.State.CLOSING`. + + """ + + +class ConcurrencyError(WebSocketException, RuntimeError): + """ + Raised when receiving or sending messages concurrently. + + WebSocket is a connection-oriented protocol. Reads must be serialized; so + must be writes. However, reading and writing concurrently is possible. + + """ + + +# At the bottom to break import cycles created by type annotations. +from . import frames, http11 # noqa: E402 + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "AbortHandshake": ".legacy.exceptions", + "InvalidStatusCode": ".legacy.exceptions", + "RedirectHandshake": ".legacy.exceptions", + "WebSocketProtocolError": ".legacy.exceptions", + }, +) diff --git a/source/websockets/extensions/__init__.py b/source/websockets/extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02838b98a5335322daad566de9c0d9d0843fc49a --- /dev/null +++ b/source/websockets/extensions/__init__.py @@ -0,0 +1,4 @@ +from .base import * + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] diff --git a/source/websockets/extensions/base.py b/source/websockets/extensions/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2fdc59f0fdae4d28fdcf18b6f9edf8d62ad22f01 --- /dev/null +++ b/source/websockets/extensions/base.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from ..frames import Frame +from ..typing import ExtensionName, ExtensionParameter + + +__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"] + + +class Extension: + """ + Base class for extensions. + + """ + + name: ExtensionName + """Extension identifier.""" + + def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame: + """ + Decode an incoming frame. + + Args: + frame: Incoming frame. + max_size: Maximum payload size in bytes. + + Returns: + Decoded frame. + + Raises: + PayloadTooBig: If decoding the payload exceeds ``max_size``. + + """ + raise NotImplementedError + + def encode(self, frame: Frame) -> Frame: + """ + Encode an outgoing frame. + + Args: + frame: Outgoing frame. + + Returns: + Encoded frame. + + """ + raise NotImplementedError + + +class ClientExtensionFactory: + """ + Base class for client-side extension factories. + + """ + + name: ExtensionName + """Extension identifier.""" + + def get_request_params(self) -> Sequence[ExtensionParameter]: + """ + Build parameters to send to the server for this extension. + + Returns: + Parameters to send to the server. + + """ + raise NotImplementedError + + def process_response_params( + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], + ) -> Extension: + """ + Process parameters received from the server. + + Args: + params: Parameters received from the server for this extension. + accepted_extensions: List of previously accepted extensions. + + Returns: + An extension instance. + + Raises: + NegotiationError: If parameters aren't acceptable. + + """ + raise NotImplementedError + + +class ServerExtensionFactory: + """ + Base class for server-side extension factories. + + """ + + name: ExtensionName + """Extension identifier.""" + + def process_request_params( + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], + ) -> tuple[list[ExtensionParameter], Extension]: + """ + Process parameters received from the client. + + Args: + params: Parameters received from the client for this extension. + accepted_extensions: List of previously accepted extensions. + + Returns: + To accept the offer, parameters to send to the client for this + extension and an extension instance. + + Raises: + NegotiationError: To reject the offer, if parameters received from + the client aren't acceptable. + + """ + raise NotImplementedError diff --git a/source/websockets/extensions/permessage_deflate.py b/source/websockets/extensions/permessage_deflate.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc63d79929321007578d3e1d96338f099fad085 --- /dev/null +++ b/source/websockets/extensions/permessage_deflate.py @@ -0,0 +1,699 @@ +from __future__ import annotations + +import zlib +from collections.abc import Sequence +from typing import Any, Literal + +from .. import frames +from ..exceptions import ( + DuplicateParameter, + InvalidParameterName, + InvalidParameterValue, + NegotiationError, + PayloadTooBig, + ProtocolError, +) +from ..typing import BytesLike, ExtensionName, ExtensionParameter +from .base import ClientExtensionFactory, Extension, ServerExtensionFactory + + +__all__ = [ + "PerMessageDeflate", + "ClientPerMessageDeflateFactory", + "enable_client_permessage_deflate", + "ServerPerMessageDeflateFactory", + "enable_server_permessage_deflate", +] + +_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff" + +_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)] + + +class PerMessageDeflate(Extension): + """ + Per-Message Deflate extension. + + """ + + name = ExtensionName("permessage-deflate") + + def __init__( + self, + remote_no_context_takeover: bool, + local_no_context_takeover: bool, + remote_max_window_bits: int, + local_max_window_bits: int, + compress_settings: dict[Any, Any] | None = None, + ) -> None: + """ + Configure the Per-Message Deflate extension. + + """ + if compress_settings is None: + compress_settings = {} + + assert remote_no_context_takeover in [False, True] + assert local_no_context_takeover in [False, True] + assert 8 <= remote_max_window_bits <= 15 + assert 8 <= local_max_window_bits <= 15 + assert "wbits" not in compress_settings + + self.remote_no_context_takeover = remote_no_context_takeover + self.local_no_context_takeover = local_no_context_takeover + self.remote_max_window_bits = remote_max_window_bits + self.local_max_window_bits = local_max_window_bits + self.compress_settings = compress_settings + + if not self.remote_no_context_takeover: + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + + if not self.local_no_context_takeover: + self.encoder = zlib.compressobj( + wbits=-self.local_max_window_bits, + **self.compress_settings, + ) + + # To handle continuation frames properly, we must keep track of + # whether that initial frame was encoded. + self.decode_cont_data = False + # There's no need for self.encode_cont_data because we always encode + # outgoing frames, so it would always be True. + + def __repr__(self) -> str: + return ( + f"PerMessageDeflate(" + f"remote_no_context_takeover={self.remote_no_context_takeover}, " + f"local_no_context_takeover={self.local_no_context_takeover}, " + f"remote_max_window_bits={self.remote_max_window_bits}, " + f"local_max_window_bits={self.local_max_window_bits})" + ) + + def decode( + self, + frame: frames.Frame, + *, + max_size: int | None = None, + ) -> frames.Frame: + """ + Decode an incoming frame. + + """ + # Skip control frames. + if frame.opcode in frames.CTRL_OPCODES: + return frame + + # Handle continuation data frames: + # - skip if the message isn't encoded + # - reset "decode continuation data" flag if it's a final frame + if frame.opcode is frames.OP_CONT: + if not self.decode_cont_data: + return frame + if frame.fin: + self.decode_cont_data = False + + # Handle text and binary data frames: + # - skip if the message isn't encoded + # - unset the rsv1 flag on the first frame of a compressed message + # - set "decode continuation data" flag if it's a non-final frame + else: + if not frame.rsv1: + return frame + if not frame.fin: + self.decode_cont_data = True + + # Re-initialize per-message decoder. + if self.remote_no_context_takeover: + self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits) + + # Uncompress data. Protect against zip bombs by preventing zlib from + # decompressing more than max_length bytes (except when the limit is + # disabled with max_size = None). + data: BytesLike + if frame.fin and len(frame.data) < 2044: + # Profiling shows that appending four bytes, which makes a copy, is + # faster than calling decompress() again when data is less than 2kB. + data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK + else: + data = frame.data + max_length = 0 if max_size is None else max_size + try: + data = self.decoder.decompress(data, max_length) + if self.decoder.unconsumed_tail: + assert max_size is not None # help mypy + raise PayloadTooBig(None, max_size) + if frame.fin and len(frame.data) >= 2044: + # This cannot generate additional data. + self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK) + except zlib.error as exc: + raise ProtocolError("decompression failed") from exc + + # Allow garbage collection of the decoder if it won't be reused. + if frame.fin and self.remote_no_context_takeover: + del self.decoder + + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Unset the rsv1 flag on the first frame of a compressed message. + False, + frame.rsv2, + frame.rsv3, + ) + + def encode(self, frame: frames.Frame) -> frames.Frame: + """ + Encode an outgoing frame. + + """ + # Skip control frames. + if frame.opcode in frames.CTRL_OPCODES: + return frame + + # Since we always encode messages, there's no "encode continuation + # data" flag similar to "decode continuation data" at this time. + + if frame.opcode is not frames.OP_CONT: + # Re-initialize per-message decoder. + if self.local_no_context_takeover: + self.encoder = zlib.compressobj( + wbits=-self.local_max_window_bits, + **self.compress_settings, + ) + + # Compress data. + data: BytesLike + data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) + if frame.fin: + # Sync flush generates between 5 or 6 bytes, ending with the bytes + # 0x00 0x00 0xff 0xff, which must be removed. + assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK + # Making a copy is faster than memoryview(a)[:-4] until 2kB. + if len(data) < 2048: + data = data[:-4] + else: + data = memoryview(data)[:-4] + + # Allow garbage collection of the encoder if it won't be reused. + if frame.fin and self.local_no_context_takeover: + del self.encoder + + return frames.Frame( + frame.opcode, + data, + frame.fin, + # Set the rsv1 flag on the first frame of a compressed message. + frame.opcode is not frames.OP_CONT, + frame.rsv2, + frame.rsv3, + ) + + +def _build_parameters( + server_no_context_takeover: bool, + client_no_context_takeover: bool, + server_max_window_bits: int | None, + client_max_window_bits: int | Literal[True] | None, +) -> list[ExtensionParameter]: + """ + Build a list of ``(name, value)`` pairs for some compression parameters. + + """ + params: list[ExtensionParameter] = [] + if server_no_context_takeover: + params.append(("server_no_context_takeover", None)) + if client_no_context_takeover: + params.append(("client_no_context_takeover", None)) + if server_max_window_bits: + params.append(("server_max_window_bits", str(server_max_window_bits))) + if client_max_window_bits is True: # only in handshake requests + params.append(("client_max_window_bits", None)) + elif client_max_window_bits: + params.append(("client_max_window_bits", str(client_max_window_bits))) + return params + + +def _extract_parameters( + params: Sequence[ExtensionParameter], *, is_server: bool +) -> tuple[bool, bool, int | None, int | Literal[True] | None]: + """ + Extract compression parameters from a list of ``(name, value)`` pairs. + + If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be + provided without a value. This is only allowed in handshake requests. + + """ + server_no_context_takeover: bool = False + client_no_context_takeover: bool = False + server_max_window_bits: int | None = None + client_max_window_bits: int | Literal[True] | None = None + + for name, value in params: + if name == "server_no_context_takeover": + if server_no_context_takeover: + raise DuplicateParameter(name) + if value is None: + server_no_context_takeover = True + else: + raise InvalidParameterValue(name, value) + + elif name == "client_no_context_takeover": + if client_no_context_takeover: + raise DuplicateParameter(name) + if value is None: + client_no_context_takeover = True + else: + raise InvalidParameterValue(name, value) + + elif name == "server_max_window_bits": + if server_max_window_bits is not None: + raise DuplicateParameter(name) + if value in _MAX_WINDOW_BITS_VALUES: + server_max_window_bits = int(value) + else: + raise InvalidParameterValue(name, value) + + elif name == "client_max_window_bits": + if client_max_window_bits is not None: + raise DuplicateParameter(name) + if is_server and value is None: # only in handshake requests + client_max_window_bits = True + elif value in _MAX_WINDOW_BITS_VALUES: + client_max_window_bits = int(value) + else: + raise InvalidParameterValue(name, value) + + else: + raise InvalidParameterName(name) + + return ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) + + +class ClientPerMessageDeflateFactory(ClientExtensionFactory): + """ + Client-side extension factory for the Per-Message Deflate extension. + + Parameters behave as described in `section 7.1 of RFC 7692`_. + + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 + + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: Maximum size of the client's LZ77 sliding window + in bits, between 8 and 15, or :obj:`True` to indicate support without + setting a limit. + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. + + """ + + name = ExtensionName("permessage-deflate") + + def __init__( + self, + server_no_context_takeover: bool = False, + client_no_context_takeover: bool = False, + server_max_window_bits: int | None = None, + client_max_window_bits: int | Literal[True] | None = True, + compress_settings: dict[str, Any] | None = None, + ) -> None: + """ + Configure the Per-Message Deflate extension factory. + + """ + if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): + raise ValueError("server_max_window_bits must be between 8 and 15") + if not ( + client_max_window_bits is None + or client_max_window_bits is True + or 8 <= client_max_window_bits <= 15 + ): + raise ValueError("client_max_window_bits must be between 8 and 15") + if compress_settings is not None and "wbits" in compress_settings: + raise ValueError( + "compress_settings must not include wbits, " + "set client_max_window_bits instead" + ) + + self.server_no_context_takeover = server_no_context_takeover + self.client_no_context_takeover = client_no_context_takeover + self.server_max_window_bits = server_max_window_bits + self.client_max_window_bits = client_max_window_bits + self.compress_settings = compress_settings + + def get_request_params(self) -> Sequence[ExtensionParameter]: + """ + Build request parameters. + + """ + return _build_parameters( + self.server_no_context_takeover, + self.client_no_context_takeover, + self.server_max_window_bits, + self.client_max_window_bits, + ) + + def process_response_params( + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], + ) -> PerMessageDeflate: + """ + Process response parameters. + + Return an extension instance. + + """ + if any(other.name == self.name for other in accepted_extensions): + raise NegotiationError(f"received duplicate {self.name}") + + # Request parameters are available in instance variables. + + # Load response parameters in local variables. + ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) = _extract_parameters(params, is_server=False) + + # After comparing the request and the response, the final + # configuration must be available in the local variables. + + # server_no_context_takeover + # + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False Error! + # True True True + + if self.server_no_context_takeover: + if not server_no_context_takeover: + raise NegotiationError("expected server_no_context_takeover") + + # client_no_context_takeover + # + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False True - must change value + # True True True + + if self.client_no_context_takeover: + if not client_no_context_takeover: + client_no_context_takeover = True + + # server_max_window_bits + + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 M + # 8≤N≤15 None Error! + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N self.server_max_window_bits: + raise NegotiationError("unsupported server_max_window_bits") + + # client_max_window_bits + + # Req. Resp. Result + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 Error! + # True None None + # True 8≤M≤15 M + # 8≤N≤15 None N - must change value + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N self.client_max_window_bits: + raise NegotiationError("unsupported client_max_window_bits") + + return PerMessageDeflate( + server_no_context_takeover, # remote_no_context_takeover + client_no_context_takeover, # local_no_context_takeover + server_max_window_bits or 15, # remote_max_window_bits + client_max_window_bits or 15, # local_max_window_bits + self.compress_settings, + ) + + +def enable_client_permessage_deflate( + extensions: Sequence[ClientExtensionFactory] | None, +) -> Sequence[ClientExtensionFactory]: + """ + Enable Per-Message Deflate with default settings in client extensions. + + If the extension is already present, perhaps with non-default settings, + the configuration isn't changed. + + """ + if extensions is None: + extensions = [] + if not any( + extension_factory.name == ClientPerMessageDeflateFactory.name + for extension_factory in extensions + ): + extensions = list(extensions) + [ + ClientPerMessageDeflateFactory( + compress_settings={"memLevel": 5}, + ) + ] + return extensions + + +class ServerPerMessageDeflateFactory(ServerExtensionFactory): + """ + Server-side extension factory for the Per-Message Deflate extension. + + Parameters behave as described in `section 7.1 of RFC 7692`_. + + .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1 + + Set them to :obj:`True` to include them in the negotiation offer without a + value or to an integer value to include them with this value. + + Args: + server_no_context_takeover: Prevent server from using context takeover. + client_no_context_takeover: Prevent client from using context takeover. + server_max_window_bits: Maximum size of the server's LZ77 sliding window + in bits, between 8 and 15. + client_max_window_bits: Maximum size of the client's LZ77 sliding window + in bits, between 8 and 15. + compress_settings: Additional keyword arguments for :func:`zlib.compressobj`, + excluding ``wbits``. + require_client_max_window_bits: Do not enable compression at all if + client doesn't advertise support for ``client_max_window_bits``; + the default behavior is to enable compression without enforcing + ``client_max_window_bits``. + + """ + + name = ExtensionName("permessage-deflate") + + def __init__( + self, + server_no_context_takeover: bool = False, + client_no_context_takeover: bool = False, + server_max_window_bits: int | None = None, + client_max_window_bits: int | None = None, + compress_settings: dict[str, Any] | None = None, + require_client_max_window_bits: bool = False, + ) -> None: + """ + Configure the Per-Message Deflate extension factory. + + """ + if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15): + raise ValueError("server_max_window_bits must be between 8 and 15") + if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15): + raise ValueError("client_max_window_bits must be between 8 and 15") + if compress_settings is not None and "wbits" in compress_settings: + raise ValueError( + "compress_settings must not include wbits, " + "set server_max_window_bits instead" + ) + if client_max_window_bits is None and require_client_max_window_bits: + raise ValueError( + "require_client_max_window_bits is enabled, " + "but client_max_window_bits isn't configured" + ) + + self.server_no_context_takeover = server_no_context_takeover + self.client_no_context_takeover = client_no_context_takeover + self.server_max_window_bits = server_max_window_bits + self.client_max_window_bits = client_max_window_bits + self.compress_settings = compress_settings + self.require_client_max_window_bits = require_client_max_window_bits + + def process_request_params( + self, + params: Sequence[ExtensionParameter], + accepted_extensions: Sequence[Extension], + ) -> tuple[list[ExtensionParameter], PerMessageDeflate]: + """ + Process request parameters. + + Return response params and an extension instance. + + """ + if any(other.name == self.name for other in accepted_extensions): + raise NegotiationError(f"skipped duplicate {self.name}") + + # Load request parameters in local variables. + ( + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits, + ) = _extract_parameters(params, is_server=True) + + # Configuration parameters are available in instance variables. + + # After comparing the request and the configuration, the response must + # be available in the local variables. + + # server_no_context_takeover + # + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # False False False + # False True True + # True False True - must change value to True + # True True True + + if self.server_no_context_takeover: + if not server_no_context_takeover: + server_no_context_takeover = True + + # client_no_context_takeover + # + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # False False False + # False True True (or False) + # True False True - must change value to True + # True True True (or False) + + if self.client_no_context_takeover: + if not client_no_context_takeover: + client_no_context_takeover = True + + # server_max_window_bits + + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # None None None + # None 8≤M≤15 M + # 8≤N≤15 None N - must change value + # 8≤N≤15 8≤M≤N M + # 8≤N≤15 N self.server_max_window_bits: + server_max_window_bits = self.server_max_window_bits + + # client_max_window_bits + + # Config Req. Resp. + # ------ ------ -------------------------------------------------- + # None None None + # None True None - must change value + # None 8≤M≤15 M (or None) + # 8≤N≤15 None None or Error! + # 8≤N≤15 True N - must change value + # 8≤N≤15 8≤M≤N M (or None) + # 8≤N≤15 N Sequence[ServerExtensionFactory]: + """ + Enable Per-Message Deflate with default settings in server extensions. + + If the extension is already present, perhaps with non-default settings, + the configuration isn't changed. + + """ + if extensions is None: + extensions = [] + if not any( + ext_factory.name == ServerPerMessageDeflateFactory.name + for ext_factory in extensions + ): + extensions = list(extensions) + [ + ServerPerMessageDeflateFactory( + server_max_window_bits=12, + client_max_window_bits=12, + compress_settings={"memLevel": 5}, + ) + ] + return extensions diff --git a/source/websockets/frames.py b/source/websockets/frames.py new file mode 100644 index 0000000000000000000000000000000000000000..7716e7a2b2896d598b4f3137c55a6e5e565201cb --- /dev/null +++ b/source/websockets/frames.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +import dataclasses +import enum +import io +import os +import secrets +import struct +from collections.abc import Generator, Sequence +from typing import Callable + +from .exceptions import PayloadTooBig, ProtocolError +from .typing import BytesLike + + +try: + from .speedups import apply_mask +except ImportError: + from .utils import apply_mask + + +__all__ = [ + "Opcode", + "OP_CONT", + "OP_TEXT", + "OP_BINARY", + "OP_CLOSE", + "OP_PING", + "OP_PONG", + "DATA_OPCODES", + "CTRL_OPCODES", + "CloseCode", + "Frame", + "Close", +] + + +class Opcode(enum.IntEnum): + """Opcode values for WebSocket frames.""" + + CONT, TEXT, BINARY = 0x00, 0x01, 0x02 + CLOSE, PING, PONG = 0x08, 0x09, 0x0A + + +OP_CONT = Opcode.CONT +OP_TEXT = Opcode.TEXT +OP_BINARY = Opcode.BINARY +OP_CLOSE = Opcode.CLOSE +OP_PING = Opcode.PING +OP_PONG = Opcode.PONG + +DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY +CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG + + +class CloseCode(enum.IntEnum): + """Close code values for WebSocket close frames.""" + + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + # 1004 is reserved + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + TLS_HANDSHAKE = 1015 + + +# See https://www.iana.org/assignments/websocket/websocket.xhtml +CLOSE_CODE_EXPLANATIONS: dict[int, str] = { + CloseCode.NORMAL_CLOSURE: "OK", + CloseCode.GOING_AWAY: "going away", + CloseCode.PROTOCOL_ERROR: "protocol error", + CloseCode.UNSUPPORTED_DATA: "unsupported data", + CloseCode.NO_STATUS_RCVD: "no status received [internal]", + CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]", + CloseCode.INVALID_DATA: "invalid frame payload data", + CloseCode.POLICY_VIOLATION: "policy violation", + CloseCode.MESSAGE_TOO_BIG: "message too big", + CloseCode.MANDATORY_EXTENSION: "mandatory extension", + CloseCode.INTERNAL_ERROR: "internal error", + CloseCode.SERVICE_RESTART: "service restart", + CloseCode.TRY_AGAIN_LATER: "try again later", + CloseCode.BAD_GATEWAY: "bad gateway", + CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]", +} + + +# Close code that are allowed in a close frame. +# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`. +EXTERNAL_CLOSE_CODES = { + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.PROTOCOL_ERROR, + CloseCode.UNSUPPORTED_DATA, + CloseCode.INVALID_DATA, + CloseCode.POLICY_VIOLATION, + CloseCode.MESSAGE_TOO_BIG, + CloseCode.MANDATORY_EXTENSION, + CloseCode.INTERNAL_ERROR, + CloseCode.SERVICE_RESTART, + CloseCode.TRY_AGAIN_LATER, + CloseCode.BAD_GATEWAY, +} + + +OK_CLOSE_CODES = { + CloseCode.NORMAL_CLOSURE, + CloseCode.GOING_AWAY, + CloseCode.NO_STATUS_RCVD, +} + + +@dataclasses.dataclass +class Frame: + """ + WebSocket frame. + + Attributes: + opcode: Opcode. + data: Payload data. + fin: FIN bit. + rsv1: RSV1 bit. + rsv2: RSV2 bit. + rsv3: RSV3 bit. + + Only these fields are needed. The MASK bit, payload length and masking-key + are handled on the fly when parsing and serializing frames. + + """ + + opcode: Opcode + data: BytesLike + fin: bool = True + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + # Configure if you want to see more in logs. Should be a multiple of 3. + MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75")) + + def __str__(self) -> str: + """ + Return a human-readable representation of a frame. + + """ + coding = None + length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}" + non_final = "" if self.fin else "continued" + + if self.opcode is OP_TEXT: + # Decoding only the beginning and the end is needlessly hard. + # Decode the entire payload then elide later if necessary. + data = repr(bytes(self.data).decode()) + elif self.opcode is OP_BINARY: + # We'll show at most the first 16 bytes and the last 8 bytes. + # Encode just what we need, plus two dummy bytes to elide later. + binary = self.data + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) + data = " ".join(f"{byte:02x}" for byte in binary) + elif self.opcode is OP_CLOSE: + data = str(Close.parse(self.data)) + elif self.data: + # We don't know if a Continuation frame contains text or binary. + # Ping and Pong frames could contain UTF-8. + # Attempt to decode as UTF-8 and display it as text; fallback to + # binary. If self.data is a memoryview, it has no decode() method, + # which raises AttributeError. + try: + data = repr(bytes(self.data).decode()) + coding = "text" + except (UnicodeDecodeError, AttributeError): + binary = self.data + if len(binary) > self.MAX_LOG_SIZE // 3: + cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8 + binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]]) + data = " ".join(f"{byte:02x}" for byte in binary) + coding = "binary" + else: + data = "''" + + if len(data) > self.MAX_LOG_SIZE: + cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24 + data = data[: 2 * cut] + "..." + data[-cut:] + + metadata = ", ".join(filter(None, [coding, length, non_final])) + + return f"{self.opcode.name} {data} [{metadata}]" + + @classmethod + def parse( + cls, + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], + *, + mask: bool, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, + ) -> Generator[None, None, Frame]: + """ + Parse a WebSocket frame. + + This is a generator-based coroutine. + + Args: + read_exact: Generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + mask: Whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. + + Raises: + EOFError: If the connection is closed without a full WebSocket frame. + PayloadTooBig: If the frame's payload size exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. + + """ + # Read the header. + data = yield from read_exact(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + + try: + opcode = Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = yield from read_exact(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = yield from read_exact(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig(length, max_size) + if mask: + mask_bytes = yield from read_exact(4) + + # Read the data. + data = yield from read_exact(length) + if mask: + data = apply_mask(data, mask_bytes) + + frame = cls(opcode, data, fin, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + frame = extension.decode(frame, max_size=max_size) + + frame.check() + + return frame + + def serialize( + self, + *, + mask: bool, + extensions: Sequence[extensions.Extension] | None = None, + ) -> bytes: + """ + Serialize a WebSocket frame. + + Args: + mask: Whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: List of extensions, applied in order. + + Raises: + ProtocolError: If the frame contains incorrect values. + + """ + self.check() + + if extensions is None: + extensions = [] + for extension in extensions: + self = extension.encode(self) + + output = io.BytesIO() + + # Prepare the header. + head1 = ( + (0b10000000 if self.fin else 0) + | (0b01000000 if self.rsv1 else 0) + | (0b00100000 if self.rsv2 else 0) + | (0b00010000 if self.rsv3 else 0) + | self.opcode + ) + + head2 = 0b10000000 if mask else 0 + + length = len(self.data) + if length < 126: + output.write(struct.pack("!BB", head1, head2 | length)) + elif length < 65536: + output.write(struct.pack("!BBH", head1, head2 | 126, length)) + else: + output.write(struct.pack("!BBQ", head1, head2 | 127, length)) + + if mask: + mask_bytes = secrets.token_bytes(4) + output.write(mask_bytes) + + # Prepare the data. + data: BytesLike + if mask: + data = apply_mask(self.data, mask_bytes) + else: + data = self.data + output.write(data) + + return output.getvalue() + + def check(self) -> None: + """ + Check that reserved bits and opcode have acceptable values. + + Raises: + ProtocolError: If a reserved bit or the opcode is invalid. + + """ + if self.rsv1 or self.rsv2 or self.rsv3: + raise ProtocolError("reserved bits must be 0") + + if self.opcode in CTRL_OPCODES: + if len(self.data) > 125: + raise ProtocolError("control frame too long") + if not self.fin: + raise ProtocolError("fragmented control frame") + + +@dataclasses.dataclass +class Close: + """ + Code and reason for WebSocket close frames. + + Attributes: + code: Close code. + reason: Close reason. + + """ + + code: CloseCode | int + reason: str + + def __str__(self) -> str: + """ + Return a human-readable representation of a close code and reason. + + """ + if 3000 <= self.code < 4000: + explanation = "registered" + elif 4000 <= self.code < 5000: + explanation = "private use" + else: + explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown") + result = f"{self.code} ({explanation})" + + if self.reason: + result = f"{result} {self.reason}" + + return result + + @classmethod + def parse(cls, data: BytesLike) -> Close: + """ + Parse the payload of a close frame. + + Args: + data: Payload of the close frame. + + Raises: + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. + + """ + if isinstance(data, memoryview): + raise AssertionError("only compressed outgoing frames use memoryview") + if len(data) >= 2: + (code,) = struct.unpack("!H", data[:2]) + reason = data[2:].decode() + close = cls(code, reason) + close.check() + return close + elif len(data) == 0: + return cls(CloseCode.NO_STATUS_RCVD, "") + else: + raise ProtocolError("close frame too short") + + def serialize(self) -> bytes: + """ + Serialize the payload of a close frame. + + """ + self.check() + return struct.pack("!H", self.code) + self.reason.encode() + + def check(self) -> None: + """ + Check that the close code has a valid value for a close frame. + + Raises: + ProtocolError: If the close code is invalid. + + """ + if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000): + raise ProtocolError("invalid status code") + + +# At the bottom to break import cycles created by type annotations. +from . import extensions # noqa: E402 diff --git a/source/websockets/headers.py b/source/websockets/headers.py new file mode 100644 index 0000000000000000000000000000000000000000..e05ff5b4c34fcd3b3021eb6456e9b2d71dbe0222 --- /dev/null +++ b/source/websockets/headers.py @@ -0,0 +1,586 @@ +from __future__ import annotations + +import base64 +import binascii +import ipaddress +import re +from collections.abc import Sequence +from typing import Callable, TypeVar, cast + +from .exceptions import InvalidHeaderFormat, InvalidHeaderValue +from .typing import ( + ConnectionOption, + ExtensionHeader, + ExtensionName, + ExtensionParameter, + Subprotocol, + UpgradeProtocol, +) + + +__all__ = [ + "build_host", + "parse_connection", + "parse_upgrade", + "parse_extension", + "build_extension", + "parse_subprotocol", + "build_subprotocol", + "validate_subprotocols", + "build_www_authenticate_basic", + "parse_authorization_basic", + "build_authorization_basic", +] + + +T = TypeVar("T") + + +def build_host( + host: str, + port: int, + secure: bool, + *, + always_include_port: bool = False, +) -> str: + """ + Build a ``Host`` header. + + """ + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 + # IPv6 addresses must be enclosed in brackets. + try: + address = ipaddress.ip_address(host) + except ValueError: + # host is a hostname + pass + else: + # host is an IP address + if address.version == 6: + host = f"[{host}]" + + if always_include_port or port != (443 if secure else 80): + host = f"{host}:{port}" + + return host + + +# To avoid a dependency on a parsing library, we implement manually the ABNF +# described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and +# https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. + + +def peek_ahead(header: str, pos: int) -> str | None: + """ + Return the next character from ``header`` at the given position. + + Return :obj:`None` at the end of ``header``. + + We never need to peek more than one character ahead. + + """ + return None if pos == len(header) else header[pos] + + +_OWS_re = re.compile(r"[\t ]*") + + +def parse_OWS(header: str, pos: int) -> int: + """ + Parse optional whitespace from ``header`` at the given position. + + Return the new position. + + The whitespace itself isn't returned because it isn't significant. + + """ + # There's always a match, possibly empty, whose content doesn't matter. + match = _OWS_re.match(header, pos) + assert match is not None + return match.end() + + +_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + + +def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a token from ``header`` at the given position. + + Return the token value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _token_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected token", header, pos) + return match.group(), match.end() + + +_quoted_string_re = re.compile( + r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"' +) + + +_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])") + + +def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a quoted string from ``header`` at the given position. + + Return the unquoted value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _quoted_string_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos) + return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end() + + +_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*") + + +_quote_re = re.compile(r"([\x22\x5c])") + + +def build_quoted_string(value: str) -> str: + """ + Format ``value`` as a quoted string. + + This is the reverse of :func:`parse_quoted_string`. + + """ + match = _quotable_re.fullmatch(value) + if match is None: + raise ValueError("invalid characters for quoted-string encoding") + return '"' + _quote_re.sub(r"\\\1", value) + '"' + + +def parse_list( + parse_item: Callable[[str, int, str], tuple[T, int]], + header: str, + pos: int, + header_name: str, +) -> list[T]: + """ + Parse a comma-separated list from ``header`` at the given position. + + This is appropriate for parsing values with the following grammar: + + 1#item + + ``parse_item`` parses one item. + + ``header`` is assumed not to start or end with whitespace. + + (This function is designed for parsing an entire header value and + :func:`~websockets.http.read_headers` strips whitespace from values.) + + Return a list of items. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient + # MUST parse and ignore a reasonable number of empty list elements"; + # hence while loops that remove extra delimiters. + + # Remove extra delimiters before the first item. + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + + items = [] + while True: + # Loop invariant: a item starts at pos in header. + item, pos = parse_item(header, pos, header_name) + items.append(item) + pos = parse_OWS(header, pos) + + # We may have reached the end of the header. + if pos == len(header): + break + + # There must be a delimiter after each element except the last one. + if peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + else: + raise InvalidHeaderFormat(header_name, "expected comma", header, pos) + + # Remove extra delimiters before the next item. + while peek_ahead(header, pos) == ",": + pos = parse_OWS(header, pos + 1) + + # We may have reached the end of the header. + if pos == len(header): + break + + # Since we only advance in the header by one character with peek_ahead() + # or with the end position of a regex match, we can't overshoot the end. + assert pos == len(header) + + return items + + +def parse_connection_option( + header: str, pos: int, header_name: str +) -> tuple[ConnectionOption, int]: + """ + Parse a Connection option from ``header`` at the given position. + + Return the protocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(ConnectionOption, item), pos + + +def parse_connection(header: str) -> list[ConnectionOption]: + """ + Parse a ``Connection`` header. + + Return a list of HTTP connection options. + + Args + header: value of the ``Connection`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_connection_option, header, 0, "Connection") + + +_protocol_re = re.compile( + r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?" +) + + +def parse_upgrade_protocol( + header: str, pos: int, header_name: str +) -> tuple[UpgradeProtocol, int]: + """ + Parse an Upgrade protocol from ``header`` at the given position. + + Return the protocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _protocol_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected protocol", header, pos) + return cast(UpgradeProtocol, match.group()), match.end() + + +def parse_upgrade(header: str) -> list[UpgradeProtocol]: + """ + Parse an ``Upgrade`` header. + + Return a list of HTTP protocols. + + Args: + header: Value of the ``Upgrade`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_upgrade_protocol, header, 0, "Upgrade") + + +def parse_extension_item_param( + header: str, pos: int, header_name: str +) -> tuple[ExtensionParameter, int]: + """ + Parse a single extension parameter from ``header`` at the given position. + + Return a ``(name, value)`` pair and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Extract parameter name. + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + # Extract parameter value, if there is one. + value: str | None = None + if peek_ahead(header, pos) == "=": + pos = parse_OWS(header, pos + 1) + if peek_ahead(header, pos) == '"': + pos_before = pos # for proper error reporting below + value, pos = parse_quoted_string(header, pos, header_name) + # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says: + # the value after quoted-string unescaping MUST conform to + # the 'token' ABNF. + if _token_re.fullmatch(value) is None: + raise InvalidHeaderFormat( + header_name, "invalid quoted header content", header, pos_before + ) + else: + value, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + + return (name, value), pos + + +def parse_extension_item( + header: str, pos: int, header_name: str +) -> tuple[ExtensionHeader, int]: + """ + Parse an extension definition from ``header`` at the given position. + + Return an ``(extension name, parameters)`` pair, where ``parameters`` is a + list of ``(name, value)`` pairs, and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + # Extract extension name. + name, pos = parse_token(header, pos, header_name) + pos = parse_OWS(header, pos) + # Extract all parameters. + parameters = [] + while peek_ahead(header, pos) == ";": + pos = parse_OWS(header, pos + 1) + parameter, pos = parse_extension_item_param(header, pos, header_name) + parameters.append(parameter) + return (cast(ExtensionName, name), parameters), pos + + +def parse_extension(header: str) -> list[ExtensionHeader]: + """ + Parse a ``Sec-WebSocket-Extensions`` header. + + Return a list of WebSocket extensions and their parameters in this format:: + + [ + ( + 'extension name', + [ + ('parameter name', 'parameter value'), + .... + ] + ), + ... + ] + + Parameter values are :obj:`None` when no value is provided. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions") + + +parse_extension_list = parse_extension # alias for backwards compatibility + + +def build_extension_item( + name: ExtensionName, parameters: Sequence[ExtensionParameter] +) -> str: + """ + Build an extension definition. + + This is the reverse of :func:`parse_extension_item`. + + """ + return "; ".join( + [cast(str, name)] + + [ + # Quoted strings aren't necessary because values are always tokens. + name if value is None else f"{name}={value}" + for name, value in parameters + ] + ) + + +def build_extension(extensions: Sequence[ExtensionHeader]) -> str: + """ + Build a ``Sec-WebSocket-Extensions`` header. + + This is the reverse of :func:`parse_extension`. + + """ + return ", ".join( + build_extension_item(name, parameters) for name, parameters in extensions + ) + + +build_extension_list = build_extension # alias for backwards compatibility + + +def parse_subprotocol_item( + header: str, pos: int, header_name: str +) -> tuple[Subprotocol, int]: + """ + Parse a subprotocol from ``header`` at the given position. + + Return the subprotocol value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + item, pos = parse_token(header, pos, header_name) + return cast(Subprotocol, item), pos + + +def parse_subprotocol(header: str) -> list[Subprotocol]: + """ + Parse a ``Sec-WebSocket-Protocol`` header. + + Return a list of WebSocket subprotocols. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol") + + +parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility + + +def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str: + """ + Build a ``Sec-WebSocket-Protocol`` header. + + This is the reverse of :func:`parse_subprotocol`. + + """ + return ", ".join(subprotocols) + + +build_subprotocol_list = build_subprotocol # alias for backwards compatibility + + +def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None: + """ + Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`. + + """ + if not isinstance(subprotocols, Sequence): + raise TypeError("subprotocols must be a list") + if isinstance(subprotocols, str): + raise TypeError("subprotocols must be a list, not a str") + for subprotocol in subprotocols: + if not _token_re.fullmatch(subprotocol): + raise ValueError(f"invalid subprotocol: {subprotocol}") + + +def build_www_authenticate_basic(realm: str) -> str: + """ + Build a ``WWW-Authenticate`` header for HTTP Basic Auth. + + Args: + realm: Identifier of the protection space. + + """ + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + realm = build_quoted_string(realm) + charset = build_quoted_string("UTF-8") + return f"Basic realm={realm}, charset={charset}" + + +_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*") + + +def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]: + """ + Parse a token68 from ``header`` at the given position. + + Return the token value and the new position. + + Raises: + InvalidHeaderFormat: On invalid inputs. + + """ + match = _token68_re.match(header, pos) + if match is None: + raise InvalidHeaderFormat(header_name, "expected token68", header, pos) + return match.group(), match.end() + + +def parse_end(header: str, pos: int, header_name: str) -> None: + """ + Check that parsing reached the end of header. + + """ + if pos < len(header): + raise InvalidHeaderFormat(header_name, "trailing data", header, pos) + + +def parse_authorization_basic(header: str) -> tuple[str, str]: + """ + Parse an ``Authorization`` header for HTTP Basic Auth. + + Return a ``(username, password)`` tuple. + + Args: + header: Value of the ``Authorization`` header. + + Raises: + InvalidHeaderFormat: On invalid inputs. + InvalidHeaderValue: On unsupported inputs. + + """ + # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1 + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + scheme, pos = parse_token(header, 0, "Authorization") + if scheme.lower() != "basic": + raise InvalidHeaderValue( + "Authorization", + f"unsupported scheme: {scheme}", + ) + if peek_ahead(header, pos) != " ": + raise InvalidHeaderFormat( + "Authorization", "expected space after scheme", header, pos + ) + pos += 1 + basic_credentials, pos = parse_token68(header, pos, "Authorization") + parse_end(header, pos, "Authorization") + + try: + user_pass = base64.b64decode(basic_credentials.encode()).decode() + except binascii.Error: + raise InvalidHeaderValue( + "Authorization", + "expected base64-encoded credentials", + ) from None + try: + username, password = user_pass.split(":", 1) + except ValueError: + raise InvalidHeaderValue( + "Authorization", + "expected username:password credentials", + ) from None + + return username, password + + +def build_authorization_basic(username: str, password: str) -> str: + """ + Build an ``Authorization`` header for HTTP Basic Auth. + + This is the reverse of :func:`parse_authorization_basic`. + + """ + # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + assert ":" not in username + user_pass = f"{username}:{password}" + basic_credentials = base64.b64encode(user_pass.encode()).decode() + return "Basic " + basic_credentials diff --git a/source/websockets/http.py b/source/websockets/http.py new file mode 100644 index 0000000000000000000000000000000000000000..0d860e5379404c12f8fb4177ca4fcb6764b86f3b --- /dev/null +++ b/source/websockets/http.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import warnings + +from .datastructures import Headers, MultipleValuesError # noqa: F401 + + +with warnings.catch_warnings(): + # Suppress redundant DeprecationWarning raised by websockets.legacy. + warnings.filterwarnings("ignore", category=DeprecationWarning) + from .legacy.http import read_request, read_response # noqa: F401 + + +warnings.warn( # deprecated in 9.0 - 2021-09-01 + "Headers and MultipleValuesError were moved " + "from websockets.http to websockets.datastructures" + "and read_request and read_response were moved " + "from websockets.http to websockets.legacy.http", + DeprecationWarning, +) diff --git a/source/websockets/http11.py b/source/websockets/http11.py new file mode 100644 index 0000000000000000000000000000000000000000..5af73eb0cf58fa4246e9a274a2b1ef36f301eb9c --- /dev/null +++ b/source/websockets/http11.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import dataclasses +import os +import re +import sys +import warnings +from collections.abc import Generator +from typing import Callable + +from .datastructures import Headers +from .exceptions import SecurityError +from .version import version as websockets_version + + +__all__ = [ + "SERVER", + "USER_AGENT", + "Request", + "Response", +] + + +PYTHON_VERSION = "{}.{}".format(*sys.version_info) + +# User-Agent header for HTTP requests. +USER_AGENT = os.environ.get( + "WEBSOCKETS_USER_AGENT", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Server header for HTTP responses. +SERVER = os.environ.get( + "WEBSOCKETS_SERVER", + f"Python/{PYTHON_VERSION} websockets/{websockets_version}", +) + +# Maximum total size of headers is around 128 * 8 KiB = 1 MiB. +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) + +# Limit request line and header lines. 8KiB is the most common default +# configuration of popular HTTP servers. +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) + +# Support for HTTP response bodies is intended to read an error message +# returned by a server. It isn't designed to perform large file transfers. +MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB + + +def d(value: bytes | bytearray) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +@dataclasses.dataclass +class Request: + """ + WebSocket handshake request. + + Attributes: + path: Request path, including optional query. + headers: Request headers. + """ + + path: str + headers: Headers + # body isn't useful is the context of this library. + + _exception: Exception | None = None + + @property + def exception(self) -> Exception | None: # pragma: no cover + warnings.warn( # deprecated in 10.3 - 2022-04-17 + "Request.exception is deprecated; use ServerProtocol.handshake_exc instead", + DeprecationWarning, + ) + return self._exception + + @classmethod + def parse( + cls, + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], + ) -> Generator[None, None, Request]: + """ + Parse a WebSocket handshake request. + + This is a generator-based coroutine. + + The request path isn't URL-decoded or validated in any way. + + The request path and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. + + :meth:`parse` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from the data stream after :meth:`parse` returns. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data + + Raises: + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, protocol = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" + ) + if method != b"GET": + raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = yield from parse_headers(read_line) + + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + if "Transfer-Encoding" in headers: + raise NotImplementedError("transfer codings aren't supported") + + if "Content-Length" in headers: + # Some devices send a Content-Length header with a value of 0. + # This raises ValueError if Content-Length isn't an integer too. + if int(headers["Content-Length"]) != 0: + raise ValueError("unsupported request body") + + return cls(path, headers) + + def serialize(self) -> bytes: + """ + Serialize a WebSocket handshake request. + + """ + # Since the request line and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {self.path} HTTP/1.1\r\n".encode() + request += self.headers.serialize() + return request + + +@dataclasses.dataclass +class Response: + """ + WebSocket handshake response. + + Attributes: + status_code: Response code. + reason_phrase: Response reason. + headers: Response headers. + body: Response body. + + """ + + status_code: int + reason_phrase: str + headers: Headers + body: bytes | bytearray = b"" + + _exception: Exception | None = None + + @property + def exception(self) -> Exception | None: # pragma: no cover + warnings.warn( # deprecated in 10.3 - 2022-04-17 + "Response.exception is deprecated; " + "use ClientProtocol.handshake_exc instead", + DeprecationWarning, + ) + return self._exception + + @classmethod + def parse( + cls, + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], + read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]], + proxy: bool = False, + ) -> Generator[None, None, Response]: + """ + Parse a WebSocket handshake response. + + This is a generator-based coroutine. + + The reason phrase and headers are expected to contain only ASCII + characters. Other characters are represented with surrogate escapes. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated + line or raises an exception if there isn't enough data. + read_exact: Generator-based coroutine that reads the requested + bytes or raises an exception if there isn't enough data. + read_to_eof: Generator-based coroutine that reads until the end + of the stream. + + Raises: + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + LookupError: If the response isn't well formatted. + ValueError: If the response isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 + + try: + status_line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + protocol, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + if proxy: # some proxies still use HTTP/1.0 + if protocol not in [b"HTTP/1.1", b"HTTP/1.0"]: + raise ValueError( + f"unsupported protocol; expected HTTP/1.1 or HTTP/1.0: " + f"{d(status_line)}" + ) + else: + if protocol != b"HTTP/1.1": + raise ValueError( + f"unsupported protocol; expected HTTP/1.1: {d(status_line)}" + ) + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError( + f"invalid status code; expected integer; got {d(raw_status_code)}" + ) from None + if not 100 <= status_code < 600: + raise ValueError( + f"invalid status code; expected 100–599; got {d(raw_status_code)}" + ) + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode("ascii", "surrogateescape") + + headers = yield from parse_headers(read_line) + + body: bytes | bytearray + if proxy: + body = b"" + else: + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) + + return cls(status_code, reason, headers, body) + + def serialize(self) -> bytes: + """ + Serialize a WebSocket handshake response. + + """ + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode() + response += self.headers.serialize() + response += self.body + return response + + +def parse_line( + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], +) -> Generator[None, None, bytes | bytearray]: + """ + Parse a single line. + + CRLF is stripped from the return value. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without a CRLF. + SecurityError: If the response exceeds a security limit. + + """ + try: + line = yield from read_line(MAX_LINE_LENGTH) + except RuntimeError: + raise SecurityError("line too long") + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] + + +def parse_headers( + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], +) -> Generator[None, None, Headers]: + """ + Parse HTTP headers. + + Non-ASCII characters are represented with surrogate escapes. + + Args: + read_line: Generator-based coroutine that reads a LF-terminated line + or raises an exception if there isn't enough data. + + Raises: + EOFError: If the connection is closed without complete headers. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_NUM_HEADERS + 1): + try: + line = yield from parse_line(read_line) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +def read_body( + status_code: int, + headers: Headers, + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], + read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]], +) -> Generator[None, None, bytes | bytearray]: + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 + + # Since websockets only does GET requests (no HEAD, no CONNECT), all + # responses except 1xx, 204, and 304 include a message body. + if 100 <= status_code < 200 or status_code == 204 or status_code == 304: + return b"" + + # MultipleValuesError is sufficiently unlikely that we don't attempt to + # handle it when accessing headers. Instead we document that its parent + # class, LookupError, may be raised. + # Conversions from str to int are protected by sys.set_int_max_str_digits.. + + elif (coding := headers.get("Transfer-Encoding")) is not None: + if coding != "chunked": + raise NotImplementedError(f"transfer coding {coding} isn't supported") + + body = b"" + while True: + chunk_size_line = yield from parse_line(read_line) + raw_chunk_size = chunk_size_line.split(b";", 1)[0] + # Set a lower limit than default_max_str_digits; 1 EB is plenty. + if len(raw_chunk_size) > 15: + str_chunk_size = raw_chunk_size.decode(errors="backslashreplace") + raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes") + chunk_size = int(raw_chunk_size, 16) + if chunk_size == 0: + break + if len(body) + chunk_size > MAX_BODY_SIZE: + raise SecurityError( + f"chunk too large: {chunk_size} bytes after {len(body)} bytes" + ) + body += yield from read_exact(chunk_size) + if (yield from read_exact(2)) != b"\r\n": + raise ValueError("chunk without CRLF") + # Read the trailer. + yield from parse_headers(read_line) + return body + + elif (raw_content_length := headers.get("Content-Length")) is not None: + # Set a lower limit than default_max_str_digits; 1 EiB is plenty. + if len(raw_content_length) > 18: + raise SecurityError(f"body too large: {raw_content_length} bytes") + content_length = int(raw_content_length) + if content_length > MAX_BODY_SIZE: + raise SecurityError(f"body too large: {content_length} bytes") + return (yield from read_exact(content_length)) + + else: + try: + return (yield from read_to_eof(MAX_BODY_SIZE)) + except RuntimeError: + raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes") diff --git a/source/websockets/imports.py b/source/websockets/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..c63fb212ec602ae6ec75fe1b86a29fb2e11334df --- /dev/null +++ b/source/websockets/imports.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import warnings +from collections.abc import Iterable +from typing import Any + + +__all__ = ["lazy_import"] + + +def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any: + """ + Import ``name`` from ``source`` in ``namespace``. + + There are two use cases: + + - ``name`` is an object defined in ``source``; + - ``name`` is a submodule of ``source``. + + Neither :func:`__import__` nor :func:`~importlib.import_module` does + exactly this. :func:`__import__` is closer to the intended behavior. + + """ + level = 0 + while source[level] == ".": + level += 1 + assert level < len(source), "importing from parent isn't supported" + module = __import__(source[level:], namespace, None, [name], level) + return getattr(module, name) + + +def lazy_import( + namespace: dict[str, Any], + aliases: dict[str, str] | None = None, + deprecated_aliases: dict[str, str] | None = None, +) -> None: + """ + Provide lazy, module-level imports. + + Typical use:: + + __getattr__, __dir__ = lazy_import( + globals(), + aliases={ + "": "", + ... + }, + deprecated_aliases={ + ..., + } + ) + + This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`. + + """ + if aliases is None: + aliases = {} + if deprecated_aliases is None: + deprecated_aliases = {} + + namespace_set = set(namespace) + aliases_set = set(aliases) + deprecated_aliases_set = set(deprecated_aliases) + + assert not namespace_set & aliases_set, "namespace conflict" + assert not namespace_set & deprecated_aliases_set, "namespace conflict" + assert not aliases_set & deprecated_aliases_set, "namespace conflict" + + package = namespace["__name__"] + + def __getattr__(name: str) -> Any: + assert aliases is not None # mypy cannot figure this out + try: + source = aliases[name] + except KeyError: + pass + else: + return import_name(name, source, namespace) + + assert deprecated_aliases is not None # mypy cannot figure this out + try: + source = deprecated_aliases[name] + except KeyError: + pass + else: + warnings.warn( + f"{package}.{name} is deprecated", + DeprecationWarning, + stacklevel=2, + ) + return import_name(name, source, namespace) + + raise AttributeError(f"module {package!r} has no attribute {name!r}") + + namespace["__getattr__"] = __getattr__ + + def __dir__() -> Iterable[str]: + return sorted(namespace_set | aliases_set | deprecated_aliases_set) + + namespace["__dir__"] = __dir__ diff --git a/source/websockets/legacy/__init__.py b/source/websockets/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9aa25064f626754bda8a8bb149d974002a064e --- /dev/null +++ b/source/websockets/legacy/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import warnings + + +warnings.warn( # deprecated in 14.0 - 2024-11-09 + "websockets.legacy is deprecated; " + "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html " + "for upgrade instructions", + DeprecationWarning, +) diff --git a/source/websockets/legacy/auth.py b/source/websockets/legacy/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..a262fcd791bc66a1ad0ee9389faed1c62c69e2be --- /dev/null +++ b/source/websockets/legacy/auth.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import functools +import hmac +import http +from collections.abc import Awaitable, Iterable +from typing import Any, Callable, cast + +from ..datastructures import Headers +from ..exceptions import InvalidHeader +from ..headers import build_www_authenticate_basic, parse_authorization_basic +from .server import HTTPResponse, WebSocketServerProtocol + + +__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] + +Credentials = tuple[str, str] + + +def is_credentials(value: Any) -> bool: + try: + username, password = value + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): + """ + WebSocket server protocol that enforces HTTP Basic Auth. + + """ + + realm: str = "" + """ + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. + """ + + username: str | None = None + """Username of the authenticated user.""" + + def __init__( + self, + *args: Any, + realm: str | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + **kwargs: Any, + ) -> None: + if realm is not None: + self.realm = realm # shadow class attribute + self._check_credentials = check_credentials + super().__init__(*args, **kwargs) + + async def check_credentials(self, username: str, password: str) -> bool: + """ + Check whether credentials are authorized. + + This coroutine may be overridden in a subclass, for example to + authenticate against a database or an external service. + + Args: + username: HTTP Basic Auth username. + password: HTTP Basic Auth password. + + Returns: + :obj:`True` if the handshake should continue; + :obj:`False` if it should fail with an HTTP 401 error. + + """ + if self._check_credentials is not None: + return await self._check_credentials(username, password) + + return False + + async def process_request( + self, + path: str, + request_headers: Headers, + ) -> HTTPResponse | None: + """ + Check HTTP Basic Auth and return an HTTP 401 response if needed. + + """ + try: + authorization = request_headers["Authorization"] + except KeyError: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Missing credentials\n", + ) + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Unsupported credentials\n", + ) + + if not await self.check_credentials(username, password): + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Invalid credentials\n", + ) + + self.username = username + + return await super().process_request(path, request_headers) + + +def basic_auth_protocol_factory( + realm: str | None = None, + credentials: Credentials | Iterable[Credentials] | None = None, + check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, + create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, +) -> Callable[..., BasicAuthWebSocketServerProtocol]: + """ + Protocol factory that enforces HTTP Basic Auth. + + :func:`basic_auth_protocol_factory` is designed to integrate with + :func:`~websockets.legacy.server.serve` like this:: + + serve( + ..., + create_protocol=basic_auth_protocol_factory( + realm="my dev server", + credentials=("hello", "iloveyou"), + ) + ) + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments + and returns a :class:`bool`. One of ``credentials`` or + ``check_credentials`` must be provided but not both. + create_protocol: Factory that creates the protocol. By default, this + is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced + by a subclass. + Raises: + TypeError: If the ``credentials`` or ``check_credentials`` argument is + wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Credentials, credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[Credentials], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + if create_protocol is None: + create_protocol = BasicAuthWebSocketServerProtocol + + # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | + # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] + create_protocol = cast( + Callable[..., BasicAuthWebSocketServerProtocol], create_protocol + ) + return functools.partial( + create_protocol, + realm=realm, + check_credentials=check_credentials, + ) diff --git a/source/websockets/legacy/client.py b/source/websockets/legacy/client.py new file mode 100644 index 0000000000000000000000000000000000000000..575c84519c35b3aa25481a5a3abf45b9719df630 --- /dev/null +++ b/source/websockets/legacy/client.py @@ -0,0 +1,703 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import random +import traceback +import urllib.parse +import warnings +from collections.abc import AsyncIterator, Generator, Sequence +from types import TracebackType +from typing import Any, Callable, cast + +from ..asyncio.compatibility import asyncio_timeout +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidHeader, + InvalidHeaderValue, + InvalidMessage, + NegotiationError, + SecurityError, +) +from ..extensions import ClientExtensionFactory, Extension +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import ( + build_authorization_basic, + build_extension, + build_host, + build_subprotocol, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) +from ..http11 import USER_AGENT +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .exceptions import InvalidStatusCode, RedirectHandshake +from .handshake import build_request, check_response +from .http import read_response +from .protocol import WebSocketCommonProtocol + + +__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] + + +class WebSocketClientProtocol(WebSocketCommonProtocol): + """ + WebSocket client connection. + + :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + See :func:`connect` for the documentation of ``logger``, ``origin``, + ``extensions``, ``subprotocols``, ``extra_headers``, and + ``user_agent_header``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + """ + + is_client = True + side = "client" + + def __init__( + self, + *, + logger: LoggerLike | None = None, + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + **kwargs: Any, + ) -> None: + if logger is None: + logger = logging.getLogger("websockets.client") + super().__init__(logger=logger, **kwargs) + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self.user_agent_header = user_agent_header + + def write_http_request(self, path: str, headers: Headers) -> None: + """ + Write request line and headers to the HTTP request. + + """ + self.path = path + self.request_headers = headers + + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + + # Since the path and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {path} HTTP/1.1\r\n" + request += str(headers) + + self.transport.write(request.encode()) + + async def read_http_response(self) -> tuple[int, Headers]: + """ + Read status line and headers from the HTTP response. + + If the response contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + Raises: + InvalidMessage: If the HTTP message is malformed or isn't an + HTTP/1.1 GET response. + + """ + try: + status_code, reason, headers = await read_response(self.reader) + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP response") from exc + + if self.debug: + self.logger.debug("< HTTP/1.1 %d %s", status_code, reason) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.response_headers = headers + + return status_code, self.response_headers + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Sequence[ClientExtensionFactory] | None, + ) -> list[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + Return the list of accepted extensions. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + connection. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + """ + accepted_extensions: list[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values: + if available_extensions is None: + raise NegotiationError("no extensions supported") + + parsed_header_values: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, response_params in parsed_header_values: + for extension_factory in available_extensions: + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + @staticmethod + def process_subprotocol( + headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + Check that it contains exactly one supported subprotocol. + + Return the selected subprotocol. + + """ + subprotocol: Subprotocol | None = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values: + if available_subprotocols is None: + raise NegotiationError("no subprotocols supported") + + parsed_header_values: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + if len(parsed_header_values) > 1: + raise InvalidHeaderValue( + "Sec-WebSocket-Protocol", + f"multiple values: {', '.join(parsed_header_values)}", + ) + + subprotocol = parsed_header_values[0] + + if subprotocol not in available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + async def handshake( + self, + wsuri: WebSocketURI, + origin: Origin | None = None, + available_extensions: Sequence[ClientExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + ) -> None: + """ + Perform the client side of the opening handshake. + + Args: + wsuri: URI of the WebSocket server. + origin: Value of the ``Origin`` header. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. + + Raises: + InvalidHandshake: If the handshake fails. + + """ + request_headers = Headers() + + request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure) + + if wsuri.user_info: + request_headers["Authorization"] = build_authorization_basic( + *wsuri.user_info + ) + + if origin is not None: + request_headers["Origin"] = origin + + key = build_request(request_headers) + + if available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in available_extensions + ] + ) + request_headers["Sec-WebSocket-Extensions"] = extensions_header + + if available_subprotocols is not None: + protocol_header = build_subprotocol(available_subprotocols) + request_headers["Sec-WebSocket-Protocol"] = protocol_header + + if self.extra_headers is not None: + request_headers.update(self.extra_headers) + + if self.user_agent_header: + request_headers.setdefault("User-Agent", self.user_agent_header) + + self.write_http_request(wsuri.resource_name, request_headers) + + status_code, response_headers = await self.read_http_response() + if status_code in (301, 302, 303, 307, 308): + if "Location" not in response_headers: + raise InvalidHeader("Location") + raise RedirectHandshake(response_headers["Location"]) + elif status_code != 101: + raise InvalidStatusCode(status_code, response_headers) + + check_response(response_headers, key) + + self.extensions = self.process_extensions( + response_headers, available_extensions + ) + + self.subprotocol = self.process_subprotocol( + response_headers, available_subprotocols + ) + + self.connection_open() + + +class Connect: + """ + Connect to the WebSocket server at ``uri``. + + Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which + can then be used to send and receive messages. + + :func:`connect` can be used as a asynchronous context manager:: + + async with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.exceptions.ConnectionClosed: + continue + + The connection is closed automatically after each iteration of the loop. + + If an error occurs while establishing the connection, :func:`connect` + retries with exponential backoff. The backoff delay starts at three + seconds and increases up to one minute. + + If an error occurs in the body of the loop, you can handle the exception + and :func:`connect` will reconnect with the next iteration; or you can + let the exception bubble up and break out of the loop. This lets you + decide which errors trigger a reconnection and which errors are fatal. + + Args: + uri: URI of the WebSocket server. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketClientProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS + settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't + provided, a TLS context is created + with :func:`~ssl.create_default_context`. + + * You can set ``host`` and ``port`` to connect to a different host and + port from those found in ``uri``. This only changes the destination of + the TCP connection. The host name from ``uri`` is still used in the TLS + handshake for secure connections and in the ``Host`` header. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + ~asyncio.TimeoutError: If the opening handshake times out. + + """ + + MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + + def __init__( + self, + uri: str, + *, + create_protocol: Callable[..., WebSocketClientProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + timeout: float | None = kwargs.pop("timeout", None) + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None) + if klass is None: + klass = WebSocketClientProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + + # Backwards compatibility: the loop parameter used to be supported. + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) + if _loop is None: + loop = asyncio.get_event_loop() + else: + loop = _loop + warnings.warn("remove loop argument", DeprecationWarning) + + wsuri = parse_uri(uri) + if wsuri.secure: + kwargs.setdefault("ssl", True) + elif kwargs.get("ssl") is not None: + raise ValueError( + "connect() received a ssl argument for a ws:// URI, " + "use a wss:// URI to enable TLS" + ) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + # Help mypy and avoid this error: "type[WebSocketClientProtocol] | + # Callable[..., WebSocketClientProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol) + factory = functools.partial( + create_protocol, + logger=logger, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + user_agent_header=user_agent_header, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, + legacy_recv=legacy_recv, + loop=_loop, + ) + + if kwargs.pop("unix", False): + path: str | None = kwargs.pop("path", None) + create_connection = functools.partial( + loop.create_unix_connection, factory, path, **kwargs + ) + else: + host: str | None + port: int | None + if kwargs.get("sock") is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port shouldn't be specified. + host, port = None, None + if kwargs.get("ssl"): + kwargs.setdefault("server_hostname", wsuri.host) + # If host and port are given, override values from the URI. + host = kwargs.pop("host", host) + port = kwargs.pop("port", port) + create_connection = functools.partial( + loop.create_connection, factory, host, port, **kwargs + ) + + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger + + # This is a coroutine function. + self._create_connection = create_connection + self._uri = uri + self._wsuri = wsuri + + def handle_redirect(self, uri: str) -> None: + # Update the state of this instance to connect to a new URI. + old_uri = self._uri + old_wsuri = self._wsuri + new_uri = urllib.parse.urljoin(old_uri, uri) + new_wsuri = parse_uri(new_uri) + + # Forbid TLS downgrade. + if old_wsuri.secure and not new_wsuri.secure: + raise SecurityError("redirect from WSS to WS") + + same_origin = ( + old_wsuri.secure == new_wsuri.secure + and old_wsuri.host == new_wsuri.host + and old_wsuri.port == new_wsuri.port + ) + + # Rewrite secure, host, and port for cross-origin redirects. + # This preserves connection overrides with the host and port + # arguments if the redirect points to the same host and port. + if not same_origin: + factory = self._create_connection.args[0] + # Support TLS upgrade. + if not old_wsuri.secure and new_wsuri.secure: + factory.keywords["secure"] = True + self._create_connection.keywords.setdefault("ssl", True) + # Replace secure, host, and port arguments of the protocol factory. + factory = functools.partial( + factory.func, + *factory.args, + **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), + ) + # Replace secure, host, and port arguments of create_connection. + self._create_connection = functools.partial( + self._create_connection.func, + *(factory, new_wsuri.host, new_wsuri.port), + **self._create_connection.keywords, + ) + + # Set the new WebSocket URI. This suffices for same-origin redirects. + self._uri = new_uri + self._wsuri = new_wsuri + + # async for ... in connect(...): + + BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5")) + BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1")) + BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0")) + BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618")) + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR + while True: + try: + async with self as protocol: + yield protocol + except Exception as exc: + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6455. + if backoff_delay == self.BACKOFF_MIN: + initial_delay = random.random() * self.BACKOFF_INITIAL + self.logger.info( + "connect failed; reconnecting in %.1f seconds: %s", + initial_delay, + traceback.format_exception_only(exc)[0].strip(), + ) + await asyncio.sleep(initial_delay) + else: + self.logger.info( + "connect failed again; retrying in %d seconds: %s", + int(backoff_delay), + traceback.format_exception_only(exc)[0].strip(), + ) + await asyncio.sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + # async with connect(...) as ...: + + async def __aenter__(self) -> WebSocketClientProtocol: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + await self.protocol.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketClientProtocol: + async with asyncio_timeout(self.open_timeout): + for _redirects in range(self.MAX_REDIRECTS_ALLOWED): + _transport, protocol = await self._create_connection() + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() + self.handle_redirect(exc.uri) + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol + else: + raise SecurityError("too many redirects") + + # ... = yield from connect(...) - remove when dropping Python < 3.11 + + __iter__ = __await__ + + +connect = Connect + + +def unix_connect( + path: str | None = None, + uri: str = "ws://localhost/", + **kwargs: Any, +) -> Connect: + """ + Similar to :func:`connect`, but for connecting to a Unix socket. + + This function builds upon the event loop's + :meth:`~asyncio.loop.create_unix_connection` method. + + It is only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server; the host is used in the TLS + handshake for secure connections and in the ``Host`` header. + + """ + return connect(uri=uri, path=path, unix=True, **kwargs) diff --git a/source/websockets/legacy/exceptions.py b/source/websockets/legacy/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..29a2525b4e73b788f773682ce0b88e13eafc6e26 --- /dev/null +++ b/source/websockets/legacy/exceptions.py @@ -0,0 +1,71 @@ +import http + +from .. import datastructures +from ..exceptions import ( + InvalidHandshake, + # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1. + InvalidMessage, # noqa: F401 + ProtocolError as WebSocketProtocolError, # noqa: F401 +) +from ..typing import StatusLike + + +class InvalidStatusCode(InvalidHandshake): + """ + Raised when a handshake response status code is invalid. + + """ + + def __init__(self, status_code: int, headers: datastructures.Headers) -> None: + self.status_code = status_code + self.headers = headers + + def __str__(self) -> str: + return f"server rejected WebSocket connection: HTTP {self.status_code}" + + +class AbortHandshake(InvalidHandshake): + """ + Raised to abort the handshake on purpose and return an HTTP response. + + This exception is an implementation detail. + + The public API is + :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`. + + Attributes: + status (~http.HTTPStatus): HTTP status code. + headers (Headers): HTTP response headers. + body (bytes): HTTP response body. + """ + + def __init__( + self, + status: StatusLike, + headers: datastructures.HeadersLike, + body: bytes = b"", + ) -> None: + # If a user passes an int instead of an HTTPStatus, fix it automatically. + self.status = http.HTTPStatus(status) + self.headers = datastructures.Headers(headers) + self.body = body + + def __str__(self) -> str: + return ( + f"HTTP {self.status:d}, {len(self.headers)} headers, {len(self.body)} bytes" + ) + + +class RedirectHandshake(InvalidHandshake): + """ + Raised when a handshake gets redirected. + + This exception is an implementation detail. + + """ + + def __init__(self, uri: str) -> None: + self.uri = uri + + def __str__(self) -> str: + return f"redirect to {self.uri}" diff --git a/source/websockets/legacy/framing.py b/source/websockets/legacy/framing.py new file mode 100644 index 0000000000000000000000000000000000000000..452d2fb34dc45577401321ce76d6a030d2e49597 --- /dev/null +++ b/source/websockets/legacy/framing.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import struct +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, NamedTuple + +from .. import extensions, frames +from ..exceptions import PayloadTooBig, ProtocolError +from ..typing import BytesLike, DataLike + + +try: + from ..speedups import apply_mask +except ImportError: + from ..utils import apply_mask + + +class Frame(NamedTuple): + fin: bool + opcode: frames.Opcode + data: BytesLike + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + @property + def new_frame(self) -> frames.Frame: + return frames.Frame( + self.opcode, + self.data, + self.fin, + self.rsv1, + self.rsv2, + self.rsv3, + ) + + def __str__(self) -> str: + return str(self.new_frame) + + def check(self) -> None: + return self.new_frame.check() + + @classmethod + async def read( + cls, + reader: Callable[[int], Awaitable[bytes]], + *, + mask: bool, + max_size: int | None = None, + extensions: Sequence[extensions.Extension] | None = None, + ) -> Frame: + """ + Read a WebSocket frame. + + Args: + reader: Coroutine that reads exactly the requested number of + bytes, unless the end of file is reached. + mask: Whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. + + Raises: + PayloadTooBig: If the frame exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. + + """ + + # Read the header. + data = await reader(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + + try: + opcode = frames.Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = await reader(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = await reader(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig(length, max_size) + if mask: + mask_bits = await reader(4) + + # Read the data. + data = await reader(length) + if mask: + data = apply_mask(data, mask_bits) + + new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + new_frame = extension.decode(new_frame, max_size=max_size) + + new_frame.check() + + return cls( + new_frame.fin, + new_frame.opcode, + new_frame.data, + new_frame.rsv1, + new_frame.rsv2, + new_frame.rsv3, + ) + + def write( + self, + write: Callable[[bytes], Any], + *, + mask: bool, + extensions: Sequence[extensions.Extension] | None = None, + ) -> None: + """ + Write a WebSocket frame. + + Args: + frame: Frame to write. + write: Function that writes bytes. + mask: Whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: List of extensions, applied in order. + + Raises: + ProtocolError: If the frame contains incorrect values. + + """ + # The frame is written in a single call to write in order to prevent + # TCP fragmentation. See #68 for details. This also makes it safe to + # send frames concurrently from multiple coroutines. + write(self.new_frame.serialize(mask=mask, extensions=extensions)) + + +def prepare_data(data: DataLike) -> tuple[int, BytesLike]: + """ + Convert a string or byte-like object to an opcode and a bytes-like object. + + This function is designed for data frames. + + If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes` + object encoding ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like + object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return frames.Opcode.TEXT, data.encode() + elif isinstance(data, BytesLike): + return frames.Opcode.BINARY, data + else: + raise TypeError("data must be str or bytes-like") + + +def prepare_ctrl(data: DataLike) -> bytes: + """ + Convert a string or byte-like object to bytes. + + This function is designed for ping and pong frames. + + If ``data`` is a :class:`str`, return a :class:`bytes` object encoding + ``data`` in UTF-8. + + If ``data`` is a bytes-like object, return a :class:`bytes` object. + + Raises: + TypeError: If ``data`` doesn't have a supported type. + + """ + if isinstance(data, str): + return data.encode() + elif isinstance(data, BytesLike): + return bytes(data) + else: + raise TypeError("data must be str or bytes-like") + + +# Backwards compatibility with previously documented public APIs +encode_data = prepare_ctrl + +# Backwards compatibility with previously documented public APIs +from ..frames import Close # noqa: E402 F401, I001 + + +def parse_close(data: bytes) -> tuple[int, str]: + """ + Parse the payload from a close frame. + + Returns: + Close code and reason. + + Raises: + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. + + """ + close = Close.parse(data) + return close.code, close.reason + + +def serialize_close(code: int, reason: str) -> bytes: + """ + Serialize the payload for a close frame. + + """ + return Close(code, reason).serialize() diff --git a/source/websockets/legacy/handshake.py b/source/websockets/legacy/handshake.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7157c010720733ca42d363cd810c354bf9d221 --- /dev/null +++ b/source/websockets/legacy/handshake.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import base64 +import binascii + +from ..datastructures import Headers, MultipleValuesError +from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade +from ..headers import parse_connection, parse_upgrade +from ..typing import ConnectionOption, UpgradeProtocol +from ..utils import accept_key as accept, generate_key + + +__all__ = ["build_request", "check_request", "build_response", "check_response"] + + +def build_request(headers: Headers) -> str: + """ + Build a handshake request to send to the server. + + Update request headers passed in argument. + + Args: + headers: Handshake request headers. + + Returns: + ``key`` that must be passed to :func:`check_response`. + + """ + key = generate_key() + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = key + headers["Sec-WebSocket-Version"] = "13" + return key + + +def check_request(headers: Headers) -> str: + """ + Check a handshake request received from the client. + + This function doesn't verify that the request is an HTTP/1.1 or higher GET + request and doesn't perform ``Host`` and ``Origin`` checks. These controls + are usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + Args: + headers: Handshake request headers. + + Returns: + ``key`` that must be passed to :func:`build_response`. + + Raises: + InvalidHandshake: If the handshake request is invalid. + Then, the server must return a 400 Bad Request error. + + """ + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", ", ".join(connection)) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_key = headers["Sec-WebSocket-Key"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Key") from exc + except MultipleValuesError as exc: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc + + try: + raw_key = base64.b64decode(s_w_key.encode(), validate=True) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) + + try: + s_w_version = headers["Sec-WebSocket-Version"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Version") from exc + except MultipleValuesError as exc: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc + + if s_w_version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) + + return s_w_key + + +def build_response(headers: Headers, key: str) -> None: + """ + Build a handshake response to send to the client. + + Update response headers passed in argument. + + Args: + headers: Handshake response headers. + key: Returned by :func:`check_request`. + + """ + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept(key) + + +def check_response(headers: Headers, key: str) -> None: + """ + Check a handshake response received from the server. + + This function doesn't verify that the response is an HTTP/1.1 or higher + response with a 101 status code. These controls are the responsibility of + the caller. + + Args: + headers: Handshake response headers. + key: Returned by :func:`build_request`. + + Raises: + InvalidHandshake: If the handshake response is invalid. + + """ + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", " ".join(connection)) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Accept") from exc + except MultipleValuesError as exc: + raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc + + if s_w_accept != accept(key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/source/websockets/legacy/http.py b/source/websockets/legacy/http.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c8a927e177d36f6a54bcc293dc853e2e15e736 --- /dev/null +++ b/source/websockets/legacy/http.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import asyncio +import os +import re + +from ..datastructures import Headers +from ..exceptions import SecurityError + + +__all__ = ["read_request", "read_response"] + +MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128")) +MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192")) + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]: + """ + Read an HTTP/1.1 GET request and return ``(path, headers)``. + + ``path`` isn't URL-decoded or validated in any way. + + ``path`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from ``stream`` after this coroutine returns. + + Args: + stream: Input to read the request from. + + Raises: + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, version = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + + if method != b"GET": + raise ValueError(f"unsupported HTTP method: {d(method)}") + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = await read_headers(stream) + + return path, headers + + +async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]: + """ + Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. + + ``reason`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the response body because + WebSocket handshake responses don't have one. If the response contains a + body, it may be read from ``stream`` after this coroutine returns. + + Args: + stream: Input to read the response from. + + Raises: + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + ValueError: If the response isn't well formatted. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2 + + # As in read_request, parsing is simple because a fixed value is expected + # for version, status_code is a 3-digit number, and reason can be ignored. + + try: + status_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + version, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None + if not 100 <= status_code < 1000: + raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode() + + headers = await read_headers(stream) + + return status_code, reason, headers + + +async def read_headers(stream: asyncio.StreamReader) -> Headers: + """ + Read HTTP headers from ``stream``. + + Non-ASCII characters are represented with surrogate escapes. + + """ + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_NUM_HEADERS + 1): + try: + line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +async def read_line(stream: asyncio.StreamReader) -> bytes: + """ + Read a single line from ``stream``. + + CRLF is stripped from the return value. + + """ + # Security: this is bounded by the StreamReader's limit (default = 32 KiB). + line = await stream.readline() + # Security: this guarantees header values are small (hard-coded = 8 KiB) + if len(line) > MAX_LINE_LENGTH: + raise SecurityError("line too long") + # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] diff --git a/source/websockets/legacy/protocol.py b/source/websockets/legacy/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4ec6bcec9ee886a97815239376dff0d166d8f9 --- /dev/null +++ b/source/websockets/legacy/protocol.py @@ -0,0 +1,1635 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +import logging +import random +import ssl +import struct +import sys +import time +import traceback +import uuid +import warnings +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping +from typing import Any, Callable, Deque, cast + +from ..asyncio.compatibility import asyncio_timeout +from ..datastructures import Headers +from ..exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from ..extensions import Extension +from ..frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + CloseCode, + Opcode, +) +from ..protocol import State +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from .framing import Frame, prepare_ctrl, prepare_data + + +__all__ = ["WebSocketCommonProtocol"] + + +# In order to ensure consistency, the code always checks the current value of +# WebSocketCommonProtocol.state before assigning a new value and never yields +# between the check and the assignment. + + +class WebSocketCommonProtocol(asyncio.Protocol): + """ + WebSocket connection. + + :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket + servers and clients. You shouldn't use it directly. Instead, use + :class:`~websockets.legacy.client.WebSocketClientProtocol` or + :class:`~websockets.legacy.server.WebSocketServerProtocol`. + + This documentation focuses on low-level details that aren't covered in the + documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol` + and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake + of simplicity. + + Once the connection is open, a Ping_ frame is sent every ``ping_interval`` + seconds. This serves as a keepalive. It helps keeping the connection open, + especially in the presence of proxies with short timeouts on inactive + connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + If the corresponding Pong_ frame isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with code 1011. + This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to :obj:`None` to disable this behavior. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + See the discussion of :doc:`keepalive <../../topics/keepalive>` for details. + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds + for clients and ``4 * close_timeout`` for servers. + + ``close_timeout`` is a parameter of the protocol because websockets usually + calls :meth:`close` implicitly upon exit: + + * on the client side, when using :func:`~websockets.legacy.client.connect` + as a context manager; + * on the server side, when the connection handler terminates. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or + :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1 MiB. If a larger message is received, + :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` + and the connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. Messages are added + to an in-memory queue when they're received; then :meth:`recv` pops from + that queue. In order to prevent excessive memory consumption when + messages are received faster than they can be processed, the queue must + be bounded. If the queue fills up, the protocol stops processing incoming + data until :meth:`recv` is called. In this situation, various receive + buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the + TCP receive window will shrink, slowing down transmission to avoid packet + loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128 MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64 KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64 KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + See the discussion of :doc:`memory usage <../../topics/memory>` for details. + + Args: + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.protocol")``. + See the :doc:`logging guide <../../topics/logging>` for details. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + For legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + max_queue: Maximum number of incoming messages in receive buffer. + :obj:`None` disables the limit. + read_limit: High-water mark of read buffer in bytes. + write_limit: High-water mark of write buffer in bytes. + + """ + + # There are only two differences between the client-side and server-side + # behavior: masking the payload and closing the underlying TCP connection. + # Set is_client = True/False and side = "client"/"server" to pick a side. + is_client: bool + side: str = "undefined" + + def __init__( + self, + *, + logger: LoggerLike | None = None, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + # The following arguments are kept only for backwards compatibility. + host: str | None = None, + port: int | None = None, + secure: bool | None = None, + legacy_recv: bool = False, + loop: asyncio.AbstractEventLoop | None = None, + timeout: float | None = None, + ) -> None: + if legacy_recv: # pragma: no cover + warnings.warn("legacy_recv is deprecated", DeprecationWarning) + + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: the loop parameter used to be supported. + if loop is None: + loop = asyncio.get_event_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.read_limit = read_limit + self.write_limit = write_limit + + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger("websockets.protocol") + self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + self.loop = loop + + self._host = host + self._port = port + self._secure = secure + self.legacy_recv = legacy_recv + + # Configure read buffer limits. The high-water limit is defined by + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) + + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: asyncio.Future[None] | None = None + + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute + # :meth:`connection_open` to change the state to OPEN. + self.state = State.CONNECTING + if self.debug: + self.logger.debug("= connection is CONNECTING") + + # HTTP protocol parameters. + self.path: str + """Path of the opening handshake request.""" + self.request_headers: Headers + """Opening handshake request headers.""" + self.response_headers: Headers + """Opening handshake response headers.""" + + # WebSocket protocol parameters. + self.extensions: list[Extension] = [] + self.subprotocol: Subprotocol | None = None + """Subprotocol, if one was negotiated.""" + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None + + # Completed when the connection state becomes CLOSED. Translates the + # :meth:`connection_lost` callback to a :class:`~asyncio.Future` + # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are + # translated by ``self.stream_reader``). + self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() + + # Queue of received messages. + self.messages: Deque[Data] = collections.deque() + self._pop_message_waiter: asyncio.Future[None] | None = None + self._put_message_waiter: asyncio.Future[None] | None = None + + # Protect sending fragmented messages. + self._fragmented_message_waiter: asyncio.Future[None] | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Task running the data transfer. + self.transfer_data_task: asyncio.Task[None] + + # Exception that occurred during data transfer, if any. + self.transfer_data_exc: BaseException | None = None + + # Task sending keepalive pings. + self.keepalive_ping_task: asyncio.Task[None] + + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep(0) + await self._drain_helper() + + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. + + Enter the OPEN state and start the data transfer phase. + + """ + # 4.1. The WebSocket Connection is Established. + assert self.state is State.CONNECTING + self.state = State.OPEN + if self.debug: + self.logger.debug("= connection is OPEN") + # Start the task that receives incoming WebSocket messages. + self.transfer_data_task = self.loop.create_task(self.transfer_data()) + # Start the task that sends pings at regular intervals. + self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) + # Start the task that eventually closes the TCP connection. + self.close_connection_task = self.loop.create_task(self.close_connection()) + + @property + def host(self) -> str | None: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) + return self._host + + @property + def port(self) -> int | None: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) + return self._port + + @property + def secure(self) -> bool | None: + warnings.warn("don't use secure", DeprecationWarning) + return self._secure + + # Public API + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getsockname`. + + :obj:`None` if the TCP connection isn't established yet. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getpeername`. + + :obj:`None` if the TCP connection isn't established yet. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("peername") + + @property + def open(self) -> bool: + """ + :obj:`True` when the connection is open; :obj:`False` otherwise. + + This attribute may be used to detect disconnections. However, this + approach is discouraged per the EAFP_ principle. Instead, you should + handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp + + """ + return self.state is State.OPEN and not self.transfer_data_task.done() + + @property + def closed(self) -> bool: + """ + :obj:`True` when the connection is closed; :obj:`False` otherwise. + + Be aware that both :attr:`open` and :attr:`closed` are :obj:`False` + during the opening and closing sequences. + + """ + return self.state is State.CLOSED + + @property + def close_code(self) -> int | None: + """ + WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. + + .. _section 7.1.5 of RFC 6455: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return CloseCode.ABNORMAL_CLOSURE + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> str | None: + """ + WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. + + .. _section 7.1.6 of RFC 6455: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator exits normally when the connection is closed with the close + code 1000 (OK) or 1001 (going away) or without a close code. + + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` + exception when the connection is closed with any other code. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing the next + message. The next invocation of :meth:`recv` will return it. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + Returns: + A string (:class:`str`) for a Text_ frame. A bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` concurrently. + + """ + if self._pop_message_waiter is not None: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already waiting for the next message" + ) + + # Don't await self.ensure_open() here: + # - messages could be available in the queue even if the connection + # is closed; + # - messages could be received before the closing frame even if the + # connection is closing. + + # Wait until there's a message in the queue (if necessary) or the + # connection is closed. + while len(self.messages) <= 0: + pop_message_waiter: asyncio.Future[None] = self.loop.create_future() + self._pop_message_waiter = pop_message_waiter + try: + # If asyncio.wait() is canceled, it doesn't cancel + # pop_message_waiter and self.transfer_data_task. + await asyncio.wait( + [pop_message_waiter, self.transfer_data_task], + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + self._pop_message_waiter = None + + # If asyncio.wait(...) exited because self.transfer_data_task + # completed before receiving a new message, raise a suitable + # exception (or return None if legacy_recv is enabled). + if not pop_message_waiter.done(): + if self.legacy_recv: + return None # type: ignore + else: + # Wait until the connection is closed to raise + # ConnectionClosed with the correct code and reason. + await self.ensure_open() + + # Pop a message from the queue. + message = self.messages.popleft() + + # Notify transfer_data(). + if self._put_message_waiter is not None: + self._put_message_waiter.set_result(None) + self._put_message_waiter = None + + return message + + async def send( + self, + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you want to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + await self.ensure_open() + + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self._fragmented_message_waiter is not None: + await asyncio.shield(self._fragmented_message_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, (str, bytes, bytearray, memoryview)): + opcode, data = prepare_data(message) + await self.write_frame(True, opcode, data) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + iter_message = iter(message) + try: + fragment = next(iter_message) + except StopIteration: + return + opcode, data = prepare_data(fragment) + + self._fragmented_message_waiter = self.loop.create_future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + for fragment in iter_message: + confirm_opcode, data = prepare_data(fragment) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except (Exception, asyncio.CancelledError): + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(CloseCode.INTERNAL_ERROR) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + # Fragmented message -- asynchronous iterator + + elif isinstance(message, AsyncIterable): + # Implement aiter_message = aiter(message) without aiter + # Work around https://github.com/python/mypy/issues/5738 + aiter_message = cast( + Callable[[AsyncIterable[DataLike]], AsyncIterator[DataLike]], + type(message).__aiter__, + )(message) + try: + # Implement fragment = anext(aiter_message) without anext + # Work around https://github.com/python/mypy/issues/5738 + fragment = await cast( + Callable[[AsyncIterator[DataLike]], Awaitable[DataLike]], + type(aiter_message).__anext__, + )(aiter_message) + except StopAsyncIteration: + return + opcode, data = prepare_data(fragment) + + self._fragmented_message_waiter = self.loop.create_future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + async for fragment in aiter_message: + confirm_opcode, data = prepare_data(fragment) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except (Exception, asyncio.CancelledError): + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(CloseCode.INTERNAL_ERROR) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + else: + raise TypeError("data must be str, bytes-like, or iterable") + + async def close( + self, + code: int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. As a consequence, there's no need + to await :meth:`wait_closed` after :meth:`close`. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given + that errors during connection termination aren't particularly useful. + + Canceling :meth:`close` is discouraged. If it takes too long, you can + set a shorter ``close_timeout``. If you don't want to wait, let the + Python process exit, then the OS will take care of closing the TCP + connection. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + async with asyncio_timeout(self.close_timeout): + await self.write_close_frame(Close(code, reason)) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Fail the connection to shut down faster. + self.fail_connection() + + # If no close frame is received within the timeout, asyncio_timeout() + # cancels the data transfer task and raises TimeoutError. + + # If close() is called multiple times concurrently and one of these + # calls hits the timeout, the data transfer task will be canceled. + # Other calls will receive a CancelledError here. + + try: + # If close() is canceled during the wait, self.transfer_data_task + # is canceled before the timeout elapses. + async with asyncio_timeout(self.close_timeout): + await self.transfer_data_task + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + # Wait for the close connection task to close the TCP connection. + await asyncio.shield(self.close_connection_task) + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + This coroutine is identical to the :attr:`closed` attribute, except it + can be awaited. + + This can make it easier to detect connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: DataLike | None = None) -> Awaitable[float]: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive, as a check that the remote endpoint + received all messages up to this point, or to measure :attr:`latency`. + + Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no + effect. + + Args: + data: Payload of the ping. A string will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + + Returns: + A future that will be completed when the corresponding pong is + received. You can ignore it if you don't intend to wait. The result + of the future is the latency of the connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + await self.ensure_open() + + if data is not None: + data = prepare_ctrl(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # Resolution of time.monotonic() may be too low on Windows. + ping_timestamp = time.perf_counter() + self.pings[data] = (pong_waiter, ping_timestamp) + + await self.write_frame(True, OP_PING, data) + + return asyncio.shield(pong_waiter) + + async def pong(self, data: DataLike = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Canceling :meth:`pong` is discouraged. If :meth:`pong` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Args: + data: Payload of the pong. A string will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + await self.ensure_open() + + data = prepare_ctrl(data) + + await self.write_frame(True, OP_PONG, data) + + # Private methods - no guarantees. + + def connection_closed_exc(self) -> ConnectionClosed: + exc: ConnectionClosed + if ( + self.close_rcvd is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent is not None + and self.close_sent.code in OK_CLOSE_CODES + ): + exc = ConnectionClosedOK( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + else: + exc = ConnectionClosedError( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception that terminated data transfer, if any. + exc.__cause__ = self.transfer_data_exc + return exc + + async def ensure_open(self) -> None: + """ + Check that the WebSocket connection is open. + + Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. + + """ + # Handle cases from most common to least common for performance. + if self.state is State.OPEN: + # If self.transfer_data_task exited without a closing handshake, + # self.close_connection_task may be closing the connection, going + # straight from OPEN to CLOSED. + if self.transfer_data_task.done(): + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + else: + return + + if self.state is State.CLOSED: + raise self.connection_closed_exc() + + if self.state is State.CLOSING: + # If we started the closing handshake, wait for its completion to + # get the proper close code and reason. self.close_connection_task + # will complete within 4 or 5 * close_timeout after close(). The + # CLOSING state also occurs when failing the connection. In that + # case self.close_connection_task will complete even faster. + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + + # Control may only reach this point in buggy third-party subclasses. + assert self.state is State.CONNECTING + raise InvalidState("WebSocket connection isn't established yet") + + async def transfer_data(self) -> None: + """ + Read incoming messages and put them in a queue. + + This coroutine runs in a task until the closing handshake is started. + + """ + try: + while True: + message = await self.read_message() + + # Exit the loop when receiving a close frame. + if message is None: + break + + # Wait until there's room in the queue (if necessary). + if self.max_queue is not None: + while len(self.messages) >= self.max_queue: + self._put_message_waiter = self.loop.create_future() + try: + await asyncio.shield(self._put_message_waiter) + finally: + self._put_message_waiter = None + + # Put the message in the queue. + self.messages.append(message) + + # Notify recv(). + if self._pop_message_waiter is not None: + self._pop_message_waiter.set_result(None) + self._pop_message_waiter = None + + except asyncio.CancelledError as exc: + self.transfer_data_exc = exc + # If fail_connection() cancels this task, avoid logging the error + # twice and failing the connection again. + raise + + except ProtocolError as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.PROTOCOL_ERROR) + + except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: + # Reading data with self.reader.readexactly may raise: + # - most subclasses of ConnectionError if the TCP connection + # breaks, is reset, or is aborted; + # - TimeoutError if the TCP connection times out; + # - IncompleteReadError, a subclass of EOFError, if fewer + # bytes are available than requested; + # - ssl.SSLError if the other side infringes the TLS protocol. + self.transfer_data_exc = exc + self.fail_connection(CloseCode.ABNORMAL_CLOSURE) + + except UnicodeDecodeError as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.INVALID_DATA) + + except PayloadTooBig as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.MESSAGE_TOO_BIG) + + except Exception as exc: + # This shouldn't happen often because exceptions expected under + # regular circumstances are handled above. If it does, consider + # catching and handling more exceptions. + self.logger.error("data transfer failed", exc_info=True) + + self.transfer_data_exc = exc + self.fail_connection(CloseCode.INTERNAL_ERROR) + + async def read_message(self) -> Data | None: + """ + Read a single message from the connection. + + Re-assemble data frames if the message is fragmented. + + Return :obj:`None` when the closing handshake is started. + + """ + frame = await self.read_data_frame(max_size=self.max_size) + + # A close frame was received. + if frame is None: + return None + + if frame.opcode == OP_TEXT: + text = True + elif frame.opcode == OP_BINARY: + text = False + else: # frame.opcode == OP_CONT + raise ProtocolError("unexpected opcode") + + # Shortcut for the common case - no fragmentation + if frame.fin: + if isinstance(frame.data, memoryview): + raise AssertionError("only compressed outgoing frames use memoryview") + return frame.data.decode() if text else bytes(frame.data) + + # 5.4. Fragmentation + fragments: list[DataLike] = [] + max_size = self.max_size + if text: + decoder_factory = codecs.getincrementaldecoder("utf-8") + decoder = decoder_factory(errors="strict") + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal fragments + fragments.append(decoder.decode(frame.data, frame.fin)) + + else: + + def append(frame: Frame) -> None: + nonlocal fragments, max_size + fragments.append(decoder.decode(frame.data, frame.fin)) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + else: + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal fragments + fragments.append(frame.data) + + else: + + def append(frame: Frame) -> None: + nonlocal fragments, max_size + fragments.append(frame.data) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + append(frame) + + while not frame.fin: + frame = await self.read_data_frame(max_size=max_size) + if frame is None: + raise ProtocolError("incomplete fragmented message") + if frame.opcode != OP_CONT: + raise ProtocolError("unexpected opcode") + append(frame) + + return ("" if text else b"").join(fragments) + + async def read_data_frame(self, max_size: int | None) -> Frame | None: + """ + Read a single data frame from the connection. + + Process control frames received before the next data frame. + + Return :obj:`None` if a close frame is encountered before any data frame. + + """ + # 6.2. Receiving Data + while True: + frame = await self.read_frame(max_size) + + # 5.5. Control Frames + if frame.opcode == OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.close_sent is not None: + self.close_rcvd_then_sent = False + try: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthesizes a 1005 close code. + await self.write_close_frame(self.close_rcvd, frame.data) + except ConnectionClosed: + # Connection closed before we could echo the close frame. + pass + return None + + elif frame.opcode == OP_PING: + # Answer pings, unless connection is CLOSING. + if self.state is State.OPEN: + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed while draining write buffer. + pass + + elif frame.opcode == OP_PONG: + if frame.data in self.pings: + pong_timestamp = time.perf_counter() + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pings.items(): + ping_ids.append(ping_id) + if not pong_waiter.done(): + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == frame.data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + # 5.6. Data Frames + else: + return frame + + async def read_frame(self, max_size: int | None) -> Frame: + """ + Read a single frame from the connection. + + """ + frame = await Frame.read( + self.reader.readexactly, + mask=not self.is_client, + max_size=max_size, + extensions=self.extensions, + ) + if self.debug: + self.logger.debug("< %s", frame) + return frame + + def write_frame_sync(self, fin: bool, opcode: int, data: BytesLike) -> None: + frame = Frame(fin, Opcode(opcode), data) + if self.debug: + self.logger.debug("> %s", frame) + frame.write( + self.transport.write, + mask=self.is_client, + extensions=self.extensions, + ) + + async def drain(self) -> None: + try: + # Handle flow control automatically. + await self._drain() + except ConnectionError: + # Terminate the connection if the socket died. + self.fail_connection() + # Wait until the connection is closed to raise ConnectionClosed + # with the correct code and reason. + await self.ensure_open() + + async def write_frame( + self, fin: bool, opcode: int, data: BytesLike, *, _state: int = State.OPEN + ) -> None: + # Defensive assertion for protocol compliance. + if self.state is not _state: # pragma: no cover + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + self.write_frame_sync(fin, opcode, data) + await self.drain() + + async def write_close_frame( + self, close: Close, data: BytesLike | None = None + ) -> None: + """ + Write a close frame if and only if the connection state is OPEN. + + This dedicated coroutine must be used for writing close frames to + ensure that at most one close frame is sent on a given connection. + + """ + # Test and set the connection state before sending the close frame to + # avoid sending two frames in case of concurrent calls. + if self.state is State.OPEN: + # 7.1.3. The WebSocket Closing Handshake is Started + self.state = State.CLOSING + if self.debug: + self.logger.debug("= connection is CLOSING") + + self.close_sent = close + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + if data is None: + data = close.serialize() + + # 7.1.2. Start the WebSocket Closing Handshake + await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + This coroutine exits when the connection terminates and one of the + following happens: + + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. + + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + if self.debug: + self.logger.debug("% sending keepalive ping") + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + # Raises CancelledError if the connection is closed, + # when close_connection() cancels keepalive_ping(). + # Raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + await pong_waiter + if self.debug: + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + self.fail_connection( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + + except ConnectionClosed: + pass + + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + async def close_connection(self) -> None: + """ + 7.1.1. Close the WebSocket Connection + + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + + """ + try: + # Wait for the data transfer phase to complete. + if hasattr(self, "transfer_data_task"): + try: + await self.transfer_data_task + except asyncio.CancelledError: + pass + + # Cancel the keepalive ping task. + if hasattr(self, "keepalive_ping_task"): + self.keepalive_ping_task.cancel() + + # A client should wait for a TCP close from the server. + if self.is_client and hasattr(self, "transfer_data_task"): + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("- timed out waiting for TCP close") + + # Half-close the TCP connection if possible (when there's no TLS). + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # "[Errno 107] Transport endpoint is not connected" happens + # but it isn't completely clear under which circumstances. + # uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("- timed out waiting for TCP close") + + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is canceled (for example). + await self.close_transport() + + async def close_transport(self) -> None: + """ + Close the TCP connection. + + """ + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and self.transport.is_closing(): + return + + # Close the TCP connection. Buffers are flushed asynchronously. + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("- timed out waiting for TCP close") + + # Abort the TCP connection. Buffers are discarded. + if self.debug: + self.logger.debug("x aborting TCP connection") + self.transport.abort() + + # connection_lost() is called quickly after aborting. + await self.wait_for_connection_lost() + + async def wait_for_connection_lost(self) -> bool: + """ + Wait until the TCP connection is closed or ``self.close_timeout`` elapses. + + Return :obj:`True` if the connection is closed and :obj:`False` + otherwise. + + """ + if not self.connection_lost_waiter.done(): + try: + async with asyncio_timeout(self.close_timeout): + await asyncio.shield(self.connection_lost_waiter) + except asyncio.TimeoutError: + pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() + + def fail_connection( + self, + code: int = CloseCode.ABNORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + 7.1.7. Fail the WebSocket Connection + + This requires: + + 1. Stopping all processing of incoming data, which means canceling + :attr:`transfer_data_task`. The close code will be 1006 unless a + close frame was received earlier. + + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + + 3. Closing the connection. :meth:`close_connection` takes care of + this once :attr:`transfer_data_task` exits after being canceled. + + (The specification describes these steps in the opposite order.) + + """ + if self.debug: + self.logger.debug("! failing connection with code %d", code) + + # Cancel transfer_data_task if the opening handshake succeeded. + # cancel() is idempotent and ignored if the task is done already. + if hasattr(self, "transfer_data_task"): + self.transfer_data_task.cancel() + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because of + # an error reading from or writing to the network. + # Don't send a close frame if the connection is broken. + if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN: + close = Close(code, reason) + + # Write the close frame without draining the write buffer. + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not drainig the write buffer is acceptable in this context. + + # This duplicates a few lines of code from write_close_frame(). + + self.state = State.CLOSING + if self.debug: + self.logger.debug("= connection is CLOSING") + + # If self.close_rcvd was set, the connection state would be + # CLOSING. Therefore self.close_rcvd isn't set and we don't + # have to set self.close_rcvd_then_sent. + assert self.close_rcvd is None + self.close_sent = close + + self.write_frame_sync(True, OP_CLOSE, close.serialize()) + + # Start close_connection_task if the opening handshake didn't succeed. + if not hasattr(self, "close_connection_task"): + self.close_connection_task = self.loop.create_task(self.close_connection()) + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.state is State.CLOSED + exc = self.connection_closed_exc() + + for pong_waiter, _ping_timestamp in self.pings.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + # asyncio.Protocol methods + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Configure write buffer limits. + + The high-water limit is defined by ``self.write_limit``. + + The low-water limit currently defaults to ``self.write_limit // 4`` in + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should + be all right for reasonable use cases of this library. + + This is the earliest point where we can get hold of the transport, + which means it's the best point for configuring it. + + """ + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport + + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) + + def connection_lost(self, exc: Exception | None) -> None: + """ + 7.1.4. The WebSocket Connection is Closed. + + """ + self.state = State.CLOSED + if self.debug: + self.logger.debug("= connection is CLOSED") + + self.abort_pings() + + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + + if True: # pragma: no cover + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def pause_writing(self) -> None: # pragma: no cover + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: # pragma: no cover + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning :obj:`True`. + + Besides, that doesn't work on TLS connections. + + """ + self.reader.feed_eof() + + +# broadcast() is defined in the protocol module even though it's primarily +# used by servers and documented in the server module because it works with +# client connections too and because it's easier to test together with the +# WebSocketCommonProtocol class. + + +def broadcast( + websockets: Iterable[WebSocketCommonProtocol], + message: DataLike, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you may + set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + While :func:`broadcast` makes more sense for servers, it works identically + with clients, if you have a use case for opening connections to many servers + and broadcasting a message to them. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if not isinstance(message, (str, bytes, bytearray, memoryview)): + raise TypeError("data must be str or bytes-like") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + + opcode, data = prepare_data(message) + + for websocket in websockets: + if websocket.state is not State.OPEN: + continue + + if websocket._fragmented_message_waiter is not None: + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + continue + + try: + websocket.write_frame_sync(True, opcode, data) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: failed to write message: %s", + traceback.format_exception_only(write_exception)[0].strip(), + ) + + if raise_exceptions and exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) + + +# Pretend that broadcast is actually defined in the server module. +broadcast.__module__ = "websockets.legacy.server" diff --git a/source/websockets/legacy/server.py b/source/websockets/legacy/server.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a69c716c34083020033b089a7d4156fbc1e93e --- /dev/null +++ b/source/websockets/legacy/server.py @@ -0,0 +1,1191 @@ +from __future__ import annotations + +import asyncio +import email.utils +import functools +import http +import inspect +import logging +import socket +import warnings +from collections.abc import Awaitable, Generator, Iterable, Sequence +from types import TracebackType +from typing import Any, Callable, cast + +from ..asyncio.compatibility import asyncio_timeout +from ..datastructures import Headers, HeadersLike, MultipleValuesError +from ..exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from ..extensions import Extension, ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import ( + build_extension, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) +from ..http11 import SERVER +from ..protocol import State +from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol +from .exceptions import AbortHandshake +from .handshake import build_response, check_request +from .http import read_request +from .protocol import WebSocketCommonProtocol, broadcast + + +__all__ = [ + "broadcast", + "serve", + "unix_serve", + "WebSocketServerProtocol", + "WebSocketServer", +] + + +HeadersLikeOrCallable = HeadersLike | Callable[[str, Headers], HeadersLike] + +HTTPResponse = tuple[StatusLike, HeadersLike, bytes] + + +class WebSocketServerProtocol(WebSocketCommonProtocol): + """ + WebSocket server connection. + + :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + You may customize the opening handshake in a subclass by + overriding :meth:`process_request` or :meth:`select_subprotocol`. + + Args: + ws_server: WebSocket server that created this connection. + + See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, + ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + """ + + is_client = False + side = "server" + + def __init__( + self, + # The version that accepts the path in the second argument is deprecated. + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ), + ws_server: WebSocketServer, + *, + logger: LoggerLike | None = None, + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = SERVER, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + **kwargs: Any, + ) -> None: + if logger is None: + logger = logging.getLogger("websockets.server") + super().__init__(logger=logger, **kwargs) + # For backwards compatibility with 6.0 or earlier. + if origins is not None and "" in origins: + warnings.warn("use None instead of '' in origins", DeprecationWarning) + origins = [None if origin == "" else origin for origin in origins] + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to serve to trigger the deprecation warning on direct + # use of WebSocketServerProtocol. + self.ws_handler = remove_path_argument(ws_handler) + self.ws_server = ws_server + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self.server_header = server_header + self._process_request = process_request + self._select_subprotocol = select_subprotocol + self.open_timeout = open_timeout + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Register connection and initialize a task to handle it. + + """ + super().connection_made(transport) + # Register the connection with the server before creating the handler + # task. Registering at the beginning of the handler coroutine would + # create a race condition between the creation of the task, which + # schedules its execution, and the moment the handler starts running. + self.ws_server.register(self) + self.handler_task = self.loop.create_task(self.handler()) + + async def handler(self) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller able to handle exceptions, it + attempts to log relevant ones and guarantees that the TCP connection is + closed before exiting. + + """ + try: + try: + async with asyncio_timeout(self.open_timeout): + await self.handshake( + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ) + except asyncio.TimeoutError: # pragma: no cover + raise + except ConnectionError: + raise + except Exception as exc: + if isinstance(exc, AbortHandshake): + status, headers, body = exc.status, exc.headers, exc.body + elif isinstance(exc, InvalidOrigin): + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) + status, headers, body = ( + http.HTTPStatus.FORBIDDEN, + Headers(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), + ) + elif isinstance(exc, InvalidUpgrade): + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) + status, headers, body = ( + http.HTTPStatus.UPGRADE_REQUIRED, + Headers([("Upgrade", "websocket")]), + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ).encode(), + ) + elif isinstance(exc, InvalidHandshake): + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" + status, headers, body = ( + http.HTTPStatus.BAD_REQUEST, + Headers(), + f"Failed to open a WebSocket connection: {exc_str}.\n".encode(), + ) + else: + self.logger.error("opening handshake failed", exc_info=True) + status, headers, body = ( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + Headers(), + ( + b"Failed to open a WebSocket connection.\n" + b"See server log for more information.\n" + ), + ) + + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + if self.server_header: + headers.setdefault("Server", self.server_header) + + headers.setdefault("Content-Length", str(len(body))) + headers.setdefault("Content-Type", "text/plain") + headers.setdefault("Connection", "close") + + self.write_http_response(status, headers, body) + self.logger.info( + "connection rejected (%d %s)", status.value, status.phrase + ) + await self.close_transport() + return + + try: + await self.ws_handler(self) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + if not self.closed: + self.fail_connection(1011) + raise + + try: + await self.close() + except ConnectionError: + raise + except Exception: + self.logger.error("closing handshake failed", exc_info=True) + raise + + except Exception: + # Last-ditch attempt to avoid leaking connections on errors. + try: + self.transport.close() + except Exception: # pragma: no cover + pass + + finally: + # Unregister the connection with the server when the handler task + # terminates. Registration is tied to the lifecycle of the handler + # task because the server waits for tasks attached to registered + # connections before terminating. + self.ws_server.unregister(self) + self.logger.info("connection closed") + + async def read_http_request(self) -> tuple[str, Headers]: + """ + Read request line and headers from the HTTP request. + + If the request contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + Raises: + InvalidMessage: If the HTTP message is malformed or isn't an + HTTP/1.1 GET request. + + """ + try: + path, headers = await read_request(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP request") from exc + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.path = path + self.request_headers = headers + + return path, headers + + def write_http_response( + self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None + ) -> None: + """ + Write status line and headers to the HTTP response. + + This coroutine is also able to write a response body. + + """ + self.response_headers = headers + + if self.debug: + self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if body is not None: + self.logger.debug("> [body] (%d bytes)", len(body)) + + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" + response += str(headers) + + self.transport.write(response.encode()) + + if body is not None: + self.transport.write(body) + + async def process_request( + self, path: str, request_headers: Headers + ) -> HTTPResponse | None: + """ + Intercept the HTTP request and return an HTTP response if appropriate. + + You may override this method in a :class:`WebSocketServerProtocol` + subclass, for example: + + * to return an HTTP 200 OK response on a given path; then a load + balancer can use this path for a health check; + * to authenticate the request and return an HTTP 401 Unauthorized or an + HTTP 403 Forbidden when authentication fails. + + You may also override this method with the ``process_request`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. This + is equivalent, except ``process_request`` won't have access to the + protocol instance, so it can't store information for later use. + + :meth:`process_request` is expected to complete quickly. If it may run + for a long time, then it should await :meth:`wait_closed` and exit if + :meth:`wait_closed` completes, or else it could prevent the server + from shutting down. + + Args: + path: Request path, including optional query string. + request_headers: Request headers. + + Returns: + tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to + continue the WebSocket handshake normally. + + An HTTP response, represented by a 3-uple of the response status, + headers, and body, to abort the WebSocket handshake and return + that HTTP response instead. + + """ + if self._process_request is not None: + response = self._process_request(path, request_headers) + if isinstance(response, Awaitable): + return await response + else: + # For backwards compatibility with 7.0. + warnings.warn( + "declare process_request as a coroutine", DeprecationWarning + ) + return response + return None + + @staticmethod + def process_origin( + headers: Headers, origins: Sequence[Origin | None] | None = None + ) -> Origin | None: + """ + Handle the Origin HTTP request header. + + Args: + headers: Request headers. + origins: Optional list of acceptable origins. + + Raises: + InvalidOrigin: If the origin isn't acceptable. + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. + try: + origin = headers.get("Origin") + except MultipleValuesError as exc: + raise InvalidHeader("Origin", "multiple values") from exc + if origin is not None: + origin = cast(Origin, origin) + if origins is not None: + if origin not in origins: + raise InvalidOrigin(origin) + return origin + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Sequence[ServerExtensionFactory] | None, + ) -> tuple[str | None, list[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Return the Sec-WebSocket-Extensions HTTP response header and the list + of accepted extensions. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: Request headers. + extensions: Optional list of supported extensions. + + Raises: + InvalidHandshake: To abort the handshake with an HTTP 400 error. + + """ + response_header_value: str | None = None + + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and available_extensions: + parsed_header_values: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + for ext_factory in available_extensions: + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + # Not @staticmethod because it calls self.select_subprotocol() + def process_subprotocol( + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Return Sec-WebSocket-Protocol HTTP response header, which is the same + as the selected subprotocol. + + Args: + headers: Request headers. + available_subprotocols: Optional list of supported subprotocols. + + Raises: + InvalidHandshake: To abort the handshake with an HTTP 400 error. + + """ + subprotocol: Subprotocol | None = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values and available_subprotocols: + parsed_header_values: list[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + subprotocol = self.select_subprotocol( + parsed_header_values, available_subprotocols + ) + + return subprotocol + + def select_subprotocol( + self, + client_subprotocols: Sequence[Subprotocol], + server_subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + """ + Pick a subprotocol among those supported by the client and the server. + + If several subprotocols are available, select the preferred subprotocol + by giving equal weight to the preferences of the client and the server. + + If no subprotocol is available, proceed without a subprotocol. + + You may provide a ``select_subprotocol`` argument to :func:`serve` or + :class:`WebSocketServerProtocol` to override this logic. For example, + you could reject the handshake if the client doesn't support a + particular subprotocol, rather than accept the handshake without that + subprotocol. + + Args: + client_subprotocols: List of subprotocols offered by the client. + server_subprotocols: List of subprotocols available on the server. + + Returns: + Selected subprotocol, if a common subprotocol was found. + + :obj:`None` to continue without a subprotocol. + + """ + if self._select_subprotocol is not None: + return self._select_subprotocol(client_subprotocols, server_subprotocols) + + subprotocols = set(client_subprotocols) & set(server_subprotocols) + if not subprotocols: + return None + return sorted( + subprotocols, + key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p), + )[0] + + async def handshake( + self, + origins: Sequence[Origin | None] | None = None, + available_extensions: Sequence[ServerExtensionFactory] | None = None, + available_subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + ) -> str: + """ + Perform the server side of the opening handshake. + + Args: + origins: List of acceptable values of the Origin HTTP header; + include :obj:`None` if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be tried. + subprotocols: List of supported subprotocols, in order of + decreasing preference. + extra_headers: Arbitrary HTTP headers to add to the response when + the handshake succeeds. + + Returns: + path of the URI of the request. + + Raises: + InvalidHandshake: If the handshake fails. + + """ + path, request_headers = await self.read_http_request() + + # Hook for customizing request handling, for example checking + # authentication or treating some paths as plain HTTP endpoints. + early_response_awaitable = self.process_request(path, request_headers) + if isinstance(early_response_awaitable, Awaitable): + early_response = await early_response_awaitable + else: + # For backwards compatibility with 7.0. + warnings.warn("declare process_request as a coroutine", DeprecationWarning) + early_response = early_response_awaitable + + # The connection may drop while process_request is running. + if self.state is State.CLOSED: + # This subclass of ConnectionError is silently ignored in handler(). + raise BrokenPipeError("connection closed during opening handshake") + + # Change the response to a 503 error if the server is shutting down. + if not self.ws_server.is_serving(): + early_response = ( + http.HTTPStatus.SERVICE_UNAVAILABLE, + [], + b"Server is shutting down.\n", + ) + + if early_response is not None: + raise AbortHandshake(*early_response) + + key = check_request(request_headers) + + self.origin = self.process_origin(request_headers, origins) + + extensions_header, self.extensions = self.process_extensions( + request_headers, available_extensions + ) + + protocol_header = self.subprotocol = self.process_subprotocol( + request_headers, available_subprotocols + ) + + response_headers = Headers() + + build_response(response_headers, key) + + if extensions_header is not None: + response_headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + response_headers["Sec-WebSocket-Protocol"] = protocol_header + + if callable(extra_headers): + extra_headers = extra_headers(path, self.request_headers) + if extra_headers is not None: + response_headers.update(extra_headers) + + response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + if self.server_header is not None: + response_headers.setdefault("Server", self.server_header) + + self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) + + self.logger.info("connection open") + + self.connection_open() + + return path + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~asyncio.Server`. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__(self, logger: LoggerLike | None = None) -> None: + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.websockets: set[WebSocketServerProtocol] = set() + + # Task responsible for closing the server and terminating connections. + self.close_task: asyncio.Task[None] | None = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] + + def wrap(self, server: asyncio.base_events.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + # Initialized here because we need a reference to the event loop. + # This could be moved back to __init__ now that Python < 3.10 isn't + # supported anymore, but I'm not taking that risk in legacy code. + self.closed_waiter = server.get_loop().create_future() + + def register(self, protocol: WebSocketServerProtocol) -> None: + """ + Register a connection with this server. + + """ + self.websockets.add(protocol) + + def unregister(self, protocol: WebSocketServerProtocol) -> None: + """ + Unregister a connection with this server. + + """ + self.websockets.remove(protocol) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://github.com/python/cpython/issues/79033 for + # details. This workaround can be removed when dropping Python < 3.11. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.websockets: + await asyncio.wait( + [websocket.handler_task for websocket in self.websockets] + ) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +class Serve: + """ + Start a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the connection handler, ``ws_handler``. + + The handler receives the :class:`WebSocketServerProtocol` and uses it to + send and receive messages. + + Once the handler completes, either normally or with an exception, the + server performs the closing handshake and closes the connection. + + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object + provides a :meth:`~WebSocketServer.close` method to shut down the server:: + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + server = await serve(...) + await stop + server.close() + await server.wait_closed() + + :func:`serve` can be used as an asynchronous context manager. Then, the + server is shut down automatically when exiting the context:: + + # set this future to exit the server + stop = asyncio.get_running_loop().create_future() + + async with serve(...): + await stop + + Args: + ws_handler: Connection handler. It receives the WebSocket connection, + which is a :class:`WebSocketServerProtocol`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketServerProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]): + Arbitrary HTTP headers to add to the response. This can be + a :data:`~websockets.datastructures.HeadersLike` or a callable + taking the request path and headers in arguments and returning + a :data:`~websockets.datastructures.HeadersLike`. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + process_request (Callable[[str, Headers], \ + Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None): + Intercept HTTP request before the opening handshake. + See :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: Select a subprotocol supported by the client. + See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to a :obj:`~socket.socket` that you created + outside of websockets. + + Returns: + WebSocket server. + + """ + + def __init__( + self, + # The version that accepts the path in the second argument is deprecated. + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ), + host: str | Sequence[str] | None = None, + port: int | None = None, + *, + create_protocol: Callable[..., WebSocketServerProtocol] | None = None, + logger: LoggerLike | None = None, + compression: str | None = "deflate", + origins: Sequence[Origin | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + extra_headers: HeadersLikeOrCallable | None = None, + server_header: str | None = SERVER, + process_request: ( + Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None + ) = None, + select_subprotocol: ( + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None + ) = None, + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = None, + max_size: int | None = 2**20, + max_queue: int | None = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + timeout: float | None = kwargs.pop("timeout", None) + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None) + if klass is None: + klass = WebSocketServerProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + + # Backwards compatibility: the loop parameter used to be supported. + _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None) + if _loop is None: + loop = asyncio.get_event_loop() + else: + loop = _loop + warnings.warn("remove loop argument", DeprecationWarning) + + ws_server = WebSocketServer(logger=logger) + + secure = kwargs.get("ssl") is not None + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + # Help mypy and avoid this error: "type[WebSocketServerProtocol] | + # Callable[..., WebSocketServerProtocol]" not callable [misc] + create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol) + factory = functools.partial( + create_protocol, + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to WebSocketServerProtocol to trigger the deprecation + # warning once per serve() call rather than once per connection. + remove_path_argument(ws_handler), + ws_server, + host=host, + port=port, + secure=secure, + open_timeout=open_timeout, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=_loop, + legacy_recv=legacy_recv, + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + server_header=server_header, + process_request=process_request, + select_subprotocol=select_subprotocol, + logger=logger, + ) + + if kwargs.pop("unix", False): + path: str | None = kwargs.pop("path", None) + # unix_serve(path) must not specify host and port parameters. + assert host is None and port is None + create_server = functools.partial( + loop.create_unix_server, factory, path, **kwargs + ) + else: + create_server = functools.partial( + loop.create_server, factory, host, port, **kwargs + ) + + # This is a coroutine function. + self._create_server = create_server + self.ws_server = ws_server + + # async with serve(...) + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.ws_server.close() + await self.ws_server.wait_closed() + + # await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server() + self.ws_server.wrap(server) + return self.ws_server + + # yield from serve(...) - remove when dropping Python < 3.11 + + __iter__ = __await__ + + +serve = Serve + + +def unix_serve( + # The version that accepts the path in the second argument is deprecated. + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ), + path: str | None = None, + **kwargs: Any, +) -> Serve: + """ + Start a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It is only available on Unix. + + Unrecognized keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_unix_server` method. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + path: File system path to the Unix socket. + + """ + return serve(ws_handler, path=path, unix=True, **kwargs) + + +def remove_path_argument( + ws_handler: ( + Callable[[WebSocketServerProtocol], Awaitable[Any]] + | Callable[[WebSocketServerProtocol, str], Awaitable[Any]] + ), +) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: + try: + inspect.signature(ws_handler).bind(None) + except TypeError: + try: + inspect.signature(ws_handler).bind(None, "") + except TypeError: # pragma: no cover + # ws_handler accepts neither one nor two arguments; leave it alone. + pass + else: + # ws_handler accepts two arguments; activate backwards compatibility. + warnings.warn("remove second argument of ws_handler", DeprecationWarning) + + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) + + return _ws_handler + + return cast( + Callable[[WebSocketServerProtocol], Awaitable[Any]], + ws_handler, + ) diff --git a/source/websockets/protocol.py b/source/websockets/protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..66ebbe3ffb302225e625f027dc79d0590e85793d --- /dev/null +++ b/source/websockets/protocol.py @@ -0,0 +1,768 @@ +from __future__ import annotations + +import enum +import logging +import uuid +from collections.abc import Generator + +from .exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from .extensions import Extension +from .frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + CloseCode, + Frame, +) +from .http11 import Request, Response +from .streams import StreamReader +from .typing import BytesLike, LoggerLike, Origin, Subprotocol + + +__all__ = [ + "Protocol", + "Side", + "State", + "SEND_EOF", +] + +Event = Request | Response | Frame +"""Events that :meth:`~Protocol.events_received` may return.""" + + +class Side(enum.IntEnum): + """A WebSocket connection is either a server or a client.""" + + SERVER, CLIENT = range(2) + + +SERVER = Side.SERVER +CLIENT = Side.CLIENT + + +class State(enum.IntEnum): + """A WebSocket connection is in one of these four states.""" + + CONNECTING, OPEN, CLOSING, CLOSED = range(4) + + +CONNECTING = State.CONNECTING +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + +SEND_EOF = b"" +"""Sentinel signaling that the TCP connection must be half-closed.""" + + +class Protocol: + """ + Sans-I/O implementation of a WebSocket connection. + + Args: + side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + logger: Logger for this connection; depending on ``side``, + defaults to ``logging.getLogger("websockets.client")`` + or ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + side: Side, + *, + state: State = OPEN, + max_size: tuple[int | None, int | None] | int | None = 2**20, + logger: LoggerLike | None = None, + ) -> None: + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger(f"websockets.{side.name.lower()}") + self.logger: LoggerLike = logger + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + # Connection side. CLIENT or SERVER. + self.side = side + + # Connection state. Initially OPEN because subclasses handle CONNECTING. + self.state = state + + # Maximum size of incoming messages in bytes. + if isinstance(max_size, int) or max_size is None: + self.max_message_size, self.max_fragment_size = max_size, None + else: + self.max_message_size, self.max_fragment_size = max_size + + # Current size of incoming message in bytes. Only set while reading a + # fragmented message i.e. a data frames with the FIN bit not set. + self.current_size: int | None = None + + # True while sending a fragmented message i.e. a data frames with the + # FIN bit not set. + self.expect_continuation_frame = False + + # WebSocket protocol parameters. + self.origin: Origin | None = None + self.extensions: list[Extension] = [] + self.subprotocol: Subprotocol | None = None + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Close | None = None + self.close_sent: Close | None = None + self.close_rcvd_then_sent: bool | None = None + + # Track if an exception happened during the handshake. + self.handshake_exc: Exception | None = None + """ + Exception to raise if the opening handshake failed. + + :obj:`None` if the opening handshake succeeded. + + """ + + # Track if send_eof() was called. + self.eof_sent = False + + # Parser state. + self.reader = StreamReader() + self.events: list[Event] = [] + self.writes: list[bytes] = [] + self.parser = self.parse() + next(self.parser) # start coroutine + self.parser_exc: Exception | None = None + + @property + def state(self) -> State: + """ + State of the WebSocket connection. + + Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`. + + .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 + .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2 + .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3 + .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4 + + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + if self.debug: + self.logger.debug("= connection is %s", state.name) + self._state = state + + @property + def close_code(self) -> int | None: + """ + WebSocket close code received from the remote endpoint. + + Defined in 7.1.5_ of :rfc:`6455`. + + .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return CloseCode.ABNORMAL_CLOSURE + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> str | None: + """ + WebSocket close reason received from the remote endpoint. + + Defined in 7.1.6_ of :rfc:`6455`. + + .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + @property + def close_exc(self) -> ConnectionClosed: + """ + Exception to raise when trying to interact with a closed connection. + + Don't raise this exception while the connection :attr:`state` + is :attr:`~websockets.protocol.State.CLOSING`; wait until + it's :attr:`~websockets.protocol.State.CLOSED`. + + Indeed, the exception includes the close code and reason, which are + known only once the connection is closed. + + Raises: + AssertionError: If the connection isn't closed yet. + + """ + assert self.state is CLOSED, "connection isn't closed yet" + exc_type: type[ConnectionClosed] + if ( + self.close_rcvd is not None + and self.close_sent is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent.code in OK_CLOSE_CODES + ): + exc_type = ConnectionClosedOK + else: + exc_type = ConnectionClosedError + exc: ConnectionClosed = exc_type( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception raised in the parser, if any. + exc.__cause__ = self.parser_exc + return exc + + # Public methods for receiving data. + + def receive_data(self, data: bytes | bytearray) -> None: + """ + Receive data from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network. + - You should call :meth:`events_received` and process resulting events. + + Raises: + EOFError: If :meth:`receive_eof` was called earlier. + + """ + self.reader.feed_data(data) + next(self.parser) + + def receive_eof(self) -> None: + """ + Receive the end of the data stream from the network. + + After calling this method: + + - You must call :meth:`data_to_send` and send this data to the network; + it will return ``[b""]``, signaling the end of the stream, or ``[]``. + - You aren't expected to call :meth:`events_received`; it won't return + any new events. + + :meth:`receive_eof` is idempotent. + + """ + if self.reader.eof: + return + self.reader.feed_eof() + next(self.parser) + + # Public methods for sending events. + + def send_continuation(self, data: BytesLike, fin: bool) -> None: + """ + Send a `Continuation frame`_. + + .. _Continuation frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing the same kind of data + as the initial frame. + fin: FIN bit; set it to :obj:`True` if this is the last frame + of a fragmented message and to :obj:`False` otherwise. + + Raises: + ProtocolError: If a fragmented message isn't in progress. + + """ + if not self.expect_continuation_frame: + raise ProtocolError("unexpected continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_CONT, data, fin)) + + def send_text(self, data: BytesLike, fin: bool = True) -> None: + """ + Send a `Text frame`_. + + .. _Text frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing text encoded with UTF-8. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: If a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_TEXT, data, fin)) + + def send_binary(self, data: BytesLike, fin: bool = True) -> None: + """ + Send a `Binary frame`_. + + .. _Binary frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + Parameters: + data: payload containing arbitrary binary data. + fin: FIN bit; set it to :obj:`False` if this is the first frame of + a fragmented message. + + Raises: + ProtocolError: If a fragmented message is in progress. + + """ + if self.expect_continuation_frame: + raise ProtocolError("expected a continuation frame") + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.expect_continuation_frame = not fin + self.send_frame(Frame(OP_BINARY, data, fin)) + + def send_close(self, code: CloseCode | int | None = None, reason: str = "") -> None: + """ + Send a `Close frame`_. + + .. _Close frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1 + + Parameters: + code: close code. + reason: close reason. + + Raises: + ProtocolError: If the code isn't valid or if a reason is provided + without a code. + + """ + # While RFC 6455 doesn't rule out sending more than one close Frame, + # websockets is conservative in what it sends and doesn't allow that. + if self._state is not OPEN: + raise InvalidState(f"connection is {self.state.name.lower()}") + if code is None: + if reason != "": + raise ProtocolError("cannot send a reason without a code") + close = Close(CloseCode.NO_STATUS_RCVD, "") + data = b"" + else: + close = Close(code, reason) + data = close.serialize() + # 7.1.3. The WebSocket Closing Handshake is Started + self.send_frame(Frame(OP_CLOSE, data)) + # Since the state is OPEN, no close frame was received yet. + # As a consequence, self.close_rcvd_then_sent remains None. + assert self.close_rcvd is None + self.close_sent = close + self.state = CLOSING + + def send_ping(self, data: BytesLike) -> None: + """ + Send a `Ping frame`_. + + .. _Ping frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + Parameters: + data: payload containing arbitrary binary data. + + """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.send_frame(Frame(OP_PING, data)) + + def send_pong(self, data: BytesLike) -> None: + """ + Send a `Pong frame`_. + + .. _Pong frame: + https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + Parameters: + data: payload containing arbitrary binary data. + + """ + # RFC 6455 allows control frames after starting the closing handshake. + if self._state is not OPEN and self._state is not CLOSING: + raise InvalidState(f"connection is {self.state.name.lower()}") + self.send_frame(Frame(OP_PONG, data)) + + def fail(self, code: CloseCode | int, reason: str = "") -> None: + """ + `Fail the WebSocket connection`_. + + .. _Fail the WebSocket connection: + https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 + + Parameters: + code: close code + reason: close reason + + Raises: + ProtocolError: If the code isn't valid. + """ + # 7.1.7. Fail the WebSocket Connection + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because + # of an error reading from or writing to the network. + if self.state is OPEN: + if code != CloseCode.ABNORMAL_CLOSURE: + close = Close(code, reason) + data = close.serialize() + self.send_frame(Frame(OP_CLOSE, data)) + self.close_sent = close + # If recv_messages() raised an exception upon receiving a close + # frame but before echoing it, then close_rcvd is not None even + # though the state is OPEN. This happens when the connection is + # closed while receiving a fragmented message. + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + self.state = CLOSING + + # When failing the connection, a server closes the TCP connection + # without waiting for the client to complete the handshake, while a + # client waits for the server to close the TCP connection, possibly + # after sending a close frame that the client will ignore. + if self.side is SERVER and not self.eof_sent: + self.send_eof() + + # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue + # to attempt to process data(including a responding Close frame) from + # the remote endpoint after being instructed to _Fail the WebSocket + # Connection_." + self.parser = self.discard() + next(self.parser) # start coroutine + + # Public method for getting incoming events after receiving data. + + def events_received(self) -> list[Event]: + """ + Fetch events generated from data received from the network. + + Call this method immediately after any of the ``receive_*()`` methods. + + Process resulting events, likely by passing them to the application. + + Returns: + Events read from the connection. + """ + events, self.events = self.events, [] + return events + + # Public method for getting outgoing data after receiving data or sending events. + + def data_to_send(self) -> list[bytes]: + """ + Obtain data to send to the network. + + Call this method immediately after any of the ``receive_*()``, + ``send_*()``, or :meth:`fail` methods. + + Write resulting data to the connection. + + The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals + the end of the data stream. When you receive it, half-close the TCP + connection. + + Returns: + Data to write to the connection. + + """ + writes, self.writes = self.writes, [] + return writes + + def close_expected(self) -> bool: + """ + Tell if the TCP connection is expected to close soon. + + Call this method immediately after any of the ``receive_*()``, + ``send_close()``, or :meth:`fail` methods. + + If it returns :obj:`True`, schedule closing the TCP connection after a + short timeout if the other side hasn't already closed it. + + Returns: + Whether the TCP connection is expected to close soon. + + """ + # During the opening handshake, when our state is CONNECTING, we expect + # a TCP close if and only if the hansdake fails. When it does, we start + # the TCP closing handshake by sending EOF with send_eof(). + + # Once the opening handshake completes successfully, we expect a TCP + # close if and only if we sent a close frame, meaning that our state + # progressed to CLOSING: + + # * Normal closure: once we send a close frame, we expect a TCP close: + # server waits for client to complete the TCP closing handshake; + # client waits for server to initiate the TCP closing handshake. + + # * Abnormal closure: we always send a close frame and the same logic + # applies, except on EOFError where we don't send a close frame + # because we already received the TCP close, so we don't expect it. + + # If our state is CLOSED, we already received a TCP close so we don't + # expect it anymore. + + # Micro-optimization: put the most common case first + if self.state is OPEN: + return False + if self.state is CLOSING: + return True + if self.state is CLOSED: + return False + assert self.state is CONNECTING + return self.eof_sent + + # Private methods for receiving data. + + def parse(self) -> Generator[None]: + """ + Parse incoming data into frames. + + :meth:`receive_data` and :meth:`receive_eof` run this generator + coroutine until it needs more data or reaches EOF. + + :meth:`parse` never raises an exception. Instead, it sets the + :attr:`parser_exc` and yields control. + + """ + try: + while True: + if (yield from self.reader.at_eof()): + if self.debug: + self.logger.debug("< EOF") + # If the WebSocket connection is closed cleanly, with a + # closing handhshake, recv_frame() substitutes parse() + # with discard(). This branch is reached only when the + # connection isn't closed cleanly. + raise EOFError("unexpected end of stream") + + max_size = None + + if self.max_message_size is not None: + if self.current_size is None: + max_size = self.max_message_size + else: + max_size = self.max_message_size - self.current_size + + if self.max_fragment_size is not None: + if max_size is None: + max_size = self.max_fragment_size + else: + max_size = min(max_size, self.max_fragment_size) + + # During a normal closure, execution ends here on the next + # iteration of the loop after receiving a close frame. At + # this point, recv_frame() replaced parse() by discard(). + frame = yield from Frame.parse( + self.reader.read_exact, + mask=self.side is SERVER, + max_size=max_size, + extensions=self.extensions, + ) + + if self.debug: + self.logger.debug("< %s", frame) + + self.recv_frame(frame) + + except ProtocolError as exc: + self.fail(CloseCode.PROTOCOL_ERROR, str(exc)) + self.parser_exc = exc + + except EOFError as exc: + self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc)) + self.parser_exc = exc + + except UnicodeDecodeError as exc: + self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}") + self.parser_exc = exc + + except PayloadTooBig as exc: + exc.set_current_size(self.current_size) + self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc)) + self.parser_exc = exc + + except Exception as exc: + self.logger.error("parser failed", exc_info=True) + # Don't include exception details, which may be security-sensitive. + self.fail(CloseCode.INTERNAL_ERROR) + self.parser_exc = exc + + # During an abnormal closure, execution ends here after catching an + # exception. At this point, fail() replaced parse() by discard(). + yield + raise AssertionError("parse() shouldn't step after error") + + def discard(self) -> Generator[None]: + """ + Discard incoming data. + + This coroutine replaces :meth:`parse`: + + - after receiving a close frame, during a normal closure (1.4); + - after sending a close frame, during an abnormal closure (7.1.7). + + """ + # After the opening handshake completes, the server closes the TCP + # connection in the same circumstances where discard() replaces parse(). + # The client closes it when it receives EOF from the server or times + # out. (The latter case cannot be handled in this Sans-I/O layer.) + assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent) + while not (yield from self.reader.at_eof()): + self.reader.discard() + if self.debug: + self.logger.debug("< EOF") + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is CLIENT and self.state is not CONNECTING: + self.send_eof() + self.state = CLOSED + # If discard() completes normally, execution ends here. + yield + # Once the reader reaches EOF, its feed_data/eof() methods raise an + # error, so our receive_data/eof() methods don't step the generator. + raise AssertionError("discard() shouldn't step after EOF") + + def recv_frame(self, frame: Frame) -> None: + """ + Process an incoming frame. + + """ + if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY: + if self.current_size is not None: + raise ProtocolError("expected a continuation frame") + if not frame.fin: + self.current_size = len(frame.data) + + elif frame.opcode is OP_CONT: + if self.current_size is None: + raise ProtocolError("unexpected continuation frame") + if frame.fin: + self.current_size = None + else: + self.current_size += len(frame.data) + + elif frame.opcode is OP_PING: + # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST + # send a Pong frame in response" + pong_frame = Frame(OP_PONG, frame.data) + self.send_frame(pong_frame) + + elif frame.opcode is OP_PONG: + # 5.5.3 Pong: "A response to an unsolicited Pong frame is not + # expected." + pass + + elif frame.opcode is OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.state is CLOSING: + assert self.close_sent is not None + self.close_rcvd_then_sent = False + + if self.current_size is not None: + raise ProtocolError("incomplete fragmented message") + + # 5.5.1 Close: "If an endpoint receives a Close frame and did + # not previously send a Close frame, the endpoint MUST send a + # Close frame in response. (When sending a Close frame in + # response, the endpoint typically echos the status code it + # received.)" + + if self.state is OPEN: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthesizes a 1005 close code. + # The rest is identical to send_close(). + self.send_frame(Frame(OP_CLOSE, frame.data)) + self.close_sent = self.close_rcvd + self.close_rcvd_then_sent = True + self.state = CLOSING + + # 7.1.2. Start the WebSocket Closing Handshake: "Once an + # endpoint has both sent and received a Close control frame, + # that endpoint SHOULD _Close the WebSocket Connection_" + + # A server closes the TCP connection immediately, while a client + # waits for the server to close the TCP connection. + if self.side is SERVER: + self.send_eof() + + # 1.4. Closing Handshake: "after receiving a control frame + # indicating the connection should be closed, a peer discards + # any further data received." + # RFC 6455 allows reading Ping and Pong frames after a Close frame. + # However, that doesn't seem useful; websockets doesn't support it. + self.parser = self.discard() + next(self.parser) # start coroutine + + else: + # This can't happen because Frame.parse() validates opcodes. + raise AssertionError(f"unexpected opcode: {frame.opcode:02x}") + + self.events.append(frame) + + # Private methods for sending events. + + def send_frame(self, frame: Frame) -> None: + if self.debug: + self.logger.debug("> %s", frame) + self.writes.append( + frame.serialize( + mask=self.side is CLIENT, + extensions=self.extensions, + ) + ) + + def send_eof(self) -> None: + assert not self.eof_sent + self.eof_sent = True + if self.debug: + self.logger.debug("> EOF") + self.writes.append(SEND_EOF) diff --git a/source/websockets/proxy.py b/source/websockets/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..a343b37bcda58da54d81002e4a2d3b9fe6f08f76 --- /dev/null +++ b/source/websockets/proxy.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import dataclasses +import urllib.parse +import urllib.request + +from .datastructures import Headers +from .exceptions import InvalidProxy +from .headers import build_authorization_basic, build_host +from .http11 import USER_AGENT +from .uri import DELIMS, WebSocketURI + + +__all__ = ["get_proxy", "parse_proxy", "Proxy"] + + +@dataclasses.dataclass +class Proxy: + """ + Proxy address. + + Attributes: + scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, + ``"https"``, or ``"http"``. + host: Normalized to lower case. + port: Always set even if it's the default. + username: Available when the proxy address contains `User Information`_. + password: Available when the proxy address contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + scheme: str + host: str + port: int + username: str | None = None + password: str | None = None + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_proxy(proxy: str) -> Proxy: + """ + Parse and validate a proxy. + + Args: + proxy: proxy. + + Returns: + Parsed proxy. + + Raises: + InvalidProxy: If ``proxy`` isn't a valid proxy. + + """ + parsed = urllib.parse.urlparse(proxy) + if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: + raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") + if parsed.hostname is None: + raise InvalidProxy(proxy, "hostname isn't provided") + if parsed.path not in ["", "/"]: + raise InvalidProxy(proxy, "path is meaningless") + if parsed.query != "": + raise InvalidProxy(proxy, "query is meaningless") + if parsed.fragment != "": + raise InvalidProxy(proxy, "fragment is meaningless") + + scheme = parsed.scheme + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidProxy(proxy, "username provided without password") + + try: + proxy.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return Proxy(scheme, host, port, username, password) + + +def get_proxy(uri: WebSocketURI) -> str | None: + """ + Return the proxy to use for connecting to the given WebSocket URI, if any. + + """ + if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): + return None + + # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if + # available, else favor the proxy for HTTPS connections over the proxy for + # HTTP connections. + + # The priority of a proxy for WebSocket connections is unspecified. We give + # it the highest priority. This makes it easy to configure a specific proxy + # for websockets. + + # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or + # as {"https": "socks5h://host:port"} depending on whether they're declared + # in the operating system or in environment variables. + + proxies = urllib.request.getproxies() + if uri.secure: + schemes = ["wss", "socks", "https"] + else: + schemes = ["ws", "socks", "https", "http"] + + for scheme in schemes: + proxy = proxies.get(scheme) + if proxy is not None: + if scheme == "socks" and proxy.startswith("http://"): + proxy = "socks5h://" + proxy[7:] + return proxy + else: + return None + + +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = USER_AGENT, +) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() diff --git a/source/websockets/py.typed b/source/websockets/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/websockets/server.py b/source/websockets/server.py new file mode 100644 index 0000000000000000000000000000000000000000..de2c63548f09e92c27bed63340e11e5e213ea137 --- /dev/null +++ b/source/websockets/server.py @@ -0,0 +1,589 @@ +from __future__ import annotations + +import base64 +import binascii +import email.utils +import http +import re +import warnings +from collections.abc import Generator, Sequence +from typing import Any, Callable, cast + +from .datastructures import Headers, MultipleValuesError +from .exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidHeaderValue, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from .extensions import Extension, ServerExtensionFactory +from .headers import ( + build_extension, + parse_connection, + parse_extension, + parse_subprotocol, + parse_upgrade, +) +from .http11 import Request, Response +from .imports import lazy_import +from .protocol import CONNECTING, OPEN, SERVER, Protocol, State +from .typing import ( + ConnectionOption, + ExtensionHeader, + LoggerLike, + Origin, + StatusLike, + Subprotocol, + UpgradeProtocol, +) +from .utils import accept_key + + +__all__ = ["ServerProtocol"] + + +class ServerProtocol(Protocol): + """ + Sans-I/O implementation of a WebSocket server connection. + + Args: + origins: Acceptable values of the ``Origin`` header. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + This is useful for defending against Cross-Site WebSocket + Hijacking attacks. + extensions: List of supported extensions, in order in which they + should be tried. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It has the same + signature as the :meth:`select_subprotocol` method, including a + :class:`ServerProtocol` instance as first argument. + state: Initial state of the WebSocket connection. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + logger: Logger for this connection; + defaults to ``logging.getLogger("websockets.server")``; + see the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + *, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + state: State = CONNECTING, + max_size: int | None | tuple[int | None, int | None] = 2**20, + logger: LoggerLike | None = None, + ) -> None: + super().__init__( + side=SERVER, + state=state, + max_size=max_size, + logger=logger, + ) + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + if select_subprotocol is not None: + # Bind select_subprotocol then shadow self.select_subprotocol. + # Use setattr to work around https://github.com/python/mypy/issues/2427. + setattr( + self, + "select_subprotocol", + select_subprotocol.__get__(self, self.__class__), + ) + + def accept(self, request: Request) -> Response: + """ + Create a handshake response to accept the connection. + + If the handshake request is valid and the handshake successful, + :meth:`accept` returns an HTTP response with status code 101. + + Else, it returns an HTTP response with another status code. This rejects + the connection, like :meth:`reject` would. + + You must send the handshake response with :meth:`send_response`. + + You may modify the response before sending it, typically by adding HTTP + headers. + + Args: + request: WebSocket handshake request received from the client. + + Returns: + WebSocket handshake response or HTTP response to send to the client. + + """ + try: + ( + accept_header, + extensions_header, + protocol_header, + ) = self.process_request(request) + except InvalidOrigin as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) + return self.reject( + http.HTTPStatus.FORBIDDEN, + f"Failed to open a WebSocket connection: {exc}.\n", + ) + except InvalidUpgrade as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) + response = self.reject( + http.HTTPStatus.UPGRADE_REQUIRED, + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ), + ) + response.headers["Upgrade"] = "websocket" + return response + except InvalidHandshake as exc: + request._exception = exc + self.handshake_exc = exc + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) + exc_chain = cast(BaseException, exc) + exc_str = f"{exc_chain}" + while exc_chain.__cause__ is not None: + exc_chain = exc_chain.__cause__ + exc_str += f"; {exc_chain}" + return self.reject( + http.HTTPStatus.BAD_REQUEST, + f"Failed to open a WebSocket connection: {exc_str}.\n", + ) + except Exception as exc: + # Handle exceptions raised by user-provided select_subprotocol and + # unexpected errors. + request._exception = exc + self.handshake_exc = exc + self.logger.error("opening handshake failed", exc_info=True) + return self.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + headers = Headers() + headers["Date"] = email.utils.formatdate(usegmt=True) + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept_header + if extensions_header is not None: + headers["Sec-WebSocket-Extensions"] = extensions_header + if protocol_header is not None: + headers["Sec-WebSocket-Protocol"] = protocol_header + return Response(101, "Switching Protocols", headers) + + def process_request( + self, + request: Request, + ) -> tuple[str, str | None, str | None]: + """ + Check a handshake request and negotiate extensions and subprotocol. + + This function doesn't verify that the request is an HTTP/1.1 or higher + GET request and doesn't check the ``Host`` header. These controls are + usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + Args: + request: WebSocket handshake request received from the client. + + Returns: + ``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and + ``Sec-WebSocket-Protocol`` headers for the handshake response. + + Raises: + InvalidHandshake: If the handshake request is invalid; + then the server must return 400 Bad Request error. + + """ + headers = request.headers + + connection: list[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade( + "Connection", ", ".join(connection) if connection else None + ) + + upgrade: list[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + + try: + key = headers["Sec-WebSocket-Key"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Key") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from None + try: + raw_key = base64.b64decode(key.encode(), validate=True) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", key) + accept_header = accept_key(key) + + try: + version = headers["Sec-WebSocket-Version"] + except KeyError: + raise InvalidHeader("Sec-WebSocket-Version") from None + except MultipleValuesError: + raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from None + if version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", version) + + self.origin = self.process_origin(headers) + extensions_header, self.extensions = self.process_extensions(headers) + protocol_header = self.subprotocol = self.process_subprotocol(headers) + + return (accept_header, extensions_header, protocol_header) + + def process_origin(self, headers: Headers) -> Origin | None: + """ + Handle the Origin HTTP request header. + + Args: + headers: WebSocket handshake request headers. + + Returns: + origin, if it is acceptable. + + Raises: + InvalidHandshake: If the Origin header is invalid. + InvalidOrigin: If the origin isn't acceptable. + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3. + try: + origin = headers.get("Origin") + except MultipleValuesError: + raise InvalidHeader("Origin", "multiple values") from None + if origin is not None: + origin = cast(Origin, origin) + if self.origins is not None: + for origin_or_regex in self.origins: + if origin_or_regex == origin or ( + isinstance(origin_or_regex, re.Pattern) + and origin is not None + and origin_or_regex.fullmatch(origin) is not None + ): + break + else: + raise InvalidOrigin(origin) + return origin + + def process_extensions( + self, + headers: Headers, + ) -> tuple[str | None, list[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Per :rfc:`6455`, negotiation rules are defined by the specification of + each extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: WebSocket handshake request headers. + + Returns: + ``Sec-WebSocket-Extensions`` HTTP response header and list of + accepted extensions. + + Raises: + InvalidHandshake: If the Sec-WebSocket-Extensions header is invalid. + + """ + response_header_value: str | None = None + + extension_headers: list[ExtensionHeader] = [] + accepted_extensions: list[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and self.available_extensions: + parsed_header_values: list[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + for ext_factory in self.available_extensions: + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + def process_subprotocol(self, headers: Headers) -> Subprotocol | None: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Args: + headers: WebSocket handshake request headers. + + Returns: + Subprotocol, if one was selected; this is also the value of the + ``Sec-WebSocket-Protocol`` response header. + + Raises: + InvalidHandshake: If the Sec-WebSocket-Subprotocol header is invalid. + + """ + subprotocols: Sequence[Subprotocol] = sum( + [ + parse_subprotocol(header_value) + for header_value in headers.get_all("Sec-WebSocket-Protocol") + ], + [], + ) + return self.select_subprotocol(subprotocols) + + def select_subprotocol( + self, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + """ + Pick a subprotocol among those offered by the client. + + If several subprotocols are supported by both the client and the server, + pick the first one in the list declared the server. + + If the server doesn't support any subprotocols, continue without a + subprotocol, regardless of what the client offers. + + If the server supports at least one subprotocol and the client doesn't + offer any, abort the handshake with an HTTP 400 error. + + You provide a ``select_subprotocol`` argument to :class:`ServerProtocol` + to override this logic. For example, you could accept the connection + even if client doesn't offer a subprotocol, rather than reject it. + + Here's how to negotiate the ``chat`` subprotocol if the client supports + it and continue without a subprotocol otherwise:: + + def select_subprotocol(protocol, subprotocols): + if "chat" in subprotocols: + return "chat" + + Args: + subprotocols: List of subprotocols offered by the client. + + Returns: + Selected subprotocol, if a common subprotocol was found. + + :obj:`None` to continue without a subprotocol. + + Raises: + NegotiationError: Custom implementations may raise this exception + to abort the handshake with an HTTP 400 error. + + """ + # Server doesn't offer any subprotocols. + if not self.available_subprotocols: # None or empty list + return None + + # Server offers at least one subprotocol but client doesn't offer any. + if not subprotocols: + raise NegotiationError("missing subprotocol") + + # Server and client both offer subprotocols. Look for a shared one. + proposed_subprotocols = set(subprotocols) + for subprotocol in self.available_subprotocols: + if subprotocol in proposed_subprotocols: + return subprotocol + + # No common subprotocol was found. + raise NegotiationError( + "invalid subprotocol; expected one of " + + ", ".join(self.available_subprotocols) + ) + + def reject(self, status: StatusLike, text: str) -> Response: + """ + Create a handshake response to reject the connection. + + A short plain text response is the best fallback when failing to + establish a WebSocket connection. + + You must send the handshake response with :meth:`send_response`. + + You may modify the response before sending it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + # If status is an int instead of an HTTPStatus, fix it automatically. + status = http.HTTPStatus(status) + body = text.encode() + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ] + ) + return Response(status.value, status.phrase, headers, body) + + def send_response(self, response: Response) -> None: + """ + Send a handshake response to the client. + + Args: + response: WebSocket handshake response event to send. + + """ + if self.debug: + code, phrase = response.status_code, response.reason_phrase + self.logger.debug("> HTTP/1.1 %d %s", code, phrase) + for key, value in response.headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if response.body: + self.logger.debug("> [body] (%d bytes)", len(response.body)) + + self.writes.append(response.serialize()) + + if response.status_code == 101: + assert self.state is CONNECTING + self.state = OPEN + self.logger.info("connection open") + + else: + self.logger.info( + "connection rejected (%d %s)", + response.status_code, + response.reason_phrase, + ) + + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + + def parse(self) -> Generator[None]: + if self.state is CONNECTING: + try: + request = yield from Request.parse( + self.reader.read_line, + ) + except Exception as exc: + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP request" + ) + self.handshake_exc.__cause__ = exc + self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine + yield + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", request.path) + for key, value in request.headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.events.append(request) + + yield from super().parse() + + +class ServerConnection(ServerProtocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( # deprecated in 11.0 - 2023-04-02 + "ServerConnection was renamed to ServerProtocol", + DeprecationWarning, + ) + super().__init__(*args, **kwargs) + + +lazy_import( + globals(), + deprecated_aliases={ + # deprecated in 14.0 - 2024-11-09 + "WebSocketServer": ".legacy.server", + "WebSocketServerProtocol": ".legacy.server", + "broadcast": ".legacy.server", + "serve": ".legacy.server", + "unix_serve": ".legacy.server", + }, +) diff --git a/source/websockets/speedups.c b/source/websockets/speedups.c new file mode 100644 index 0000000000000000000000000000000000000000..f14ba3b9740a9f7ea901471c2f524f008c7c8ef5 --- /dev/null +++ b/source/websockets/speedups.c @@ -0,0 +1,229 @@ +/* C implementation of performance sensitive functions. */ + +#define PY_SSIZE_T_CLEAN +#include +#include /* uint8_t, uint32_t, uint64_t */ + +#if __ARM_NEON +#include +#elif __SSE2__ +#include +#endif + +static const Py_ssize_t MASK_LEN = 4; + +/* Similar to PyBytes_AsStringAndSize, but accepts more types */ + +static int +_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length) +{ + // This supports bytes, bytearrays, and memoryview objects, + // which are common data structures for handling byte streams. + // If *tmp isn't NULL, the caller gets a new reference. + if (PyBytes_Check(obj)) + { + *tmp = NULL; + *buffer = PyBytes_AS_STRING(obj); + *length = PyBytes_GET_SIZE(obj); + } + else if (PyByteArray_Check(obj)) + { + *tmp = NULL; + *buffer = PyByteArray_AS_STRING(obj); + *length = PyByteArray_GET_SIZE(obj); + } + else if (PyMemoryView_Check(obj)) + { + *tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C'); + if (*tmp == NULL) + { + return -1; + } + Py_buffer *mv_buf; + mv_buf = PyMemoryView_GET_BUFFER(*tmp); + *buffer = mv_buf->buf; + *length = mv_buf->len; + } + else + { + PyErr_Format( + PyExc_TypeError, + "expected a bytes-like object, %.200s found", + Py_TYPE(obj)->tp_name); + return -1; + } + + return 0; +} + +/* C implementation of websockets.utils.apply_mask */ + +static PyObject * +apply_mask(PyObject *self, PyObject *args, PyObject *kwds) +{ + + // In order to support various bytes-like types, accept any Python object. + + static char *kwlist[] = {"data", "mask", NULL}; + PyObject *input_obj; + PyObject *mask_obj; + + // A pointer to a char * + length will be extracted from the data and mask + // arguments, possibly via a Py_buffer. + + PyObject *input_tmp = NULL; + char *input; + Py_ssize_t input_len; + PyObject *mask_tmp = NULL; + char *mask; + Py_ssize_t mask_len; + + // Initialize a PyBytesObject then get a pointer to the underlying char * + // in order to avoid an extra memory copy in PyBytes_FromStringAndSize. + + PyObject *result = NULL; + char *output; + + // Other variables. + + Py_ssize_t i = 0; + + // Parse inputs. + + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "OO", kwlist, &input_obj, &mask_obj)) + { + goto exit; + } + + if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1) + { + goto exit; + } + + if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1) + { + goto exit; + } + + if (mask_len != MASK_LEN) + { + PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes"); + goto exit; + } + + // Create output. + + result = PyBytes_FromStringAndSize(NULL, input_len); + if (result == NULL) + { + goto exit; + } + + // Since we just created result, we don't need error checks. + output = PyBytes_AS_STRING(result); + + // Perform the masking operation. + + // Apparently GCC cannot figure out the following optimizations by itself. + + // We need a new scope for MSVC 2010 (non C99 friendly) + { +#if __ARM_NEON + + // With NEON support, XOR by blocks of 16 bytes = 128 bits. + + Py_ssize_t input_len_128 = input_len & ~15; + uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask)); + + for (; i < input_len_128; i += 16) + { + uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i)); + uint8x16_t out_128 = veorq_u8(in_128, mask_128); + vst1q_u8((uint8_t *)(output + i), out_128); + } + +#elif __SSE2__ + + // With SSE2 support, XOR by blocks of 16 bytes = 128 bits. + + // Since we cannot control the 16-bytes alignment of input and output + // buffers, we rely on loadu/storeu rather than load/store. + + Py_ssize_t input_len_128 = input_len & ~15; + __m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask); + + for (; i < input_len_128; i += 16) + { + __m128i in_128 = _mm_loadu_si128((__m128i *)(input + i)); + __m128i out_128 = _mm_xor_si128(in_128, mask_128); + _mm_storeu_si128((__m128i *)(output + i), out_128); + } + +#else + + // Without SSE2 support, XOR by blocks of 8 bytes = 64 bits. + + // We assume the memory allocator aligns everything on 8 bytes boundaries. + + Py_ssize_t input_len_64 = input_len & ~7; + uint32_t mask_32 = *(uint32_t *)mask; + uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32; + + for (; i < input_len_64; i += 8) + { + *(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64; + } + +#endif + } + + // XOR the remainder of the input byte by byte. + + for (; i < input_len; i++) + { + output[i] = input[i] ^ mask[i & (MASK_LEN - 1)]; + } + +exit: + Py_XDECREF(input_tmp); + Py_XDECREF(mask_tmp); + return result; + +} + +static PyMethodDef speedups_methods[] = { + { + "apply_mask", + (PyCFunction)apply_mask, + METH_VARARGS | METH_KEYWORDS, + "Apply masking to the data of a WebSocket message.", + }, + {NULL, NULL, 0, NULL}, /* Sentinel */ +}; + +static struct PyModuleDef speedups_module = { + PyModuleDef_HEAD_INIT, + "websocket.speedups", /* m_name */ + "C implementation of performance sensitive functions.", + /* m_doc */ + -1, /* m_size */ + speedups_methods, /* m_methods */ + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC +PyInit_speedups(void) +{ + PyObject *m = PyModule_Create(&speedups_module); + if (m == NULL) { + return NULL; + } +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED); +#endif + return m; +} diff --git a/source/websockets/speedups.cpython-312-x86_64-linux-gnu.so b/source/websockets/speedups.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..e1cd3ca69152c9350ce46e9b92c5d6d27f7254f1 Binary files /dev/null and b/source/websockets/speedups.cpython-312-x86_64-linux-gnu.so differ diff --git a/source/websockets/speedups.pyi b/source/websockets/speedups.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ffd6c3e07e451e712b53df7b7cf36b61d3a0bf1d --- /dev/null +++ b/source/websockets/speedups.pyi @@ -0,0 +1,3 @@ +from .typing import BytesLike + +def apply_mask(data: BytesLike, mask: bytes | bytearray) -> bytes: ... diff --git a/source/websockets/streams.py b/source/websockets/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..309ce152ddb5d6407a29bc06e1edbe0a273e45e3 --- /dev/null +++ b/source/websockets/streams.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from collections.abc import Generator + + +class StreamReader: + """ + Generator-based stream reader. + + This class doesn't support concurrent calls to :meth:`read_line`, + :meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are + serialized. + + """ + + def __init__(self) -> None: + self.buffer = bytearray() + self.eof = False + + def read_line(self, m: int) -> Generator[None, None, bytearray]: + """ + Read a LF-terminated line from the stream. + + This is a generator-based coroutine. + + The return value includes the LF character. + + Args: + m: Maximum number bytes to read; this is a security limit. + + Raises: + EOFError: If the stream ends without a LF. + RuntimeError: If the stream ends in more than ``m`` bytes. + + """ + n = 0 # number of bytes to read + p = 0 # number of bytes without a newline + while True: + n = self.buffer.find(b"\n", p) + 1 + if n > 0: + break + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") + if self.eof: + raise EOFError(f"stream ends after {p} bytes, before end of line") + yield + if n > m: + raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes") + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_exact(self, n: int) -> Generator[None, None, bytearray]: + """ + Read a given number of bytes from the stream. + + This is a generator-based coroutine. + + Args: + n: How many bytes to read. + + Raises: + EOFError: If the stream ends in less than ``n`` bytes. + + """ + assert n >= 0 + while len(self.buffer) < n: + if self.eof: + p = len(self.buffer) + raise EOFError(f"stream ends after {p} bytes, expected {n} bytes") + yield + r = self.buffer[:n] + del self.buffer[:n] + return r + + def read_to_eof(self, m: int) -> Generator[None, None, bytearray]: + """ + Read all bytes from the stream. + + This is a generator-based coroutine. + + Args: + m: Maximum number bytes to read; this is a security limit. + + Raises: + RuntimeError: If the stream ends in more than ``m`` bytes. + + """ + while not self.eof: + p = len(self.buffer) + if p > m: + raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes") + yield + r = self.buffer[:] + del self.buffer[:] + return r + + def at_eof(self) -> Generator[None, None, bool]: + """ + Tell whether the stream has ended and all data was read. + + This is a generator-based coroutine. + + """ + while True: + if self.buffer: + return False + if self.eof: + return True + # When all data was read but the stream hasn't ended, we can't + # tell if until either feed_data() or feed_eof() is called. + yield + + def feed_data(self, data: bytes | bytearray) -> None: + """ + Write data to the stream. + + :meth:`feed_data` cannot be called after :meth:`feed_eof`. + + Args: + data: Data to write. + + Raises: + EOFError: If the stream has ended. + + """ + if self.eof: + raise EOFError("stream ended") + self.buffer += data + + def feed_eof(self) -> None: + """ + End the stream. + + :meth:`feed_eof` cannot be called more than once. + + Raises: + EOFError: If the stream has ended. + + """ + if self.eof: + raise EOFError("stream ended") + self.eof = True + + def discard(self) -> None: + """ + Discard all buffered data, but don't end the stream. + + """ + del self.buffer[:] diff --git a/source/websockets/sync/__init__.py b/source/websockets/sync/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/websockets/sync/client.py b/source/websockets/sync/client.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fff44eee51ca2961f40c0994d7078c0670a9da --- /dev/null +++ b/source/websockets/sync/client.py @@ -0,0 +1,633 @@ +from __future__ import annotations + +import socket +import ssl as ssl_module +import threading +import warnings +from collections.abc import Sequence +from typing import Any, Callable, Literal, TypeVar, cast + +from ..client import ClientProtocol +from ..datastructures import HeadersLike +from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http11 import USER_AGENT, Response +from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request +from ..streams import StreamReader +from ..typing import BytesLike, LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .connection import Connection +from .utils import Deadline + + +__all__ = ["connect", "unix_connect", "ClientConnection"] + + +class ClientConnection(Connection): + """ + :mod:`threading` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. + + Args: + socket: Socket connected to a WebSocket server. + protocol: Sans-I/O connection. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ClientProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ClientProtocol + self.response_rcvd = threading.Event() + super().__init__( + socket, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + timeout: float | None = None, + ) -> None: + """ + Perform the opening handshake. + + """ + with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers.setdefault("User-Agent", user_agent_header) + self.protocol.send_request(self.request) + + if not self.response_rcvd.wait(timeout): + raise TimeoutError("timed out while waiting for handshake response") + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.response_rcvd.set() + + +def connect( + uri: str, + *, + # TCP/TLS + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + **kwargs: Any, +) -> ClientConnection: + """ + Connect to the WebSocket server at ``uri``. + + This function returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as a context manager:: + + from websockets.sync.client import connect + + with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + Args: + uri: URI of the WebSocket server. + sock: Preexisting TCP socket. ``sock`` overrides the host and port + from ``uri``. You may call :func:`socket.create_connection` to + create a suitable TCP socket. + ssl: Configuration for enabling TLS on the connection. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to :func:`~socket.create_connection`. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + # Process parameters + + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) + + ws_uri = parse_uri(uri) + if not ws_uri.secure and ssl is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") + + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: str | None = kwargs.pop("path", None) + + if unix: + if path is None and sock is None: + raise ValueError("missing path argument") + elif path is not None and sock is not None: + raise ValueError("path and sock arguments are incompatible") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if unix: + proxy = None + if sock is not None: + proxy = None + if proxy is True: + proxy = get_proxy(ws_uri) + + # Calculate timeouts on the TCP, TLS, and WebSocket handshakes. + # The TCP and TLS timeouts must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + if create_connection is None: + create_connection = ClientConnection + + try: + # Connect socket + + if sock is None: + if unix: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(deadline.timeout()) + assert path is not None # mypy cannot figure this out + sock.connect(path) + elif proxy is not None: + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = connect_socks_proxy( + proxy_parsed, + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) + elif proxy_parsed.scheme[:4] == "http": + # Validate the proxy_ssl argument. + if proxy_parsed.scheme != "https" and proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + sock = connect_http_proxy( + proxy_parsed, + ws_uri, + deadline, + user_agent_header=user_agent_header, + ssl=proxy_ssl, + server_hostname=proxy_server_hostname, + **kwargs, + ) + else: + raise AssertionError("unsupported proxy") + else: + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection( + (ws_uri.host, ws_uri.port), + **kwargs, + ) + sock.settimeout(None) + + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Initialize TLS wrapper and perform TLS handshake + + if ws_uri.secure: + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = ws_uri.host + sock.settimeout(deadline.timeout()) + if proxy_ssl is None: + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + else: + sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) + # Let's pretend that sock is a socket, even though it isn't. + sock = cast(socket.socket, sock_2) + sock.settimeout(None) + + # Initialize WebSocket protocol + + protocol = ClientProtocol( + ws_uri, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket connection + + connection = create_connection( + sock, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + except Exception: + if sock is not None: + sock.close() + raise + + try: + connection.handshake( + additional_headers, + user_agent_header, + deadline.timeout(), + ) + except Exception: + connection.close_socket() + connection.recv_events_thread.join() + raise + + connection.start_keepalive() + return connection + + +def unix_connect( + path: str | None = None, + uri: str | None = None, + **kwargs: Any, +) -> ClientConnection: + """ + Connect to a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`connect`. + + It's only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server. ``uri`` defaults to + ``ws://localhost/`` or, when a ``ssl`` is provided, to + ``wss://localhost/``. + + """ + if uri is None: + # Backwards compatibility: ssl used to be called ssl_context. + if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None: + uri = "ws://localhost/" + else: + uri = "wss://localhost/" + return connect(uri=uri, unix=True, path=path, **kwargs) + + +try: + from python_socks import ProxyType + from python_socks.sync import Proxy as SocksProxy + +except ImportError: + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + raise ImportError("connecting through a SOCKS proxy requires python-socks") + +else: + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + **kwargs: Any, + ) -> socket.socket: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + kwargs.setdefault("timeout", deadline.timeout()) + # connect() is documented to raise OSError and TimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. + try: + return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + except (OSError, TimeoutError, socket.timeout): + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc + + +def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + proxy=True, + ) + try: + while True: + sock.settimeout(deadline.timeout()) + data = sock.recv(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except socket.timeout: + raise TimeoutError("timed out while connecting to HTTP proxy") + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + finally: + sock.settimeout(None) + + +def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + deadline: Deadline, + *, + user_agent_header: str | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> socket.socket: + # Connect socket + + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((proxy.host, proxy.port), **kwargs) + + # Initialize TLS wrapper and perform TLS handshake + + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + sock.settimeout(deadline.timeout()) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Send CONNECT request to the proxy and read response. + + sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) + try: + read_connect_response(sock, deadline) + except Exception: + sock.close() + raise + + return sock + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., T]) + + +class SSLSSLSocket: + """ + Socket-like object providing TLS-in-TLS. + + Only methods that are used by websockets are implemented. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl_module.SSLContext, + server_hostname: str | None = None, + ) -> None: + self.incoming = ssl_module.MemoryBIO() + self.outgoing = ssl_module.MemoryBIO() + self.ssl_socket = sock + self.ssl_object = ssl_context.wrap_bio( + self.incoming, + self.outgoing, + server_hostname=server_hostname, + ) + self.run_io(self.ssl_object.do_handshake) + + def run_io(self, func: Callable[..., T], *args: Any) -> T: + while True: + want_read = False + want_write = False + try: + result = func(*args) + except ssl_module.SSLWantReadError: + want_read = True + except ssl_module.SSLWantWriteError: # pragma: no cover + want_write = True + + # Write outgoing data in all cases. + data = self.outgoing.read() + if data: + self.ssl_socket.sendall(data) + + # Read incoming data and retry on SSLWantReadError. + if want_read: + data = self.ssl_socket.recv(self.recv_bufsize) + if data: + self.incoming.write(data) + else: + self.incoming.write_eof() + continue + # Retry after writing outgoing data on SSLWantWriteError. + if want_write: # pragma: no cover + continue + # Return result if no error happened. + return result + + def recv(self, buflen: int) -> bytes: + try: + return self.run_io(self.ssl_object.read, buflen) + except ssl_module.SSLEOFError: + return b"" # always ignore ragged EOFs + + def send(self, data: BytesLike) -> int: + return self.run_io(self.ssl_object.write, data) + + def sendall(self, data: BytesLike) -> None: + # adapted from ssl_module.SSLSocket.sendall() + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + count += self.send(byte_view[count:]) + + # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the + # flags argument aren't implemented because websockets doesn't need them. + + def __getattr__(self, name: str) -> Any: + return getattr(self.ssl_socket, name) diff --git a/source/websockets/sync/connection.py b/source/websockets/sync/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..665f478ac9377174ca5c912f5939f81bdd8db3d7 --- /dev/null +++ b/source/websockets/sync/connection.py @@ -0,0 +1,1078 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import socket +import struct +import threading +import time +import uuid +from collections.abc import Iterable, Iterator, Mapping +from types import TracebackType +from typing import Any, Literal, overload + +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from .messages import Assembler +from .utils import Deadline + + +__all__ = ["Connection"] + + +class Connection: + """ + :mod:`threading` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.sync.client.ClientConnection` or + :class:`~websockets.sync.server.ServerConnection`. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + socket: socket.socket, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.socket = socket + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + if isinstance(max_queue, int) or max_queue is None: + max_queue_high, max_queue_low = max_queue, None + else: + max_queue_high, max_queue_low = max_queue + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Mutex serializing interactions with the protocol. + self.protocol_mutex = threading.Lock() + + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler( + max_queue_high, + max_queue_low, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) + + # Deadline for the closing handshake. + self.close_deadline: Deadline | None = None + + # Whether we are busy sending a fragmented message. + self.send_in_progress = False + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} + + self.latency: float = 0.0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0.0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Thread that sends keepalive pings. None when ping_interval is None. + self.keepalive_thread: threading.Thread | None = None + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Receiving events from the socket. This thread is marked as daemon to + # allow creating a connection in a non-daemon thread and using it in a + # daemon thread. This mustn't prevent the interpreter from exiting. + self.recv_events_thread = threading.Thread( + target=self.recv_events, + daemon=True, + ) + + # Start recv_events only after all attributes are initialized. + self.recv_events_thread.start() + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + return self.socket.getsockname() + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + return self.socket.getpeername() + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + def __enter__(self) -> Connection: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + self.close() + else: + self.close(CloseCode.INTERNAL_ERROR) + + def __iter__(self) -> Iterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages in an infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield self.recv() + except ConnectionClosedOK: + return + + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def recv(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def recv( + self, timeout: float | None = None, *, decode: Literal[False] + ) -> bytes: ... + + @overload + def recv( + self, timeout: float | None = None, decode: bool | None = None + ) -> Data: ... + + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + If ``timeout`` is :obj:`None`, block until a message is received. If + ``timeout`` is set, wait up to ``timeout`` seconds for a message to be + received and return it, else raise :exc:`TimeoutError`. If ``timeout`` + is ``0`` or negative, check if a message has been received already and + return it, else raise :exc:`TimeoutError`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + return strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return self.recv_messages.get(timeout, decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another thread " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ... + + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + yield bytestrings (:class:`bytes`). This improves performance + when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + yield strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + yield from self.recv_messages.get_iter(decode) + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from self.recv_exc + + def send( + self, + message: DataLike | Iterable[DataLike], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects to enable fragmentation_. Each item is treated as a + message fragment and sent in its own frame. All items must be of the + same type, or else :meth:`send` will raise a :exc:`TypeError` and the + connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If the connection is sending a fragmented message. + TypeError: If ``message`` doesn't have a supported type. + + """ + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + with self.send_context(): + if self.send_in_progress: + raise ConcurrencyError( + "cannot call send while another thread is already running send" + ) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + with self.send_context(): + if self.send_in_progress: + raise ConcurrencyError( + "cannot call send while another thread is already running send" + ) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + try: + # First fragment. + if isinstance(chunk, str): + with self.send_context(): + if self.send_in_progress: + raise ConcurrencyError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + with self.send_context(): + if self.send_in_progress: + raise ConcurrencyError( + "cannot call send while another thread " + "is already running send" + ) + self.send_in_progress = True + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + with self.send_context(): + assert self.send_in_progress + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + self.send_in_progress = False + + except ConcurrencyError: + # We didn't start sending a fragmented message. + # The connection is still usable. + raise + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + else: + raise TypeError("data must be str, bytes, or iterable") + + def close( + self, + code: CloseCode | int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + with self.send_context(): + if self.send_in_progress: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + def ping( + self, + data: DataLike | None = None, + ack_on_close: bool = False, + ) -> threading.Event: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_received = ws.ping() + # only if you want to wait for the corresponding pong + pong_received.wait() + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = threading.Event() + ping_timestamp = time.monotonic() + self.pending_pings[data] = (pong_received, ping_timestamp, ack_on_close) + self.protocol.send_ping(data) + return pong_received + + def pong(self, data: DataLike = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + with self.protocol_mutex: + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = time.monotonic() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ( + pong_received, + ping_timestamp, + _ack_on_close, + ) in self.pending_pings.items(): + ping_ids.append(ping_id) + pong_received.set() + if ping_id == data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def terminate_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol_mutex.locked() + assert self.protocol.state is CLOSED + + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): + if ack_on_close: + pong_received.set() + + self.pending_pings.clear() + + def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + self.recv_events_thread.join(self.ping_interval - self.latency) + if not self.recv_events_thread.is_alive(): + break + + try: + pong_received = self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + if pong_received.wait(self.ping_timeout): + if self.debug: + self.logger.debug("% received keepalive pong") + else: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a thread, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + # This thread is marked as daemon like self.recv_events_thread. + self.keepalive_thread = threading.Thread( + target=self.keepalive, + daemon=True, + ) + self.keepalive_thread.start() + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + Run this method in a thread as long as the connection is alive. + + ``recv_events()`` exits immediately when ``self.socket`` is closed. + + """ + try: + while True: + try: + # If the assembler buffer is full, block until it drains. + with self.recv_flow_control: + pass + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + data = self.socket.recv(self.recv_bufsize) + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the socket. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. + with self.protocol_mutex: + self.set_recv_exc(exc) + break + + if data == b"": + break + + # Acquire the connection lock. + with self.protocol_mutex: + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the socket after recv() + # returns above but before send_data() calls send(). + self.set_recv_exc(exc) + break + + # If needed, set the close deadline based on the close timeout. + if self.protocol.close_expected(): + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) + + # Unlock conn_mutex before processing events. Else, the + # application can't send messages in response to events. + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the socket doesn't work anymore. + + with self.protocol_mutex: + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream and it handles errors by itself. + self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + with self.protocol_mutex: + self.set_recv_exc(exc) + finally: + # This isn't expected to raise an exception. + self.close_socket() + + @contextlib.contextmanager + def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> Iterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` acquires the connection lock and checks + that the connection is open; on exit, it writes outgoing data to the + socket and releases the connection lock:: + + with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the socket and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + # Acquire the protocol lock. + with self.protocol_mutex: + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN + # (or CONNECTING) and we didn't release protocol_mutex, + # self.close_deadline is still None. + assert self.close_deadline is None + self.close_deadline = Deadline(self.close_timeout) + # Write outgoing data to the socket. + try: + self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) + raise_close_exc = True + + # To avoid a deadlock, release the connection lock by exiting the + # context manager before waiting for recv_events() to terminate. + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + # Thread.join() returns immediately if timeout is negative. + assert self.close_deadline is not None + timeout = self.close_deadline.timeout(raise_if_elapsed=False) + self.recv_events_thread.join(timeout) + if self.recv_events_thread.is_alive(): + # There's no risk of overwriting another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the socket in order to get + # proper exception reporting. + raise_close_exc = True + with self.protocol_mutex: + self.set_recv_exc(original_exc) + + # If an error occurred, close the socket to terminate the connection and + # raise an exception. + if raise_close_exc: + self.close_socket() + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() + raise self.protocol.close_exc from original_exc + + def send_data(self) -> None: + """ + Send outgoing data. + + This method requires holding protocol_mutex. + + """ + assert self.protocol_mutex.locked() + for data in self.protocol.data_to_send(): + if data: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) + self.socket.sendall(data) + else: + try: + self.socket.shutdown(socket.SHUT_WR) + except OSError: # socket already closed + pass + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + This method requires holding protocol_mutex and must be called only from + the thread running recv_events(). + + """ + assert self.protocol_mutex.locked() + if self.recv_exc is None: + self.recv_exc = exc + + def close_socket(self) -> None: + """ + Shutdown and close socket. Close message assembler. + + Calling close_socket() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on socket.recv() or on recv_messages.put(). + + """ + # shutdown() is required to interrupt recv() on Linux. + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: # socket already closed + pass + self.socket.close() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + with self.protocol_mutex: + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.terminate_pending_pings() diff --git a/source/websockets/sync/messages.py b/source/websockets/sync/messages.py new file mode 100644 index 0000000000000000000000000000000000000000..d95519f63ecfb20e681de13ed6bf3481c966d2d3 --- /dev/null +++ b/source/websockets/sync/messages.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import codecs +import queue +import threading +from typing import Any, Callable, Iterable, Iterator, Literal, overload + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data +from .utils import Deadline + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Serialize reads and writes -- except for reads via synchronization + # primitives provided by the threading and queue modules. + self.mutex = threading.Lock() + + # Queue of incoming frames. + self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. + if self.closed: + try: + frame = self.frames.get(block=False) + except queue.Empty: + raise EOFError("stream of frames ended") from None + else: + try: + # Check for a frame that's already received if timeout <= 0. + # SimpleQueue.get() doesn't support negative timeout values. + if timeout is not None and timeout <= 0: + frame = self.frames.get(block=False) + else: + frame = self.frames.get(block=True, timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get(block=False)) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. + for frame in queued: # pragma: no cover + self.frames.put(frame) + + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def get(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ... + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + timeout: If a timeout is provided and elapses before a complete + message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + TimeoutError: If a timeout is provided and elapses before a + complete message is received. + + """ + with self.mutex: + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or times out. + + try: + deadline = Deadline(timeout) + + # Fetch the first frame. + frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Fetch subsequent frames for fragmented messages. + while not frame.fin: + try: + frame = self.get_next_frame( + deadline.timeout(raise_if_elapsed=False) + ) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + # This converts frame.data to bytes when it's a bytearray. + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ... + + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` yields a :class:`str` or + :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + with self.mutex: + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or times out. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # Yield the first frame. + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + # Yield subsequent frames for fragmented messages. + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + with self.mutex: + if self.closed: + raise EOFError("stream of frames ended") + + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled. + if self.high is None: + return + + assert self.mutex.locked() + + # Check for "> high" to support high = 0. + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled. + if self.low is None: + return + + assert self.mutex.locked() + + # Check for "<= low" to support low = 0. + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + with self.mutex: + if self.closed: + return + + self.closed = True + + if self.get_in_progress: + # Unblock get() or get_iter(). + self.frames.put(None) + + if self.paused: + # Unblock recv_events(). + self.paused = False + self.resume() diff --git a/source/websockets/sync/router.py b/source/websockets/sync/router.py new file mode 100644 index 0000000000000000000000000000000000000000..1c35e8aaea909b3dd1a16f9346cd94f5063aa17c --- /dev/null +++ b/source/websockets/sync/router.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import http +import ssl as ssl_module +import urllib.parse +from typing import Any, Callable, Literal + +from ..http11 import Request, Response +from .server import Server, ServerConnection, serve + + +__all__ = ["route", "unix_route", "Router"] + + +try: + from werkzeug.exceptions import NotFound + from werkzeug.routing import Map, RequestRedirect + +except ImportError: + + def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, + ) -> Server: + raise ImportError("route() requires werkzeug") + + def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, + ) -> Server: + raise ImportError("unix_route() requires werkzeug") + +else: + + def route( + url_map: Map, + *args: Any, + server_name: str | None = None, + ssl: ssl_module.SSLContext | Literal[True] | None = None, + create_router: type[Router] | None = None, + **kwargs: Any, + ) -> Server: + """ + Create a WebSocket server dispatching connections to different handlers. + + This feature requires the third-party library `werkzeug`_: + + .. code-block:: console + + $ pip install werkzeug + + .. _werkzeug: https://werkzeug.palletsprojects.com/ + + :func:`route` accepts the same arguments as + :func:`~websockets.sync.server.serve`, except as described below. + + The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns + to connection handlers. In addition to the connection, handlers receive + parameters captured in the URL as keyword arguments. + + Here's an example:: + + + from websockets.sync.router import route + from werkzeug.routing import Map, Rule + + def channel_handler(websocket, channel_id): + ... + + url_map = Map([ + Rule("/channel/", endpoint=channel_handler), + ... + ]) + + with route(url_map, ...) as server: + server.serve_forever() + + Refer to the documentation of :mod:`werkzeug.routing` for details. + + If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map, + when the server runs behind a reverse proxy that modifies the ``Host`` + header or terminates TLS, you need additional configuration: + + * Set ``server_name`` to the name of the server as seen by clients. When + not provided, websockets uses the value of the ``Host`` header. + + * Set ``ssl=True`` to generate ``wss://`` URIs without enabling TLS. + Under the hood, this bind the URL map with a ``url_scheme`` of + ``wss://`` instead of ``ws://``. + + There is no need to specify ``websocket=True`` in each rule. It is added + automatically. + + Args: + url_map: Mapping of URL patterns to connection handlers. + server_name: Name of the server as seen by clients. If :obj:`None`, + websockets uses the value of the ``Host`` header. + ssl: Configuration for enabling TLS on the connection. Set it to + :obj:`True` if a reverse proxy terminates TLS connections. + create_router: Factory for the :class:`Router` dispatching requests to + handlers. Set it to a wrapper or a subclass to customize routing. + + """ + url_scheme = "ws" if ssl is None else "wss" + if ssl is not True and ssl is not None: + kwargs["ssl"] = ssl + + if create_router is None: + create_router = Router + + router = create_router(url_map, server_name, url_scheme) + + _process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = kwargs.pop("process_request", None) + if _process_request is None: + process_request: Callable[ + [ServerConnection, Request], + Response | None, + ] = router.route_request + else: + + def process_request( + connection: ServerConnection, request: Request + ) -> Response | None: + response = _process_request(connection, request) + if response is not None: + return response + return router.route_request(connection, request) + + return serve(router.handler, *args, process_request=process_request, **kwargs) + + def unix_route( + url_map: Map, + path: str | None = None, + **kwargs: Any, + ) -> Server: + """ + Create a WebSocket Unix server dispatching connections to different handlers. + + :func:`unix_route` combines the behaviors of :func:`route` and + :func:`~websockets.sync.server.unix_serve`. + + Args: + url_map: Mapping of URL patterns to connection handlers. + path: File system path to the Unix socket. + + """ + return route(url_map, unix=True, path=path, **kwargs) + + +class Router: + """WebSocket router supporting :func:`route`.""" + + def __init__( + self, + url_map: Map, + server_name: str | None = None, + url_scheme: str = "ws", + ) -> None: + self.url_map = url_map + self.server_name = server_name + self.url_scheme = url_scheme + for rule in self.url_map.iter_rules(): + rule.websocket = True + + def get_server_name(self, connection: ServerConnection, request: Request) -> str: + if self.server_name is None: + return request.headers["Host"] + else: + return self.server_name + + def redirect(self, connection: ServerConnection, url: str) -> Response: + response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}") + response.headers["Location"] = url + return response + + def not_found(self, connection: ServerConnection) -> Response: + return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found") + + def route_request( + self, connection: ServerConnection, request: Request + ) -> Response | None: + """Route incoming request.""" + url_map_adapter = self.url_map.bind( + server_name=self.get_server_name(connection, request), + url_scheme=self.url_scheme, + ) + try: + parsed = urllib.parse.urlparse(request.path) + handler, kwargs = url_map_adapter.match( + path_info=parsed.path, + query_args=parsed.query, + ) + except RequestRedirect as redirect: + return self.redirect(connection, redirect.new_url) + except NotFound: + return self.not_found(connection) + connection.handler, connection.handler_kwargs = handler, kwargs + return None + + def handler(self, connection: ServerConnection) -> None: + """Handle a connection.""" + return connection.handler(connection, **connection.handler_kwargs) diff --git a/source/websockets/sync/server.py b/source/websockets/sync/server.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd82fbad36e10e2f7a6a96cbcfba432b7ee780c --- /dev/null +++ b/source/websockets/sync/server.py @@ -0,0 +1,765 @@ +from __future__ import annotations + +import hmac +import http +import logging +import os +import re +import selectors +import socket +import ssl as ssl_module +import sys +import threading +import warnings +from collections.abc import Iterable, Sequence +from types import TracebackType +from typing import Any, Callable, Mapping, cast + +from ..exceptions import InvalidHeader +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode +from ..headers import ( + build_www_authenticate_basic, + parse_authorization_basic, + validate_subprotocols, +) +from ..http11 import SERVER, Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol +from .connection import Connection +from .utils import Deadline + + +__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"] + + +class ServerConnection(Connection): + """ + :mod:`threading` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports iteration to receive messages:: + + for message in websocket: + process(message) + + The iterator exits normally when the connection is closed with code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. + + Args: + socket: Socket connected to a WebSocket client. + protocol: Sans-I/O connection. + + """ + + def __init__( + self, + socket: socket.socket, + protocol: ServerProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ServerProtocol + self.request_rcvd = threading.Event() + super().__init__( + socket, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], None] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() + + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + + def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + timeout: float | None = None, + ) -> None: + """ + Perform the opening handshake. + + """ + if not self.request_rcvd.wait(timeout): + raise TimeoutError("timed out while waiting for handshake request") + + if self.request is not None: + with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + self.response = self.protocol.accept(self.request) + else: + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + def recv_events(self) -> None: + """ + Read incoming data from the socket and process events. + + """ + try: + super().recv_events() + finally: + # If the connection is closed during the handshake, unblock it. + self.request_rcvd.set() + + +class Server: + """ + WebSocket server returned by :func:`serve`. + + This class mirrors the API of :class:`~socketserver.BaseServer`, notably the + :meth:`~socketserver.BaseServer.serve_forever` and + :meth:`~socketserver.BaseServer.shutdown` methods, as well as the context + manager protocol. + + Args: + socket: Server socket listening for new connections. + handler: Handler for one connection. Receives the socket and address + returned by :meth:`~socket.socket.accept`. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + socket: socket.socket, + handler: Callable[[socket.socket, Any], None], + logger: LoggerLike | None = None, + ) -> None: + self.socket = socket + self.handler = handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + if sys.platform != "win32": + self.shutdown_watcher, self.shutdown_notifier = os.pipe() + + def serve_forever(self) -> None: + """ + See :meth:`socketserver.BaseServer.serve_forever`. + + This method doesn't return. Calling :meth:`shutdown` from another thread + stops the server. + + Typical use:: + + with serve(...) as server: + server.serve_forever() + + """ + poller = selectors.DefaultSelector() + try: + poller.register(self.socket, selectors.EVENT_READ) + except ValueError: # pragma: no cover + # If shutdown() is called before poller.register(), + # the socket is closed and poller.register() raises + # ValueError: Invalid file descriptor: -1 + return + if sys.platform != "win32": + poller.register(self.shutdown_watcher, selectors.EVENT_READ) + + while True: + poller.select() + try: + # If the socket is closed, this will raise an exception and exit + # the loop. So we don't need to check the return value of select(). + sock, addr = self.socket.accept() + except OSError: + break + # Since there isn't a mechanism for tracking connections and waiting + # for them to terminate, we cannot use daemon threads, or else all + # connections would be terminate brutally when closing the server. + thread = threading.Thread(target=self.handler, args=(sock, addr)) + thread.start() + + def shutdown(self) -> None: + """ + See :meth:`socketserver.BaseServer.shutdown`. + + """ + self.socket.close() + if sys.platform != "win32": + os.write(self.shutdown_notifier, b"x") + + def fileno(self) -> int: + """ + See :meth:`socketserver.BaseServer.fileno`. + + """ + return self.socket.fileno() + + def __enter__(self) -> Server: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.shutdown() + + +def __getattr__(name: str) -> Any: + if name == "WebSocketServer": + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "WebSocketServer was renamed to Server", + DeprecationWarning, + ) + return Server + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def serve( + handler: Callable[[ServerConnection], None], + host: str | None = None, + port: int | None = None, + *, + # TCP/TLS + sock: socket.socket | None = None, + ssl: ssl_module.SSLContext | None = None, + # WebSocket + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, + extensions: Sequence[ServerExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + select_subprotocol: ( + Callable[ + [ServerConnection, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None, + compression: str | None = "deflate", + # HTTP + process_request: ( + Callable[ + [ServerConnection, Request], + Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ServerConnection] | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a :class:`ServerConnection`, + performs the opening handshake, and delegates to the ``handler``. + + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. + + Once the handler completes, either normally or with an exception, the server + performs the closing handshake and closes the connection. + + This function returns a :class:`Server` whose API mirrors + :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure + that it will be closed and call :meth:`~Server.serve_forever` to serve + requests:: + + from websockets.sync.server import serve + + def handler(websocket): + ... + + with serve(handler, ...) as server: + server.serve_forever() + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + host: Network interfaces the server binds to. + See :func:`~socket.create_server` for details. + port: TCP port the server listens on. + See :func:`~socket.create_server` for details. + sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``. + You may call :func:`socket.create_server` to create a suitable TCP + socket. + ssl: Configuration for enabling TLS on the connection. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Values can be + :class:`str` to test for an exact match or regular expressions + compiled by :func:`re.compile` to test against a pattern. Include + :obj:`None` in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + select_subprotocol: Callback for selecting a subprotocol among + those supported by the client and the server. It receives a + :class:`ServerConnection` (not a + :class:`~websockets.server.ServerProtocol`!) instance and a list of + subprotocols offered by the client. Other than the first argument, + it has the same behavior as the + :meth:`ServerProtocol.select_subprotocol + ` method. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + process_request: Intercept the request during the opening handshake. + Return an HTTP response to force the response. Return :obj:`None` to + continue normally. When you force an HTTP 101 Continue response, the + handshake is successful. Else, the connection is aborted. + process_response: Intercept the response during the opening handshake. + Modify the response or return a new HTTP response to force the + response. Return :obj:`None` to continue normally. When you force an + HTTP 101 Continue response, the handshake is successful. Else, the + connection is aborted. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to + :obj:`None` removes the header. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing connections in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. See the + :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ServerConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to :func:`~socket.create_server`. + + """ + + # Process parameters + + # Backwards compatibility: ssl used to be called ssl_context. + if ssl is None and "ssl_context" in kwargs: + ssl = kwargs.pop("ssl_context") + warnings.warn( # deprecated in 13.0 - 2024-08-20 + "ssl_context was renamed to ssl", + DeprecationWarning, + ) + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Bind socket and listen + + # Private APIs for unix_connect() + unix: bool = kwargs.pop("unix", False) + path: str | None = kwargs.pop("path", None) + + if sock is None: + if unix: + if path is None: + raise ValueError("missing path argument") + kwargs.setdefault("family", socket.AF_UNIX) + sock = socket.create_server(path, **kwargs) + else: + sock = socket.create_server((host, port), **kwargs) + else: + if path is not None: + raise ValueError("path and sock arguments are incompatible") + + # Initialize TLS wrapper + + if ssl is not None: + sock = ssl.wrap_socket( + sock, + server_side=True, + # Delay TLS handshake until after we set a timeout on the socket. + do_handshake_on_connect=False, + ) + + # Define request handler + + def conn_handler(sock: socket.socket, addr: Any) -> None: + # Calculate timeouts on the TLS and WebSocket handshakes. + # The TLS timeout must be set on the socket, then removed + # to avoid conflicting with the WebSocket timeout in handshake(). + deadline = Deadline(open_timeout) + + try: + # Disable Nagle algorithm + + if not unix: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + + # Perform TLS handshake + + if ssl is not None: + sock.settimeout(deadline.timeout()) + # mypy cannot figure this out + assert isinstance(sock, ssl_module.SSLSocket) + sock.do_handshake() + sock.settimeout(None) + + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket protocol + + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket connection + + assert create_connection is not None # help mypy + connection = create_connection( + sock, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + except Exception: + sock.close() + return + + try: + try: + connection.handshake( + process_request, + process_response, + server_header, + deadline.timeout(), + ) + except TimeoutError: + connection.close_socket() + connection.recv_events_thread.join() + return + except Exception: + connection.logger.error("opening handshake failed", exc_info=True) + connection.close_socket() + connection.recv_events_thread.join() + return + + assert connection.protocol.state is OPEN + try: + connection.start_keepalive() + handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + connection.close(CloseCode.INTERNAL_ERROR) + else: + connection.close() + + except Exception: # pragma: no cover + # Don't leak sockets on unexpected errors. + sock.close() + + # Initialize server + + return Server(sock, conn_handler, logger) + + +def unix_serve( + handler: Callable[[ServerConnection], None], + path: str | None = None, + **kwargs: Any, +) -> Server: + """ + Create a WebSocket server listening on a Unix socket. + + This function accepts the same keyword arguments as :func:`serve`. + + It's only available on Unix. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + handler: Connection handler. It receives the WebSocket connection, + which is a :class:`ServerConnection`, in argument. + path: File system path to the Unix socket. + + """ + return serve(handler, unix=True, path=path, **kwargs) + + +def is_credentials(credentials: Any) -> bool: + try: + username, password = credentials + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +def basic_auth( + realm: str = "", + credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None, + check_credentials: Callable[[str, str], bool] | None = None, +) -> Callable[[ServerConnection, Request], Response | None]: + """ + Factory for ``process_request`` to enforce HTTP Basic Authentication. + + :func:`basic_auth` is designed to integrate with :func:`serve` as follows:: + + from websockets.sync.server import basic_auth, serve + + with serve( + ..., + process_request=basic_auth( + realm="my dev server", + credentials=("hello", "iloveyou"), + ), + ): + + If authentication succeeds, the connection's ``username`` attribute is set. + If it fails, the server responds with an HTTP 401 Unauthorized status. + + One of ``credentials`` or ``check_credentials`` must be provided; not both. + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. Refer to + section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Function that verifies credentials. + It receives ``username`` and ``password`` arguments and returns + whether they're valid. + Raises: + TypeError: If ``credentials`` or ``check_credentials`` is wrong. + ValueError: If ``credentials`` and ``check_credentials`` are both + provided or both not provided. + + """ + if (credentials is None) == (check_credentials is None): + raise ValueError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(tuple[str, str], credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(cast(Iterable[tuple[str, str]], credentials)) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + assert check_credentials is not None # help mypy + + def process_request( + connection: ServerConnection, + request: Request, + ) -> Response | None: + """ + Perform HTTP Basic Authentication. + + If it succeeds, set the connection's ``username`` attribute and return + :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss. + + """ + try: + authorization = request.headers["Authorization"] + except KeyError: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Missing credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Unsupported credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + if not check_credentials(username, password): + response = connection.respond( + http.HTTPStatus.UNAUTHORIZED, + "Invalid credentials\n", + ) + response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm) + return response + + connection.username = username + return None + + return process_request diff --git a/source/websockets/sync/utils.py b/source/websockets/sync/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00bce2cc6bb19fa280a6ef2b3481403e6f6ba74f --- /dev/null +++ b/source/websockets/sync/utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import time + + +__all__ = ["Deadline"] + + +class Deadline: + """ + Manage timeouts across multiple steps. + + Args: + timeout: Time available in seconds or :obj:`None` if there is no limit. + + """ + + def __init__(self, timeout: float | None) -> None: + self.deadline: float | None + if timeout is None: + self.deadline = None + else: + self.deadline = time.monotonic() + timeout + + def timeout(self, *, raise_if_elapsed: bool = True) -> float | None: + """ + Calculate a timeout from a deadline. + + Args: + raise_if_elapsed: Whether to raise :exc:`TimeoutError` + if the deadline lapsed. + + Raises: + TimeoutError: If the deadline lapsed. + + Returns: + Time left in seconds or :obj:`None` if there is no limit. + + """ + if self.deadline is None: + return None + timeout = self.deadline - time.monotonic() + if raise_if_elapsed and timeout <= 0: + raise TimeoutError("timed out") + return timeout diff --git a/source/websockets/typing.py b/source/websockets/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..69b1a8d372837b0a297177f8353a335375cfbd49 --- /dev/null +++ b/source/websockets/typing.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import http +import logging +from typing import TYPE_CHECKING, Any, NewType, Sequence + + +__all__ = [ + "Data", + "LoggerLike", + "StatusLike", + "Origin", + "Subprotocol", + "ExtensionName", + "ExtensionParameter", +] + + +# Public types used in the signature of public APIs + +Data = str | bytes +"""Types supported in a WebSocket message: +:class:`str` for a Text_ frame, :class:`bytes` for a Binary_ frame. + +.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 +.. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + +""" + +BytesLike = bytes | bytearray | memoryview +"""Types accepted where :class:`bytes` is expected.""" + +DataLike = str | bytes | bytearray | memoryview +"""Types accepted where :class:`Data` is expected.""" + +if TYPE_CHECKING: + LoggerLike = logging.Logger | logging.LoggerAdapter[Any] + """Types accepted where a :class:`~logging.Logger` is expected.""" +else: # remove this branch when dropping support for Python < 3.11 + LoggerLike = logging.Logger | logging.LoggerAdapter + """Types accepted where a :class:`~logging.Logger` is expected.""" + + +StatusLike = http.HTTPStatus | int +""" +Types accepted where an :class:`~http.HTTPStatus` is expected.""" + + +Origin = NewType("Origin", str) +"""Value of a ``Origin`` header.""" + + +Subprotocol = NewType("Subprotocol", str) +"""Subprotocol in a ``Sec-WebSocket-Protocol`` header.""" + + +ExtensionName = NewType("ExtensionName", str) +"""Name of a WebSocket extension.""" + +ExtensionParameter = tuple[str, str | None] +"""Parameter of a WebSocket extension.""" + + +# Private types + +ExtensionHeader = tuple[ExtensionName, Sequence[ExtensionParameter]] +"""Extension in a ``Sec-WebSocket-Extensions`` header.""" + + +ConnectionOption = NewType("ConnectionOption", str) +"""Connection option in a ``Connection`` header.""" + + +UpgradeProtocol = NewType("UpgradeProtocol", str) +"""Upgrade protocol in an ``Upgrade`` header.""" diff --git a/source/websockets/uri.py b/source/websockets/uri.py new file mode 100644 index 0000000000000000000000000000000000000000..f85e16810b8e43bbd348871ef661b2644ea34240 --- /dev/null +++ b/source/websockets/uri.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import dataclasses +import urllib.parse + +from .exceptions import InvalidURI + + +__all__ = ["parse_uri", "WebSocketURI"] + + +# All characters from the gen-delims and sub-delims sets in RFC 3987. +DELIMS = ":/?#[]@!$&'()*+,;=" + + +@dataclasses.dataclass +class WebSocketURI: + """ + WebSocket URI. + + Attributes: + secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI. + host: Normalized to lower case. + port: Always set even if it's the default. + path: May be empty. + query: May be empty if the URI doesn't include a query component. + username: Available when the URI contains `User Information`_. + password: Available when the URI contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + secure: bool + host: str + port: int + path: str + query: str + username: str | None = None + password: str | None = None + + @property + def resource_name(self) -> str: + if self.path: + resource_name = self.path + else: + resource_name = "/" + if self.query: + resource_name += "?" + self.query + return resource_name + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_uri(uri: str) -> WebSocketURI: + """ + Parse and validate a WebSocket URI. + + Args: + uri: WebSocket URI. + + Returns: + Parsed WebSocket URI. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + + """ + parsed = urllib.parse.urlparse(uri) + if parsed.scheme not in ["ws", "wss"]: + raise InvalidURI(uri, "scheme isn't ws or wss") + if parsed.hostname is None: + raise InvalidURI(uri, "hostname isn't provided") + if parsed.fragment != "": + raise InvalidURI(uri, "fragment identifier is meaningless") + + secure = parsed.scheme == "wss" + host = parsed.hostname + port = parsed.port or (443 if secure else 80) + path = parsed.path + query = parsed.query + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidURI(uri, "username provided without password") + + try: + uri.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + path = urllib.parse.quote(path, safe=DELIMS) + query = urllib.parse.quote(query, safe=DELIMS) + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return WebSocketURI(secure, host, port, path, query, username, password) diff --git a/source/websockets/utils.py b/source/websockets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a90e52b5a88e2a086fdff088ace214fdf23bb9 --- /dev/null +++ b/source/websockets/utils.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import base64 +import hashlib +import secrets +import sys + +from .typing import BytesLike + + +__all__ = ["accept_key", "apply_mask"] + + +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def generate_key() -> str: + """ + Generate a random key for the Sec-WebSocket-Key header. + + """ + key = secrets.token_bytes(16) + return base64.b64encode(key).decode() + + +def accept_key(key: str) -> str: + """ + Compute the value of the Sec-WebSocket-Accept header. + + Args: + key: Value of the Sec-WebSocket-Key header. + + """ + sha1 = hashlib.sha1((key + GUID).encode()).digest() + return base64.b64encode(sha1).decode() + + +def apply_mask(data: BytesLike, mask: bytes | bytearray) -> bytes: + """ + Apply masking to the data of a WebSocket message. + + Args: + data: Data to mask. + mask: 4-bytes mask. + + """ + if len(mask) != 4: + raise ValueError("mask must contain 4 bytes") + + data_int = int.from_bytes(data, sys.byteorder) + mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4] + mask_int = int.from_bytes(mask_repeated, sys.byteorder) + return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder) diff --git a/source/websockets/version.py b/source/websockets/version.py new file mode 100644 index 0000000000000000000000000000000000000000..dde52b6b0b1db8c0c3763ff537c23b7ce718b60c --- /dev/null +++ b/source/websockets/version.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import importlib.metadata + + +__all__ = ["tag", "version", "commit"] + + +# ========= =========== =================== +# release development +# ========= =========== =================== +# tag X.Y X.Y (upcoming) +# version X.Y X.Y.dev1+g5678cde +# commit X.Y 5678cde +# ========= =========== =================== + + +# When tagging a release, set `released = True`. +# After tagging a release, set `released = False` and increment `tag`. + +released = True + +tag = version = commit = "16.0" + + +if not released: # pragma: no cover + import pathlib + import re + import subprocess + + def get_version(tag: str) -> str: + # Since setup.py executes the contents of src/websockets/version.py, + # __file__ can point to either of these two files. + file_path = pathlib.Path(__file__) + root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2] + + # Read version from package metadata if it is installed. + try: + version = importlib.metadata.version("websockets") + except ImportError: + pass + else: + # Check that this file belongs to the installed package. + files = importlib.metadata.files("websockets") + if files: + version_files = [f for f in files if f.name == file_path.name] + if version_files: + version_file = version_files[0] + if version_file.locate() == file_path: + return version + + # Read version from git if available. + try: + description = subprocess.run( + ["git", "describe", "--dirty", "--tags", "--long"], + capture_output=True, + cwd=root_dir, + timeout=1, + check=True, + text=True, + ).stdout.strip() + # subprocess.run raises FileNotFoundError if git isn't on $PATH. + except ( + FileNotFoundError, + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ): + pass + else: + description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)" + match = re.fullmatch(description_re, description) + if match is None: + raise ValueError(f"Unexpected git description: {description}") + distance, remainder = match.groups() + remainder = remainder.replace("-", ".") # required by PEP 440 + return f"{tag}.dev{distance}+{remainder}" + + # Avoid crashing if the development version cannot be determined. + return f"{tag}.dev0+gunknown" + + version = get_version(tag) + + def get_commit(tag: str, version: str) -> str: + # Extract commit from version, falling back to tag if not available. + version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?" + match = re.fullmatch(version_re, version) + if match is None: + raise ValueError(f"Unexpected version: {version}") + (commit,) = match.groups() + return tag if commit == "unknown" else commit + + commit = get_commit(tag, version) diff --git a/source/yaml/__init__.py b/source/yaml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d58f0891737def7f38e5d86dde2dbf9be0c13dce --- /dev/null +++ b/source/yaml/__init__.py @@ -0,0 +1,390 @@ + +from .error import * + +from .tokens import * +from .events import * +from .nodes import * + +from .loader import * +from .dumper import * + +__version__ = '6.0.3' +try: + from .cyaml import * + __with_libyaml__ = True +except ImportError: + __with_libyaml__ = False + +import io + +#------------------------------------------------------------------------------ +# XXX "Warnings control" is now deprecated. Leaving in the API function to not +# break code that uses it. +#------------------------------------------------------------------------------ +def warnings(settings=None): + if settings is None: + return {} + +#------------------------------------------------------------------------------ +def scan(stream, Loader=Loader): + """ + Scan a YAML stream and produce scanning tokens. + """ + loader = Loader(stream) + try: + while loader.check_token(): + yield loader.get_token() + finally: + loader.dispose() + +def parse(stream, Loader=Loader): + """ + Parse a YAML stream and produce parsing events. + """ + loader = Loader(stream) + try: + while loader.check_event(): + yield loader.get_event() + finally: + loader.dispose() + +def compose(stream, Loader=Loader): + """ + Parse the first YAML document in a stream + and produce the corresponding representation tree. + """ + loader = Loader(stream) + try: + return loader.get_single_node() + finally: + loader.dispose() + +def compose_all(stream, Loader=Loader): + """ + Parse all YAML documents in a stream + and produce corresponding representation trees. + """ + loader = Loader(stream) + try: + while loader.check_node(): + yield loader.get_node() + finally: + loader.dispose() + +def load(stream, Loader): + """ + Parse the first YAML document in a stream + and produce the corresponding Python object. + """ + loader = Loader(stream) + try: + return loader.get_single_data() + finally: + loader.dispose() + +def load_all(stream, Loader): + """ + Parse all YAML documents in a stream + and produce corresponding Python objects. + """ + loader = Loader(stream) + try: + while loader.check_data(): + yield loader.get_data() + finally: + loader.dispose() + +def full_load(stream): + """ + Parse the first YAML document in a stream + and produce the corresponding Python object. + + Resolve all tags except those known to be + unsafe on untrusted input. + """ + return load(stream, FullLoader) + +def full_load_all(stream): + """ + Parse all YAML documents in a stream + and produce corresponding Python objects. + + Resolve all tags except those known to be + unsafe on untrusted input. + """ + return load_all(stream, FullLoader) + +def safe_load(stream): + """ + Parse the first YAML document in a stream + and produce the corresponding Python object. + + Resolve only basic YAML tags. This is known + to be safe for untrusted input. + """ + return load(stream, SafeLoader) + +def safe_load_all(stream): + """ + Parse all YAML documents in a stream + and produce corresponding Python objects. + + Resolve only basic YAML tags. This is known + to be safe for untrusted input. + """ + return load_all(stream, SafeLoader) + +def unsafe_load(stream): + """ + Parse the first YAML document in a stream + and produce the corresponding Python object. + + Resolve all tags, even those known to be + unsafe on untrusted input. + """ + return load(stream, UnsafeLoader) + +def unsafe_load_all(stream): + """ + Parse all YAML documents in a stream + and produce corresponding Python objects. + + Resolve all tags, even those known to be + unsafe on untrusted input. + """ + return load_all(stream, UnsafeLoader) + +def emit(events, stream=None, Dumper=Dumper, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None): + """ + Emit YAML parsing events into a stream. + If stream is None, return the produced string instead. + """ + getvalue = None + if stream is None: + stream = io.StringIO() + getvalue = stream.getvalue + dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break) + try: + for event in events: + dumper.emit(event) + finally: + dumper.dispose() + if getvalue: + return getvalue() + +def serialize_all(nodes, stream=None, Dumper=Dumper, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None): + """ + Serialize a sequence of representation trees into a YAML stream. + If stream is None, return the produced string instead. + """ + getvalue = None + if stream is None: + if encoding is None: + stream = io.StringIO() + else: + stream = io.BytesIO() + getvalue = stream.getvalue + dumper = Dumper(stream, canonical=canonical, indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break, + encoding=encoding, version=version, tags=tags, + explicit_start=explicit_start, explicit_end=explicit_end) + try: + dumper.open() + for node in nodes: + dumper.serialize(node) + dumper.close() + finally: + dumper.dispose() + if getvalue: + return getvalue() + +def serialize(node, stream=None, Dumper=Dumper, **kwds): + """ + Serialize a representation tree into a YAML stream. + If stream is None, return the produced string instead. + """ + return serialize_all([node], stream, Dumper=Dumper, **kwds) + +def dump_all(documents, stream=None, Dumper=Dumper, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + """ + Serialize a sequence of Python objects into a YAML stream. + If stream is None, return the produced string instead. + """ + getvalue = None + if stream is None: + if encoding is None: + stream = io.StringIO() + else: + stream = io.BytesIO() + getvalue = stream.getvalue + dumper = Dumper(stream, default_style=default_style, + default_flow_style=default_flow_style, + canonical=canonical, indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break, + encoding=encoding, version=version, tags=tags, + explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys) + try: + dumper.open() + for data in documents: + dumper.represent(data) + dumper.close() + finally: + dumper.dispose() + if getvalue: + return getvalue() + +def dump(data, stream=None, Dumper=Dumper, **kwds): + """ + Serialize a Python object into a YAML stream. + If stream is None, return the produced string instead. + """ + return dump_all([data], stream, Dumper=Dumper, **kwds) + +def safe_dump_all(documents, stream=None, **kwds): + """ + Serialize a sequence of Python objects into a YAML stream. + Produce only basic YAML tags. + If stream is None, return the produced string instead. + """ + return dump_all(documents, stream, Dumper=SafeDumper, **kwds) + +def safe_dump(data, stream=None, **kwds): + """ + Serialize a Python object into a YAML stream. + Produce only basic YAML tags. + If stream is None, return the produced string instead. + """ + return dump_all([data], stream, Dumper=SafeDumper, **kwds) + +def add_implicit_resolver(tag, regexp, first=None, + Loader=None, Dumper=Dumper): + """ + Add an implicit scalar detector. + If an implicit scalar value matches the given regexp, + the corresponding tag is assigned to the scalar. + first is a sequence of possible initial characters or None. + """ + if Loader is None: + loader.Loader.add_implicit_resolver(tag, regexp, first) + loader.FullLoader.add_implicit_resolver(tag, regexp, first) + loader.UnsafeLoader.add_implicit_resolver(tag, regexp, first) + else: + Loader.add_implicit_resolver(tag, regexp, first) + Dumper.add_implicit_resolver(tag, regexp, first) + +def add_path_resolver(tag, path, kind=None, Loader=None, Dumper=Dumper): + """ + Add a path based resolver for the given tag. + A path is a list of keys that forms a path + to a node in the representation tree. + Keys can be string values, integers, or None. + """ + if Loader is None: + loader.Loader.add_path_resolver(tag, path, kind) + loader.FullLoader.add_path_resolver(tag, path, kind) + loader.UnsafeLoader.add_path_resolver(tag, path, kind) + else: + Loader.add_path_resolver(tag, path, kind) + Dumper.add_path_resolver(tag, path, kind) + +def add_constructor(tag, constructor, Loader=None): + """ + Add a constructor for the given tag. + Constructor is a function that accepts a Loader instance + and a node object and produces the corresponding Python object. + """ + if Loader is None: + loader.Loader.add_constructor(tag, constructor) + loader.FullLoader.add_constructor(tag, constructor) + loader.UnsafeLoader.add_constructor(tag, constructor) + else: + Loader.add_constructor(tag, constructor) + +def add_multi_constructor(tag_prefix, multi_constructor, Loader=None): + """ + Add a multi-constructor for the given tag prefix. + Multi-constructor is called for a node if its tag starts with tag_prefix. + Multi-constructor accepts a Loader instance, a tag suffix, + and a node object and produces the corresponding Python object. + """ + if Loader is None: + loader.Loader.add_multi_constructor(tag_prefix, multi_constructor) + loader.FullLoader.add_multi_constructor(tag_prefix, multi_constructor) + loader.UnsafeLoader.add_multi_constructor(tag_prefix, multi_constructor) + else: + Loader.add_multi_constructor(tag_prefix, multi_constructor) + +def add_representer(data_type, representer, Dumper=Dumper): + """ + Add a representer for the given type. + Representer is a function accepting a Dumper instance + and an instance of the given data type + and producing the corresponding representation node. + """ + Dumper.add_representer(data_type, representer) + +def add_multi_representer(data_type, multi_representer, Dumper=Dumper): + """ + Add a representer for the given type. + Multi-representer is a function accepting a Dumper instance + and an instance of the given data type or subtype + and producing the corresponding representation node. + """ + Dumper.add_multi_representer(data_type, multi_representer) + +class YAMLObjectMetaclass(type): + """ + The metaclass for YAMLObject. + """ + def __init__(cls, name, bases, kwds): + super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds) + if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None: + if isinstance(cls.yaml_loader, list): + for loader in cls.yaml_loader: + loader.add_constructor(cls.yaml_tag, cls.from_yaml) + else: + cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml) + + cls.yaml_dumper.add_representer(cls, cls.to_yaml) + +class YAMLObject(metaclass=YAMLObjectMetaclass): + """ + An object that can dump itself to a YAML stream + and load itself from a YAML stream. + """ + + __slots__ = () # no direct instantiation, so allow immutable subclasses + + yaml_loader = [Loader, FullLoader, UnsafeLoader] + yaml_dumper = Dumper + + yaml_tag = None + yaml_flow_style = None + + @classmethod + def from_yaml(cls, loader, node): + """ + Convert a representation node to a Python object. + """ + return loader.construct_yaml_object(node, cls) + + @classmethod + def to_yaml(cls, dumper, data): + """ + Convert a Python object to a representation node. + """ + return dumper.represent_yaml_object(cls.yaml_tag, data, cls, + flow_style=cls.yaml_flow_style) + diff --git a/source/yaml/_yaml.cpython-312-x86_64-linux-gnu.so b/source/yaml/_yaml.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..5a17311d6696643869b7726d2cd50c046db77b68 --- /dev/null +++ b/source/yaml/_yaml.cpython-312-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:957a099a4521c1f7669306fb6f79ca63fa4b9a0b4c463f7cac833a65a5c5c0cb +size 2679264 diff --git a/source/yaml/composer.py b/source/yaml/composer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d15cb40e3b4198819c91c6f8d8b32807fcf53b2 --- /dev/null +++ b/source/yaml/composer.py @@ -0,0 +1,139 @@ + +__all__ = ['Composer', 'ComposerError'] + +from .error import MarkedYAMLError +from .events import * +from .nodes import * + +class ComposerError(MarkedYAMLError): + pass + +class Composer: + + def __init__(self): + self.anchors = {} + + def check_node(self): + # Drop the STREAM-START event. + if self.check_event(StreamStartEvent): + self.get_event() + + # If there are more documents available? + return not self.check_event(StreamEndEvent) + + def get_node(self): + # Get the root node of the next document. + if not self.check_event(StreamEndEvent): + return self.compose_document() + + def get_single_node(self): + # Drop the STREAM-START event. + self.get_event() + + # Compose a document if the stream is not empty. + document = None + if not self.check_event(StreamEndEvent): + document = self.compose_document() + + # Ensure that the stream contains no more documents. + if not self.check_event(StreamEndEvent): + event = self.get_event() + raise ComposerError("expected a single document in the stream", + document.start_mark, "but found another document", + event.start_mark) + + # Drop the STREAM-END event. + self.get_event() + + return document + + def compose_document(self): + # Drop the DOCUMENT-START event. + self.get_event() + + # Compose the root node. + node = self.compose_node(None, None) + + # Drop the DOCUMENT-END event. + self.get_event() + + self.anchors = {} + return node + + def compose_node(self, parent, index): + if self.check_event(AliasEvent): + event = self.get_event() + anchor = event.anchor + if anchor not in self.anchors: + raise ComposerError(None, None, "found undefined alias %r" + % anchor, event.start_mark) + return self.anchors[anchor] + event = self.peek_event() + anchor = event.anchor + if anchor is not None: + if anchor in self.anchors: + raise ComposerError("found duplicate anchor %r; first occurrence" + % anchor, self.anchors[anchor].start_mark, + "second occurrence", event.start_mark) + self.descend_resolver(parent, index) + if self.check_event(ScalarEvent): + node = self.compose_scalar_node(anchor) + elif self.check_event(SequenceStartEvent): + node = self.compose_sequence_node(anchor) + elif self.check_event(MappingStartEvent): + node = self.compose_mapping_node(anchor) + self.ascend_resolver() + return node + + def compose_scalar_node(self, anchor): + event = self.get_event() + tag = event.tag + if tag is None or tag == '!': + tag = self.resolve(ScalarNode, event.value, event.implicit) + node = ScalarNode(tag, event.value, + event.start_mark, event.end_mark, style=event.style) + if anchor is not None: + self.anchors[anchor] = node + return node + + def compose_sequence_node(self, anchor): + start_event = self.get_event() + tag = start_event.tag + if tag is None or tag == '!': + tag = self.resolve(SequenceNode, None, start_event.implicit) + node = SequenceNode(tag, [], + start_event.start_mark, None, + flow_style=start_event.flow_style) + if anchor is not None: + self.anchors[anchor] = node + index = 0 + while not self.check_event(SequenceEndEvent): + node.value.append(self.compose_node(node, index)) + index += 1 + end_event = self.get_event() + node.end_mark = end_event.end_mark + return node + + def compose_mapping_node(self, anchor): + start_event = self.get_event() + tag = start_event.tag + if tag is None or tag == '!': + tag = self.resolve(MappingNode, None, start_event.implicit) + node = MappingNode(tag, [], + start_event.start_mark, None, + flow_style=start_event.flow_style) + if anchor is not None: + self.anchors[anchor] = node + while not self.check_event(MappingEndEvent): + #key_event = self.peek_event() + item_key = self.compose_node(node, None) + #if item_key in node.value: + # raise ComposerError("while composing a mapping", start_event.start_mark, + # "found duplicate key", key_event.start_mark) + item_value = self.compose_node(node, item_key) + #node.value[item_key] = item_value + node.value.append((item_key, item_value)) + end_event = self.get_event() + node.end_mark = end_event.end_mark + return node + diff --git a/source/yaml/constructor.py b/source/yaml/constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..619acd3070a4845c653fcf22a626e05158035bc2 --- /dev/null +++ b/source/yaml/constructor.py @@ -0,0 +1,748 @@ + +__all__ = [ + 'BaseConstructor', + 'SafeConstructor', + 'FullConstructor', + 'UnsafeConstructor', + 'Constructor', + 'ConstructorError' +] + +from .error import * +from .nodes import * + +import collections.abc, datetime, base64, binascii, re, sys, types + +class ConstructorError(MarkedYAMLError): + pass + +class BaseConstructor: + + yaml_constructors = {} + yaml_multi_constructors = {} + + def __init__(self): + self.constructed_objects = {} + self.recursive_objects = {} + self.state_generators = [] + self.deep_construct = False + + def check_data(self): + # If there are more documents available? + return self.check_node() + + def check_state_key(self, key): + """Block special attributes/methods from being set in a newly created + object, to prevent user-controlled methods from being called during + deserialization""" + if self.get_state_keys_blacklist_regexp().match(key): + raise ConstructorError(None, None, + "blacklisted key '%s' in instance state found" % (key,), None) + + def get_data(self): + # Construct and return the next document. + if self.check_node(): + return self.construct_document(self.get_node()) + + def get_single_data(self): + # Ensure that the stream contains a single document and construct it. + node = self.get_single_node() + if node is not None: + return self.construct_document(node) + return None + + def construct_document(self, node): + data = self.construct_object(node) + while self.state_generators: + state_generators = self.state_generators + self.state_generators = [] + for generator in state_generators: + for dummy in generator: + pass + self.constructed_objects = {} + self.recursive_objects = {} + self.deep_construct = False + return data + + def construct_object(self, node, deep=False): + if node in self.constructed_objects: + return self.constructed_objects[node] + if deep: + old_deep = self.deep_construct + self.deep_construct = True + if node in self.recursive_objects: + raise ConstructorError(None, None, + "found unconstructable recursive node", node.start_mark) + self.recursive_objects[node] = None + constructor = None + tag_suffix = None + if node.tag in self.yaml_constructors: + constructor = self.yaml_constructors[node.tag] + else: + for tag_prefix in self.yaml_multi_constructors: + if tag_prefix is not None and node.tag.startswith(tag_prefix): + tag_suffix = node.tag[len(tag_prefix):] + constructor = self.yaml_multi_constructors[tag_prefix] + break + else: + if None in self.yaml_multi_constructors: + tag_suffix = node.tag + constructor = self.yaml_multi_constructors[None] + elif None in self.yaml_constructors: + constructor = self.yaml_constructors[None] + elif isinstance(node, ScalarNode): + constructor = self.__class__.construct_scalar + elif isinstance(node, SequenceNode): + constructor = self.__class__.construct_sequence + elif isinstance(node, MappingNode): + constructor = self.__class__.construct_mapping + if tag_suffix is None: + data = constructor(self, node) + else: + data = constructor(self, tag_suffix, node) + if isinstance(data, types.GeneratorType): + generator = data + data = next(generator) + if self.deep_construct: + for dummy in generator: + pass + else: + self.state_generators.append(generator) + self.constructed_objects[node] = data + del self.recursive_objects[node] + if deep: + self.deep_construct = old_deep + return data + + def construct_scalar(self, node): + if not isinstance(node, ScalarNode): + raise ConstructorError(None, None, + "expected a scalar node, but found %s" % node.id, + node.start_mark) + return node.value + + def construct_sequence(self, node, deep=False): + if not isinstance(node, SequenceNode): + raise ConstructorError(None, None, + "expected a sequence node, but found %s" % node.id, + node.start_mark) + return [self.construct_object(child, deep=deep) + for child in node.value] + + def construct_mapping(self, node, deep=False): + if not isinstance(node, MappingNode): + raise ConstructorError(None, None, + "expected a mapping node, but found %s" % node.id, + node.start_mark) + mapping = {} + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) + if not isinstance(key, collections.abc.Hashable): + raise ConstructorError("while constructing a mapping", node.start_mark, + "found unhashable key", key_node.start_mark) + value = self.construct_object(value_node, deep=deep) + mapping[key] = value + return mapping + + def construct_pairs(self, node, deep=False): + if not isinstance(node, MappingNode): + raise ConstructorError(None, None, + "expected a mapping node, but found %s" % node.id, + node.start_mark) + pairs = [] + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) + value = self.construct_object(value_node, deep=deep) + pairs.append((key, value)) + return pairs + + @classmethod + def add_constructor(cls, tag, constructor): + if not 'yaml_constructors' in cls.__dict__: + cls.yaml_constructors = cls.yaml_constructors.copy() + cls.yaml_constructors[tag] = constructor + + @classmethod + def add_multi_constructor(cls, tag_prefix, multi_constructor): + if not 'yaml_multi_constructors' in cls.__dict__: + cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy() + cls.yaml_multi_constructors[tag_prefix] = multi_constructor + +class SafeConstructor(BaseConstructor): + + def construct_scalar(self, node): + if isinstance(node, MappingNode): + for key_node, value_node in node.value: + if key_node.tag == 'tag:yaml.org,2002:value': + return self.construct_scalar(value_node) + return super().construct_scalar(node) + + def flatten_mapping(self, node): + merge = [] + index = 0 + while index < len(node.value): + key_node, value_node = node.value[index] + if key_node.tag == 'tag:yaml.org,2002:merge': + del node.value[index] + if isinstance(value_node, MappingNode): + self.flatten_mapping(value_node) + merge.extend(value_node.value) + elif isinstance(value_node, SequenceNode): + submerge = [] + for subnode in value_node.value: + if not isinstance(subnode, MappingNode): + raise ConstructorError("while constructing a mapping", + node.start_mark, + "expected a mapping for merging, but found %s" + % subnode.id, subnode.start_mark) + self.flatten_mapping(subnode) + submerge.append(subnode.value) + submerge.reverse() + for value in submerge: + merge.extend(value) + else: + raise ConstructorError("while constructing a mapping", node.start_mark, + "expected a mapping or list of mappings for merging, but found %s" + % value_node.id, value_node.start_mark) + elif key_node.tag == 'tag:yaml.org,2002:value': + key_node.tag = 'tag:yaml.org,2002:str' + index += 1 + else: + index += 1 + if merge: + node.value = merge + node.value + + def construct_mapping(self, node, deep=False): + if isinstance(node, MappingNode): + self.flatten_mapping(node) + return super().construct_mapping(node, deep=deep) + + def construct_yaml_null(self, node): + self.construct_scalar(node) + return None + + bool_values = { + 'yes': True, + 'no': False, + 'true': True, + 'false': False, + 'on': True, + 'off': False, + } + + def construct_yaml_bool(self, node): + value = self.construct_scalar(node) + return self.bool_values[value.lower()] + + def construct_yaml_int(self, node): + value = self.construct_scalar(node) + value = value.replace('_', '') + sign = +1 + if value[0] == '-': + sign = -1 + if value[0] in '+-': + value = value[1:] + if value == '0': + return 0 + elif value.startswith('0b'): + return sign*int(value[2:], 2) + elif value.startswith('0x'): + return sign*int(value[2:], 16) + elif value[0] == '0': + return sign*int(value, 8) + elif ':' in value: + digits = [int(part) for part in value.split(':')] + digits.reverse() + base = 1 + value = 0 + for digit in digits: + value += digit*base + base *= 60 + return sign*value + else: + return sign*int(value) + + inf_value = 1e300 + while inf_value != inf_value*inf_value: + inf_value *= inf_value + nan_value = -inf_value/inf_value # Trying to make a quiet NaN (like C99). + + def construct_yaml_float(self, node): + value = self.construct_scalar(node) + value = value.replace('_', '').lower() + sign = +1 + if value[0] == '-': + sign = -1 + if value[0] in '+-': + value = value[1:] + if value == '.inf': + return sign*self.inf_value + elif value == '.nan': + return self.nan_value + elif ':' in value: + digits = [float(part) for part in value.split(':')] + digits.reverse() + base = 1 + value = 0.0 + for digit in digits: + value += digit*base + base *= 60 + return sign*value + else: + return sign*float(value) + + def construct_yaml_binary(self, node): + try: + value = self.construct_scalar(node).encode('ascii') + except UnicodeEncodeError as exc: + raise ConstructorError(None, None, + "failed to convert base64 data into ascii: %s" % exc, + node.start_mark) + try: + if hasattr(base64, 'decodebytes'): + return base64.decodebytes(value) + else: + return base64.decodestring(value) + except binascii.Error as exc: + raise ConstructorError(None, None, + "failed to decode base64 data: %s" % exc, node.start_mark) + + timestamp_regexp = re.compile( + r'''^(?P[0-9][0-9][0-9][0-9]) + -(?P[0-9][0-9]?) + -(?P[0-9][0-9]?) + (?:(?:[Tt]|[ \t]+) + (?P[0-9][0-9]?) + :(?P[0-9][0-9]) + :(?P[0-9][0-9]) + (?:\.(?P[0-9]*))? + (?:[ \t]*(?PZ|(?P[-+])(?P[0-9][0-9]?) + (?::(?P[0-9][0-9]))?))?)?$''', re.X) + + def construct_yaml_timestamp(self, node): + value = self.construct_scalar(node) + match = self.timestamp_regexp.match(node.value) + values = match.groupdict() + year = int(values['year']) + month = int(values['month']) + day = int(values['day']) + if not values['hour']: + return datetime.date(year, month, day) + hour = int(values['hour']) + minute = int(values['minute']) + second = int(values['second']) + fraction = 0 + tzinfo = None + if values['fraction']: + fraction = values['fraction'][:6] + while len(fraction) < 6: + fraction += '0' + fraction = int(fraction) + if values['tz_sign']: + tz_hour = int(values['tz_hour']) + tz_minute = int(values['tz_minute'] or 0) + delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) + if values['tz_sign'] == '-': + delta = -delta + tzinfo = datetime.timezone(delta) + elif values['tz']: + tzinfo = datetime.timezone.utc + return datetime.datetime(year, month, day, hour, minute, second, fraction, + tzinfo=tzinfo) + + def construct_yaml_omap(self, node): + # Note: we do not check for duplicate keys, because it's too + # CPU-expensive. + omap = [] + yield omap + if not isinstance(node, SequenceNode): + raise ConstructorError("while constructing an ordered map", node.start_mark, + "expected a sequence, but found %s" % node.id, node.start_mark) + for subnode in node.value: + if not isinstance(subnode, MappingNode): + raise ConstructorError("while constructing an ordered map", node.start_mark, + "expected a mapping of length 1, but found %s" % subnode.id, + subnode.start_mark) + if len(subnode.value) != 1: + raise ConstructorError("while constructing an ordered map", node.start_mark, + "expected a single mapping item, but found %d items" % len(subnode.value), + subnode.start_mark) + key_node, value_node = subnode.value[0] + key = self.construct_object(key_node) + value = self.construct_object(value_node) + omap.append((key, value)) + + def construct_yaml_pairs(self, node): + # Note: the same code as `construct_yaml_omap`. + pairs = [] + yield pairs + if not isinstance(node, SequenceNode): + raise ConstructorError("while constructing pairs", node.start_mark, + "expected a sequence, but found %s" % node.id, node.start_mark) + for subnode in node.value: + if not isinstance(subnode, MappingNode): + raise ConstructorError("while constructing pairs", node.start_mark, + "expected a mapping of length 1, but found %s" % subnode.id, + subnode.start_mark) + if len(subnode.value) != 1: + raise ConstructorError("while constructing pairs", node.start_mark, + "expected a single mapping item, but found %d items" % len(subnode.value), + subnode.start_mark) + key_node, value_node = subnode.value[0] + key = self.construct_object(key_node) + value = self.construct_object(value_node) + pairs.append((key, value)) + + def construct_yaml_set(self, node): + data = set() + yield data + value = self.construct_mapping(node) + data.update(value) + + def construct_yaml_str(self, node): + return self.construct_scalar(node) + + def construct_yaml_seq(self, node): + data = [] + yield data + data.extend(self.construct_sequence(node)) + + def construct_yaml_map(self, node): + data = {} + yield data + value = self.construct_mapping(node) + data.update(value) + + def construct_yaml_object(self, node, cls): + data = cls.__new__(cls) + yield data + if hasattr(data, '__setstate__'): + state = self.construct_mapping(node, deep=True) + data.__setstate__(state) + else: + state = self.construct_mapping(node) + data.__dict__.update(state) + + def construct_undefined(self, node): + raise ConstructorError(None, None, + "could not determine a constructor for the tag %r" % node.tag, + node.start_mark) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:null', + SafeConstructor.construct_yaml_null) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:bool', + SafeConstructor.construct_yaml_bool) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:int', + SafeConstructor.construct_yaml_int) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:float', + SafeConstructor.construct_yaml_float) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:binary', + SafeConstructor.construct_yaml_binary) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:timestamp', + SafeConstructor.construct_yaml_timestamp) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:omap', + SafeConstructor.construct_yaml_omap) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:pairs', + SafeConstructor.construct_yaml_pairs) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:set', + SafeConstructor.construct_yaml_set) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:str', + SafeConstructor.construct_yaml_str) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:seq', + SafeConstructor.construct_yaml_seq) + +SafeConstructor.add_constructor( + 'tag:yaml.org,2002:map', + SafeConstructor.construct_yaml_map) + +SafeConstructor.add_constructor(None, + SafeConstructor.construct_undefined) + +class FullConstructor(SafeConstructor): + # 'extend' is blacklisted because it is used by + # construct_python_object_apply to add `listitems` to a newly generate + # python instance + def get_state_keys_blacklist(self): + return ['^extend$', '^__.*__$'] + + def get_state_keys_blacklist_regexp(self): + if not hasattr(self, 'state_keys_blacklist_regexp'): + self.state_keys_blacklist_regexp = re.compile('(' + '|'.join(self.get_state_keys_blacklist()) + ')') + return self.state_keys_blacklist_regexp + + def construct_python_str(self, node): + return self.construct_scalar(node) + + def construct_python_unicode(self, node): + return self.construct_scalar(node) + + def construct_python_bytes(self, node): + try: + value = self.construct_scalar(node).encode('ascii') + except UnicodeEncodeError as exc: + raise ConstructorError(None, None, + "failed to convert base64 data into ascii: %s" % exc, + node.start_mark) + try: + if hasattr(base64, 'decodebytes'): + return base64.decodebytes(value) + else: + return base64.decodestring(value) + except binascii.Error as exc: + raise ConstructorError(None, None, + "failed to decode base64 data: %s" % exc, node.start_mark) + + def construct_python_long(self, node): + return self.construct_yaml_int(node) + + def construct_python_complex(self, node): + return complex(self.construct_scalar(node)) + + def construct_python_tuple(self, node): + return tuple(self.construct_sequence(node)) + + def find_python_module(self, name, mark, unsafe=False): + if not name: + raise ConstructorError("while constructing a Python module", mark, + "expected non-empty name appended to the tag", mark) + if unsafe: + try: + __import__(name) + except ImportError as exc: + raise ConstructorError("while constructing a Python module", mark, + "cannot find module %r (%s)" % (name, exc), mark) + if name not in sys.modules: + raise ConstructorError("while constructing a Python module", mark, + "module %r is not imported" % name, mark) + return sys.modules[name] + + def find_python_name(self, name, mark, unsafe=False): + if not name: + raise ConstructorError("while constructing a Python object", mark, + "expected non-empty name appended to the tag", mark) + if '.' in name: + module_name, object_name = name.rsplit('.', 1) + else: + module_name = 'builtins' + object_name = name + if unsafe: + try: + __import__(module_name) + except ImportError as exc: + raise ConstructorError("while constructing a Python object", mark, + "cannot find module %r (%s)" % (module_name, exc), mark) + if module_name not in sys.modules: + raise ConstructorError("while constructing a Python object", mark, + "module %r is not imported" % module_name, mark) + module = sys.modules[module_name] + if not hasattr(module, object_name): + raise ConstructorError("while constructing a Python object", mark, + "cannot find %r in the module %r" + % (object_name, module.__name__), mark) + return getattr(module, object_name) + + def construct_python_name(self, suffix, node): + value = self.construct_scalar(node) + if value: + raise ConstructorError("while constructing a Python name", node.start_mark, + "expected the empty value, but found %r" % value, node.start_mark) + return self.find_python_name(suffix, node.start_mark) + + def construct_python_module(self, suffix, node): + value = self.construct_scalar(node) + if value: + raise ConstructorError("while constructing a Python module", node.start_mark, + "expected the empty value, but found %r" % value, node.start_mark) + return self.find_python_module(suffix, node.start_mark) + + def make_python_instance(self, suffix, node, + args=None, kwds=None, newobj=False, unsafe=False): + if not args: + args = [] + if not kwds: + kwds = {} + cls = self.find_python_name(suffix, node.start_mark) + if not (unsafe or isinstance(cls, type)): + raise ConstructorError("while constructing a Python instance", node.start_mark, + "expected a class, but found %r" % type(cls), + node.start_mark) + if newobj and isinstance(cls, type): + return cls.__new__(cls, *args, **kwds) + else: + return cls(*args, **kwds) + + def set_python_instance_state(self, instance, state, unsafe=False): + if hasattr(instance, '__setstate__'): + instance.__setstate__(state) + else: + slotstate = {} + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if hasattr(instance, '__dict__'): + if not unsafe and state: + for key in state.keys(): + self.check_state_key(key) + instance.__dict__.update(state) + elif state: + slotstate.update(state) + for key, value in slotstate.items(): + if not unsafe: + self.check_state_key(key) + setattr(instance, key, value) + + def construct_python_object(self, suffix, node): + # Format: + # !!python/object:module.name { ... state ... } + instance = self.make_python_instance(suffix, node, newobj=True) + yield instance + deep = hasattr(instance, '__setstate__') + state = self.construct_mapping(node, deep=deep) + self.set_python_instance_state(instance, state) + + def construct_python_object_apply(self, suffix, node, newobj=False): + # Format: + # !!python/object/apply # (or !!python/object/new) + # args: [ ... arguments ... ] + # kwds: { ... keywords ... } + # state: ... state ... + # listitems: [ ... listitems ... ] + # dictitems: { ... dictitems ... } + # or short format: + # !!python/object/apply [ ... arguments ... ] + # The difference between !!python/object/apply and !!python/object/new + # is how an object is created, check make_python_instance for details. + if isinstance(node, SequenceNode): + args = self.construct_sequence(node, deep=True) + kwds = {} + state = {} + listitems = [] + dictitems = {} + else: + value = self.construct_mapping(node, deep=True) + args = value.get('args', []) + kwds = value.get('kwds', {}) + state = value.get('state', {}) + listitems = value.get('listitems', []) + dictitems = value.get('dictitems', {}) + instance = self.make_python_instance(suffix, node, args, kwds, newobj) + if state: + self.set_python_instance_state(instance, state) + if listitems: + instance.extend(listitems) + if dictitems: + for key in dictitems: + instance[key] = dictitems[key] + return instance + + def construct_python_object_new(self, suffix, node): + return self.construct_python_object_apply(suffix, node, newobj=True) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/none', + FullConstructor.construct_yaml_null) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/bool', + FullConstructor.construct_yaml_bool) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/str', + FullConstructor.construct_python_str) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/unicode', + FullConstructor.construct_python_unicode) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/bytes', + FullConstructor.construct_python_bytes) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/int', + FullConstructor.construct_yaml_int) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/long', + FullConstructor.construct_python_long) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/float', + FullConstructor.construct_yaml_float) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/complex', + FullConstructor.construct_python_complex) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/list', + FullConstructor.construct_yaml_seq) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/tuple', + FullConstructor.construct_python_tuple) + +FullConstructor.add_constructor( + 'tag:yaml.org,2002:python/dict', + FullConstructor.construct_yaml_map) + +FullConstructor.add_multi_constructor( + 'tag:yaml.org,2002:python/name:', + FullConstructor.construct_python_name) + +class UnsafeConstructor(FullConstructor): + + def find_python_module(self, name, mark): + return super(UnsafeConstructor, self).find_python_module(name, mark, unsafe=True) + + def find_python_name(self, name, mark): + return super(UnsafeConstructor, self).find_python_name(name, mark, unsafe=True) + + def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False): + return super(UnsafeConstructor, self).make_python_instance( + suffix, node, args, kwds, newobj, unsafe=True) + + def set_python_instance_state(self, instance, state): + return super(UnsafeConstructor, self).set_python_instance_state( + instance, state, unsafe=True) + +UnsafeConstructor.add_multi_constructor( + 'tag:yaml.org,2002:python/module:', + UnsafeConstructor.construct_python_module) + +UnsafeConstructor.add_multi_constructor( + 'tag:yaml.org,2002:python/object:', + UnsafeConstructor.construct_python_object) + +UnsafeConstructor.add_multi_constructor( + 'tag:yaml.org,2002:python/object/new:', + UnsafeConstructor.construct_python_object_new) + +UnsafeConstructor.add_multi_constructor( + 'tag:yaml.org,2002:python/object/apply:', + UnsafeConstructor.construct_python_object_apply) + +# Constructor is same as UnsafeConstructor. Need to leave this in place in case +# people have extended it directly. +class Constructor(UnsafeConstructor): + pass diff --git a/source/yaml/cyaml.py b/source/yaml/cyaml.py new file mode 100644 index 0000000000000000000000000000000000000000..0c21345879b298bb8668201bebe7d289586b17f9 --- /dev/null +++ b/source/yaml/cyaml.py @@ -0,0 +1,101 @@ + +__all__ = [ + 'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader', + 'CBaseDumper', 'CSafeDumper', 'CDumper' +] + +from yaml._yaml import CParser, CEmitter + +from .constructor import * + +from .serializer import * +from .representer import * + +from .resolver import * + +class CBaseLoader(CParser, BaseConstructor, BaseResolver): + + def __init__(self, stream): + CParser.__init__(self, stream) + BaseConstructor.__init__(self) + BaseResolver.__init__(self) + +class CSafeLoader(CParser, SafeConstructor, Resolver): + + def __init__(self, stream): + CParser.__init__(self, stream) + SafeConstructor.__init__(self) + Resolver.__init__(self) + +class CFullLoader(CParser, FullConstructor, Resolver): + + def __init__(self, stream): + CParser.__init__(self, stream) + FullConstructor.__init__(self) + Resolver.__init__(self) + +class CUnsafeLoader(CParser, UnsafeConstructor, Resolver): + + def __init__(self, stream): + CParser.__init__(self, stream) + UnsafeConstructor.__init__(self) + Resolver.__init__(self) + +class CLoader(CParser, Constructor, Resolver): + + def __init__(self, stream): + CParser.__init__(self, stream) + Constructor.__init__(self) + Resolver.__init__(self) + +class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + CEmitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, encoding=encoding, + allow_unicode=allow_unicode, line_break=line_break, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + Representer.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + +class CSafeDumper(CEmitter, SafeRepresenter, Resolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + CEmitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, encoding=encoding, + allow_unicode=allow_unicode, line_break=line_break, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + SafeRepresenter.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + +class CDumper(CEmitter, Serializer, Representer, Resolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + CEmitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, encoding=encoding, + allow_unicode=allow_unicode, line_break=line_break, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + Representer.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + diff --git a/source/yaml/dumper.py b/source/yaml/dumper.py new file mode 100644 index 0000000000000000000000000000000000000000..6aadba551f3836b02f4752277f4b3027073defad --- /dev/null +++ b/source/yaml/dumper.py @@ -0,0 +1,62 @@ + +__all__ = ['BaseDumper', 'SafeDumper', 'Dumper'] + +from .emitter import * +from .serializer import * +from .representer import * +from .resolver import * + +class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + Emitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break) + Serializer.__init__(self, encoding=encoding, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + Representer.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + +class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + Emitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break) + Serializer.__init__(self, encoding=encoding, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + SafeRepresenter.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + +class Dumper(Emitter, Serializer, Representer, Resolver): + + def __init__(self, stream, + default_style=None, default_flow_style=False, + canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None, + encoding=None, explicit_start=None, explicit_end=None, + version=None, tags=None, sort_keys=True): + Emitter.__init__(self, stream, canonical=canonical, + indent=indent, width=width, + allow_unicode=allow_unicode, line_break=line_break) + Serializer.__init__(self, encoding=encoding, + explicit_start=explicit_start, explicit_end=explicit_end, + version=version, tags=tags) + Representer.__init__(self, default_style=default_style, + default_flow_style=default_flow_style, sort_keys=sort_keys) + Resolver.__init__(self) + diff --git a/source/yaml/emitter.py b/source/yaml/emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..a664d011162af69184df2f8e59ab7feec818f7c7 --- /dev/null +++ b/source/yaml/emitter.py @@ -0,0 +1,1137 @@ + +# Emitter expects events obeying the following grammar: +# stream ::= STREAM-START document* STREAM-END +# document ::= DOCUMENT-START node DOCUMENT-END +# node ::= SCALAR | sequence | mapping +# sequence ::= SEQUENCE-START node* SEQUENCE-END +# mapping ::= MAPPING-START (node node)* MAPPING-END + +__all__ = ['Emitter', 'EmitterError'] + +from .error import YAMLError +from .events import * + +class EmitterError(YAMLError): + pass + +class ScalarAnalysis: + def __init__(self, scalar, empty, multiline, + allow_flow_plain, allow_block_plain, + allow_single_quoted, allow_double_quoted, + allow_block): + self.scalar = scalar + self.empty = empty + self.multiline = multiline + self.allow_flow_plain = allow_flow_plain + self.allow_block_plain = allow_block_plain + self.allow_single_quoted = allow_single_quoted + self.allow_double_quoted = allow_double_quoted + self.allow_block = allow_block + +class Emitter: + + DEFAULT_TAG_PREFIXES = { + '!' : '!', + 'tag:yaml.org,2002:' : '!!', + } + + def __init__(self, stream, canonical=None, indent=None, width=None, + allow_unicode=None, line_break=None): + + # The stream should have the methods `write` and possibly `flush`. + self.stream = stream + + # Encoding can be overridden by STREAM-START. + self.encoding = None + + # Emitter is a state machine with a stack of states to handle nested + # structures. + self.states = [] + self.state = self.expect_stream_start + + # Current event and the event queue. + self.events = [] + self.event = None + + # The current indentation level and the stack of previous indents. + self.indents = [] + self.indent = None + + # Flow level. + self.flow_level = 0 + + # Contexts. + self.root_context = False + self.sequence_context = False + self.mapping_context = False + self.simple_key_context = False + + # Characteristics of the last emitted character: + # - current position. + # - is it a whitespace? + # - is it an indention character + # (indentation space, '-', '?', or ':')? + self.line = 0 + self.column = 0 + self.whitespace = True + self.indention = True + + # Whether the document requires an explicit document indicator + self.open_ended = False + + # Formatting details. + self.canonical = canonical + self.allow_unicode = allow_unicode + self.best_indent = 2 + if indent and 1 < indent < 10: + self.best_indent = indent + self.best_width = 80 + if width and width > self.best_indent*2: + self.best_width = width + self.best_line_break = '\n' + if line_break in ['\r', '\n', '\r\n']: + self.best_line_break = line_break + + # Tag prefixes. + self.tag_prefixes = None + + # Prepared anchor and tag. + self.prepared_anchor = None + self.prepared_tag = None + + # Scalar analysis and style. + self.analysis = None + self.style = None + + def dispose(self): + # Reset the state attributes (to clear self-references) + self.states = [] + self.state = None + + def emit(self, event): + self.events.append(event) + while not self.need_more_events(): + self.event = self.events.pop(0) + self.state() + self.event = None + + # In some cases, we wait for a few next events before emitting. + + def need_more_events(self): + if not self.events: + return True + event = self.events[0] + if isinstance(event, DocumentStartEvent): + return self.need_events(1) + elif isinstance(event, SequenceStartEvent): + return self.need_events(2) + elif isinstance(event, MappingStartEvent): + return self.need_events(3) + else: + return False + + def need_events(self, count): + level = 0 + for event in self.events[1:]: + if isinstance(event, (DocumentStartEvent, CollectionStartEvent)): + level += 1 + elif isinstance(event, (DocumentEndEvent, CollectionEndEvent)): + level -= 1 + elif isinstance(event, StreamEndEvent): + level = -1 + if level < 0: + return False + return (len(self.events) < count+1) + + def increase_indent(self, flow=False, indentless=False): + self.indents.append(self.indent) + if self.indent is None: + if flow: + self.indent = self.best_indent + else: + self.indent = 0 + elif not indentless: + self.indent += self.best_indent + + # States. + + # Stream handlers. + + def expect_stream_start(self): + if isinstance(self.event, StreamStartEvent): + if self.event.encoding and not hasattr(self.stream, 'encoding'): + self.encoding = self.event.encoding + self.write_stream_start() + self.state = self.expect_first_document_start + else: + raise EmitterError("expected StreamStartEvent, but got %s" + % self.event) + + def expect_nothing(self): + raise EmitterError("expected nothing, but got %s" % self.event) + + # Document handlers. + + def expect_first_document_start(self): + return self.expect_document_start(first=True) + + def expect_document_start(self, first=False): + if isinstance(self.event, DocumentStartEvent): + if (self.event.version or self.event.tags) and self.open_ended: + self.write_indicator('...', True) + self.write_indent() + if self.event.version: + version_text = self.prepare_version(self.event.version) + self.write_version_directive(version_text) + self.tag_prefixes = self.DEFAULT_TAG_PREFIXES.copy() + if self.event.tags: + handles = sorted(self.event.tags.keys()) + for handle in handles: + prefix = self.event.tags[handle] + self.tag_prefixes[prefix] = handle + handle_text = self.prepare_tag_handle(handle) + prefix_text = self.prepare_tag_prefix(prefix) + self.write_tag_directive(handle_text, prefix_text) + implicit = (first and not self.event.explicit and not self.canonical + and not self.event.version and not self.event.tags + and not self.check_empty_document()) + if not implicit: + self.write_indent() + self.write_indicator('---', True) + if self.canonical: + self.write_indent() + self.state = self.expect_document_root + elif isinstance(self.event, StreamEndEvent): + if self.open_ended: + self.write_indicator('...', True) + self.write_indent() + self.write_stream_end() + self.state = self.expect_nothing + else: + raise EmitterError("expected DocumentStartEvent, but got %s" + % self.event) + + def expect_document_end(self): + if isinstance(self.event, DocumentEndEvent): + self.write_indent() + if self.event.explicit: + self.write_indicator('...', True) + self.write_indent() + self.flush_stream() + self.state = self.expect_document_start + else: + raise EmitterError("expected DocumentEndEvent, but got %s" + % self.event) + + def expect_document_root(self): + self.states.append(self.expect_document_end) + self.expect_node(root=True) + + # Node handlers. + + def expect_node(self, root=False, sequence=False, mapping=False, + simple_key=False): + self.root_context = root + self.sequence_context = sequence + self.mapping_context = mapping + self.simple_key_context = simple_key + if isinstance(self.event, AliasEvent): + self.expect_alias() + elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)): + self.process_anchor('&') + self.process_tag() + if isinstance(self.event, ScalarEvent): + self.expect_scalar() + elif isinstance(self.event, SequenceStartEvent): + if self.flow_level or self.canonical or self.event.flow_style \ + or self.check_empty_sequence(): + self.expect_flow_sequence() + else: + self.expect_block_sequence() + elif isinstance(self.event, MappingStartEvent): + if self.flow_level or self.canonical or self.event.flow_style \ + or self.check_empty_mapping(): + self.expect_flow_mapping() + else: + self.expect_block_mapping() + else: + raise EmitterError("expected NodeEvent, but got %s" % self.event) + + def expect_alias(self): + if self.event.anchor is None: + raise EmitterError("anchor is not specified for alias") + self.process_anchor('*') + self.state = self.states.pop() + + def expect_scalar(self): + self.increase_indent(flow=True) + self.process_scalar() + self.indent = self.indents.pop() + self.state = self.states.pop() + + # Flow sequence handlers. + + def expect_flow_sequence(self): + self.write_indicator('[', True, whitespace=True) + self.flow_level += 1 + self.increase_indent(flow=True) + self.state = self.expect_first_flow_sequence_item + + def expect_first_flow_sequence_item(self): + if isinstance(self.event, SequenceEndEvent): + self.indent = self.indents.pop() + self.flow_level -= 1 + self.write_indicator(']', False) + self.state = self.states.pop() + else: + if self.canonical or self.column > self.best_width: + self.write_indent() + self.states.append(self.expect_flow_sequence_item) + self.expect_node(sequence=True) + + def expect_flow_sequence_item(self): + if isinstance(self.event, SequenceEndEvent): + self.indent = self.indents.pop() + self.flow_level -= 1 + if self.canonical: + self.write_indicator(',', False) + self.write_indent() + self.write_indicator(']', False) + self.state = self.states.pop() + else: + self.write_indicator(',', False) + if self.canonical or self.column > self.best_width: + self.write_indent() + self.states.append(self.expect_flow_sequence_item) + self.expect_node(sequence=True) + + # Flow mapping handlers. + + def expect_flow_mapping(self): + self.write_indicator('{', True, whitespace=True) + self.flow_level += 1 + self.increase_indent(flow=True) + self.state = self.expect_first_flow_mapping_key + + def expect_first_flow_mapping_key(self): + if isinstance(self.event, MappingEndEvent): + self.indent = self.indents.pop() + self.flow_level -= 1 + self.write_indicator('}', False) + self.state = self.states.pop() + else: + if self.canonical or self.column > self.best_width: + self.write_indent() + if not self.canonical and self.check_simple_key(): + self.states.append(self.expect_flow_mapping_simple_value) + self.expect_node(mapping=True, simple_key=True) + else: + self.write_indicator('?', True) + self.states.append(self.expect_flow_mapping_value) + self.expect_node(mapping=True) + + def expect_flow_mapping_key(self): + if isinstance(self.event, MappingEndEvent): + self.indent = self.indents.pop() + self.flow_level -= 1 + if self.canonical: + self.write_indicator(',', False) + self.write_indent() + self.write_indicator('}', False) + self.state = self.states.pop() + else: + self.write_indicator(',', False) + if self.canonical or self.column > self.best_width: + self.write_indent() + if not self.canonical and self.check_simple_key(): + self.states.append(self.expect_flow_mapping_simple_value) + self.expect_node(mapping=True, simple_key=True) + else: + self.write_indicator('?', True) + self.states.append(self.expect_flow_mapping_value) + self.expect_node(mapping=True) + + def expect_flow_mapping_simple_value(self): + self.write_indicator(':', False) + self.states.append(self.expect_flow_mapping_key) + self.expect_node(mapping=True) + + def expect_flow_mapping_value(self): + if self.canonical or self.column > self.best_width: + self.write_indent() + self.write_indicator(':', True) + self.states.append(self.expect_flow_mapping_key) + self.expect_node(mapping=True) + + # Block sequence handlers. + + def expect_block_sequence(self): + indentless = (self.mapping_context and not self.indention) + self.increase_indent(flow=False, indentless=indentless) + self.state = self.expect_first_block_sequence_item + + def expect_first_block_sequence_item(self): + return self.expect_block_sequence_item(first=True) + + def expect_block_sequence_item(self, first=False): + if not first and isinstance(self.event, SequenceEndEvent): + self.indent = self.indents.pop() + self.state = self.states.pop() + else: + self.write_indent() + self.write_indicator('-', True, indention=True) + self.states.append(self.expect_block_sequence_item) + self.expect_node(sequence=True) + + # Block mapping handlers. + + def expect_block_mapping(self): + self.increase_indent(flow=False) + self.state = self.expect_first_block_mapping_key + + def expect_first_block_mapping_key(self): + return self.expect_block_mapping_key(first=True) + + def expect_block_mapping_key(self, first=False): + if not first and isinstance(self.event, MappingEndEvent): + self.indent = self.indents.pop() + self.state = self.states.pop() + else: + self.write_indent() + if self.check_simple_key(): + self.states.append(self.expect_block_mapping_simple_value) + self.expect_node(mapping=True, simple_key=True) + else: + self.write_indicator('?', True, indention=True) + self.states.append(self.expect_block_mapping_value) + self.expect_node(mapping=True) + + def expect_block_mapping_simple_value(self): + self.write_indicator(':', False) + self.states.append(self.expect_block_mapping_key) + self.expect_node(mapping=True) + + def expect_block_mapping_value(self): + self.write_indent() + self.write_indicator(':', True, indention=True) + self.states.append(self.expect_block_mapping_key) + self.expect_node(mapping=True) + + # Checkers. + + def check_empty_sequence(self): + return (isinstance(self.event, SequenceStartEvent) and self.events + and isinstance(self.events[0], SequenceEndEvent)) + + def check_empty_mapping(self): + return (isinstance(self.event, MappingStartEvent) and self.events + and isinstance(self.events[0], MappingEndEvent)) + + def check_empty_document(self): + if not isinstance(self.event, DocumentStartEvent) or not self.events: + return False + event = self.events[0] + return (isinstance(event, ScalarEvent) and event.anchor is None + and event.tag is None and event.implicit and event.value == '') + + def check_simple_key(self): + length = 0 + if isinstance(self.event, NodeEvent) and self.event.anchor is not None: + if self.prepared_anchor is None: + self.prepared_anchor = self.prepare_anchor(self.event.anchor) + length += len(self.prepared_anchor) + if isinstance(self.event, (ScalarEvent, CollectionStartEvent)) \ + and self.event.tag is not None: + if self.prepared_tag is None: + self.prepared_tag = self.prepare_tag(self.event.tag) + length += len(self.prepared_tag) + if isinstance(self.event, ScalarEvent): + if self.analysis is None: + self.analysis = self.analyze_scalar(self.event.value) + length += len(self.analysis.scalar) + return (length < 128 and (isinstance(self.event, AliasEvent) + or (isinstance(self.event, ScalarEvent) + and not self.analysis.empty and not self.analysis.multiline) + or self.check_empty_sequence() or self.check_empty_mapping())) + + # Anchor, Tag, and Scalar processors. + + def process_anchor(self, indicator): + if self.event.anchor is None: + self.prepared_anchor = None + return + if self.prepared_anchor is None: + self.prepared_anchor = self.prepare_anchor(self.event.anchor) + if self.prepared_anchor: + self.write_indicator(indicator+self.prepared_anchor, True) + self.prepared_anchor = None + + def process_tag(self): + tag = self.event.tag + if isinstance(self.event, ScalarEvent): + if self.style is None: + self.style = self.choose_scalar_style() + if ((not self.canonical or tag is None) and + ((self.style == '' and self.event.implicit[0]) + or (self.style != '' and self.event.implicit[1]))): + self.prepared_tag = None + return + if self.event.implicit[0] and tag is None: + tag = '!' + self.prepared_tag = None + else: + if (not self.canonical or tag is None) and self.event.implicit: + self.prepared_tag = None + return + if tag is None: + raise EmitterError("tag is not specified") + if self.prepared_tag is None: + self.prepared_tag = self.prepare_tag(tag) + if self.prepared_tag: + self.write_indicator(self.prepared_tag, True) + self.prepared_tag = None + + def choose_scalar_style(self): + if self.analysis is None: + self.analysis = self.analyze_scalar(self.event.value) + if self.event.style == '"' or self.canonical: + return '"' + if not self.event.style and self.event.implicit[0]: + if (not (self.simple_key_context and + (self.analysis.empty or self.analysis.multiline)) + and (self.flow_level and self.analysis.allow_flow_plain + or (not self.flow_level and self.analysis.allow_block_plain))): + return '' + if self.event.style and self.event.style in '|>': + if (not self.flow_level and not self.simple_key_context + and self.analysis.allow_block): + return self.event.style + if not self.event.style or self.event.style == '\'': + if (self.analysis.allow_single_quoted and + not (self.simple_key_context and self.analysis.multiline)): + return '\'' + return '"' + + def process_scalar(self): + if self.analysis is None: + self.analysis = self.analyze_scalar(self.event.value) + if self.style is None: + self.style = self.choose_scalar_style() + split = (not self.simple_key_context) + #if self.analysis.multiline and split \ + # and (not self.style or self.style in '\'\"'): + # self.write_indent() + if self.style == '"': + self.write_double_quoted(self.analysis.scalar, split) + elif self.style == '\'': + self.write_single_quoted(self.analysis.scalar, split) + elif self.style == '>': + self.write_folded(self.analysis.scalar) + elif self.style == '|': + self.write_literal(self.analysis.scalar) + else: + self.write_plain(self.analysis.scalar, split) + self.analysis = None + self.style = None + + # Analyzers. + + def prepare_version(self, version): + major, minor = version + if major != 1: + raise EmitterError("unsupported YAML version: %d.%d" % (major, minor)) + return '%d.%d' % (major, minor) + + def prepare_tag_handle(self, handle): + if not handle: + raise EmitterError("tag handle must not be empty") + if handle[0] != '!' or handle[-1] != '!': + raise EmitterError("tag handle must start and end with '!': %r" % handle) + for ch in handle[1:-1]: + if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_'): + raise EmitterError("invalid character %r in the tag handle: %r" + % (ch, handle)) + return handle + + def prepare_tag_prefix(self, prefix): + if not prefix: + raise EmitterError("tag prefix must not be empty") + chunks = [] + start = end = 0 + if prefix[0] == '!': + end = 1 + while end < len(prefix): + ch = prefix[end] + if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-;/?!:@&=+$,_.~*\'()[]': + end += 1 + else: + if start < end: + chunks.append(prefix[start:end]) + start = end = end+1 + data = ch.encode('utf-8') + for ch in data: + chunks.append('%%%02X' % ord(ch)) + if start < end: + chunks.append(prefix[start:end]) + return ''.join(chunks) + + def prepare_tag(self, tag): + if not tag: + raise EmitterError("tag must not be empty") + if tag == '!': + return tag + handle = None + suffix = tag + prefixes = sorted(self.tag_prefixes.keys()) + for prefix in prefixes: + if tag.startswith(prefix) \ + and (prefix == '!' or len(prefix) < len(tag)): + handle = self.tag_prefixes[prefix] + suffix = tag[len(prefix):] + chunks = [] + start = end = 0 + while end < len(suffix): + ch = suffix[end] + if '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-;/?:@&=+$,_.~*\'()[]' \ + or (ch == '!' and handle != '!'): + end += 1 + else: + if start < end: + chunks.append(suffix[start:end]) + start = end = end+1 + data = ch.encode('utf-8') + for ch in data: + chunks.append('%%%02X' % ch) + if start < end: + chunks.append(suffix[start:end]) + suffix_text = ''.join(chunks) + if handle: + return '%s%s' % (handle, suffix_text) + else: + return '!<%s>' % suffix_text + + def prepare_anchor(self, anchor): + if not anchor: + raise EmitterError("anchor must not be empty") + for ch in anchor: + if not ('0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_'): + raise EmitterError("invalid character %r in the anchor: %r" + % (ch, anchor)) + return anchor + + def analyze_scalar(self, scalar): + + # Empty scalar is a special case. + if not scalar: + return ScalarAnalysis(scalar=scalar, empty=True, multiline=False, + allow_flow_plain=False, allow_block_plain=True, + allow_single_quoted=True, allow_double_quoted=True, + allow_block=False) + + # Indicators and special characters. + block_indicators = False + flow_indicators = False + line_breaks = False + special_characters = False + + # Important whitespace combinations. + leading_space = False + leading_break = False + trailing_space = False + trailing_break = False + break_space = False + space_break = False + + # Check document indicators. + if scalar.startswith('---') or scalar.startswith('...'): + block_indicators = True + flow_indicators = True + + # First character or preceded by a whitespace. + preceded_by_whitespace = True + + # Last character or followed by a whitespace. + followed_by_whitespace = (len(scalar) == 1 or + scalar[1] in '\0 \t\r\n\x85\u2028\u2029') + + # The previous character is a space. + previous_space = False + + # The previous character is a break. + previous_break = False + + index = 0 + while index < len(scalar): + ch = scalar[index] + + # Check for indicators. + if index == 0: + # Leading indicators are special characters. + if ch in '#,[]{}&*!|>\'\"%@`': + flow_indicators = True + block_indicators = True + if ch in '?:': + flow_indicators = True + if followed_by_whitespace: + block_indicators = True + if ch == '-' and followed_by_whitespace: + flow_indicators = True + block_indicators = True + else: + # Some indicators cannot appear within a scalar as well. + if ch in ',?[]{}': + flow_indicators = True + if ch == ':': + flow_indicators = True + if followed_by_whitespace: + block_indicators = True + if ch == '#' and preceded_by_whitespace: + flow_indicators = True + block_indicators = True + + # Check for line breaks, special, and unicode characters. + if ch in '\n\x85\u2028\u2029': + line_breaks = True + if not (ch == '\n' or '\x20' <= ch <= '\x7E'): + if (ch == '\x85' or '\xA0' <= ch <= '\uD7FF' + or '\uE000' <= ch <= '\uFFFD' + or '\U00010000' <= ch < '\U0010ffff') and ch != '\uFEFF': + unicode_characters = True + if not self.allow_unicode: + special_characters = True + else: + special_characters = True + + # Detect important whitespace combinations. + if ch == ' ': + if index == 0: + leading_space = True + if index == len(scalar)-1: + trailing_space = True + if previous_break: + break_space = True + previous_space = True + previous_break = False + elif ch in '\n\x85\u2028\u2029': + if index == 0: + leading_break = True + if index == len(scalar)-1: + trailing_break = True + if previous_space: + space_break = True + previous_space = False + previous_break = True + else: + previous_space = False + previous_break = False + + # Prepare for the next character. + index += 1 + preceded_by_whitespace = (ch in '\0 \t\r\n\x85\u2028\u2029') + followed_by_whitespace = (index+1 >= len(scalar) or + scalar[index+1] in '\0 \t\r\n\x85\u2028\u2029') + + # Let's decide what styles are allowed. + allow_flow_plain = True + allow_block_plain = True + allow_single_quoted = True + allow_double_quoted = True + allow_block = True + + # Leading and trailing whitespaces are bad for plain scalars. + if (leading_space or leading_break + or trailing_space or trailing_break): + allow_flow_plain = allow_block_plain = False + + # We do not permit trailing spaces for block scalars. + if trailing_space: + allow_block = False + + # Spaces at the beginning of a new line are only acceptable for block + # scalars. + if break_space: + allow_flow_plain = allow_block_plain = allow_single_quoted = False + + # Spaces followed by breaks, as well as special character are only + # allowed for double quoted scalars. + if space_break or special_characters: + allow_flow_plain = allow_block_plain = \ + allow_single_quoted = allow_block = False + + # Although the plain scalar writer supports breaks, we never emit + # multiline plain scalars. + if line_breaks: + allow_flow_plain = allow_block_plain = False + + # Flow indicators are forbidden for flow plain scalars. + if flow_indicators: + allow_flow_plain = False + + # Block indicators are forbidden for block plain scalars. + if block_indicators: + allow_block_plain = False + + return ScalarAnalysis(scalar=scalar, + empty=False, multiline=line_breaks, + allow_flow_plain=allow_flow_plain, + allow_block_plain=allow_block_plain, + allow_single_quoted=allow_single_quoted, + allow_double_quoted=allow_double_quoted, + allow_block=allow_block) + + # Writers. + + def flush_stream(self): + if hasattr(self.stream, 'flush'): + self.stream.flush() + + def write_stream_start(self): + # Write BOM if needed. + if self.encoding and self.encoding.startswith('utf-16'): + self.stream.write('\uFEFF'.encode(self.encoding)) + + def write_stream_end(self): + self.flush_stream() + + def write_indicator(self, indicator, need_whitespace, + whitespace=False, indention=False): + if self.whitespace or not need_whitespace: + data = indicator + else: + data = ' '+indicator + self.whitespace = whitespace + self.indention = self.indention and indention + self.column += len(data) + self.open_ended = False + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + + def write_indent(self): + indent = self.indent or 0 + if not self.indention or self.column > indent \ + or (self.column == indent and not self.whitespace): + self.write_line_break() + if self.column < indent: + self.whitespace = True + data = ' '*(indent-self.column) + self.column = indent + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + + def write_line_break(self, data=None): + if data is None: + data = self.best_line_break + self.whitespace = True + self.indention = True + self.line += 1 + self.column = 0 + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + + def write_version_directive(self, version_text): + data = '%%YAML %s' % version_text + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + self.write_line_break() + + def write_tag_directive(self, handle_text, prefix_text): + data = '%%TAG %s %s' % (handle_text, prefix_text) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + self.write_line_break() + + # Scalar streams. + + def write_single_quoted(self, text, split=True): + self.write_indicator('\'', True) + spaces = False + breaks = False + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if spaces: + if ch is None or ch != ' ': + if start+1 == end and self.column > self.best_width and split \ + and start != 0 and end != len(text): + self.write_indent() + else: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + elif breaks: + if ch is None or ch not in '\n\x85\u2028\u2029': + if text[start] == '\n': + self.write_line_break() + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + self.write_indent() + start = end + else: + if ch is None or ch in ' \n\x85\u2028\u2029' or ch == '\'': + if start < end: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + if ch == '\'': + data = '\'\'' + self.column += 2 + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + 1 + if ch is not None: + spaces = (ch == ' ') + breaks = (ch in '\n\x85\u2028\u2029') + end += 1 + self.write_indicator('\'', False) + + ESCAPE_REPLACEMENTS = { + '\0': '0', + '\x07': 'a', + '\x08': 'b', + '\x09': 't', + '\x0A': 'n', + '\x0B': 'v', + '\x0C': 'f', + '\x0D': 'r', + '\x1B': 'e', + '\"': '\"', + '\\': '\\', + '\x85': 'N', + '\xA0': '_', + '\u2028': 'L', + '\u2029': 'P', + } + + def write_double_quoted(self, text, split=True): + self.write_indicator('"', True) + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if ch is None or ch in '"\\\x85\u2028\u2029\uFEFF' \ + or not ('\x20' <= ch <= '\x7E' + or (self.allow_unicode + and ('\xA0' <= ch <= '\uD7FF' + or '\uE000' <= ch <= '\uFFFD'))): + if start < end: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + if ch is not None: + if ch in self.ESCAPE_REPLACEMENTS: + data = '\\'+self.ESCAPE_REPLACEMENTS[ch] + elif ch <= '\xFF': + data = '\\x%02X' % ord(ch) + elif ch <= '\uFFFF': + data = '\\u%04X' % ord(ch) + else: + data = '\\U%08X' % ord(ch) + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end+1 + if 0 < end < len(text)-1 and (ch == ' ' or start >= end) \ + and self.column+(end-start) > self.best_width and split: + data = text[start:end]+'\\' + if start < end: + start = end + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + self.write_indent() + self.whitespace = False + self.indention = False + if text[start] == ' ': + data = '\\' + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + end += 1 + self.write_indicator('"', False) + + def determine_block_hints(self, text): + hints = '' + if text: + if text[0] in ' \n\x85\u2028\u2029': + hints += str(self.best_indent) + if text[-1] not in '\n\x85\u2028\u2029': + hints += '-' + elif len(text) == 1 or text[-2] in '\n\x85\u2028\u2029': + hints += '+' + return hints + + def write_folded(self, text): + hints = self.determine_block_hints(text) + self.write_indicator('>'+hints, True) + if hints[-1:] == '+': + self.open_ended = True + self.write_line_break() + leading_space = True + spaces = False + breaks = True + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if breaks: + if ch is None or ch not in '\n\x85\u2028\u2029': + if not leading_space and ch is not None and ch != ' ' \ + and text[start] == '\n': + self.write_line_break() + leading_space = (ch == ' ') + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + if ch is not None: + self.write_indent() + start = end + elif spaces: + if ch != ' ': + if start+1 == end and self.column > self.best_width: + self.write_indent() + else: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + else: + if ch is None or ch in ' \n\x85\u2028\u2029': + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + if ch is None: + self.write_line_break() + start = end + if ch is not None: + breaks = (ch in '\n\x85\u2028\u2029') + spaces = (ch == ' ') + end += 1 + + def write_literal(self, text): + hints = self.determine_block_hints(text) + self.write_indicator('|'+hints, True) + if hints[-1:] == '+': + self.open_ended = True + self.write_line_break() + breaks = True + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if breaks: + if ch is None or ch not in '\n\x85\u2028\u2029': + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + if ch is not None: + self.write_indent() + start = end + else: + if ch is None or ch in '\n\x85\u2028\u2029': + data = text[start:end] + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + if ch is None: + self.write_line_break() + start = end + if ch is not None: + breaks = (ch in '\n\x85\u2028\u2029') + end += 1 + + def write_plain(self, text, split=True): + if self.root_context: + self.open_ended = True + if not text: + return + if not self.whitespace: + data = ' ' + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + self.whitespace = False + self.indention = False + spaces = False + breaks = False + start = end = 0 + while end <= len(text): + ch = None + if end < len(text): + ch = text[end] + if spaces: + if ch != ' ': + if start+1 == end and self.column > self.best_width and split: + self.write_indent() + self.whitespace = False + self.indention = False + else: + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + elif breaks: + if ch not in '\n\x85\u2028\u2029': + if text[start] == '\n': + self.write_line_break() + for br in text[start:end]: + if br == '\n': + self.write_line_break() + else: + self.write_line_break(br) + self.write_indent() + self.whitespace = False + self.indention = False + start = end + else: + if ch is None or ch in ' \n\x85\u2028\u2029': + data = text[start:end] + self.column += len(data) + if self.encoding: + data = data.encode(self.encoding) + self.stream.write(data) + start = end + if ch is not None: + spaces = (ch == ' ') + breaks = (ch in '\n\x85\u2028\u2029') + end += 1 diff --git a/source/yaml/error.py b/source/yaml/error.py new file mode 100644 index 0000000000000000000000000000000000000000..b796b4dc519512c4825ff539a2e6aa20f4d370d0 --- /dev/null +++ b/source/yaml/error.py @@ -0,0 +1,75 @@ + +__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError'] + +class Mark: + + def __init__(self, name, index, line, column, buffer, pointer): + self.name = name + self.index = index + self.line = line + self.column = column + self.buffer = buffer + self.pointer = pointer + + def get_snippet(self, indent=4, max_length=75): + if self.buffer is None: + return None + head = '' + start = self.pointer + while start > 0 and self.buffer[start-1] not in '\0\r\n\x85\u2028\u2029': + start -= 1 + if self.pointer-start > max_length/2-1: + head = ' ... ' + start += 5 + break + tail = '' + end = self.pointer + while end < len(self.buffer) and self.buffer[end] not in '\0\r\n\x85\u2028\u2029': + end += 1 + if end-self.pointer > max_length/2-1: + tail = ' ... ' + end -= 5 + break + snippet = self.buffer[start:end] + return ' '*indent + head + snippet + tail + '\n' \ + + ' '*(indent+self.pointer-start+len(head)) + '^' + + def __str__(self): + snippet = self.get_snippet() + where = " in \"%s\", line %d, column %d" \ + % (self.name, self.line+1, self.column+1) + if snippet is not None: + where += ":\n"+snippet + return where + +class YAMLError(Exception): + pass + +class MarkedYAMLError(YAMLError): + + def __init__(self, context=None, context_mark=None, + problem=None, problem_mark=None, note=None): + self.context = context + self.context_mark = context_mark + self.problem = problem + self.problem_mark = problem_mark + self.note = note + + def __str__(self): + lines = [] + if self.context is not None: + lines.append(self.context) + if self.context_mark is not None \ + and (self.problem is None or self.problem_mark is None + or self.context_mark.name != self.problem_mark.name + or self.context_mark.line != self.problem_mark.line + or self.context_mark.column != self.problem_mark.column): + lines.append(str(self.context_mark)) + if self.problem is not None: + lines.append(self.problem) + if self.problem_mark is not None: + lines.append(str(self.problem_mark)) + if self.note is not None: + lines.append(self.note) + return '\n'.join(lines) + diff --git a/source/yaml/events.py b/source/yaml/events.py new file mode 100644 index 0000000000000000000000000000000000000000..f79ad389cb6c9517e391dcd25534866bc9ccd36a --- /dev/null +++ b/source/yaml/events.py @@ -0,0 +1,86 @@ + +# Abstract classes. + +class Event(object): + def __init__(self, start_mark=None, end_mark=None): + self.start_mark = start_mark + self.end_mark = end_mark + def __repr__(self): + attributes = [key for key in ['anchor', 'tag', 'implicit', 'value'] + if hasattr(self, key)] + arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) + for key in attributes]) + return '%s(%s)' % (self.__class__.__name__, arguments) + +class NodeEvent(Event): + def __init__(self, anchor, start_mark=None, end_mark=None): + self.anchor = anchor + self.start_mark = start_mark + self.end_mark = end_mark + +class CollectionStartEvent(NodeEvent): + def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None, + flow_style=None): + self.anchor = anchor + self.tag = tag + self.implicit = implicit + self.start_mark = start_mark + self.end_mark = end_mark + self.flow_style = flow_style + +class CollectionEndEvent(Event): + pass + +# Implementations. + +class StreamStartEvent(Event): + def __init__(self, start_mark=None, end_mark=None, encoding=None): + self.start_mark = start_mark + self.end_mark = end_mark + self.encoding = encoding + +class StreamEndEvent(Event): + pass + +class DocumentStartEvent(Event): + def __init__(self, start_mark=None, end_mark=None, + explicit=None, version=None, tags=None): + self.start_mark = start_mark + self.end_mark = end_mark + self.explicit = explicit + self.version = version + self.tags = tags + +class DocumentEndEvent(Event): + def __init__(self, start_mark=None, end_mark=None, + explicit=None): + self.start_mark = start_mark + self.end_mark = end_mark + self.explicit = explicit + +class AliasEvent(NodeEvent): + pass + +class ScalarEvent(NodeEvent): + def __init__(self, anchor, tag, implicit, value, + start_mark=None, end_mark=None, style=None): + self.anchor = anchor + self.tag = tag + self.implicit = implicit + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + self.style = style + +class SequenceStartEvent(CollectionStartEvent): + pass + +class SequenceEndEvent(CollectionEndEvent): + pass + +class MappingStartEvent(CollectionStartEvent): + pass + +class MappingEndEvent(CollectionEndEvent): + pass + diff --git a/source/yaml/loader.py b/source/yaml/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e90c11224c38e559cdf0cb205f0692ebd4fb8681 --- /dev/null +++ b/source/yaml/loader.py @@ -0,0 +1,63 @@ + +__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader'] + +from .reader import * +from .scanner import * +from .parser import * +from .composer import * +from .constructor import * +from .resolver import * + +class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver): + + def __init__(self, stream): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + BaseConstructor.__init__(self) + BaseResolver.__init__(self) + +class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver): + + def __init__(self, stream): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + FullConstructor.__init__(self) + Resolver.__init__(self) + +class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver): + + def __init__(self, stream): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + SafeConstructor.__init__(self) + Resolver.__init__(self) + +class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver): + + def __init__(self, stream): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + Constructor.__init__(self) + Resolver.__init__(self) + +# UnsafeLoader is the same as Loader (which is and was always unsafe on +# untrusted input). Use of either Loader or UnsafeLoader should be rare, since +# FullLoad should be able to load almost all YAML safely. Loader is left intact +# to ensure backwards compatibility. +class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver): + + def __init__(self, stream): + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + Constructor.__init__(self) + Resolver.__init__(self) diff --git a/source/yaml/nodes.py b/source/yaml/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f070c41e1fb1bc01af27d69329e92dded38908 --- /dev/null +++ b/source/yaml/nodes.py @@ -0,0 +1,49 @@ + +class Node(object): + def __init__(self, tag, value, start_mark, end_mark): + self.tag = tag + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + def __repr__(self): + value = self.value + #if isinstance(value, list): + # if len(value) == 0: + # value = '' + # elif len(value) == 1: + # value = '<1 item>' + # else: + # value = '<%d items>' % len(value) + #else: + # if len(value) > 75: + # value = repr(value[:70]+u' ... ') + # else: + # value = repr(value) + value = repr(value) + return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value) + +class ScalarNode(Node): + id = 'scalar' + def __init__(self, tag, value, + start_mark=None, end_mark=None, style=None): + self.tag = tag + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + self.style = style + +class CollectionNode(Node): + def __init__(self, tag, value, + start_mark=None, end_mark=None, flow_style=None): + self.tag = tag + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + self.flow_style = flow_style + +class SequenceNode(CollectionNode): + id = 'sequence' + +class MappingNode(CollectionNode): + id = 'mapping' + diff --git a/source/yaml/parser.py b/source/yaml/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..13a5995d292045d0f865a99abf692bd35dc87814 --- /dev/null +++ b/source/yaml/parser.py @@ -0,0 +1,589 @@ + +# The following YAML grammar is LL(1) and is parsed by a recursive descent +# parser. +# +# stream ::= STREAM-START implicit_document? explicit_document* STREAM-END +# implicit_document ::= block_node DOCUMENT-END* +# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* +# block_node_or_indentless_sequence ::= +# ALIAS +# | properties (block_content | indentless_block_sequence)? +# | block_content +# | indentless_block_sequence +# block_node ::= ALIAS +# | properties block_content? +# | block_content +# flow_node ::= ALIAS +# | properties flow_content? +# | flow_content +# properties ::= TAG ANCHOR? | ANCHOR TAG? +# block_content ::= block_collection | flow_collection | SCALAR +# flow_content ::= flow_collection | SCALAR +# block_collection ::= block_sequence | block_mapping +# flow_collection ::= flow_sequence | flow_mapping +# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END +# indentless_sequence ::= (BLOCK-ENTRY block_node?)+ +# block_mapping ::= BLOCK-MAPPING_START +# ((KEY block_node_or_indentless_sequence?)? +# (VALUE block_node_or_indentless_sequence?)?)* +# BLOCK-END +# flow_sequence ::= FLOW-SEQUENCE-START +# (flow_sequence_entry FLOW-ENTRY)* +# flow_sequence_entry? +# FLOW-SEQUENCE-END +# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +# flow_mapping ::= FLOW-MAPPING-START +# (flow_mapping_entry FLOW-ENTRY)* +# flow_mapping_entry? +# FLOW-MAPPING-END +# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? +# +# FIRST sets: +# +# stream: { STREAM-START } +# explicit_document: { DIRECTIVE DOCUMENT-START } +# implicit_document: FIRST(block_node) +# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START } +# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START } +# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } +# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } +# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START } +# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } +# block_sequence: { BLOCK-SEQUENCE-START } +# block_mapping: { BLOCK-MAPPING-START } +# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY } +# indentless_sequence: { ENTRY } +# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } +# flow_sequence: { FLOW-SEQUENCE-START } +# flow_mapping: { FLOW-MAPPING-START } +# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } +# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY } + +__all__ = ['Parser', 'ParserError'] + +from .error import MarkedYAMLError +from .tokens import * +from .events import * +from .scanner import * + +class ParserError(MarkedYAMLError): + pass + +class Parser: + # Since writing a recursive-descendant parser is a straightforward task, we + # do not give many comments here. + + DEFAULT_TAGS = { + '!': '!', + '!!': 'tag:yaml.org,2002:', + } + + def __init__(self): + self.current_event = None + self.yaml_version = None + self.tag_handles = {} + self.states = [] + self.marks = [] + self.state = self.parse_stream_start + + def dispose(self): + # Reset the state attributes (to clear self-references) + self.states = [] + self.state = None + + def check_event(self, *choices): + # Check the type of the next event. + if self.current_event is None: + if self.state: + self.current_event = self.state() + if self.current_event is not None: + if not choices: + return True + for choice in choices: + if isinstance(self.current_event, choice): + return True + return False + + def peek_event(self): + # Get the next event. + if self.current_event is None: + if self.state: + self.current_event = self.state() + return self.current_event + + def get_event(self): + # Get the next event and proceed further. + if self.current_event is None: + if self.state: + self.current_event = self.state() + value = self.current_event + self.current_event = None + return value + + # stream ::= STREAM-START implicit_document? explicit_document* STREAM-END + # implicit_document ::= block_node DOCUMENT-END* + # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* + + def parse_stream_start(self): + + # Parse the stream start. + token = self.get_token() + event = StreamStartEvent(token.start_mark, token.end_mark, + encoding=token.encoding) + + # Prepare the next state. + self.state = self.parse_implicit_document_start + + return event + + def parse_implicit_document_start(self): + + # Parse an implicit document. + if not self.check_token(DirectiveToken, DocumentStartToken, + StreamEndToken): + self.tag_handles = self.DEFAULT_TAGS + token = self.peek_token() + start_mark = end_mark = token.start_mark + event = DocumentStartEvent(start_mark, end_mark, + explicit=False) + + # Prepare the next state. + self.states.append(self.parse_document_end) + self.state = self.parse_block_node + + return event + + else: + return self.parse_document_start() + + def parse_document_start(self): + + # Parse any extra document end indicators. + while self.check_token(DocumentEndToken): + self.get_token() + + # Parse an explicit document. + if not self.check_token(StreamEndToken): + token = self.peek_token() + start_mark = token.start_mark + version, tags = self.process_directives() + if not self.check_token(DocumentStartToken): + raise ParserError(None, None, + "expected '', but found %r" + % self.peek_token().id, + self.peek_token().start_mark) + token = self.get_token() + end_mark = token.end_mark + event = DocumentStartEvent(start_mark, end_mark, + explicit=True, version=version, tags=tags) + self.states.append(self.parse_document_end) + self.state = self.parse_document_content + else: + # Parse the end of the stream. + token = self.get_token() + event = StreamEndEvent(token.start_mark, token.end_mark) + assert not self.states + assert not self.marks + self.state = None + return event + + def parse_document_end(self): + + # Parse the document end. + token = self.peek_token() + start_mark = end_mark = token.start_mark + explicit = False + if self.check_token(DocumentEndToken): + token = self.get_token() + end_mark = token.end_mark + explicit = True + event = DocumentEndEvent(start_mark, end_mark, + explicit=explicit) + + # Prepare the next state. + self.state = self.parse_document_start + + return event + + def parse_document_content(self): + if self.check_token(DirectiveToken, + DocumentStartToken, DocumentEndToken, StreamEndToken): + event = self.process_empty_scalar(self.peek_token().start_mark) + self.state = self.states.pop() + return event + else: + return self.parse_block_node() + + def process_directives(self): + self.yaml_version = None + self.tag_handles = {} + while self.check_token(DirectiveToken): + token = self.get_token() + if token.name == 'YAML': + if self.yaml_version is not None: + raise ParserError(None, None, + "found duplicate YAML directive", token.start_mark) + major, minor = token.value + if major != 1: + raise ParserError(None, None, + "found incompatible YAML document (version 1.* is required)", + token.start_mark) + self.yaml_version = token.value + elif token.name == 'TAG': + handle, prefix = token.value + if handle in self.tag_handles: + raise ParserError(None, None, + "duplicate tag handle %r" % handle, + token.start_mark) + self.tag_handles[handle] = prefix + if self.tag_handles: + value = self.yaml_version, self.tag_handles.copy() + else: + value = self.yaml_version, None + for key in self.DEFAULT_TAGS: + if key not in self.tag_handles: + self.tag_handles[key] = self.DEFAULT_TAGS[key] + return value + + # block_node_or_indentless_sequence ::= ALIAS + # | properties (block_content | indentless_block_sequence)? + # | block_content + # | indentless_block_sequence + # block_node ::= ALIAS + # | properties block_content? + # | block_content + # flow_node ::= ALIAS + # | properties flow_content? + # | flow_content + # properties ::= TAG ANCHOR? | ANCHOR TAG? + # block_content ::= block_collection | flow_collection | SCALAR + # flow_content ::= flow_collection | SCALAR + # block_collection ::= block_sequence | block_mapping + # flow_collection ::= flow_sequence | flow_mapping + + def parse_block_node(self): + return self.parse_node(block=True) + + def parse_flow_node(self): + return self.parse_node() + + def parse_block_node_or_indentless_sequence(self): + return self.parse_node(block=True, indentless_sequence=True) + + def parse_node(self, block=False, indentless_sequence=False): + if self.check_token(AliasToken): + token = self.get_token() + event = AliasEvent(token.value, token.start_mark, token.end_mark) + self.state = self.states.pop() + else: + anchor = None + tag = None + start_mark = end_mark = tag_mark = None + if self.check_token(AnchorToken): + token = self.get_token() + start_mark = token.start_mark + end_mark = token.end_mark + anchor = token.value + if self.check_token(TagToken): + token = self.get_token() + tag_mark = token.start_mark + end_mark = token.end_mark + tag = token.value + elif self.check_token(TagToken): + token = self.get_token() + start_mark = tag_mark = token.start_mark + end_mark = token.end_mark + tag = token.value + if self.check_token(AnchorToken): + token = self.get_token() + end_mark = token.end_mark + anchor = token.value + if tag is not None: + handle, suffix = tag + if handle is not None: + if handle not in self.tag_handles: + raise ParserError("while parsing a node", start_mark, + "found undefined tag handle %r" % handle, + tag_mark) + tag = self.tag_handles[handle]+suffix + else: + tag = suffix + #if tag == '!': + # raise ParserError("while parsing a node", start_mark, + # "found non-specific tag '!'", tag_mark, + # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.") + if start_mark is None: + start_mark = end_mark = self.peek_token().start_mark + event = None + implicit = (tag is None or tag == '!') + if indentless_sequence and self.check_token(BlockEntryToken): + end_mark = self.peek_token().end_mark + event = SequenceStartEvent(anchor, tag, implicit, + start_mark, end_mark) + self.state = self.parse_indentless_sequence_entry + else: + if self.check_token(ScalarToken): + token = self.get_token() + end_mark = token.end_mark + if (token.plain and tag is None) or tag == '!': + implicit = (True, False) + elif tag is None: + implicit = (False, True) + else: + implicit = (False, False) + event = ScalarEvent(anchor, tag, implicit, token.value, + start_mark, end_mark, style=token.style) + self.state = self.states.pop() + elif self.check_token(FlowSequenceStartToken): + end_mark = self.peek_token().end_mark + event = SequenceStartEvent(anchor, tag, implicit, + start_mark, end_mark, flow_style=True) + self.state = self.parse_flow_sequence_first_entry + elif self.check_token(FlowMappingStartToken): + end_mark = self.peek_token().end_mark + event = MappingStartEvent(anchor, tag, implicit, + start_mark, end_mark, flow_style=True) + self.state = self.parse_flow_mapping_first_key + elif block and self.check_token(BlockSequenceStartToken): + end_mark = self.peek_token().start_mark + event = SequenceStartEvent(anchor, tag, implicit, + start_mark, end_mark, flow_style=False) + self.state = self.parse_block_sequence_first_entry + elif block and self.check_token(BlockMappingStartToken): + end_mark = self.peek_token().start_mark + event = MappingStartEvent(anchor, tag, implicit, + start_mark, end_mark, flow_style=False) + self.state = self.parse_block_mapping_first_key + elif anchor is not None or tag is not None: + # Empty scalars are allowed even if a tag or an anchor is + # specified. + event = ScalarEvent(anchor, tag, (implicit, False), '', + start_mark, end_mark) + self.state = self.states.pop() + else: + if block: + node = 'block' + else: + node = 'flow' + token = self.peek_token() + raise ParserError("while parsing a %s node" % node, start_mark, + "expected the node content, but found %r" % token.id, + token.start_mark) + return event + + # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END + + def parse_block_sequence_first_entry(self): + token = self.get_token() + self.marks.append(token.start_mark) + return self.parse_block_sequence_entry() + + def parse_block_sequence_entry(self): + if self.check_token(BlockEntryToken): + token = self.get_token() + if not self.check_token(BlockEntryToken, BlockEndToken): + self.states.append(self.parse_block_sequence_entry) + return self.parse_block_node() + else: + self.state = self.parse_block_sequence_entry + return self.process_empty_scalar(token.end_mark) + if not self.check_token(BlockEndToken): + token = self.peek_token() + raise ParserError("while parsing a block collection", self.marks[-1], + "expected , but found %r" % token.id, token.start_mark) + token = self.get_token() + event = SequenceEndEvent(token.start_mark, token.end_mark) + self.state = self.states.pop() + self.marks.pop() + return event + + # indentless_sequence ::= (BLOCK-ENTRY block_node?)+ + + def parse_indentless_sequence_entry(self): + if self.check_token(BlockEntryToken): + token = self.get_token() + if not self.check_token(BlockEntryToken, + KeyToken, ValueToken, BlockEndToken): + self.states.append(self.parse_indentless_sequence_entry) + return self.parse_block_node() + else: + self.state = self.parse_indentless_sequence_entry + return self.process_empty_scalar(token.end_mark) + token = self.peek_token() + event = SequenceEndEvent(token.start_mark, token.start_mark) + self.state = self.states.pop() + return event + + # block_mapping ::= BLOCK-MAPPING_START + # ((KEY block_node_or_indentless_sequence?)? + # (VALUE block_node_or_indentless_sequence?)?)* + # BLOCK-END + + def parse_block_mapping_first_key(self): + token = self.get_token() + self.marks.append(token.start_mark) + return self.parse_block_mapping_key() + + def parse_block_mapping_key(self): + if self.check_token(KeyToken): + token = self.get_token() + if not self.check_token(KeyToken, ValueToken, BlockEndToken): + self.states.append(self.parse_block_mapping_value) + return self.parse_block_node_or_indentless_sequence() + else: + self.state = self.parse_block_mapping_value + return self.process_empty_scalar(token.end_mark) + if not self.check_token(BlockEndToken): + token = self.peek_token() + raise ParserError("while parsing a block mapping", self.marks[-1], + "expected , but found %r" % token.id, token.start_mark) + token = self.get_token() + event = MappingEndEvent(token.start_mark, token.end_mark) + self.state = self.states.pop() + self.marks.pop() + return event + + def parse_block_mapping_value(self): + if self.check_token(ValueToken): + token = self.get_token() + if not self.check_token(KeyToken, ValueToken, BlockEndToken): + self.states.append(self.parse_block_mapping_key) + return self.parse_block_node_or_indentless_sequence() + else: + self.state = self.parse_block_mapping_key + return self.process_empty_scalar(token.end_mark) + else: + self.state = self.parse_block_mapping_key + token = self.peek_token() + return self.process_empty_scalar(token.start_mark) + + # flow_sequence ::= FLOW-SEQUENCE-START + # (flow_sequence_entry FLOW-ENTRY)* + # flow_sequence_entry? + # FLOW-SEQUENCE-END + # flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? + # + # Note that while production rules for both flow_sequence_entry and + # flow_mapping_entry are equal, their interpretations are different. + # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?` + # generate an inline mapping (set syntax). + + def parse_flow_sequence_first_entry(self): + token = self.get_token() + self.marks.append(token.start_mark) + return self.parse_flow_sequence_entry(first=True) + + def parse_flow_sequence_entry(self, first=False): + if not self.check_token(FlowSequenceEndToken): + if not first: + if self.check_token(FlowEntryToken): + self.get_token() + else: + token = self.peek_token() + raise ParserError("while parsing a flow sequence", self.marks[-1], + "expected ',' or ']', but got %r" % token.id, token.start_mark) + + if self.check_token(KeyToken): + token = self.peek_token() + event = MappingStartEvent(None, None, True, + token.start_mark, token.end_mark, + flow_style=True) + self.state = self.parse_flow_sequence_entry_mapping_key + return event + elif not self.check_token(FlowSequenceEndToken): + self.states.append(self.parse_flow_sequence_entry) + return self.parse_flow_node() + token = self.get_token() + event = SequenceEndEvent(token.start_mark, token.end_mark) + self.state = self.states.pop() + self.marks.pop() + return event + + def parse_flow_sequence_entry_mapping_key(self): + token = self.get_token() + if not self.check_token(ValueToken, + FlowEntryToken, FlowSequenceEndToken): + self.states.append(self.parse_flow_sequence_entry_mapping_value) + return self.parse_flow_node() + else: + self.state = self.parse_flow_sequence_entry_mapping_value + return self.process_empty_scalar(token.end_mark) + + def parse_flow_sequence_entry_mapping_value(self): + if self.check_token(ValueToken): + token = self.get_token() + if not self.check_token(FlowEntryToken, FlowSequenceEndToken): + self.states.append(self.parse_flow_sequence_entry_mapping_end) + return self.parse_flow_node() + else: + self.state = self.parse_flow_sequence_entry_mapping_end + return self.process_empty_scalar(token.end_mark) + else: + self.state = self.parse_flow_sequence_entry_mapping_end + token = self.peek_token() + return self.process_empty_scalar(token.start_mark) + + def parse_flow_sequence_entry_mapping_end(self): + self.state = self.parse_flow_sequence_entry + token = self.peek_token() + return MappingEndEvent(token.start_mark, token.start_mark) + + # flow_mapping ::= FLOW-MAPPING-START + # (flow_mapping_entry FLOW-ENTRY)* + # flow_mapping_entry? + # FLOW-MAPPING-END + # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? + + def parse_flow_mapping_first_key(self): + token = self.get_token() + self.marks.append(token.start_mark) + return self.parse_flow_mapping_key(first=True) + + def parse_flow_mapping_key(self, first=False): + if not self.check_token(FlowMappingEndToken): + if not first: + if self.check_token(FlowEntryToken): + self.get_token() + else: + token = self.peek_token() + raise ParserError("while parsing a flow mapping", self.marks[-1], + "expected ',' or '}', but got %r" % token.id, token.start_mark) + if self.check_token(KeyToken): + token = self.get_token() + if not self.check_token(ValueToken, + FlowEntryToken, FlowMappingEndToken): + self.states.append(self.parse_flow_mapping_value) + return self.parse_flow_node() + else: + self.state = self.parse_flow_mapping_value + return self.process_empty_scalar(token.end_mark) + elif not self.check_token(FlowMappingEndToken): + self.states.append(self.parse_flow_mapping_empty_value) + return self.parse_flow_node() + token = self.get_token() + event = MappingEndEvent(token.start_mark, token.end_mark) + self.state = self.states.pop() + self.marks.pop() + return event + + def parse_flow_mapping_value(self): + if self.check_token(ValueToken): + token = self.get_token() + if not self.check_token(FlowEntryToken, FlowMappingEndToken): + self.states.append(self.parse_flow_mapping_key) + return self.parse_flow_node() + else: + self.state = self.parse_flow_mapping_key + return self.process_empty_scalar(token.end_mark) + else: + self.state = self.parse_flow_mapping_key + token = self.peek_token() + return self.process_empty_scalar(token.start_mark) + + def parse_flow_mapping_empty_value(self): + self.state = self.parse_flow_mapping_key + return self.process_empty_scalar(self.peek_token().start_mark) + + def process_empty_scalar(self, mark): + return ScalarEvent(None, None, (True, False), '', mark, mark) + diff --git a/source/yaml/reader.py b/source/yaml/reader.py new file mode 100644 index 0000000000000000000000000000000000000000..774b0219b5932a0ee1c27e637371de5ba8d9cb16 --- /dev/null +++ b/source/yaml/reader.py @@ -0,0 +1,185 @@ +# This module contains abstractions for the input stream. You don't have to +# looks further, there are no pretty code. +# +# We define two classes here. +# +# Mark(source, line, column) +# It's just a record and its only use is producing nice error messages. +# Parser does not use it for any other purposes. +# +# Reader(source, data) +# Reader determines the encoding of `data` and converts it to unicode. +# Reader provides the following methods and attributes: +# reader.peek(length=1) - return the next `length` characters +# reader.forward(length=1) - move the current position to `length` characters. +# reader.index - the number of the current character. +# reader.line, stream.column - the line and the column of the current character. + +__all__ = ['Reader', 'ReaderError'] + +from .error import YAMLError, Mark + +import codecs, re + +class ReaderError(YAMLError): + + def __init__(self, name, position, character, encoding, reason): + self.name = name + self.character = character + self.position = position + self.encoding = encoding + self.reason = reason + + def __str__(self): + if isinstance(self.character, bytes): + return "'%s' codec can't decode byte #x%02x: %s\n" \ + " in \"%s\", position %d" \ + % (self.encoding, ord(self.character), self.reason, + self.name, self.position) + else: + return "unacceptable character #x%04x: %s\n" \ + " in \"%s\", position %d" \ + % (self.character, self.reason, + self.name, self.position) + +class Reader(object): + # Reader: + # - determines the data encoding and converts it to a unicode string, + # - checks if characters are in allowed range, + # - adds '\0' to the end. + + # Reader accepts + # - a `bytes` object, + # - a `str` object, + # - a file-like object with its `read` method returning `str`, + # - a file-like object with its `read` method returning `unicode`. + + # Yeah, it's ugly and slow. + + def __init__(self, stream): + self.name = None + self.stream = None + self.stream_pointer = 0 + self.eof = True + self.buffer = '' + self.pointer = 0 + self.raw_buffer = None + self.raw_decode = None + self.encoding = None + self.index = 0 + self.line = 0 + self.column = 0 + if isinstance(stream, str): + self.name = "" + self.check_printable(stream) + self.buffer = stream+'\0' + elif isinstance(stream, bytes): + self.name = "" + self.raw_buffer = stream + self.determine_encoding() + else: + self.stream = stream + self.name = getattr(stream, 'name', "") + self.eof = False + self.raw_buffer = None + self.determine_encoding() + + def peek(self, index=0): + try: + return self.buffer[self.pointer+index] + except IndexError: + self.update(index+1) + return self.buffer[self.pointer+index] + + def prefix(self, length=1): + if self.pointer+length >= len(self.buffer): + self.update(length) + return self.buffer[self.pointer:self.pointer+length] + + def forward(self, length=1): + if self.pointer+length+1 >= len(self.buffer): + self.update(length+1) + while length: + ch = self.buffer[self.pointer] + self.pointer += 1 + self.index += 1 + if ch in '\n\x85\u2028\u2029' \ + or (ch == '\r' and self.buffer[self.pointer] != '\n'): + self.line += 1 + self.column = 0 + elif ch != '\uFEFF': + self.column += 1 + length -= 1 + + def get_mark(self): + if self.stream is None: + return Mark(self.name, self.index, self.line, self.column, + self.buffer, self.pointer) + else: + return Mark(self.name, self.index, self.line, self.column, + None, None) + + def determine_encoding(self): + while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2): + self.update_raw() + if isinstance(self.raw_buffer, bytes): + if self.raw_buffer.startswith(codecs.BOM_UTF16_LE): + self.raw_decode = codecs.utf_16_le_decode + self.encoding = 'utf-16-le' + elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE): + self.raw_decode = codecs.utf_16_be_decode + self.encoding = 'utf-16-be' + else: + self.raw_decode = codecs.utf_8_decode + self.encoding = 'utf-8' + self.update(1) + + NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]') + def check_printable(self, data): + match = self.NON_PRINTABLE.search(data) + if match: + character = match.group() + position = self.index+(len(self.buffer)-self.pointer)+match.start() + raise ReaderError(self.name, position, ord(character), + 'unicode', "special characters are not allowed") + + def update(self, length): + if self.raw_buffer is None: + return + self.buffer = self.buffer[self.pointer:] + self.pointer = 0 + while len(self.buffer) < length: + if not self.eof: + self.update_raw() + if self.raw_decode is not None: + try: + data, converted = self.raw_decode(self.raw_buffer, + 'strict', self.eof) + except UnicodeDecodeError as exc: + character = self.raw_buffer[exc.start] + if self.stream is not None: + position = self.stream_pointer-len(self.raw_buffer)+exc.start + else: + position = exc.start + raise ReaderError(self.name, position, character, + exc.encoding, exc.reason) + else: + data = self.raw_buffer + converted = len(data) + self.check_printable(data) + self.buffer += data + self.raw_buffer = self.raw_buffer[converted:] + if self.eof: + self.buffer += '\0' + self.raw_buffer = None + break + + def update_raw(self, size=4096): + data = self.stream.read(size) + if self.raw_buffer is None: + self.raw_buffer = data + else: + self.raw_buffer += data + self.stream_pointer += len(data) + if not data: + self.eof = True diff --git a/source/yaml/representer.py b/source/yaml/representer.py new file mode 100644 index 0000000000000000000000000000000000000000..808ca06dfbd60c9a23eb079151b74a82ef688749 --- /dev/null +++ b/source/yaml/representer.py @@ -0,0 +1,389 @@ + +__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer', + 'RepresenterError'] + +from .error import * +from .nodes import * + +import datetime, copyreg, types, base64, collections + +class RepresenterError(YAMLError): + pass + +class BaseRepresenter: + + yaml_representers = {} + yaml_multi_representers = {} + + def __init__(self, default_style=None, default_flow_style=False, sort_keys=True): + self.default_style = default_style + self.sort_keys = sort_keys + self.default_flow_style = default_flow_style + self.represented_objects = {} + self.object_keeper = [] + self.alias_key = None + + def represent(self, data): + node = self.represent_data(data) + self.serialize(node) + self.represented_objects = {} + self.object_keeper = [] + self.alias_key = None + + def represent_data(self, data): + if self.ignore_aliases(data): + self.alias_key = None + else: + self.alias_key = id(data) + if self.alias_key is not None: + if self.alias_key in self.represented_objects: + node = self.represented_objects[self.alias_key] + #if node is None: + # raise RepresenterError("recursive objects are not allowed: %r" % data) + return node + #self.represented_objects[alias_key] = None + self.object_keeper.append(data) + data_types = type(data).__mro__ + if data_types[0] in self.yaml_representers: + node = self.yaml_representers[data_types[0]](self, data) + else: + for data_type in data_types: + if data_type in self.yaml_multi_representers: + node = self.yaml_multi_representers[data_type](self, data) + break + else: + if None in self.yaml_multi_representers: + node = self.yaml_multi_representers[None](self, data) + elif None in self.yaml_representers: + node = self.yaml_representers[None](self, data) + else: + node = ScalarNode(None, str(data)) + #if alias_key is not None: + # self.represented_objects[alias_key] = node + return node + + @classmethod + def add_representer(cls, data_type, representer): + if not 'yaml_representers' in cls.__dict__: + cls.yaml_representers = cls.yaml_representers.copy() + cls.yaml_representers[data_type] = representer + + @classmethod + def add_multi_representer(cls, data_type, representer): + if not 'yaml_multi_representers' in cls.__dict__: + cls.yaml_multi_representers = cls.yaml_multi_representers.copy() + cls.yaml_multi_representers[data_type] = representer + + def represent_scalar(self, tag, value, style=None): + if style is None: + style = self.default_style + node = ScalarNode(tag, value, style=style) + if self.alias_key is not None: + self.represented_objects[self.alias_key] = node + return node + + def represent_sequence(self, tag, sequence, flow_style=None): + value = [] + node = SequenceNode(tag, value, flow_style=flow_style) + if self.alias_key is not None: + self.represented_objects[self.alias_key] = node + best_style = True + for item in sequence: + node_item = self.represent_data(item) + if not (isinstance(node_item, ScalarNode) and not node_item.style): + best_style = False + value.append(node_item) + if flow_style is None: + if self.default_flow_style is not None: + node.flow_style = self.default_flow_style + else: + node.flow_style = best_style + return node + + def represent_mapping(self, tag, mapping, flow_style=None): + value = [] + node = MappingNode(tag, value, flow_style=flow_style) + if self.alias_key is not None: + self.represented_objects[self.alias_key] = node + best_style = True + if hasattr(mapping, 'items'): + mapping = list(mapping.items()) + if self.sort_keys: + try: + mapping = sorted(mapping) + except TypeError: + pass + for item_key, item_value in mapping: + node_key = self.represent_data(item_key) + node_value = self.represent_data(item_value) + if not (isinstance(node_key, ScalarNode) and not node_key.style): + best_style = False + if not (isinstance(node_value, ScalarNode) and not node_value.style): + best_style = False + value.append((node_key, node_value)) + if flow_style is None: + if self.default_flow_style is not None: + node.flow_style = self.default_flow_style + else: + node.flow_style = best_style + return node + + def ignore_aliases(self, data): + return False + +class SafeRepresenter(BaseRepresenter): + + def ignore_aliases(self, data): + if data is None: + return True + if isinstance(data, tuple) and data == (): + return True + if isinstance(data, (str, bytes, bool, int, float)): + return True + + def represent_none(self, data): + return self.represent_scalar('tag:yaml.org,2002:null', 'null') + + def represent_str(self, data): + return self.represent_scalar('tag:yaml.org,2002:str', data) + + def represent_binary(self, data): + if hasattr(base64, 'encodebytes'): + data = base64.encodebytes(data).decode('ascii') + else: + data = base64.encodestring(data).decode('ascii') + return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|') + + def represent_bool(self, data): + if data: + value = 'true' + else: + value = 'false' + return self.represent_scalar('tag:yaml.org,2002:bool', value) + + def represent_int(self, data): + return self.represent_scalar('tag:yaml.org,2002:int', str(data)) + + inf_value = 1e300 + while repr(inf_value) != repr(inf_value*inf_value): + inf_value *= inf_value + + def represent_float(self, data): + if data != data or (data == 0.0 and data == 1.0): + value = '.nan' + elif data == self.inf_value: + value = '.inf' + elif data == -self.inf_value: + value = '-.inf' + else: + value = repr(data).lower() + # Note that in some cases `repr(data)` represents a float number + # without the decimal parts. For instance: + # >>> repr(1e17) + # '1e17' + # Unfortunately, this is not a valid float representation according + # to the definition of the `!!float` tag. We fix this by adding + # '.0' before the 'e' symbol. + if '.' not in value and 'e' in value: + value = value.replace('e', '.0e', 1) + return self.represent_scalar('tag:yaml.org,2002:float', value) + + def represent_list(self, data): + #pairs = (len(data) > 0 and isinstance(data, list)) + #if pairs: + # for item in data: + # if not isinstance(item, tuple) or len(item) != 2: + # pairs = False + # break + #if not pairs: + return self.represent_sequence('tag:yaml.org,2002:seq', data) + #value = [] + #for item_key, item_value in data: + # value.append(self.represent_mapping(u'tag:yaml.org,2002:map', + # [(item_key, item_value)])) + #return SequenceNode(u'tag:yaml.org,2002:pairs', value) + + def represent_dict(self, data): + return self.represent_mapping('tag:yaml.org,2002:map', data) + + def represent_set(self, data): + value = {} + for key in data: + value[key] = None + return self.represent_mapping('tag:yaml.org,2002:set', value) + + def represent_date(self, data): + value = data.isoformat() + return self.represent_scalar('tag:yaml.org,2002:timestamp', value) + + def represent_datetime(self, data): + value = data.isoformat(' ') + return self.represent_scalar('tag:yaml.org,2002:timestamp', value) + + def represent_yaml_object(self, tag, data, cls, flow_style=None): + if hasattr(data, '__getstate__'): + state = data.__getstate__() + else: + state = data.__dict__.copy() + return self.represent_mapping(tag, state, flow_style=flow_style) + + def represent_undefined(self, data): + raise RepresenterError("cannot represent an object", data) + +SafeRepresenter.add_representer(type(None), + SafeRepresenter.represent_none) + +SafeRepresenter.add_representer(str, + SafeRepresenter.represent_str) + +SafeRepresenter.add_representer(bytes, + SafeRepresenter.represent_binary) + +SafeRepresenter.add_representer(bool, + SafeRepresenter.represent_bool) + +SafeRepresenter.add_representer(int, + SafeRepresenter.represent_int) + +SafeRepresenter.add_representer(float, + SafeRepresenter.represent_float) + +SafeRepresenter.add_representer(list, + SafeRepresenter.represent_list) + +SafeRepresenter.add_representer(tuple, + SafeRepresenter.represent_list) + +SafeRepresenter.add_representer(dict, + SafeRepresenter.represent_dict) + +SafeRepresenter.add_representer(set, + SafeRepresenter.represent_set) + +SafeRepresenter.add_representer(datetime.date, + SafeRepresenter.represent_date) + +SafeRepresenter.add_representer(datetime.datetime, + SafeRepresenter.represent_datetime) + +SafeRepresenter.add_representer(None, + SafeRepresenter.represent_undefined) + +class Representer(SafeRepresenter): + + def represent_complex(self, data): + if data.imag == 0.0: + data = '%r' % data.real + elif data.real == 0.0: + data = '%rj' % data.imag + elif data.imag > 0: + data = '%r+%rj' % (data.real, data.imag) + else: + data = '%r%rj' % (data.real, data.imag) + return self.represent_scalar('tag:yaml.org,2002:python/complex', data) + + def represent_tuple(self, data): + return self.represent_sequence('tag:yaml.org,2002:python/tuple', data) + + def represent_name(self, data): + name = '%s.%s' % (data.__module__, data.__name__) + return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '') + + def represent_module(self, data): + return self.represent_scalar( + 'tag:yaml.org,2002:python/module:'+data.__name__, '') + + def represent_object(self, data): + # We use __reduce__ API to save the data. data.__reduce__ returns + # a tuple of length 2-5: + # (function, args, state, listitems, dictitems) + + # For reconstructing, we calls function(*args), then set its state, + # listitems, and dictitems if they are not None. + + # A special case is when function.__name__ == '__newobj__'. In this + # case we create the object with args[0].__new__(*args). + + # Another special case is when __reduce__ returns a string - we don't + # support it. + + # We produce a !!python/object, !!python/object/new or + # !!python/object/apply node. + + cls = type(data) + if cls in copyreg.dispatch_table: + reduce = copyreg.dispatch_table[cls](data) + elif hasattr(data, '__reduce_ex__'): + reduce = data.__reduce_ex__(2) + elif hasattr(data, '__reduce__'): + reduce = data.__reduce__() + else: + raise RepresenterError("cannot represent an object", data) + reduce = (list(reduce)+[None]*5)[:5] + function, args, state, listitems, dictitems = reduce + args = list(args) + if state is None: + state = {} + if listitems is not None: + listitems = list(listitems) + if dictitems is not None: + dictitems = dict(dictitems) + if function.__name__ == '__newobj__': + function = args[0] + args = args[1:] + tag = 'tag:yaml.org,2002:python/object/new:' + newobj = True + else: + tag = 'tag:yaml.org,2002:python/object/apply:' + newobj = False + function_name = '%s.%s' % (function.__module__, function.__name__) + if not args and not listitems and not dictitems \ + and isinstance(state, dict) and newobj: + return self.represent_mapping( + 'tag:yaml.org,2002:python/object:'+function_name, state) + if not listitems and not dictitems \ + and isinstance(state, dict) and not state: + return self.represent_sequence(tag+function_name, args) + value = {} + if args: + value['args'] = args + if state or not isinstance(state, dict): + value['state'] = state + if listitems: + value['listitems'] = listitems + if dictitems: + value['dictitems'] = dictitems + return self.represent_mapping(tag+function_name, value) + + def represent_ordered_dict(self, data): + # Provide uniform representation across different Python versions. + data_type = type(data) + tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \ + % (data_type.__module__, data_type.__name__) + items = [[key, value] for key, value in data.items()] + return self.represent_sequence(tag, [items]) + +Representer.add_representer(complex, + Representer.represent_complex) + +Representer.add_representer(tuple, + Representer.represent_tuple) + +Representer.add_multi_representer(type, + Representer.represent_name) + +Representer.add_representer(collections.OrderedDict, + Representer.represent_ordered_dict) + +Representer.add_representer(types.FunctionType, + Representer.represent_name) + +Representer.add_representer(types.BuiltinFunctionType, + Representer.represent_name) + +Representer.add_representer(types.ModuleType, + Representer.represent_module) + +Representer.add_multi_representer(object, + Representer.represent_object) + diff --git a/source/yaml/resolver.py b/source/yaml/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..3522bdaaf6358110b608f4e6503b9d314c82d887 --- /dev/null +++ b/source/yaml/resolver.py @@ -0,0 +1,227 @@ + +__all__ = ['BaseResolver', 'Resolver'] + +from .error import * +from .nodes import * + +import re + +class ResolverError(YAMLError): + pass + +class BaseResolver: + + DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str' + DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq' + DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map' + + yaml_implicit_resolvers = {} + yaml_path_resolvers = {} + + def __init__(self): + self.resolver_exact_paths = [] + self.resolver_prefix_paths = [] + + @classmethod + def add_implicit_resolver(cls, tag, regexp, first): + if not 'yaml_implicit_resolvers' in cls.__dict__: + implicit_resolvers = {} + for key in cls.yaml_implicit_resolvers: + implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:] + cls.yaml_implicit_resolvers = implicit_resolvers + if first is None: + first = [None] + for ch in first: + cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp)) + + @classmethod + def add_path_resolver(cls, tag, path, kind=None): + # Note: `add_path_resolver` is experimental. The API could be changed. + # `new_path` is a pattern that is matched against the path from the + # root to the node that is being considered. `node_path` elements are + # tuples `(node_check, index_check)`. `node_check` is a node class: + # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None` + # matches any kind of a node. `index_check` could be `None`, a boolean + # value, a string value, or a number. `None` and `False` match against + # any _value_ of sequence and mapping nodes. `True` matches against + # any _key_ of a mapping node. A string `index_check` matches against + # a mapping value that corresponds to a scalar key which content is + # equal to the `index_check` value. An integer `index_check` matches + # against a sequence value with the index equal to `index_check`. + if not 'yaml_path_resolvers' in cls.__dict__: + cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy() + new_path = [] + for element in path: + if isinstance(element, (list, tuple)): + if len(element) == 2: + node_check, index_check = element + elif len(element) == 1: + node_check = element[0] + index_check = True + else: + raise ResolverError("Invalid path element: %s" % element) + else: + node_check = None + index_check = element + if node_check is str: + node_check = ScalarNode + elif node_check is list: + node_check = SequenceNode + elif node_check is dict: + node_check = MappingNode + elif node_check not in [ScalarNode, SequenceNode, MappingNode] \ + and not isinstance(node_check, str) \ + and node_check is not None: + raise ResolverError("Invalid node checker: %s" % node_check) + if not isinstance(index_check, (str, int)) \ + and index_check is not None: + raise ResolverError("Invalid index checker: %s" % index_check) + new_path.append((node_check, index_check)) + if kind is str: + kind = ScalarNode + elif kind is list: + kind = SequenceNode + elif kind is dict: + kind = MappingNode + elif kind not in [ScalarNode, SequenceNode, MappingNode] \ + and kind is not None: + raise ResolverError("Invalid node kind: %s" % kind) + cls.yaml_path_resolvers[tuple(new_path), kind] = tag + + def descend_resolver(self, current_node, current_index): + if not self.yaml_path_resolvers: + return + exact_paths = {} + prefix_paths = [] + if current_node: + depth = len(self.resolver_prefix_paths) + for path, kind in self.resolver_prefix_paths[-1]: + if self.check_resolver_prefix(depth, path, kind, + current_node, current_index): + if len(path) > depth: + prefix_paths.append((path, kind)) + else: + exact_paths[kind] = self.yaml_path_resolvers[path, kind] + else: + for path, kind in self.yaml_path_resolvers: + if not path: + exact_paths[kind] = self.yaml_path_resolvers[path, kind] + else: + prefix_paths.append((path, kind)) + self.resolver_exact_paths.append(exact_paths) + self.resolver_prefix_paths.append(prefix_paths) + + def ascend_resolver(self): + if not self.yaml_path_resolvers: + return + self.resolver_exact_paths.pop() + self.resolver_prefix_paths.pop() + + def check_resolver_prefix(self, depth, path, kind, + current_node, current_index): + node_check, index_check = path[depth-1] + if isinstance(node_check, str): + if current_node.tag != node_check: + return + elif node_check is not None: + if not isinstance(current_node, node_check): + return + if index_check is True and current_index is not None: + return + if (index_check is False or index_check is None) \ + and current_index is None: + return + if isinstance(index_check, str): + if not (isinstance(current_index, ScalarNode) + and index_check == current_index.value): + return + elif isinstance(index_check, int) and not isinstance(index_check, bool): + if index_check != current_index: + return + return True + + def resolve(self, kind, value, implicit): + if kind is ScalarNode and implicit[0]: + if value == '': + resolvers = self.yaml_implicit_resolvers.get('', []) + else: + resolvers = self.yaml_implicit_resolvers.get(value[0], []) + wildcard_resolvers = self.yaml_implicit_resolvers.get(None, []) + for tag, regexp in resolvers + wildcard_resolvers: + if regexp.match(value): + return tag + implicit = implicit[1] + if self.yaml_path_resolvers: + exact_paths = self.resolver_exact_paths[-1] + if kind in exact_paths: + return exact_paths[kind] + if None in exact_paths: + return exact_paths[None] + if kind is ScalarNode: + return self.DEFAULT_SCALAR_TAG + elif kind is SequenceNode: + return self.DEFAULT_SEQUENCE_TAG + elif kind is MappingNode: + return self.DEFAULT_MAPPING_TAG + +class Resolver(BaseResolver): + pass + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:bool', + re.compile(r'''^(?:yes|Yes|YES|no|No|NO + |true|True|TRUE|false|False|FALSE + |on|On|ON|off|Off|OFF)$''', re.X), + list('yYnNtTfFoO')) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:float', + re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)? + |\.[0-9][0-9_]*(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]* + |[-+]?\.(?:inf|Inf|INF) + |\.(?:nan|NaN|NAN))$''', re.X), + list('-+0123456789.')) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:int', + re.compile(r'''^(?:[-+]?0b[0-1_]+ + |[-+]?0[0-7_]+ + |[-+]?(?:0|[1-9][0-9_]*) + |[-+]?0x[0-9a-fA-F_]+ + |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X), + list('-+0123456789')) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:merge', + re.compile(r'^(?:<<)$'), + ['<']) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:null', + re.compile(r'''^(?: ~ + |null|Null|NULL + | )$''', re.X), + ['~', 'n', 'N', '']) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:timestamp', + re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] + |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]? + (?:[Tt]|[ \t]+)[0-9][0-9]? + :[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)? + (?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X), + list('0123456789')) + +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:value', + re.compile(r'^(?:=)$'), + ['=']) + +# The following resolver is only for documentation purposes. It cannot work +# because plain scalars cannot start with '!', '&', or '*'. +Resolver.add_implicit_resolver( + 'tag:yaml.org,2002:yaml', + re.compile(r'^(?:!|&|\*)$'), + list('!&*')) + diff --git a/source/yaml/scanner.py b/source/yaml/scanner.py new file mode 100644 index 0000000000000000000000000000000000000000..de925b07f1eaec33c9c305a8a69f9eb7ac5983c5 --- /dev/null +++ b/source/yaml/scanner.py @@ -0,0 +1,1435 @@ + +# Scanner produces tokens of the following types: +# STREAM-START +# STREAM-END +# DIRECTIVE(name, value) +# DOCUMENT-START +# DOCUMENT-END +# BLOCK-SEQUENCE-START +# BLOCK-MAPPING-START +# BLOCK-END +# FLOW-SEQUENCE-START +# FLOW-MAPPING-START +# FLOW-SEQUENCE-END +# FLOW-MAPPING-END +# BLOCK-ENTRY +# FLOW-ENTRY +# KEY +# VALUE +# ALIAS(value) +# ANCHOR(value) +# TAG(value) +# SCALAR(value, plain, style) +# +# Read comments in the Scanner code for more details. +# + +__all__ = ['Scanner', 'ScannerError'] + +from .error import MarkedYAMLError +from .tokens import * + +class ScannerError(MarkedYAMLError): + pass + +class SimpleKey: + # See below simple keys treatment. + + def __init__(self, token_number, required, index, line, column, mark): + self.token_number = token_number + self.required = required + self.index = index + self.line = line + self.column = column + self.mark = mark + +class Scanner: + + def __init__(self): + """Initialize the scanner.""" + # It is assumed that Scanner and Reader will have a common descendant. + # Reader do the dirty work of checking for BOM and converting the + # input data to Unicode. It also adds NUL to the end. + # + # Reader supports the following methods + # self.peek(i=0) # peek the next i-th character + # self.prefix(l=1) # peek the next l characters + # self.forward(l=1) # read the next l characters and move the pointer. + + # Had we reached the end of the stream? + self.done = False + + # The number of unclosed '{' and '['. `flow_level == 0` means block + # context. + self.flow_level = 0 + + # List of processed tokens that are not yet emitted. + self.tokens = [] + + # Add the STREAM-START token. + self.fetch_stream_start() + + # Number of tokens that were emitted through the `get_token` method. + self.tokens_taken = 0 + + # The current indentation level. + self.indent = -1 + + # Past indentation levels. + self.indents = [] + + # Variables related to simple keys treatment. + + # A simple key is a key that is not denoted by the '?' indicator. + # Example of simple keys: + # --- + # block simple key: value + # ? not a simple key: + # : { flow simple key: value } + # We emit the KEY token before all keys, so when we find a potential + # simple key, we try to locate the corresponding ':' indicator. + # Simple keys should be limited to a single line and 1024 characters. + + # Can a simple key start at the current position? A simple key may + # start: + # - at the beginning of the line, not counting indentation spaces + # (in block context), + # - after '{', '[', ',' (in the flow context), + # - after '?', ':', '-' (in the block context). + # In the block context, this flag also signifies if a block collection + # may start at the current position. + self.allow_simple_key = True + + # Keep track of possible simple keys. This is a dictionary. The key + # is `flow_level`; there can be no more that one possible simple key + # for each level. The value is a SimpleKey record: + # (token_number, required, index, line, column, mark) + # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow), + # '[', or '{' tokens. + self.possible_simple_keys = {} + + # Public methods. + + def check_token(self, *choices): + # Check if the next token is one of the given types. + while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + if not choices: + return True + for choice in choices: + if isinstance(self.tokens[0], choice): + return True + return False + + def peek_token(self): + # Return the next token, but do not delete if from the queue. + # Return None if no more tokens. + while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + return self.tokens[0] + else: + return None + + def get_token(self): + # Return the next token. + while self.need_more_tokens(): + self.fetch_more_tokens() + if self.tokens: + self.tokens_taken += 1 + return self.tokens.pop(0) + + # Private methods. + + def need_more_tokens(self): + if self.done: + return False + if not self.tokens: + return True + # The current token may be a potential simple key, so we + # need to look further. + self.stale_possible_simple_keys() + if self.next_possible_simple_key() == self.tokens_taken: + return True + + def fetch_more_tokens(self): + + # Eat whitespaces and comments until we reach the next token. + self.scan_to_next_token() + + # Remove obsolete possible simple keys. + self.stale_possible_simple_keys() + + # Compare the current indentation and column. It may add some tokens + # and decrease the current indentation level. + self.unwind_indent(self.column) + + # Peek the next character. + ch = self.peek() + + # Is it the end of stream? + if ch == '\0': + return self.fetch_stream_end() + + # Is it a directive? + if ch == '%' and self.check_directive(): + return self.fetch_directive() + + # Is it the document start? + if ch == '-' and self.check_document_start(): + return self.fetch_document_start() + + # Is it the document end? + if ch == '.' and self.check_document_end(): + return self.fetch_document_end() + + # TODO: support for BOM within a stream. + #if ch == '\uFEFF': + # return self.fetch_bom() <-- issue BOMToken + + # Note: the order of the following checks is NOT significant. + + # Is it the flow sequence start indicator? + if ch == '[': + return self.fetch_flow_sequence_start() + + # Is it the flow mapping start indicator? + if ch == '{': + return self.fetch_flow_mapping_start() + + # Is it the flow sequence end indicator? + if ch == ']': + return self.fetch_flow_sequence_end() + + # Is it the flow mapping end indicator? + if ch == '}': + return self.fetch_flow_mapping_end() + + # Is it the flow entry indicator? + if ch == ',': + return self.fetch_flow_entry() + + # Is it the block entry indicator? + if ch == '-' and self.check_block_entry(): + return self.fetch_block_entry() + + # Is it the key indicator? + if ch == '?' and self.check_key(): + return self.fetch_key() + + # Is it the value indicator? + if ch == ':' and self.check_value(): + return self.fetch_value() + + # Is it an alias? + if ch == '*': + return self.fetch_alias() + + # Is it an anchor? + if ch == '&': + return self.fetch_anchor() + + # Is it a tag? + if ch == '!': + return self.fetch_tag() + + # Is it a literal scalar? + if ch == '|' and not self.flow_level: + return self.fetch_literal() + + # Is it a folded scalar? + if ch == '>' and not self.flow_level: + return self.fetch_folded() + + # Is it a single quoted scalar? + if ch == '\'': + return self.fetch_single() + + # Is it a double quoted scalar? + if ch == '\"': + return self.fetch_double() + + # It must be a plain scalar then. + if self.check_plain(): + return self.fetch_plain() + + # No? It's an error. Let's produce a nice error message. + raise ScannerError("while scanning for the next token", None, + "found character %r that cannot start any token" % ch, + self.get_mark()) + + # Simple keys treatment. + + def next_possible_simple_key(self): + # Return the number of the nearest possible simple key. Actually we + # don't need to loop through the whole dictionary. We may replace it + # with the following code: + # if not self.possible_simple_keys: + # return None + # return self.possible_simple_keys[ + # min(self.possible_simple_keys.keys())].token_number + min_token_number = None + for level in self.possible_simple_keys: + key = self.possible_simple_keys[level] + if min_token_number is None or key.token_number < min_token_number: + min_token_number = key.token_number + return min_token_number + + def stale_possible_simple_keys(self): + # Remove entries that are no longer possible simple keys. According to + # the YAML specification, simple keys + # - should be limited to a single line, + # - should be no longer than 1024 characters. + # Disabling this procedure will allow simple keys of any length and + # height (may cause problems if indentation is broken though). + for level in list(self.possible_simple_keys): + key = self.possible_simple_keys[level] + if key.line != self.line \ + or self.index-key.index > 1024: + if key.required: + raise ScannerError("while scanning a simple key", key.mark, + "could not find expected ':'", self.get_mark()) + del self.possible_simple_keys[level] + + def save_possible_simple_key(self): + # The next token may start a simple key. We check if it's possible + # and save its position. This function is called for + # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'. + + # Check if a simple key is required at the current position. + required = not self.flow_level and self.indent == self.column + + # The next token might be a simple key. Let's save it's number and + # position. + if self.allow_simple_key: + self.remove_possible_simple_key() + token_number = self.tokens_taken+len(self.tokens) + key = SimpleKey(token_number, required, + self.index, self.line, self.column, self.get_mark()) + self.possible_simple_keys[self.flow_level] = key + + def remove_possible_simple_key(self): + # Remove the saved possible key position at the current flow level. + if self.flow_level in self.possible_simple_keys: + key = self.possible_simple_keys[self.flow_level] + + if key.required: + raise ScannerError("while scanning a simple key", key.mark, + "could not find expected ':'", self.get_mark()) + + del self.possible_simple_keys[self.flow_level] + + # Indentation functions. + + def unwind_indent(self, column): + + ## In flow context, tokens should respect indentation. + ## Actually the condition should be `self.indent >= column` according to + ## the spec. But this condition will prohibit intuitively correct + ## constructions such as + ## key : { + ## } + #if self.flow_level and self.indent > column: + # raise ScannerError(None, None, + # "invalid indentation or unclosed '[' or '{'", + # self.get_mark()) + + # In the flow context, indentation is ignored. We make the scanner less + # restrictive then specification requires. + if self.flow_level: + return + + # In block context, we may need to issue the BLOCK-END tokens. + while self.indent > column: + mark = self.get_mark() + self.indent = self.indents.pop() + self.tokens.append(BlockEndToken(mark, mark)) + + def add_indent(self, column): + # Check if we need to increase indentation. + if self.indent < column: + self.indents.append(self.indent) + self.indent = column + return True + return False + + # Fetchers. + + def fetch_stream_start(self): + # We always add STREAM-START as the first token and STREAM-END as the + # last token. + + # Read the token. + mark = self.get_mark() + + # Add STREAM-START. + self.tokens.append(StreamStartToken(mark, mark, + encoding=self.encoding)) + + + def fetch_stream_end(self): + + # Set the current indentation to -1. + self.unwind_indent(-1) + + # Reset simple keys. + self.remove_possible_simple_key() + self.allow_simple_key = False + self.possible_simple_keys = {} + + # Read the token. + mark = self.get_mark() + + # Add STREAM-END. + self.tokens.append(StreamEndToken(mark, mark)) + + # The steam is finished. + self.done = True + + def fetch_directive(self): + + # Set the current indentation to -1. + self.unwind_indent(-1) + + # Reset simple keys. + self.remove_possible_simple_key() + self.allow_simple_key = False + + # Scan and add DIRECTIVE. + self.tokens.append(self.scan_directive()) + + def fetch_document_start(self): + self.fetch_document_indicator(DocumentStartToken) + + def fetch_document_end(self): + self.fetch_document_indicator(DocumentEndToken) + + def fetch_document_indicator(self, TokenClass): + + # Set the current indentation to -1. + self.unwind_indent(-1) + + # Reset simple keys. Note that there could not be a block collection + # after '---'. + self.remove_possible_simple_key() + self.allow_simple_key = False + + # Add DOCUMENT-START or DOCUMENT-END. + start_mark = self.get_mark() + self.forward(3) + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_sequence_start(self): + self.fetch_flow_collection_start(FlowSequenceStartToken) + + def fetch_flow_mapping_start(self): + self.fetch_flow_collection_start(FlowMappingStartToken) + + def fetch_flow_collection_start(self, TokenClass): + + # '[' and '{' may start a simple key. + self.save_possible_simple_key() + + # Increase the flow level. + self.flow_level += 1 + + # Simple keys are allowed after '[' and '{'. + self.allow_simple_key = True + + # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_sequence_end(self): + self.fetch_flow_collection_end(FlowSequenceEndToken) + + def fetch_flow_mapping_end(self): + self.fetch_flow_collection_end(FlowMappingEndToken) + + def fetch_flow_collection_end(self, TokenClass): + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Decrease the flow level. + self.flow_level -= 1 + + # No simple keys after ']' or '}'. + self.allow_simple_key = False + + # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(TokenClass(start_mark, end_mark)) + + def fetch_flow_entry(self): + + # Simple keys are allowed after ','. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add FLOW-ENTRY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(FlowEntryToken(start_mark, end_mark)) + + def fetch_block_entry(self): + + # Block context needs additional checks. + if not self.flow_level: + + # Are we allowed to start a new entry? + if not self.allow_simple_key: + raise ScannerError(None, None, + "sequence entries are not allowed here", + self.get_mark()) + + # We may need to add BLOCK-SEQUENCE-START. + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockSequenceStartToken(mark, mark)) + + # It's an error for the block entry to occur in the flow context, + # but we let the parser detect this. + else: + pass + + # Simple keys are allowed after '-'. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add BLOCK-ENTRY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(BlockEntryToken(start_mark, end_mark)) + + def fetch_key(self): + + # Block context needs additional checks. + if not self.flow_level: + + # Are we allowed to start a key (not necessary a simple)? + if not self.allow_simple_key: + raise ScannerError(None, None, + "mapping keys are not allowed here", + self.get_mark()) + + # We may need to add BLOCK-MAPPING-START. + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockMappingStartToken(mark, mark)) + + # Simple keys are allowed after '?' in the block context. + self.allow_simple_key = not self.flow_level + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add KEY. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(KeyToken(start_mark, end_mark)) + + def fetch_value(self): + + # Do we determine a simple key? + if self.flow_level in self.possible_simple_keys: + + # Add KEY. + key = self.possible_simple_keys[self.flow_level] + del self.possible_simple_keys[self.flow_level] + self.tokens.insert(key.token_number-self.tokens_taken, + KeyToken(key.mark, key.mark)) + + # If this key starts a new block mapping, we need to add + # BLOCK-MAPPING-START. + if not self.flow_level: + if self.add_indent(key.column): + self.tokens.insert(key.token_number-self.tokens_taken, + BlockMappingStartToken(key.mark, key.mark)) + + # There cannot be two simple keys one after another. + self.allow_simple_key = False + + # It must be a part of a complex key. + else: + + # Block context needs additional checks. + # (Do we really need them? They will be caught by the parser + # anyway.) + if not self.flow_level: + + # We are allowed to start a complex value if and only if + # we can start a simple key. + if not self.allow_simple_key: + raise ScannerError(None, None, + "mapping values are not allowed here", + self.get_mark()) + + # If this value starts a new block mapping, we need to add + # BLOCK-MAPPING-START. It will be detected as an error later by + # the parser. + if not self.flow_level: + if self.add_indent(self.column): + mark = self.get_mark() + self.tokens.append(BlockMappingStartToken(mark, mark)) + + # Simple keys are allowed after ':' in the block context. + self.allow_simple_key = not self.flow_level + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Add VALUE. + start_mark = self.get_mark() + self.forward() + end_mark = self.get_mark() + self.tokens.append(ValueToken(start_mark, end_mark)) + + def fetch_alias(self): + + # ALIAS could be a simple key. + self.save_possible_simple_key() + + # No simple keys after ALIAS. + self.allow_simple_key = False + + # Scan and add ALIAS. + self.tokens.append(self.scan_anchor(AliasToken)) + + def fetch_anchor(self): + + # ANCHOR could start a simple key. + self.save_possible_simple_key() + + # No simple keys after ANCHOR. + self.allow_simple_key = False + + # Scan and add ANCHOR. + self.tokens.append(self.scan_anchor(AnchorToken)) + + def fetch_tag(self): + + # TAG could start a simple key. + self.save_possible_simple_key() + + # No simple keys after TAG. + self.allow_simple_key = False + + # Scan and add TAG. + self.tokens.append(self.scan_tag()) + + def fetch_literal(self): + self.fetch_block_scalar(style='|') + + def fetch_folded(self): + self.fetch_block_scalar(style='>') + + def fetch_block_scalar(self, style): + + # A simple key may follow a block scalar. + self.allow_simple_key = True + + # Reset possible simple key on the current level. + self.remove_possible_simple_key() + + # Scan and add SCALAR. + self.tokens.append(self.scan_block_scalar(style)) + + def fetch_single(self): + self.fetch_flow_scalar(style='\'') + + def fetch_double(self): + self.fetch_flow_scalar(style='"') + + def fetch_flow_scalar(self, style): + + # A flow scalar could be a simple key. + self.save_possible_simple_key() + + # No simple keys after flow scalars. + self.allow_simple_key = False + + # Scan and add SCALAR. + self.tokens.append(self.scan_flow_scalar(style)) + + def fetch_plain(self): + + # A plain scalar could be a simple key. + self.save_possible_simple_key() + + # No simple keys after plain scalars. But note that `scan_plain` will + # change this flag if the scan is finished at the beginning of the + # line. + self.allow_simple_key = False + + # Scan and add SCALAR. May change `allow_simple_key`. + self.tokens.append(self.scan_plain()) + + # Checkers. + + def check_directive(self): + + # DIRECTIVE: ^ '%' ... + # The '%' indicator is already checked. + if self.column == 0: + return True + + def check_document_start(self): + + # DOCUMENT-START: ^ '---' (' '|'\n') + if self.column == 0: + if self.prefix(3) == '---' \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return True + + def check_document_end(self): + + # DOCUMENT-END: ^ '...' (' '|'\n') + if self.column == 0: + if self.prefix(3) == '...' \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return True + + def check_block_entry(self): + + # BLOCK-ENTRY: '-' (' '|'\n') + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_key(self): + + # KEY(flow context): '?' + if self.flow_level: + return True + + # KEY(block context): '?' (' '|'\n') + else: + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_value(self): + + # VALUE(flow context): ':' + if self.flow_level: + return True + + # VALUE(block context): ':' (' '|'\n') + else: + return self.peek(1) in '\0 \t\r\n\x85\u2028\u2029' + + def check_plain(self): + + # A plain scalar may start with any non-space character except: + # '-', '?', ':', ',', '[', ']', '{', '}', + # '#', '&', '*', '!', '|', '>', '\'', '\"', + # '%', '@', '`'. + # + # It may also start with + # '-', '?', ':' + # if it is followed by a non-space character. + # + # Note that we limit the last rule to the block context (except the + # '-' character) because we want the flow context to be space + # independent. + ch = self.peek() + return ch not in '\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>\'\"%@`' \ + or (self.peek(1) not in '\0 \t\r\n\x85\u2028\u2029' + and (ch == '-' or (not self.flow_level and ch in '?:'))) + + # Scanners. + + def scan_to_next_token(self): + # We ignore spaces, line breaks and comments. + # If we find a line break in the block context, we set the flag + # `allow_simple_key` on. + # The byte order mark is stripped if it's the first character in the + # stream. We do not yet support BOM inside the stream as the + # specification requires. Any such mark will be considered as a part + # of the document. + # + # TODO: We need to make tab handling rules more sane. A good rule is + # Tabs cannot precede tokens + # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END, + # KEY(block), VALUE(block), BLOCK-ENTRY + # So the checking code is + # if : + # self.allow_simple_keys = False + # We also need to add the check for `allow_simple_keys == True` to + # `unwind_indent` before issuing BLOCK-END. + # Scanners for block, flow, and plain scalars need to be modified. + + if self.index == 0 and self.peek() == '\uFEFF': + self.forward() + found = False + while not found: + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + if self.scan_line_break(): + if not self.flow_level: + self.allow_simple_key = True + else: + found = True + + def scan_directive(self): + # See the specification for details. + start_mark = self.get_mark() + self.forward() + name = self.scan_directive_name(start_mark) + value = None + if name == 'YAML': + value = self.scan_yaml_directive_value(start_mark) + end_mark = self.get_mark() + elif name == 'TAG': + value = self.scan_tag_directive_value(start_mark) + end_mark = self.get_mark() + else: + end_mark = self.get_mark() + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + self.scan_directive_ignored_line(start_mark) + return DirectiveToken(name, value, start_mark, end_mark) + + def scan_directive_name(self, start_mark): + # See the specification for details. + length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if not length: + raise ScannerError("while scanning a directive", start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + value = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + return value + + def scan_yaml_directive_value(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + major = self.scan_yaml_directive_number(start_mark) + if self.peek() != '.': + raise ScannerError("while scanning a directive", start_mark, + "expected a digit or '.', but found %r" % self.peek(), + self.get_mark()) + self.forward() + minor = self.scan_yaml_directive_number(start_mark) + if self.peek() not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected a digit or ' ', but found %r" % self.peek(), + self.get_mark()) + return (major, minor) + + def scan_yaml_directive_number(self, start_mark): + # See the specification for details. + ch = self.peek() + if not ('0' <= ch <= '9'): + raise ScannerError("while scanning a directive", start_mark, + "expected a digit, but found %r" % ch, self.get_mark()) + length = 0 + while '0' <= self.peek(length) <= '9': + length += 1 + value = int(self.prefix(length)) + self.forward(length) + return value + + def scan_tag_directive_value(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + handle = self.scan_tag_directive_handle(start_mark) + while self.peek() == ' ': + self.forward() + prefix = self.scan_tag_directive_prefix(start_mark) + return (handle, prefix) + + def scan_tag_directive_handle(self, start_mark): + # See the specification for details. + value = self.scan_tag_handle('directive', start_mark) + ch = self.peek() + if ch != ' ': + raise ScannerError("while scanning a directive", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + return value + + def scan_tag_directive_prefix(self, start_mark): + # See the specification for details. + value = self.scan_tag_uri('directive', start_mark) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + return value + + def scan_directive_ignored_line(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + ch = self.peek() + if ch not in '\0\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a directive", start_mark, + "expected a comment or a line break, but found %r" + % ch, self.get_mark()) + self.scan_line_break() + + def scan_anchor(self, TokenClass): + # The specification does not restrict characters for anchors and + # aliases. This may lead to problems, for instance, the document: + # [ *alias, value ] + # can be interpreted in two ways, as + # [ "value" ] + # and + # [ *alias , "value" ] + # Therefore we restrict aliases to numbers and ASCII letters. + start_mark = self.get_mark() + indicator = self.peek() + if indicator == '*': + name = 'alias' + else: + name = 'anchor' + self.forward() + length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if not length: + raise ScannerError("while scanning an %s" % name, start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + value = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch not in '\0 \t\r\n\x85\u2028\u2029?:,]}%@`': + raise ScannerError("while scanning an %s" % name, start_mark, + "expected alphabetic or numeric character, but found %r" + % ch, self.get_mark()) + end_mark = self.get_mark() + return TokenClass(value, start_mark, end_mark) + + def scan_tag(self): + # See the specification for details. + start_mark = self.get_mark() + ch = self.peek(1) + if ch == '<': + handle = None + self.forward(2) + suffix = self.scan_tag_uri('tag', start_mark) + if self.peek() != '>': + raise ScannerError("while parsing a tag", start_mark, + "expected '>', but found %r" % self.peek(), + self.get_mark()) + self.forward() + elif ch in '\0 \t\r\n\x85\u2028\u2029': + handle = None + suffix = '!' + self.forward() + else: + length = 1 + use_handle = False + while ch not in '\0 \r\n\x85\u2028\u2029': + if ch == '!': + use_handle = True + break + length += 1 + ch = self.peek(length) + handle = '!' + if use_handle: + handle = self.scan_tag_handle('tag', start_mark) + else: + handle = '!' + self.forward() + suffix = self.scan_tag_uri('tag', start_mark) + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a tag", start_mark, + "expected ' ', but found %r" % ch, self.get_mark()) + value = (handle, suffix) + end_mark = self.get_mark() + return TagToken(value, start_mark, end_mark) + + def scan_block_scalar(self, style): + # See the specification for details. + + if style == '>': + folded = True + else: + folded = False + + chunks = [] + start_mark = self.get_mark() + + # Scan the header. + self.forward() + chomping, increment = self.scan_block_scalar_indicators(start_mark) + self.scan_block_scalar_ignored_line(start_mark) + + # Determine the indentation level and go to the first non-empty line. + min_indent = self.indent+1 + if min_indent < 1: + min_indent = 1 + if increment is None: + breaks, max_indent, end_mark = self.scan_block_scalar_indentation() + indent = max(min_indent, max_indent) + else: + indent = min_indent+increment-1 + breaks, end_mark = self.scan_block_scalar_breaks(indent) + line_break = '' + + # Scan the inner part of the block scalar. + while self.column == indent and self.peek() != '\0': + chunks.extend(breaks) + leading_non_space = self.peek() not in ' \t' + length = 0 + while self.peek(length) not in '\0\r\n\x85\u2028\u2029': + length += 1 + chunks.append(self.prefix(length)) + self.forward(length) + line_break = self.scan_line_break() + breaks, end_mark = self.scan_block_scalar_breaks(indent) + if self.column == indent and self.peek() != '\0': + + # Unfortunately, folding rules are ambiguous. + # + # This is the folding according to the specification: + + if folded and line_break == '\n' \ + and leading_non_space and self.peek() not in ' \t': + if not breaks: + chunks.append(' ') + else: + chunks.append(line_break) + + # This is Clark Evans's interpretation (also in the spec + # examples): + # + #if folded and line_break == '\n': + # if not breaks: + # if self.peek() not in ' \t': + # chunks.append(' ') + # else: + # chunks.append(line_break) + #else: + # chunks.append(line_break) + else: + break + + # Chomp the tail. + if chomping is not False: + chunks.append(line_break) + if chomping is True: + chunks.extend(breaks) + + # We are done. + return ScalarToken(''.join(chunks), False, start_mark, end_mark, + style) + + def scan_block_scalar_indicators(self, start_mark): + # See the specification for details. + chomping = None + increment = None + ch = self.peek() + if ch in '+-': + if ch == '+': + chomping = True + else: + chomping = False + self.forward() + ch = self.peek() + if ch in '0123456789': + increment = int(ch) + if increment == 0: + raise ScannerError("while scanning a block scalar", start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark()) + self.forward() + elif ch in '0123456789': + increment = int(ch) + if increment == 0: + raise ScannerError("while scanning a block scalar", start_mark, + "expected indentation indicator in the range 1-9, but found 0", + self.get_mark()) + self.forward() + ch = self.peek() + if ch in '+-': + if ch == '+': + chomping = True + else: + chomping = False + self.forward() + ch = self.peek() + if ch not in '\0 \r\n\x85\u2028\u2029': + raise ScannerError("while scanning a block scalar", start_mark, + "expected chomping or indentation indicators, but found %r" + % ch, self.get_mark()) + return chomping, increment + + def scan_block_scalar_ignored_line(self, start_mark): + # See the specification for details. + while self.peek() == ' ': + self.forward() + if self.peek() == '#': + while self.peek() not in '\0\r\n\x85\u2028\u2029': + self.forward() + ch = self.peek() + if ch not in '\0\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a block scalar", start_mark, + "expected a comment or a line break, but found %r" % ch, + self.get_mark()) + self.scan_line_break() + + def scan_block_scalar_indentation(self): + # See the specification for details. + chunks = [] + max_indent = 0 + end_mark = self.get_mark() + while self.peek() in ' \r\n\x85\u2028\u2029': + if self.peek() != ' ': + chunks.append(self.scan_line_break()) + end_mark = self.get_mark() + else: + self.forward() + if self.column > max_indent: + max_indent = self.column + return chunks, max_indent, end_mark + + def scan_block_scalar_breaks(self, indent): + # See the specification for details. + chunks = [] + end_mark = self.get_mark() + while self.column < indent and self.peek() == ' ': + self.forward() + while self.peek() in '\r\n\x85\u2028\u2029': + chunks.append(self.scan_line_break()) + end_mark = self.get_mark() + while self.column < indent and self.peek() == ' ': + self.forward() + return chunks, end_mark + + def scan_flow_scalar(self, style): + # See the specification for details. + # Note that we loose indentation rules for quoted scalars. Quoted + # scalars don't need to adhere indentation because " and ' clearly + # mark the beginning and the end of them. Therefore we are less + # restrictive then the specification requires. We only need to check + # that document separators are not included in scalars. + if style == '"': + double = True + else: + double = False + chunks = [] + start_mark = self.get_mark() + quote = self.peek() + self.forward() + chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) + while self.peek() != quote: + chunks.extend(self.scan_flow_scalar_spaces(double, start_mark)) + chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) + self.forward() + end_mark = self.get_mark() + return ScalarToken(''.join(chunks), False, start_mark, end_mark, + style) + + ESCAPE_REPLACEMENTS = { + '0': '\0', + 'a': '\x07', + 'b': '\x08', + 't': '\x09', + '\t': '\x09', + 'n': '\x0A', + 'v': '\x0B', + 'f': '\x0C', + 'r': '\x0D', + 'e': '\x1B', + ' ': '\x20', + '\"': '\"', + '\\': '\\', + '/': '/', + 'N': '\x85', + '_': '\xA0', + 'L': '\u2028', + 'P': '\u2029', + } + + ESCAPE_CODES = { + 'x': 2, + 'u': 4, + 'U': 8, + } + + def scan_flow_scalar_non_spaces(self, double, start_mark): + # See the specification for details. + chunks = [] + while True: + length = 0 + while self.peek(length) not in '\'\"\\\0 \t\r\n\x85\u2028\u2029': + length += 1 + if length: + chunks.append(self.prefix(length)) + self.forward(length) + ch = self.peek() + if not double and ch == '\'' and self.peek(1) == '\'': + chunks.append('\'') + self.forward(2) + elif (double and ch == '\'') or (not double and ch in '\"\\'): + chunks.append(ch) + self.forward() + elif double and ch == '\\': + self.forward() + ch = self.peek() + if ch in self.ESCAPE_REPLACEMENTS: + chunks.append(self.ESCAPE_REPLACEMENTS[ch]) + self.forward() + elif ch in self.ESCAPE_CODES: + length = self.ESCAPE_CODES[ch] + self.forward() + for k in range(length): + if self.peek(k) not in '0123456789ABCDEFabcdef': + raise ScannerError("while scanning a double-quoted scalar", start_mark, + "expected escape sequence of %d hexadecimal numbers, but found %r" % + (length, self.peek(k)), self.get_mark()) + code = int(self.prefix(length), 16) + chunks.append(chr(code)) + self.forward(length) + elif ch in '\r\n\x85\u2028\u2029': + self.scan_line_break() + chunks.extend(self.scan_flow_scalar_breaks(double, start_mark)) + else: + raise ScannerError("while scanning a double-quoted scalar", start_mark, + "found unknown escape character %r" % ch, self.get_mark()) + else: + return chunks + + def scan_flow_scalar_spaces(self, double, start_mark): + # See the specification for details. + chunks = [] + length = 0 + while self.peek(length) in ' \t': + length += 1 + whitespaces = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch == '\0': + raise ScannerError("while scanning a quoted scalar", start_mark, + "found unexpected end of stream", self.get_mark()) + elif ch in '\r\n\x85\u2028\u2029': + line_break = self.scan_line_break() + breaks = self.scan_flow_scalar_breaks(double, start_mark) + if line_break != '\n': + chunks.append(line_break) + elif not breaks: + chunks.append(' ') + chunks.extend(breaks) + else: + chunks.append(whitespaces) + return chunks + + def scan_flow_scalar_breaks(self, double, start_mark): + # See the specification for details. + chunks = [] + while True: + # Instead of checking indentation, we check for document + # separators. + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + raise ScannerError("while scanning a quoted scalar", start_mark, + "found unexpected document separator", self.get_mark()) + while self.peek() in ' \t': + self.forward() + if self.peek() in '\r\n\x85\u2028\u2029': + chunks.append(self.scan_line_break()) + else: + return chunks + + def scan_plain(self): + # See the specification for details. + # We add an additional restriction for the flow context: + # plain scalars in the flow context cannot contain ',' or '?'. + # We also keep track of the `allow_simple_key` flag here. + # Indentation rules are loosed for the flow context. + chunks = [] + start_mark = self.get_mark() + end_mark = start_mark + indent = self.indent+1 + # We allow zero indentation for scalars, but then we need to check for + # document separators at the beginning of the line. + #if indent == 0: + # indent = 1 + spaces = [] + while True: + length = 0 + if self.peek() == '#': + break + while True: + ch = self.peek(length) + if ch in '\0 \t\r\n\x85\u2028\u2029' \ + or (ch == ':' and + self.peek(length+1) in '\0 \t\r\n\x85\u2028\u2029' + + (u',[]{}' if self.flow_level else u''))\ + or (self.flow_level and ch in ',?[]{}'): + break + length += 1 + if length == 0: + break + self.allow_simple_key = False + chunks.extend(spaces) + chunks.append(self.prefix(length)) + self.forward(length) + end_mark = self.get_mark() + spaces = self.scan_plain_spaces(indent, start_mark) + if not spaces or self.peek() == '#' \ + or (not self.flow_level and self.column < indent): + break + return ScalarToken(''.join(chunks), True, start_mark, end_mark) + + def scan_plain_spaces(self, indent, start_mark): + # See the specification for details. + # The specification is really confusing about tabs in plain scalars. + # We just forbid them completely. Do not use tabs in YAML! + chunks = [] + length = 0 + while self.peek(length) in ' ': + length += 1 + whitespaces = self.prefix(length) + self.forward(length) + ch = self.peek() + if ch in '\r\n\x85\u2028\u2029': + line_break = self.scan_line_break() + self.allow_simple_key = True + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return + breaks = [] + while self.peek() in ' \r\n\x85\u2028\u2029': + if self.peek() == ' ': + self.forward() + else: + breaks.append(self.scan_line_break()) + prefix = self.prefix(3) + if (prefix == '---' or prefix == '...') \ + and self.peek(3) in '\0 \t\r\n\x85\u2028\u2029': + return + if line_break != '\n': + chunks.append(line_break) + elif not breaks: + chunks.append(' ') + chunks.extend(breaks) + elif whitespaces: + chunks.append(whitespaces) + return chunks + + def scan_tag_handle(self, name, start_mark): + # See the specification for details. + # For some strange reasons, the specification does not allow '_' in + # tag handles. I have allowed it anyway. + ch = self.peek() + if ch != '!': + raise ScannerError("while scanning a %s" % name, start_mark, + "expected '!', but found %r" % ch, self.get_mark()) + length = 1 + ch = self.peek(length) + if ch != ' ': + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-_': + length += 1 + ch = self.peek(length) + if ch != '!': + self.forward(length) + raise ScannerError("while scanning a %s" % name, start_mark, + "expected '!', but found %r" % ch, self.get_mark()) + length += 1 + value = self.prefix(length) + self.forward(length) + return value + + def scan_tag_uri(self, name, start_mark): + # See the specification for details. + # Note: we do not check if URI is well-formed. + chunks = [] + length = 0 + ch = self.peek(length) + while '0' <= ch <= '9' or 'A' <= ch <= 'Z' or 'a' <= ch <= 'z' \ + or ch in '-;/?:@&=+$,_.!~*\'()[]%': + if ch == '%': + chunks.append(self.prefix(length)) + self.forward(length) + length = 0 + chunks.append(self.scan_uri_escapes(name, start_mark)) + else: + length += 1 + ch = self.peek(length) + if length: + chunks.append(self.prefix(length)) + self.forward(length) + length = 0 + if not chunks: + raise ScannerError("while parsing a %s" % name, start_mark, + "expected URI, but found %r" % ch, self.get_mark()) + return ''.join(chunks) + + def scan_uri_escapes(self, name, start_mark): + # See the specification for details. + codes = [] + mark = self.get_mark() + while self.peek() == '%': + self.forward() + for k in range(2): + if self.peek(k) not in '0123456789ABCDEFabcdef': + raise ScannerError("while scanning a %s" % name, start_mark, + "expected URI escape sequence of 2 hexadecimal numbers, but found %r" + % self.peek(k), self.get_mark()) + codes.append(int(self.prefix(2), 16)) + self.forward(2) + try: + value = bytes(codes).decode('utf-8') + except UnicodeDecodeError as exc: + raise ScannerError("while scanning a %s" % name, start_mark, str(exc), mark) + return value + + def scan_line_break(self): + # Transforms: + # '\r\n' : '\n' + # '\r' : '\n' + # '\n' : '\n' + # '\x85' : '\n' + # '\u2028' : '\u2028' + # '\u2029 : '\u2029' + # default : '' + ch = self.peek() + if ch in '\r\n\x85': + if self.prefix(2) == '\r\n': + self.forward(2) + else: + self.forward() + return '\n' + elif ch in '\u2028\u2029': + self.forward() + return ch + return '' diff --git a/source/yaml/serializer.py b/source/yaml/serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..fe911e67ae7a739abb491fbbc6834b9c37bbda4b --- /dev/null +++ b/source/yaml/serializer.py @@ -0,0 +1,111 @@ + +__all__ = ['Serializer', 'SerializerError'] + +from .error import YAMLError +from .events import * +from .nodes import * + +class SerializerError(YAMLError): + pass + +class Serializer: + + ANCHOR_TEMPLATE = 'id%03d' + + def __init__(self, encoding=None, + explicit_start=None, explicit_end=None, version=None, tags=None): + self.use_encoding = encoding + self.use_explicit_start = explicit_start + self.use_explicit_end = explicit_end + self.use_version = version + self.use_tags = tags + self.serialized_nodes = {} + self.anchors = {} + self.last_anchor_id = 0 + self.closed = None + + def open(self): + if self.closed is None: + self.emit(StreamStartEvent(encoding=self.use_encoding)) + self.closed = False + elif self.closed: + raise SerializerError("serializer is closed") + else: + raise SerializerError("serializer is already opened") + + def close(self): + if self.closed is None: + raise SerializerError("serializer is not opened") + elif not self.closed: + self.emit(StreamEndEvent()) + self.closed = True + + #def __del__(self): + # self.close() + + def serialize(self, node): + if self.closed is None: + raise SerializerError("serializer is not opened") + elif self.closed: + raise SerializerError("serializer is closed") + self.emit(DocumentStartEvent(explicit=self.use_explicit_start, + version=self.use_version, tags=self.use_tags)) + self.anchor_node(node) + self.serialize_node(node, None, None) + self.emit(DocumentEndEvent(explicit=self.use_explicit_end)) + self.serialized_nodes = {} + self.anchors = {} + self.last_anchor_id = 0 + + def anchor_node(self, node): + if node in self.anchors: + if self.anchors[node] is None: + self.anchors[node] = self.generate_anchor(node) + else: + self.anchors[node] = None + if isinstance(node, SequenceNode): + for item in node.value: + self.anchor_node(item) + elif isinstance(node, MappingNode): + for key, value in node.value: + self.anchor_node(key) + self.anchor_node(value) + + def generate_anchor(self, node): + self.last_anchor_id += 1 + return self.ANCHOR_TEMPLATE % self.last_anchor_id + + def serialize_node(self, node, parent, index): + alias = self.anchors[node] + if node in self.serialized_nodes: + self.emit(AliasEvent(alias)) + else: + self.serialized_nodes[node] = True + self.descend_resolver(parent, index) + if isinstance(node, ScalarNode): + detected_tag = self.resolve(ScalarNode, node.value, (True, False)) + default_tag = self.resolve(ScalarNode, node.value, (False, True)) + implicit = (node.tag == detected_tag), (node.tag == default_tag) + self.emit(ScalarEvent(alias, node.tag, implicit, node.value, + style=node.style)) + elif isinstance(node, SequenceNode): + implicit = (node.tag + == self.resolve(SequenceNode, node.value, True)) + self.emit(SequenceStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style)) + index = 0 + for item in node.value: + self.serialize_node(item, node, index) + index += 1 + self.emit(SequenceEndEvent()) + elif isinstance(node, MappingNode): + implicit = (node.tag + == self.resolve(MappingNode, node.value, True)) + self.emit(MappingStartEvent(alias, node.tag, implicit, + flow_style=node.flow_style)) + for key, value in node.value: + self.serialize_node(key, node, None) + self.serialize_node(value, node, key) + self.emit(MappingEndEvent()) + self.ascend_resolver() + diff --git a/source/yaml/tokens.py b/source/yaml/tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0b48a394ac8c019b401516a12f688df361cf90 --- /dev/null +++ b/source/yaml/tokens.py @@ -0,0 +1,104 @@ + +class Token(object): + def __init__(self, start_mark, end_mark): + self.start_mark = start_mark + self.end_mark = end_mark + def __repr__(self): + attributes = [key for key in self.__dict__ + if not key.endswith('_mark')] + attributes.sort() + arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) + for key in attributes]) + return '%s(%s)' % (self.__class__.__name__, arguments) + +#class BOMToken(Token): +# id = '' + +class DirectiveToken(Token): + id = '' + def __init__(self, name, value, start_mark, end_mark): + self.name = name + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class DocumentStartToken(Token): + id = '' + +class DocumentEndToken(Token): + id = '' + +class StreamStartToken(Token): + id = '' + def __init__(self, start_mark=None, end_mark=None, + encoding=None): + self.start_mark = start_mark + self.end_mark = end_mark + self.encoding = encoding + +class StreamEndToken(Token): + id = '' + +class BlockSequenceStartToken(Token): + id = '' + +class BlockMappingStartToken(Token): + id = '' + +class BlockEndToken(Token): + id = '' + +class FlowSequenceStartToken(Token): + id = '[' + +class FlowMappingStartToken(Token): + id = '{' + +class FlowSequenceEndToken(Token): + id = ']' + +class FlowMappingEndToken(Token): + id = '}' + +class KeyToken(Token): + id = '?' + +class ValueToken(Token): + id = ':' + +class BlockEntryToken(Token): + id = '-' + +class FlowEntryToken(Token): + id = ',' + +class AliasToken(Token): + id = '' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class AnchorToken(Token): + id = '' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class TagToken(Token): + id = '' + def __init__(self, value, start_mark, end_mark): + self.value = value + self.start_mark = start_mark + self.end_mark = end_mark + +class ScalarToken(Token): + id = '' + def __init__(self, value, plain, start_mark, end_mark, style=None): + self.value = value + self.plain = plain + self.start_mark = start_mark + self.end_mark = end_mark + self.style = style + diff --git a/source/yarl-1.22.0.dist-info/INSTALLER b/source/yarl-1.22.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/source/yarl-1.22.0.dist-info/METADATA b/source/yarl-1.22.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..9556b218e19af99a382a55cd6e95e8634748d8e3 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/METADATA @@ -0,0 +1,2478 @@ +Metadata-Version: 2.4 +Name: yarl +Version: 1.22.0 +Summary: Yet another URL library +Home-page: https://github.com/aio-libs/yarl +Author: Andrew Svetlov +Author-email: andrew.svetlov@gmail.com +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache-2.0 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: GitHub Workflows, https://github.com/aio-libs/yarl/actions?query=branch:master +Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/yarl +Project-URL: Docs: Changelog, https://yarl.aio-libs.org/en/latest/changes/ +Project-URL: Docs: RTD, https://yarl.aio-libs.org +Project-URL: GitHub: issues, https://github.com/aio-libs/yarl/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/yarl +Keywords: cython,cext,yarl +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Cython +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: 3.14 +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.9 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: NOTICE +Requires-Dist: idna>=2.0 +Requires-Dist: multidict>=4.0 +Requires-Dist: propcache>=0.2.1 +Dynamic: license-file + +yarl +==== + +The module provides handy URL class for URL parsing and changing. + +.. image:: https://github.com/aio-libs/yarl/workflows/CI/badge.svg + :target: https://github.com/aio-libs/yarl/actions?query=workflow%3ACI + :align: right + +.. image:: https://codecov.io/gh/aio-libs/yarl/graph/badge.svg?flag=pytest + :target: https://app.codecov.io/gh/aio-libs/yarl?flags[]=pytest + :alt: Codecov coverage for the pytest-driven measurements + +.. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json + :target: https://codspeed.io/aio-libs/yarl + +.. image:: https://badge.fury.io/py/yarl.svg + :target: https://badge.fury.io/py/yarl + +.. image:: https://readthedocs.org/projects/yarl/badge/?version=latest + :target: https://yarl.aio-libs.org + +.. image:: https://img.shields.io/pypi/pyversions/yarl.svg + :target: https://pypi.python.org/pypi/yarl + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + + +Introduction +------------ + +Url is constructed from ``str``: + +.. code-block:: pycon + + >>> from yarl import URL + >>> url = URL('https://www.python.org/~guido?arg=1#frag') + >>> url + URL('https://www.python.org/~guido?arg=1#frag') + +All url parts: *scheme*, *user*, *password*, *host*, *port*, *path*, +*query* and *fragment* are accessible by properties: + +.. code-block:: pycon + + >>> url.scheme + 'https' + >>> url.host + 'www.python.org' + >>> url.path + '/~guido' + >>> url.query_string + 'arg=1' + >>> url.query + + >>> url.fragment + 'frag' + +All url manipulations produce a new url object: + +.. code-block:: pycon + + >>> url = URL('https://www.python.org') + >>> url / 'foo' / 'bar' + URL('https://www.python.org/foo/bar') + >>> url / 'foo' % {'bar': 'baz'} + URL('https://www.python.org/foo?bar=baz') + +Strings passed to constructor and modification methods are +automatically encoded giving canonical representation as result: + +.. code-block:: pycon + + >>> url = URL('https://www.python.org/шлях') + >>> url + URL('https://www.python.org/%D1%88%D0%BB%D1%8F%D1%85') + +Regular properties are *percent-decoded*, use ``raw_`` versions for +getting *encoded* strings: + +.. code-block:: pycon + + >>> url.path + '/шлях' + + >>> url.raw_path + '/%D1%88%D0%BB%D1%8F%D1%85' + +Human readable representation of URL is available as ``.human_repr()``: + +.. code-block:: pycon + + >>> url.human_repr() + 'https://www.python.org/шлях' + +For full documentation please read https://yarl.aio-libs.org. + + +Installation +------------ + +:: + + $ pip install yarl + +The library is Python 3 only! + +PyPI contains binary wheels for Linux, Windows and MacOS. If you want to install +``yarl`` on another operating system where wheels are not provided, +the tarball will be used to compile the library from +the source code. It requires a C compiler and and Python headers installed. + +To skip the compilation you must explicitly opt-in by using a PEP 517 +configuration setting ``pure-python``, or setting the ``YARL_NO_EXTENSIONS`` +environment variable to a non-empty value, e.g.: + +.. code-block:: console + + $ pip install yarl --config-settings=pure-python=false + +Please note that the pure-Python (uncompiled) version is much slower. However, +PyPy always uses a pure-Python implementation, and, as such, it is unaffected +by this variable. + +Dependencies +------------ + +YARL requires multidict_ and propcache_ libraries. + + +API documentation +------------------ + +The documentation is located at https://yarl.aio-libs.org. + + +Why isn't boolean supported by the URL query API? +------------------------------------------------- + +There is no standard for boolean representation of boolean values. + +Some systems prefer ``true``/``false``, others like ``yes``/``no``, ``on``/``off``, +``Y``/``N``, ``1``/``0``, etc. + +``yarl`` cannot make an unambiguous decision on how to serialize ``bool`` values because +it is specific to how the end-user's application is built and would be different for +different apps. The library doesn't accept booleans in the API; a user should convert +bools into strings using own preferred translation protocol. + + +Comparison with other URL libraries +------------------------------------ + +* furl (https://pypi.python.org/pypi/furl) + + The library has rich functionality but the ``furl`` object is mutable. + + I'm afraid to pass this object into foreign code: who knows if the + code will modify my url in a terrible way while I just want to send URL + with handy helpers for accessing URL properties. + + ``furl`` has other non-obvious tricky things but the main objection + is mutability. + +* URLObject (https://pypi.python.org/pypi/URLObject) + + URLObject is immutable, that's pretty good. + + Every URL change generates a new URL object. + + But the library doesn't do any decode/encode transformations leaving the + end user to cope with these gory details. + + +Source code +----------- + +The project is hosted on GitHub_ + +Please file an issue on the `bug tracker +`_ if you have found a bug +or have some suggestion in order to improve the library. + +Discussion list +--------------- + +*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs + +Feel free to post your questions and ideas here. + + +Authors and License +------------------- + +The ``yarl`` package is written by Andrew Svetlov. + +It's *Apache 2* licensed and freely available. + + +.. _GitHub: https://github.com/aio-libs/yarl + +.. _multidict: https://github.com/aio-libs/multidict + +.. _propcache: https://github.com/aio-libs/propcache + +========= +Changelog +========= + +.. + You should *NOT* be adding new change log entries to this file, this + file is managed by towncrier. You *may* edit previous change logs to + fix problems like typo corrections or such. + To add a new change log entry, please see + https://pip.pypa.io/en/latest/development/#adding-a-news-entry + we named the news folder "changes". + + WARNING: Don't drop the next directive! + +.. towncrier release notes start + +1.22.0 +====== + +*(2025-10-05)* + + +Features +-------- + +- Added arm64 Windows wheel builds + -- by `@finnagin `__. + + *Related issues and pull requests on GitHub:* + `#1516 `__. + + +---- + + +1.21.0 +====== + +*(2025-10-05)* + + +Contributor-facing changes +-------------------------- + +- The ``reusable-cibuildwheel.yml`` workflow has been refactored to + be more generic and ``ci-cd.yml`` now holds all the configuration + toggles -- by `@webknjaz `__. + + *Related issues and pull requests on GitHub:* + `#1535 `__. + +- When building wheels, the source distribution is now passed directly + to the ``cibuildwheel`` invocation -- by `@webknjaz `__. + + *Related issues and pull requests on GitHub:* + `#1536 `__. + +- Added CI for Python 3.14 -- by `@kumaraditya303 `__. + + *Related issues and pull requests on GitHub:* + `#1560 `__. + + +---- + + +1.20.1 +====== + +*(2025-06-09)* + + +Bug fixes +--------- + +- Started raising a ``ValueError`` exception raised for corrupted + IPv6 URL values. + + These fixes the issue where exception ``IndexError`` was + leaking from the internal code because of not being handled and + transformed into a user-facing error. The problem was happening + under the following conditions: empty IPv6 URL, brackets in + reverse order. + + -- by `@MaelPic `__. + + *Related issues and pull requests on GitHub:* + `#1512 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Updated to use Cython 3.1 universally across the build path -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#1514 `__. + +- Made Cython line tracing opt-in via the ``with-cython-tracing`` build config setting -- by `@bdraco `__. + + Previously, line tracing was enabled by default in ``pyproject.toml``, which caused build issues for some users and made wheels nearly twice as slow. + Now line tracing is only enabled when explicitly requested via ``pip install . --config-setting=with-cython-tracing=true`` or by setting the ``YARL_CYTHON_TRACING`` environment variable. + + *Related issues and pull requests on GitHub:* + `#1521 `__. + + +---- + + +1.20.0 +====== + +*(2025-04-16)* + + +Features +-------- + +- Implemented support for the free-threaded build of CPython 3.13 -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#1456 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Started building wheels for the free-threaded build of CPython 3.13 -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#1456 `__. + + +---- + + +1.19.0 +====== + +*(2025-04-05)* + + +Bug fixes +--------- + +- Fixed entire name being re-encoded when using ``yarl.URL.with_suffix()`` -- by `@NTFSvolume `__. + + *Related issues and pull requests on GitHub:* + `#1468 `__. + + +Features +-------- + +- Started building armv7l wheels for manylinux -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1495 `__. + + +Contributor-facing changes +-------------------------- + +- GitHub Actions CI/CD is now configured to manage caching pip-ecosystem + dependencies using `re-actors/cache-python-deps`_ -- an action by + `@webknjaz `__ that takes into account ABI stability and the exact + version of Python runtime. + + .. _`re-actors/cache-python-deps`: + https://github.com/marketplace/actions/cache-python-deps + + *Related issues and pull requests on GitHub:* + `#1471 `__. + +- Increased minimum `propcache`_ version to 0.2.1 to fix failing tests -- by `@bdraco `__. + + .. _`propcache`: + https://github.com/aio-libs/propcache + + *Related issues and pull requests on GitHub:* + `#1479 `__. + +- Added all hidden folders to pytest's ``norecursedirs`` to prevent it + from trying to collect tests there -- by `@lysnikolaou `__. + + *Related issues and pull requests on GitHub:* + `#1480 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved accuracy of type annotations -- by `@Dreamsorcerer `__. + + *Related issues and pull requests on GitHub:* + `#1484 `__. + +- Improved performance of parsing query strings -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1493 `__, `#1497 `__. + +- Improved performance of the C unquoter -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1496 `__, `#1498 `__. + + +---- + + +1.18.3 +====== + +*(2024-12-01)* + + +Bug fixes +--------- + +- Fixed uppercase ASCII hosts being rejected by ``URL.build()()`` and ``yarl.URL.with_host()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#954 `__, `#1442 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performances of multiple path properties on cache miss -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1443 `__. + + +---- + + +1.18.2 +====== + +*(2024-11-29)* + + +No significant changes. + + +---- + + +1.18.1 +====== + +*(2024-11-29)* + + +Miscellaneous internal changes +------------------------------ + +- Improved cache performance when ``~yarl.URL`` objects are constructed from ``yarl.URL.build()`` with ``encoded=True`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1432 `__. + +- Improved cache performance for operations that produce a new ``~yarl.URL`` object -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1434 `__, `#1436 `__. + + +---- + + +1.18.0 +====== + +*(2024-11-21)* + + +Features +-------- + +- Added ``keep_query`` and ``keep_fragment`` flags in the ``yarl.URL.with_path()``, ``yarl.URL.with_name()`` and ``yarl.URL.with_suffix()`` methods, allowing users to optionally retain the query string and fragment in the resulting URL when replacing the path -- by `@paul-nameless `__. + + *Related issues and pull requests on GitHub:* + `#111 `__, `#1421 `__. + + +Contributor-facing changes +-------------------------- + +- Started running downstream ``aiohttp`` tests in CI -- by `@Cycloctane `__. + + *Related issues and pull requests on GitHub:* + `#1415 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of converting ``~yarl.URL`` to a string -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1422 `__. + + +---- + + +1.17.2 +====== + +*(2024-11-17)* + + +Bug fixes +--------- + +- Stopped implicitly allowing the use of Cython pre-release versions when + building the distribution package -- by `@ajsanchezsanz `__ and + `@markgreene74 `__. + + *Related issues and pull requests on GitHub:* + `#1411 `__, `#1412 `__. + +- Fixed a bug causing ``~yarl.URL.port`` to return the default port when the given port was zero + -- by `@gmacon `__. + + *Related issues and pull requests on GitHub:* + `#1413 `__. + + +Features +-------- + +- Make error messages include details of incorrect type when ``port`` is not int in ``yarl.URL.build()``. + -- by `@Cycloctane `__. + + *Related issues and pull requests on GitHub:* + `#1414 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Stopped implicitly allowing the use of Cython pre-release versions when + building the distribution package -- by `@ajsanchezsanz `__ and + `@markgreene74 `__. + + *Related issues and pull requests on GitHub:* + `#1411 `__, `#1412 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of the ``yarl.URL.joinpath()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1418 `__. + + +---- + + +1.17.1 +====== + +*(2024-10-30)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of many ``~yarl.URL`` methods -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1396 `__, `#1397 `__, `#1398 `__. + +- Improved performance of passing a `dict` or `str` to ``yarl.URL.extend_query()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1401 `__. + + +---- + + +1.17.0 +====== + +*(2024-10-28)* + + +Features +-------- + +- Added ``~yarl.URL.host_port_subcomponent`` which returns the ``3986#section-3.2.2`` host and ``3986#section-3.2.3`` port subcomponent -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1375 `__. + + +---- + + +1.16.0 +====== + +*(2024-10-21)* + + +Bug fixes +--------- + +- Fixed blocking I/O to load Python code when creating a new ``~yarl.URL`` with non-ascii characters in the network location part -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1342 `__. + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Migrated to using a single cache for encoding hosts -- by `@bdraco `__. + + Passing ``ip_address_size`` and ``host_validate_size`` to ``yarl.cache_configure()`` is deprecated in favor of the new ``encode_host_size`` parameter and will be removed in a future release. For backwards compatibility, the old parameters affect the ``encode_host`` cache size. + + *Related issues and pull requests on GitHub:* + `#1348 `__, `#1357 `__, `#1363 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of constructing ``~yarl.URL`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1336 `__. + +- Improved performance of calling ``yarl.URL.build()`` and constructing unencoded ``~yarl.URL`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1345 `__. + +- Reworked the internal encoding cache to improve performance on cache hit -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1369 `__. + + +---- + + +1.15.5 +====== + +*(2024-10-18)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of the ``yarl.URL.joinpath()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1304 `__. + +- Improved performance of the ``yarl.URL.extend_query()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1305 `__. + +- Improved performance of the ``yarl.URL.origin()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1306 `__. + +- Improved performance of the ``yarl.URL.with_path()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1307 `__. + +- Improved performance of the ``yarl.URL.with_query()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1308 `__, `#1328 `__. + +- Improved performance of the ``yarl.URL.update_query()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1309 `__, `#1327 `__. + +- Improved performance of the ``yarl.URL.join()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1313 `__. + +- Improved performance of ``~yarl.URL`` equality checks -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1315 `__. + +- Improved performance of ``~yarl.URL`` methods that modify the network location -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1316 `__. + +- Improved performance of the ``yarl.URL.with_fragment()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1317 `__. + +- Improved performance of calculating the hash of ``~yarl.URL`` objects -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1318 `__. + +- Improved performance of the ``yarl.URL.relative()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1319 `__. + +- Improved performance of the ``yarl.URL.with_name()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1320 `__. + +- Improved performance of ``~yarl.URL.parent`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1321 `__. + +- Improved performance of the ``yarl.URL.with_scheme()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1322 `__. + + +---- + + +1.15.4 +====== + +*(2024-10-16)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of the quoter when all characters are safe -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1288 `__. + +- Improved performance of unquoting strings -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1292 `__, `#1293 `__. + +- Improved performance of calling ``yarl.URL.build()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1297 `__. + + +---- + + +1.15.3 +====== + +*(2024-10-15)* + + +Bug fixes +--------- + +- Fixed ``yarl.URL.build()`` failing to validate paths must start with a ``/`` when passing ``authority`` -- by `@bdraco `__. + + The validation only worked correctly when passing ``host``. + + *Related issues and pull requests on GitHub:* + `#1265 `__. + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Removed support for Python 3.8 as it has reached end of life -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1203 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of constructing ``~yarl.URL`` when the net location is only the host -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1271 `__. + + +---- + + +1.15.2 +====== + +*(2024-10-13)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of converting ``~yarl.URL`` to a string -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1234 `__. + +- Improved performance of ``yarl.URL.joinpath()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1248 `__, `#1250 `__. + +- Improved performance of constructing query strings from ``~multidict.MultiDict`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1256 `__. + +- Improved performance of constructing query strings with ``int`` values -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1259 `__. + + +---- + + +1.15.1 +====== + +*(2024-10-12)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of calling ``yarl.URL.build()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1222 `__. + +- Improved performance of all ``~yarl.URL`` methods that create new ``~yarl.URL`` objects -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1226 `__. + +- Improved performance of ``~yarl.URL`` methods that modify the network location -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1229 `__. + + +---- + + +1.15.0 +====== + +*(2024-10-11)* + + +Bug fixes +--------- + +- Fixed validation with ``yarl.URL.with_scheme()`` when passed scheme is not lowercase -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1189 `__. + + +Features +-------- + +- Started building ``armv7l`` wheels -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1204 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of constructing unencoded ``~yarl.URL`` objects -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1188 `__. + +- Added a cache for parsing hosts to reduce overhead of encoding ``~yarl.URL`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1190 `__. + +- Improved performance of constructing query strings from ``~collections.abc.Mapping`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1193 `__. + +- Improved performance of converting ``~yarl.URL`` objects to strings -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1198 `__. + + +---- + + +1.14.0 +====== + +*(2024-10-08)* + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Switched to using the ``propcache`` package for property caching + -- by `@bdraco `__. + + The ``propcache`` package is derived from the property caching + code in ``yarl`` and has been broken out to avoid maintaining it for multiple + projects. + + *Related issues and pull requests on GitHub:* + `#1169 `__. + + +Contributor-facing changes +-------------------------- + +- Started testing with Hypothesis -- by `@webknjaz `__ and `@bdraco `__. + + Special thanks to `@Zac-HD `__ for helping us get started with this framework. + + *Related issues and pull requests on GitHub:* + `#860 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of ``yarl.URL.is_default_port()`` when no explicit port is set -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1168 `__. + +- Improved performance of converting ``~yarl.URL`` to a string when no explicit port is set -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1170 `__. + +- Improved performance of the ``yarl.URL.origin()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1175 `__. + +- Improved performance of encoding hosts -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1176 `__. + + +---- + + +1.13.1 +====== + +*(2024-09-27)* + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of calling ``yarl.URL.build()`` with ``authority`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1163 `__. + + +---- + + +1.13.0 +====== + +*(2024-09-26)* + + +Bug fixes +--------- + +- Started rejecting ASCII hostnames with invalid characters. For host strings that + look like authority strings, the exception message includes advice on what to do + instead -- by `@mjpieters `__. + + *Related issues and pull requests on GitHub:* + `#880 `__, `#954 `__. + +- Fixed IPv6 addresses missing brackets when the ``~yarl.URL`` was converted to a string -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1157 `__, `#1158 `__. + + +Features +-------- + +- Added ``~yarl.URL.host_subcomponent`` which returns the ``3986#section-3.2.2`` host subcomponent -- by `@bdraco `__. + + The only current practical difference between ``~yarl.URL.raw_host`` and ``~yarl.URL.host_subcomponent`` is that IPv6 addresses are returned bracketed. + + *Related issues and pull requests on GitHub:* + `#1159 `__. + + +---- + + +1.12.1 +====== + +*(2024-09-23)* + + +No significant changes. + + +---- + + +1.12.0 +====== + +*(2024-09-23)* + + +Features +-------- + +- Added ``~yarl.URL.path_safe`` to be able to fetch the path without ``%2F`` and ``%25`` decoded -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1150 `__. + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Restore decoding ``%2F`` (``/``) in ``URL.path`` -- by `@bdraco `__. + + This change restored the behavior before `#1057 `__. + + *Related issues and pull requests on GitHub:* + `#1151 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of processing paths -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1143 `__. + + +---- + + +1.11.1 +====== + +*(2024-09-09)* + + +Bug fixes +--------- + +- Allowed scheme replacement for relative URLs if the scheme does not require a host -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#280 `__, `#1138 `__. + +- Allowed empty host for URL schemes other than the special schemes listed in the WHATWG URL spec -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1136 `__. + + +Features +-------- + +- Loosened restriction on integers as query string values to allow classes that implement ``__int__`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1139 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of normalizing paths -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1137 `__. + + +---- + + +1.11.0 +====== + +*(2024-09-08)* + + +Features +-------- + +- Added ``URL.extend_query()()`` method, which can be used to extend parameters without replacing same named keys -- by `@bdraco `__. + + This method was primarily added to replace the inefficient hand rolled method currently used in ``aiohttp``. + + *Related issues and pull requests on GitHub:* + `#1128 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of the Cython ``cached_property`` implementation -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1122 `__. + +- Simplified computing ports by removing unnecessary code -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1123 `__. + +- Improved performance of encoding non IPv6 hosts -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1125 `__. + +- Improved performance of ``URL.build()()`` when the path, query string, or fragment is an empty string -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1126 `__. + +- Improved performance of the ``URL.update_query()()`` method -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1130 `__. + +- Improved performance of processing query string changes when arguments are ``str`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1131 `__. + + +---- + + +1.10.0 +====== + +*(2024-09-06)* + + +Bug fixes +--------- + +- Fixed joining a path when the existing path was empty -- by `@bdraco `__. + + A regression in ``URL.join()()`` was introduced in `#1082 `__. + + *Related issues and pull requests on GitHub:* + `#1118 `__. + + +Features +-------- + +- Added ``URL.without_query_params()()`` method, to drop some parameters from query string -- by `@hongquan `__. + + *Related issues and pull requests on GitHub:* + `#774 `__, `#898 `__, `#1010 `__. + +- The previously protected types ``_SimpleQuery``, ``_QueryVariable``, and ``_Query`` are now available for use externally as ``SimpleQuery``, ``QueryVariable``, and ``Query`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1050 `__, `#1113 `__. + + +Contributor-facing changes +-------------------------- + +- Replaced all ``~typing.Optional`` with ``~typing.Union`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1095 `__. + + +Miscellaneous internal changes +------------------------------ + +- Significantly improved performance of parsing the network location -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1112 `__. + +- Added internal types to the cache to prevent future refactoring errors -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1117 `__. + + +---- + + +1.9.11 +====== + +*(2024-09-04)* + + +Bug fixes +--------- + +- Fixed a ``TypeError`` with ``MultiDictProxy`` and Python 3.8 -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1084 `__, `#1105 `__, `#1107 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of encoding hosts -- by `@bdraco `__. + + Previously, the library would unconditionally try to parse a host as an IP Address. The library now avoids trying to parse a host as an IP Address if the string is not in one of the formats described in ``3986#section-3.2.2``. + + *Related issues and pull requests on GitHub:* + `#1104 `__. + + +---- + + +1.9.10 +====== + +*(2024-09-04)* + + +Bug fixes +--------- + +- ``URL.join()()`` has been changed to match + ``3986`` and align with + ``/ operation()`` and ``URL.joinpath()()`` + when joining URLs with empty segments. + Previously ``urllib.parse.urljoin`` was used, + which has known issues with empty segments + (`python/cpython#84774 `_). + + Due to the semantics of ``URL.join()()``, joining an + URL with scheme requires making it relative, prefixing with ``./``. + + .. code-block:: pycon + + >>> URL("https://web.archive.org/web/").join(URL("./https://github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + + Empty segments are honored in the base as well as the joined part. + + .. code-block:: pycon + + >>> URL("https://web.archive.org/web/https://").join(URL("github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + + + -- by `@commonism `__ + + This change initially appeared in 1.9.5 but was reverted in 1.9.6 to resolve a problem with query string handling. + + *Related issues and pull requests on GitHub:* + `#1039 `__, `#1082 `__. + + +Features +-------- + +- Added ``~yarl.URL.absolute`` which is now preferred over ``URL.is_absolute()`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1100 `__. + + +---- + + +1.9.9 +===== + +*(2024-09-04)* + + +Bug fixes +--------- + +- Added missing type on ``~yarl.URL.port`` -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1097 `__. + + +---- + + +1.9.8 +===== + +*(2024-09-03)* + + +Features +-------- + +- Covered the ``~yarl.URL`` object with types -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1084 `__. + +- Cache parsing of IP Addresses when encoding hosts -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1086 `__. + + +Contributor-facing changes +-------------------------- + +- Covered the ``~yarl.URL`` object with types -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1084 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of handling ports -- by `@bdraco `__. + + *Related issues and pull requests on GitHub:* + `#1081 `__. + + +---- + + +1.9.7 +===== + +*(2024-09-01)* + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Removed support ``3986#section-3.2.3`` port normalization when the scheme is not one of ``http``, ``https``, ``wss``, or ``ws`` -- by `@bdraco `__. + + Support for port normalization was recently added in `#1033 `__ and contained code that would do blocking I/O if the scheme was not one of the four listed above. The code has been removed because this library is intended to be safe for usage with ``asyncio``. + + *Related issues and pull requests on GitHub:* + `#1076 `__. + + +Miscellaneous internal changes +------------------------------ + +- Improved performance of property caching -- by `@bdraco `__. + + The ``reify`` implementation from ``aiohttp`` was adapted to replace the internal ``cached_property`` implementation. + + *Related issues and pull requests on GitHub:* + `#1070 `__. + + +---- + + +1.9.6 +===== + +*(2024-08-30)* + + +Bug fixes +--------- + +- Reverted ``3986`` compatible ``URL.join()()`` honoring empty segments which was introduced in `#1039 `__. + + This change introduced a regression handling query string parameters with joined URLs. The change was reverted to maintain compatibility with the previous behavior. + + *Related issues and pull requests on GitHub:* + `#1067 `__. + + +---- + + +1.9.5 +===== + +*(2024-08-30)* + + +Bug fixes +--------- + +- Joining URLs with empty segments has been changed + to match ``3986``. + + Previously empty segments would be removed from path, + breaking use-cases such as + + .. code-block:: python + + URL("https://web.archive.org/web/") / "https://github.com/" + + Now ``/ operation()`` and ``URL.joinpath()()`` + keep empty segments, but do not introduce new empty segments. + e.g. + + .. code-block:: python + + URL("https://example.org/") / "" + + does not introduce an empty segment. + + -- by `@commonism `__ and `@youtux `__ + + *Related issues and pull requests on GitHub:* + `#1026 `__. + +- The default protocol ports of well-known URI schemes are now taken into account + during the normalization of the URL string representation in accordance with + ``3986#section-3.2.3``. + + Specified ports are removed from the ``str`` representation of a ``~yarl.URL`` + if the port matches the scheme's default port -- by `@commonism `__. + + *Related issues and pull requests on GitHub:* + `#1033 `__. + +- ``URL.join()()`` has been changed to match + ``3986`` and align with + ``/ operation()`` and ``URL.joinpath()()`` + when joining URLs with empty segments. + Previously ``urllib.parse.urljoin`` was used, + which has known issues with empty segments + (`python/cpython#84774 `_). + + Due to the semantics of ``URL.join()()``, joining an + URL with scheme requires making it relative, prefixing with ``./``. + + .. code-block:: pycon + + >>> URL("https://web.archive.org/web/").join(URL("./https://github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + + Empty segments are honored in the base as well as the joined part. + + .. code-block:: pycon + + >>> URL("https://web.archive.org/web/https://").join(URL("github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + + + -- by `@commonism `__ + + *Related issues and pull requests on GitHub:* + `#1039 `__. + + +Removals and backward incompatible breaking changes +--------------------------------------------------- + +- Stopped decoding ``%2F`` (``/``) in ``URL.path``, as this could lead to code incorrectly treating it as a path separator + -- by `@Dreamsorcerer `__. + + *Related issues and pull requests on GitHub:* + `#1057 `__. + +- Dropped support for Python 3.7 -- by `@Dreamsorcerer `__. + + *Related issues and pull requests on GitHub:* + `#1016 `__. + + +Improved documentation +---------------------- + +- On the ``Contributing docs`` page, + a link to the ``Towncrier philosophy`` has been fixed. + + *Related issues and pull requests on GitHub:* + `#981 `__. + +- The pre-existing ``/ magic method()`` + has been documented in the API reference -- by `@commonism `__. + + *Related issues and pull requests on GitHub:* + `#1026 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- A flaw in the logic for copying the project directory into a + temporary folder that led to infinite recursion when ``TMPDIR`` + was set to a project subdirectory path. This was happening in Fedora + and its downstream due to the use of `pyproject-rpm-macros + `__. It was + only reproducible with ``pip wheel`` and was not affecting the + ``pyproject-build`` users. + + -- by `@hroncok `__ and `@webknjaz `__ + + *Related issues and pull requests on GitHub:* + `#992 `__, `#1014 `__. + +- Support Python 3.13 and publish non-free-threaded wheels + + *Related issues and pull requests on GitHub:* + `#1054 `__. + + +Contributor-facing changes +-------------------------- + +- The CI/CD setup has been updated to test ``arm64`` wheels + under macOS 14, except for Python 3.7 that is unsupported + in that environment -- by `@webknjaz `__. + + *Related issues and pull requests on GitHub:* + `#1015 `__. + +- Removed unused type ignores and casts -- by `@hauntsaninja `__. + + *Related issues and pull requests on GitHub:* + `#1031 `__. + + +Miscellaneous internal changes +------------------------------ + +- ``port``, ``scheme``, and ``raw_host`` are now ``cached_property`` -- by `@bdraco `__. + + ``aiohttp`` accesses these properties quite often, which cause ``urllib`` to build the ``_hostinfo`` property every time. ``port``, ``scheme``, and ``raw_host`` are now cached properties, which will improve performance. + + *Related issues and pull requests on GitHub:* + `#1044 `__, `#1058 `__. + + +---- + + +1.9.4 (2023-12-06) +================== + +Bug fixes +--------- + +- Started raising ``TypeError`` when a string value is passed into + ``yarl.URL.build()`` as the ``port`` argument -- by `@commonism `__. + + Previously the empty string as port would create malformed URLs when rendered as string representations. (`#883 `__) + + +Packaging updates and notes for downstreams +------------------------------------------- + +- The leading ``--`` has been dropped from the `PEP 517 `__ in-tree build + backend config setting names. ``--pure-python`` is now just ``pure-python`` + -- by `@webknjaz `__. + + The usage now looks as follows: + + .. code-block:: console + + $ python -m build \ + --config-setting=pure-python=true \ + --config-setting=with-cython-tracing=true + + (`#963 `__) + + +Contributor-facing changes +-------------------------- + +- A step-by-step ``Release Guide`` guide has + been added, describing how to release *yarl* -- by `@webknjaz `__. + + This is primarily targeting maintainers. (`#960 `__) +- Coverage collection has been implemented for the Cython modules + -- by `@webknjaz `__. + + It will also be reported to Codecov from any non-release CI jobs. + + To measure coverage in a development environment, *yarl* can be + installed in editable mode: + + .. code-block:: console + + $ python -Im pip install -e . + + Editable install produces C-files required for the Cython coverage + plugin to map the measurements back to the PYX-files. + + `#961 `__ + +- It is now possible to request line tracing in Cython builds using the + ``with-cython-tracing`` `PEP 517 `__ config setting + -- `@webknjaz `__. + + This can be used in CI and development environment to measure coverage + on Cython modules, but is not normally useful to the end-users or + downstream packagers. + + Here's a usage example: + + .. code-block:: console + + $ python -Im pip install . --config-settings=with-cython-tracing=true + + For editable installs, this setting is on by default. Otherwise, it's + off unless requested explicitly. + + The following produces C-files required for the Cython coverage + plugin to map the measurements back to the PYX-files: + + .. code-block:: console + + $ python -Im pip install -e . + + Alternatively, the ``YARL_CYTHON_TRACING=1`` environment variable + can be set to do the same as the `PEP 517 `__ config setting. + + `#962 `__ + + +1.9.3 (2023-11-20) +================== + +Bug fixes +--------- + +- Stopped dropping trailing slashes in ``yarl.URL.joinpath()`` -- by `@gmacon `__. (`#862 `__, `#866 `__) +- Started accepting string subclasses in ``yarl.URL.__truediv__()`` operations (``URL / segment``) -- by `@mjpieters `__. (`#871 `__, `#884 `__) +- Fixed the human representation of URLs with square brackets in usernames and passwords -- by `@mjpieters `__. (`#876 `__, `#882 `__) +- Updated type hints to include ``URL.missing_port()``, ``URL.__bytes__()`` + and the ``encoding`` argument to ``yarl.URL.joinpath()`` + -- by `@mjpieters `__. (`#891 `__) + + +Packaging updates and notes for downstreams +------------------------------------------- + +- Integrated Cython 3 to enable building *yarl* under Python 3.12 -- by `@mjpieters `__. (`#829 `__, `#881 `__) +- Declared modern ``setuptools.build_meta`` as the `PEP 517 `__ build + backend in ``pyproject.toml`` explicitly -- by `@webknjaz `__. (`#886 `__) +- Converted most of the packaging setup into a declarative ``setup.cfg`` + config -- by `@webknjaz `__. (`#890 `__) +- The packaging is replaced from an old-fashioned ``setup.py`` to an + in-tree `PEP 517 `__ build backend -- by `@webknjaz `__. + + Whenever the end-users or downstream packagers need to build ``yarl`` from + source (a Git checkout or an sdist), they may pass a ``config_settings`` + flag ``--pure-python``. If this flag is not set, a C-extension will be built + and included into the distribution. + + Here is how this can be done with ``pip``: + + .. code-block:: console + + $ python -m pip install . --config-settings=--pure-python=false + + This will also work with ``-e | --editable``. + + The same can be achieved via ``pypa/build``: + + .. code-block:: console + + $ python -m build --config-setting=--pure-python=false + + Adding ``-w | --wheel`` can force ``pypa/build`` produce a wheel from source + directly, as opposed to building an ``sdist`` and then building from it. (`#893 `__) + + .. attention:: + + v1.9.3 was the only version using the ``--pure-python`` setting name. + Later versions dropped the ``--`` prefix, making it just ``pure-python``. + +- Declared Python 3.12 supported officially in the distribution package metadata + -- by `@edgarrmondragon `__. (`#942 `__) + + +Contributor-facing changes +-------------------------- + +- A regression test for no-host URLs was added per `#821 `__ + and ``3986`` -- by `@kenballus `__. (`#821 `__, `#822 `__) +- Started testing *yarl* against Python 3.12 in CI -- by `@mjpieters `__. (`#881 `__) +- All Python 3.12 jobs are now marked as required to pass in CI + -- by `@edgarrmondragon `__. (`#942 `__) +- MyST is now integrated in Sphinx -- by `@webknjaz `__. + + This allows the contributors to author new documents in Markdown + when they have difficulties with going straight RST. (`#953 `__) + + +1.9.2 (2023-04-25) +================== + +Bugfixes +-------- + +- Fix regression with ``yarl.URL.__truediv__()`` and absolute URLs with empty paths causing the raw path to lack the leading ``/``. + (`#854 `_) + + +1.9.1 (2023-04-21) +================== + +Bugfixes +-------- + +- Marked tests that fail on older Python patch releases (< 3.7.10, < 3.8.8 and < 3.9.2) as expected to fail due to missing a security fix for CVE-2021-23336. (`#850 `_) + + +1.9.0 (2023-04-19) +================== + +This release was never published to PyPI, due to issues with the build process. + +Features +-------- + +- Added ``URL.joinpath(*elements)``, to create a new URL appending multiple path elements. (`#704 `_) +- Made ``URL.__truediv__()()`` return ``NotImplemented`` if called with an + unsupported type — by `@michaeljpeters `__. + (`#832 `_) + + +Bugfixes +-------- + +- Path normalization for absolute URLs no longer raises a ValueError exception + when ``..`` segments would otherwise go beyond the URL path root. + (`#536 `_) +- Fixed an issue with update_query() not getting rid of the query when argument is None. (`#792 `_) +- Added some input restrictions on with_port() function to prevent invalid boolean inputs or out of valid port inputs; handled incorrect 0 port representation. (`#793 `_) +- Made ``yarl.URL.build()`` raise a ``TypeError`` if the ``host`` argument is ``None`` — by `@paulpapacz `__. (`#808 `_) +- Fixed an issue with ``update_query()`` getting rid of the query when the argument + is empty but not ``None``. (`#845 `_) + + +Misc +---- + +- `#220 `_ + + +1.8.2 (2022-12-03) +================== + +This is the first release that started shipping wheels for Python 3.11. + + +1.8.1 (2022-08-01) +================== + +Misc +---- + +- `#694 `_, `#699 `_, `#700 `_, `#701 `_, `#702 `_, `#703 `_, `#739 `_ + + +1.8.0 (2022-08-01) +================== + +Features +-------- + +- Added ``URL.raw_suffix``, ``URL.suffix``, ``URL.raw_suffixes``, ``URL.suffixes``, ``URL.with_suffix``. (`#613 `_) + + +Improved Documentation +---------------------- + +- Fixed broken internal references to ``yarl.URL.human_repr()``. + (`#665 `_) +- Fixed broken external references to ``multidict:index`` docs. (`#665 `_) + + +Deprecations and Removals +------------------------- + +- Dropped Python 3.6 support. (`#672 `_) + + +Misc +---- + +- `#646 `_, `#699 `_, `#701 `_ + + +1.7.2 (2021-11-01) +================== + +Bugfixes +-------- + +- Changed call in ``with_port()`` to stop reencoding parts of the URL that were already encoded. (`#623 `_) + + +1.7.1 (2021-10-07) +================== + +Bugfixes +-------- + +- Fix 1.7.0 build error + +1.7.0 (2021-10-06) +================== + +Features +-------- + +- Add ``__bytes__()`` magic method so that ``bytes(url)`` will work and use optimal ASCII encoding. + (`#582 `_) +- Started shipping platform-specific arm64 wheels for Apple Silicon. (`#622 `_) +- Started shipping platform-specific wheels with the ``musl`` tag targeting typical Alpine Linux runtimes. (`#622 `_) +- Added support for Python 3.10. (`#622 `_) + + +1.6.3 (2020-11-14) +================== + +Bugfixes +-------- + +- No longer loose characters when decoding incorrect percent-sequences (like ``%e2%82%f8``). All non-decodable percent-sequences are now preserved. + `#517 `_ +- Provide x86 Windows wheels. + `#535 `_ + + +---- + + +1.6.2 (2020-10-12) +================== + + +Bugfixes +-------- + +- Provide generated ``.c`` files in TarBall distribution. + `#530 `_ + +1.6.1 (2020-10-12) +================== + +Features +-------- + +- Provide wheels for ``aarch64``, ``i686``, ``ppc64le``, ``s390x`` architectures on + Linux as well as ``x86_64``. + `#507 `_ +- Provide wheels for Python 3.9. + `#526 `_ + +Bugfixes +-------- + +- ``human_repr()`` now always produces valid representation equivalent to the original URL (if the original URL is valid). + `#511 `_ +- Fixed requoting a single percent followed by a percent-encoded character in the Cython implementation. + `#514 `_ +- Fix ValueError when decoding ``%`` which is not followed by two hexadecimal digits. + `#516 `_ +- Fix decoding ``%`` followed by a space and hexadecimal digit. + `#520 `_ +- Fix annotation of ``with_query()``/``update_query()`` methods for ``key=[val1, val2]`` case. + `#528 `_ + +Removal +------- + +- Drop Python 3.5 support; Python 3.6 is the minimal supported Python version. + + +---- + + +1.6.0 (2020-09-23) +================== + +Features +-------- + +- Allow for int and float subclasses in query, while still denying bool. + `#492 `_ + + +Bugfixes +-------- + +- Do not requote arguments in ``URL.build()``, ``with_xxx()`` and in ``/`` operator. + `#502 `_ +- Keep IPv6 brackets in ``origin()``. + `#504 `_ + + +---- + + +1.5.1 (2020-08-01) +================== + +Bugfixes +-------- + +- Fix including relocated internal ``yarl._quoting_c`` C-extension into published PyPI dists. + `#485 `_ + + +Misc +---- + +- `#484 `_ + + +---- + + +1.5.0 (2020-07-26) +================== + +Features +-------- + +- Convert host to lowercase on URL building. + `#386 `_ +- Allow using ``mod`` operator (``%``) for updating query string (an alias for ``update_query()`` method). + `#435 `_ +- Allow use of sequences such as ``list`` and ``tuple`` in the values + of a mapping such as ``dict`` to represent that a key has many values:: + + url = URL("http://example.com") + assert url.with_query({"a": [1, 2]}) == URL("http://example.com/?a=1&a=2") + + `#443 `_ +- Support ``URL.build()`` with scheme and path (creates a relative URL). + `#464 `_ +- Cache slow IDNA encode/decode calls. + `#476 `_ +- Add ``@final`` / ``Final`` type hints + `#477 `_ +- Support URL authority/raw_authority properties and authority argument of ``URL.build()`` method. + `#478 `_ +- Hide the library implementation details, make the exposed public list very clean. + `#483 `_ + + +Bugfixes +-------- + +- Fix tests with newer Python (3.7.6, 3.8.1 and 3.9.0+). + `#409 `_ +- Fix a bug where query component, passed in a form of mapping or sequence, is unquoted in unexpected way. + `#426 `_ +- Hide ``Query`` and ``QueryVariable`` type aliases in ``__init__.pyi``, now they are prefixed with underscore. + `#431 `_ +- Keep IPv6 brackets after updating port/user/password. + `#451 `_ + + +---- + + +1.4.2 (2019-12-05) +================== + +Features +-------- + +- Workaround for missing ``str.isascii()`` in Python 3.6 + `#389 `_ + + +---- + + +1.4.1 (2019-11-29) +================== + +* Fix regression, make the library work on Python 3.5 and 3.6 again. + +1.4.0 (2019-11-29) +================== + +* Distinguish an empty password in URL from a password not provided at all (#262) + +* Fixed annotations for optional parameters of ``URL.build`` (#309) + +* Use None as default value of ``user`` parameter of ``URL.build`` (#309) + +* Enforce building C Accelerated modules when installing from source tarball, use + ``YARL_NO_EXTENSIONS`` environment variable for falling back to (slower) Pure Python + implementation (#329) + +* Drop Python 3.5 support + +* Fix quoting of plus in path by pure python version (#339) + +* Don't create a new URL if fragment is unchanged (#292) + +* Included in error message the path that produces starting slash forbidden error (#376) + +* Skip slow IDNA encoding for ASCII-only strings (#387) + + +1.3.0 (2018-12-11) +================== + +* Fix annotations for ``query`` parameter (#207) + +* An incoming query sequence can have int variables (the same as for + Mapping type) (#208) + +* Add ``URL.explicit_port`` property (#218) + +* Give a friendlier error when port can't be converted to int (#168) + +* ``bool(URL())`` now returns ``False`` (#272) + +1.2.6 (2018-06-14) +================== + +* Drop Python 3.4 trove classifier (#205) + +1.2.5 (2018-05-23) +================== + +* Fix annotations for ``build`` (#199) + +1.2.4 (2018-05-08) +================== + +* Fix annotations for ``cached_property`` (#195) + +1.2.3 (2018-05-03) +================== + +* Accept ``str`` subclasses in ``URL`` constructor (#190) + +1.2.2 (2018-05-01) +================== + +* Fix build + +1.2.1 (2018-04-30) +================== + +* Pin minimal required Python to 3.5.3 (#189) + +1.2.0 (2018-04-30) +================== + +* Forbid inheritance, replace ``__init__`` with ``__new__`` (#171) + +* Support PEP-561 (provide type hinting marker) (#182) + +1.1.1 (2018-02-17) +================== + +* Fix performance regression: don't encode empty ``netloc`` (#170) + +1.1.0 (2018-01-21) +================== + +* Make pure Python quoter consistent with Cython version (#162) + +1.0.0 (2018-01-15) +================== + +* Use fast path if quoted string does not need requoting (#154) + +* Speed up quoting/unquoting by ``_Quoter`` and ``_Unquoter`` classes (#155) + +* Drop ``yarl.quote`` and ``yarl.unquote`` public functions (#155) + +* Add custom string writer, reuse static buffer if available (#157) + Code is 50-80 times faster than Pure Python version (was 4-5 times faster) + +* Don't recode IP zone (#144) + +* Support ``encoded=True`` in ``yarl.URL.build()`` (#158) + +* Fix updating query with multiple keys (#160) + +0.18.0 (2018-01-10) +=================== + +* Fallback to IDNA 2003 if domain name is not IDNA 2008 compatible (#152) + +0.17.0 (2017-12-30) +=================== + +* Use IDNA 2008 for domain name processing (#149) + +0.16.0 (2017-12-07) +=================== + +* Fix raising ``TypeError`` by ``url.query_string()`` after + ``url.with_query({})`` (empty mapping) (#141) + +0.15.0 (2017-11-23) +=================== + +* Add ``raw_path_qs`` attribute (#137) + +0.14.2 (2017-11-14) +=================== + +* Restore ``strict`` parameter as no-op in ``quote`` / ``unquote`` + +0.14.1 (2017-11-13) +=================== + +* Restore ``strict`` parameter as no-op for sake of compatibility with + aiohttp 2.2 + +0.14.0 (2017-11-11) +=================== + +* Drop strict mode (#123) + +* Fix ``"ValueError: Unallowed PCT %"`` when there's a ``"%"`` in the URL (#124) + +0.13.0 (2017-10-01) +=================== + +* Document ``encoded`` parameter (#102) + +* Support relative URLs like ``'?key=value'`` (#100) + +* Unsafe encoding for QS fixed. Encode ``;`` character in value parameter (#104) + +* Process passwords without user names (#95) + +0.12.0 (2017-06-26) +=================== + +* Properly support paths without leading slash in ``URL.with_path()`` (#90) + +* Enable type annotation checks + +0.11.0 (2017-06-26) +=================== + +* Normalize path (#86) + +* Clear query and fragment parts in ``.with_path()`` (#85) + +0.10.3 (2017-06-13) +=================== + +* Prevent double URL arguments unquoting (#83) + +0.10.2 (2017-05-05) +=================== + +* Unexpected hash behavior (#75) + + +0.10.1 (2017-05-03) +=================== + +* Unexpected compare behavior (#73) + +* Do not quote or unquote + if not a query string. (#74) + + +0.10.0 (2017-03-14) +=================== + +* Added ``URL.build`` class method (#58) + +* Added ``path_qs`` attribute (#42) + + +0.9.8 (2017-02-16) +================== + +* Do not quote ``:`` in path + + +0.9.7 (2017-02-16) +================== + +* Load from pickle without _cache (#56) + +* Percent-encoded pluses in path variables become spaces (#59) + + +0.9.6 (2017-02-15) +================== + +* Revert backward incompatible change (BaseURL) + + +0.9.5 (2017-02-14) +================== + +* Fix BaseURL rich comparison support + + +0.9.4 (2017-02-14) +================== + +* Use BaseURL + + +0.9.3 (2017-02-14) +================== + +* Added BaseURL + + +0.9.2 (2017-02-08) +================== + +* Remove debug print + + +0.9.1 (2017-02-07) +================== + +* Do not lose tail chars (#45) + + +0.9.0 (2017-02-07) +================== + +* Allow to quote ``%`` in non strict mode (#21) + +* Incorrect parsing of query parameters with %3B (;) inside (#34) + +* Fix core dumps (#41) + +* ``tmpbuf`` - compiling error (#43) + +* Added ``URL.update_path()`` method + +* Added ``URL.update_query()`` method (#47) + + +0.8.1 (2016-12-03) +================== + +* Fix broken aiohttp: revert back ``quote`` / ``unquote``. + + +0.8.0 (2016-12-03) +================== + +* Support more verbose error messages in ``.with_query()`` (#24) + +* Don't percent-encode ``@`` and ``:`` in path (#32) + +* Don't expose ``yarl.quote`` and ``yarl.unquote``, these functions are + part of private API + +0.7.1 (2016-11-18) +================== + +* Accept not only ``str`` but all classes inherited from ``str`` also (#25) + +0.7.0 (2016-11-07) +================== + +* Accept ``int`` as value for ``.with_query()`` + +0.6.0 (2016-11-07) +================== + +* Explicitly use UTF8 encoding in ``setup.py`` (#20) +* Properly unquote non-UTF8 strings (#19) + +0.5.3 (2016-11-02) +================== + +* Don't use ``typing.NamedTuple`` fields but indexes on URL construction + +0.5.2 (2016-11-02) +================== + +* Inline ``_encode`` class method + +0.5.1 (2016-11-02) +================== + +* Make URL construction faster by removing extra classmethod calls + +0.5.0 (2016-11-02) +================== + +* Add Cython optimization for quoting/unquoting +* Provide binary wheels + +0.4.3 (2016-09-29) +================== + +* Fix typing stubs + +0.4.2 (2016-09-29) +================== + +* Expose ``quote()`` and ``unquote()`` as public API + +0.4.1 (2016-09-28) +================== + +* Support empty values in query (``'/path?arg'``) + +0.4.0 (2016-09-27) +================== + +* Introduce ``relative()`` (#16) + +0.3.2 (2016-09-27) +================== + +* Typo fixes #15 + +0.3.1 (2016-09-26) +================== + +* Support sequence of pairs as ``with_query()`` parameter + +0.3.0 (2016-09-26) +================== + +* Introduce ``is_default_port()`` + +0.2.1 (2016-09-26) +================== + +* Raise ValueError for URLs like 'http://:8080/' + +0.2.0 (2016-09-18) +================== + +* Avoid doubling slashes when joining paths (#13) + +* Appending path starting from slash is forbidden (#12) + +0.1.4 (2016-09-09) +================== + +* Add ``kwargs`` support for ``with_query()`` (#10) + +0.1.3 (2016-09-07) +================== + +* Document ``with_query()``, ``with_fragment()`` and ``origin()`` + +* Allow ``None`` for ``with_query()`` and ``with_fragment()`` + +0.1.2 (2016-09-07) +================== + +* Fix links, tune docs theme. + +0.1.1 (2016-09-06) +================== + +* Update README, old version used obsolete API + +0.1.0 (2016-09-06) +================== + +* The library was deeply refactored, bytes are gone away but all + accepted strings are encoded if needed. + +0.0.1 (2016-08-30) +================== + +* The first release. diff --git a/source/yarl-1.22.0.dist-info/RECORD b/source/yarl-1.22.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..7e4c78b4e737749744b3d3d26983e02d5563640d --- /dev/null +++ b/source/yarl-1.22.0.dist-info/RECORD @@ -0,0 +1,26 @@ +yarl-1.22.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +yarl-1.22.0.dist-info/METADATA,sha256=gATDmMXVC53tnEFDwFcOd1skK5FM8kJNbY8GFkCRDLg,75118 +yarl-1.22.0.dist-info/RECORD,, +yarl-1.22.0.dist-info/WHEEL,sha256=DxRnWQz-Kp9-4a4hdDHsSv0KUC3H7sN9Nbef3-8RjXU,190 +yarl-1.22.0.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358 +yarl-1.22.0.dist-info/licenses/NOTICE,sha256=VtasbIEFwKUTBMIdsGDjYa-ajqCvmnXCOcKLXRNpODg,609 +yarl-1.22.0.dist-info/top_level.txt,sha256=vf3SJuQh-k7YtvsUrV_OPOrT9Kqn0COlk7IPYyhtGkQ,5 +yarl/__init__.py,sha256=woYZp7KGli7_1P_hR7ZU9ckEj6ho41smyP-PLfEL-lk,281 +yarl/__pycache__/__init__.cpython-312.pyc,, +yarl/__pycache__/_parse.cpython-312.pyc,, +yarl/__pycache__/_path.cpython-312.pyc,, +yarl/__pycache__/_query.cpython-312.pyc,, +yarl/__pycache__/_quoters.cpython-312.pyc,, +yarl/__pycache__/_quoting.cpython-312.pyc,, +yarl/__pycache__/_quoting_py.cpython-312.pyc,, +yarl/__pycache__/_url.cpython-312.pyc,, +yarl/_parse.py,sha256=gNt8zxVFGr95ufUQpSMiiZ9vDrvg4zq6MEtT3f6_8J0,7185 +yarl/_path.py,sha256=A0FJUylZyzmlT0a3UDOBbK-EzZXCAYuQQBvG9eAC9hs,1291 +yarl/_query.py,sha256=nwGAYewdOU8nt5YZNZxqQ4BGES82Y3Y6LanxqTjnZxw,4068 +yarl/_quoters.py,sha256=z-BzsXfLnJK-bd-HrGaoKGri9L3GpDv6vxFEtmu-uCM,1154 +yarl/_quoting.py,sha256=yKIqFTzFzWLVb08xy1DSxKNjFwo4f-oLlzxTuKwC57M,506 +yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so,sha256=TapJ2HY3skYb6kJkICWISGliO6JvMgDiQtRQcD8JMxA,1170216 +yarl/_quoting_c.pyx,sha256=X40gvQSUB4l7nPKGeiS6pq2JreM36avLhVeBMxd5zmo,14297 +yarl/_quoting_py.py,sha256=7WD7IHhgaJiLZWoIewvB0JRUsbz9McmfZw5TnjlVs9o,6783 +yarl/_url.py,sha256=4K5gCdoQtVi9FmnQdssEqafdlJILKxSap8RNCBC4IGE,55608 +yarl/py.typed,sha256=ay5OMO475PlcZ_Fbun9maHW7Y6MBTk0UXL4ztHx3Iug,14 diff --git a/source/yarl-1.22.0.dist-info/WHEEL b/source/yarl-1.22.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..f3e8a970f16adf4526f1722547053522d94bf860 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/WHEEL @@ -0,0 +1,7 @@ +Wheel-Version: 1.0 +Generator: setuptools (80.9.0) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_17_x86_64 +Tag: cp312-cp312-manylinux2014_x86_64 +Tag: cp312-cp312-manylinux_2_28_x86_64 + diff --git a/source/yarl-1.22.0.dist-info/licenses/LICENSE b/source/yarl-1.22.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/licenses/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/source/yarl-1.22.0.dist-info/licenses/NOTICE b/source/yarl-1.22.0.dist-info/licenses/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..fa53b2b138df881c4c95239d0e4bede831b36ab5 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/licenses/NOTICE @@ -0,0 +1,13 @@ + Copyright 2016-2021, Andrew Svetlov and aio-libs team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/source/yarl-1.22.0.dist-info/top_level.txt b/source/yarl-1.22.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e93e8bddefb14a8a753f7ecab6b934fd899cd9e5 --- /dev/null +++ b/source/yarl-1.22.0.dist-info/top_level.txt @@ -0,0 +1 @@ +yarl diff --git a/source/yarl/__init__.py b/source/yarl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e45554be4aa1babca506934bec6c0800dbe8adaa --- /dev/null +++ b/source/yarl/__init__.py @@ -0,0 +1,14 @@ +from ._query import Query, QueryVariable, SimpleQuery +from ._url import URL, cache_clear, cache_configure, cache_info + +__version__ = "1.22.0" + +__all__ = ( + "URL", + "SimpleQuery", + "QueryVariable", + "Query", + "cache_clear", + "cache_configure", + "cache_info", +) diff --git a/source/yarl/_parse.py b/source/yarl/_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..115d772360e61f4322eb72ced557d42a30930518 --- /dev/null +++ b/source/yarl/_parse.py @@ -0,0 +1,203 @@ +"""URL parsing utilities.""" + +import re +import unicodedata +from functools import lru_cache +from typing import Union +from urllib.parse import scheme_chars, uses_netloc + +from ._quoters import QUOTER, UNQUOTER_PLUS + +# Leading and trailing C0 control and space to be stripped per WHATWG spec. +# == "".join([chr(i) for i in range(0, 0x20 + 1)]) +WHATWG_C0_CONTROL_OR_SPACE = ( + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10" + "\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f " +) + +# Unsafe bytes to be removed per WHATWG spec +UNSAFE_URL_BYTES_TO_REMOVE = ["\t", "\r", "\n"] +USES_AUTHORITY = frozenset(uses_netloc) + +SplitURLType = tuple[str, str, str, str, str] + + +def split_url(url: str) -> SplitURLType: + """Split URL into parts.""" + # Adapted from urllib.parse.urlsplit + # Only lstrip url as some applications rely on preserving trailing space. + # (https://url.spec.whatwg.org/#concept-basic-url-parser would strip both) + url = url.lstrip(WHATWG_C0_CONTROL_OR_SPACE) + for b in UNSAFE_URL_BYTES_TO_REMOVE: + if b in url: + url = url.replace(b, "") + + scheme = netloc = query = fragment = "" + i = url.find(":") + if i > 0 and url[0] in scheme_chars: + for c in url[1:i]: + if c not in scheme_chars: + break + else: + scheme, url = url[:i].lower(), url[i + 1 :] + has_hash = "#" in url + has_question_mark = "?" in url + if url[:2] == "//": + delim = len(url) # position of end of domain part of url, default is end + if has_hash and has_question_mark: + delim_chars = "/?#" + elif has_question_mark: + delim_chars = "/?" + elif has_hash: + delim_chars = "/#" + else: + delim_chars = "/" + for c in delim_chars: # look for delimiters; the order is NOT important + wdelim = url.find(c, 2) # find first of this delim + if wdelim >= 0 and wdelim < delim: # if found + delim = wdelim # use earliest delim position + netloc = url[2:delim] + url = url[delim:] + has_left_bracket = "[" in netloc + has_right_bracket = "]" in netloc + if (has_left_bracket and not has_right_bracket) or ( + has_right_bracket and not has_left_bracket + ): + raise ValueError("Invalid IPv6 URL") + if has_left_bracket: + bracketed_host = netloc.partition("[")[2].partition("]")[0] + # Valid bracketed hosts are defined in + # https://www.rfc-editor.org/rfc/rfc3986#page-49 + # https://url.spec.whatwg.org/ + if bracketed_host and bracketed_host[0] == "v": + if not re.match(r"\Av[a-fA-F0-9]+\..+\Z", bracketed_host): + raise ValueError("IPvFuture address is invalid") + elif ":" not in bracketed_host: + raise ValueError("The IPv6 content between brackets is not valid") + if has_hash: + url, _, fragment = url.partition("#") + if has_question_mark: + url, _, query = url.partition("?") + if netloc and not netloc.isascii(): + _check_netloc(netloc) + return scheme, netloc, url, query, fragment + + +def _check_netloc(netloc: str) -> None: + # Adapted from urllib.parse._checknetloc + # looking for characters like \u2100 that expand to 'a/c' + # IDNA uses NFKC equivalence, so normalize for this check + + # ignore characters already included + # but not the surrounding text + n = netloc.replace("@", "").replace(":", "").replace("#", "").replace("?", "") + normalized_netloc = unicodedata.normalize("NFKC", n) + if n == normalized_netloc: + return + # Note that there are no unicode decompositions for the character '@' so + # its currently impossible to have test coverage for this branch, however if the + # one should be added in the future we want to make sure its still checked. + for c in "/?#@:": # pragma: no branch + if c in normalized_netloc: + raise ValueError( + f"netloc '{netloc}' contains invalid " + "characters under NFKC normalization" + ) + + +@lru_cache # match the same size as urlsplit +def split_netloc( + netloc: str, +) -> tuple[Union[str, None], Union[str, None], Union[str, None], Union[int, None]]: + """Split netloc into username, password, host and port.""" + if "@" not in netloc: + username: Union[str, None] = None + password: Union[str, None] = None + hostinfo = netloc + else: + userinfo, _, hostinfo = netloc.rpartition("@") + username, have_password, password = userinfo.partition(":") + if not have_password: + password = None + + if "[" in hostinfo: + _, _, bracketed = hostinfo.partition("[") + hostname, _, port_str = bracketed.partition("]") + _, _, port_str = port_str.partition(":") + else: + hostname, _, port_str = hostinfo.partition(":") + + if not port_str: + return username or None, password, hostname or None, None + + try: + port = int(port_str) + except ValueError: + raise ValueError("Invalid URL: port can't be converted to integer") + if not (0 <= port <= 65535): + raise ValueError("Port out of range 0-65535") + return username or None, password, hostname or None, port + + +def unsplit_result( + scheme: str, netloc: str, url: str, query: str, fragment: str +) -> str: + """Unsplit a URL without any normalization.""" + if netloc or (scheme and scheme in USES_AUTHORITY) or url[:2] == "//": + if url and url[:1] != "/": + url = f"{scheme}://{netloc}/{url}" if scheme else f"{scheme}:{url}" + else: + url = f"{scheme}://{netloc}{url}" if scheme else f"//{netloc}{url}" + elif scheme: + url = f"{scheme}:{url}" + if query: + url = f"{url}?{query}" + return f"{url}#{fragment}" if fragment else url + + +@lru_cache # match the same size as urlsplit +def make_netloc( + user: Union[str, None], + password: Union[str, None], + host: Union[str, None], + port: Union[int, None], + encode: bool = False, +) -> str: + """Make netloc from parts. + + The user and password are encoded if encode is True. + + The host must already be encoded with _encode_host. + """ + if host is None: + return "" + ret = host + if port is not None: + ret = f"{ret}:{port}" + if user is None and password is None: + return ret + if password is not None: + if not user: + user = "" + elif encode: + user = QUOTER(user) + if encode: + password = QUOTER(password) + user = f"{user}:{password}" + elif user and encode: + user = QUOTER(user) + return f"{user}@{ret}" if user else ret + + +def query_to_pairs(query_string: str) -> list[tuple[str, str]]: + """Parse a query given as a string argument. + + Works like urllib.parse.parse_qsl with keep empty values. + """ + pairs: list[tuple[str, str]] = [] + if not query_string: + return pairs + for k_v in query_string.split("&"): + k, _, v = k_v.partition("=") + pairs.append((UNQUOTER_PLUS(k), UNQUOTER_PLUS(v))) + return pairs diff --git a/source/yarl/_path.py b/source/yarl/_path.py new file mode 100644 index 0000000000000000000000000000000000000000..c22f0b4b8cdd9280fd36789e2bc052b1c4938167 --- /dev/null +++ b/source/yarl/_path.py @@ -0,0 +1,41 @@ +"""Utilities for working with paths.""" + +from collections.abc import Sequence +from contextlib import suppress + + +def normalize_path_segments(segments: Sequence[str]) -> list[str]: + """Drop '.' and '..' from a sequence of str segments""" + + resolved_path: list[str] = [] + + for seg in segments: + if seg == "..": + # ignore any .. segments that would otherwise cause an + # IndexError when popped from resolved_path if + # resolving for rfc3986 + with suppress(IndexError): + resolved_path.pop() + elif seg != ".": + resolved_path.append(seg) + + if segments and segments[-1] in (".", ".."): + # do some post-processing here. + # if the last segment was a relative dir, + # then we need to append the trailing '/' + resolved_path.append("") + + return resolved_path + + +def normalize_path(path: str) -> str: + # Drop '.' and '..' from str path + prefix = "" + if path and path[0] == "/": + # preserve the "/" root element of absolute paths, copying it to the + # normalised output as per sections 5.2.4 and 6.2.2.3 of rfc3986. + prefix = "/" + path = path[1:] + + segments = path.split("/") + return prefix + "/".join(normalize_path_segments(segments)) diff --git a/source/yarl/_query.py b/source/yarl/_query.py new file mode 100644 index 0000000000000000000000000000000000000000..d911bcf0b24a6defa12c306d6342ba699e137b3d --- /dev/null +++ b/source/yarl/_query.py @@ -0,0 +1,121 @@ +"""Query string handling.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, SupportsInt, Union, cast + +from multidict import istr + +from ._quoters import QUERY_PART_QUOTER, QUERY_QUOTER + +SimpleQuery = Union[str, SupportsInt, float] +QueryVariable = Union[SimpleQuery, Sequence[SimpleQuery]] +Query = Union[ + None, str, Mapping[str, QueryVariable], Sequence[tuple[str, QueryVariable]] +] + + +def query_var(v: SimpleQuery) -> str: + """Convert a query variable to a string.""" + cls = type(v) + if cls is int: # Fast path for non-subclassed int + return str(v) + if isinstance(v, str): + return v + if isinstance(v, float): + if math.isinf(v): + raise ValueError("float('inf') is not supported") + if math.isnan(v): + raise ValueError("float('nan') is not supported") + return str(float(v)) + if cls is not bool and isinstance(v, SupportsInt): + return str(int(v)) + raise TypeError( + "Invalid variable type: value " + "should be str, int or float, got {!r} " + "of type {}".format(v, cls) + ) + + +def get_str_query_from_sequence_iterable( + items: Iterable[tuple[Union[str, istr], QueryVariable]], +) -> str: + """Return a query string from a sequence of (key, value) pairs. + + value is a single value or a sequence of values for the key + + The sequence of values must be a list or tuple. + """ + quoter = QUERY_PART_QUOTER + pairs = [ + f"{quoter(k)}={quoter(v if type(v) is str else query_var(v))}" + for k, val in items + for v in ( + val if type(val) is not str and isinstance(val, (list, tuple)) else (val,) + ) + ] + return "&".join(pairs) + + +def get_str_query_from_iterable( + items: Iterable[tuple[Union[str, istr], SimpleQuery]], +) -> str: + """Return a query string from an iterable. + + The iterable must contain (key, value) pairs. + + The values are not allowed to be sequences, only single values are + allowed. For sequences, use `_get_str_query_from_sequence_iterable`. + """ + quoter = QUERY_PART_QUOTER + # A listcomp is used since listcomps are inlined on CPython 3.12+ and + # they are a bit faster than a generator expression. + pairs = [ + f"{quoter(k)}={quoter(v if type(v) is str else query_var(v))}" for k, v in items + ] + return "&".join(pairs) + + +def get_str_query(*args: Any, **kwargs: Any) -> Union[str, None]: + """Return a query string from supported args.""" + query: Union[ + str, + Mapping[str, QueryVariable], + Sequence[tuple[Union[str, istr], SimpleQuery]], + None, + ] + if kwargs: + if args: + msg = "Either kwargs or single query parameter must be present" + raise ValueError(msg) + query = kwargs + elif len(args) == 1: + query = args[0] + else: + raise ValueError("Either kwargs or single query parameter must be present") + + if query is None: + return None + if not query: + return "" + if type(query) is dict: + return get_str_query_from_sequence_iterable(query.items()) + if type(query) is str or isinstance(query, str): + return QUERY_QUOTER(query) + if isinstance(query, Mapping): + return get_str_query_from_sequence_iterable(query.items()) + if isinstance(query, (bytes, bytearray, memoryview)): + msg = "Invalid query type: bytes, bytearray and memoryview are forbidden" + raise TypeError(msg) + if isinstance(query, Sequence): + # We don't expect sequence values if we're given a list of pairs + # already; only mappings like builtin `dict` which can't have the + # same key pointing to multiple values are allowed to use + # `_query_seq_pairs`. + if TYPE_CHECKING: + query = cast(Sequence[tuple[Union[str, istr], SimpleQuery]], query) + return get_str_query_from_iterable(query) + raise TypeError( + "Invalid query type: only str, mapping or " + "sequence of (key, value) pairs is allowed" + ) diff --git a/source/yarl/_quoters.py b/source/yarl/_quoters.py new file mode 100644 index 0000000000000000000000000000000000000000..0feb5b141131697a6dc87df19941ffe6714b20ff --- /dev/null +++ b/source/yarl/_quoters.py @@ -0,0 +1,33 @@ +"""Quoting and unquoting utilities for URL parts.""" + +from typing import Union +from urllib.parse import quote + +from ._quoting import _Quoter, _Unquoter + +QUOTER = _Quoter(requote=False) +REQUOTER = _Quoter() +PATH_QUOTER = _Quoter(safe="@:", protected="/+", requote=False) +PATH_REQUOTER = _Quoter(safe="@:", protected="/+") +QUERY_QUOTER = _Quoter(safe="?/:@", protected="=+&;", qs=True, requote=False) +QUERY_REQUOTER = _Quoter(safe="?/:@", protected="=+&;", qs=True) +QUERY_PART_QUOTER = _Quoter(safe="?/:@", qs=True, requote=False) +FRAGMENT_QUOTER = _Quoter(safe="?/:@", requote=False) +FRAGMENT_REQUOTER = _Quoter(safe="?/:@") + +UNQUOTER = _Unquoter() +PATH_UNQUOTER = _Unquoter(unsafe="+") +PATH_SAFE_UNQUOTER = _Unquoter(ignore="/%", unsafe="+") +QS_UNQUOTER = _Unquoter(qs=True) +UNQUOTER_PLUS = _Unquoter(plus=True) # to match urllib.parse.unquote_plus + + +def human_quote(s: Union[str, None], unsafe: str) -> Union[str, None]: + if not s: + return s + for c in "%" + unsafe: + if c in s: + s = s.replace(c, f"%{ord(c):02X}") + if s.isprintable(): + return s + return "".join(c if c.isprintable() else quote(c) for c in s) diff --git a/source/yarl/_quoting.py b/source/yarl/_quoting.py new file mode 100644 index 0000000000000000000000000000000000000000..25d76c885cacaa815bb7e0149aedbe76c20f2228 --- /dev/null +++ b/source/yarl/_quoting.py @@ -0,0 +1,19 @@ +import os +import sys +from typing import TYPE_CHECKING + +__all__ = ("_Quoter", "_Unquoter") + + +NO_EXTENSIONS = bool(os.environ.get("YARL_NO_EXTENSIONS")) # type: bool +if sys.implementation.name != "cpython": + NO_EXTENSIONS = True + + +if TYPE_CHECKING or NO_EXTENSIONS: + from ._quoting_py import _Quoter, _Unquoter +else: + try: + from ._quoting_c import _Quoter, _Unquoter + except ImportError: # pragma: no cover + from ._quoting_py import _Quoter, _Unquoter # type: ignore[assignment] diff --git a/source/yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so b/source/yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..37efec8a7aea93001da0cb68d562b26a6029664d --- /dev/null +++ b/source/yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4daa49d87637b2461bea42642025884869623ba26f3200e242d450703f093310 +size 1170216 diff --git a/source/yarl/_quoting_c.pyx b/source/yarl/_quoting_c.pyx new file mode 100644 index 0000000000000000000000000000000000000000..dacf6b088c53e77a43f0435c06a2cbf51dcf4486 --- /dev/null +++ b/source/yarl/_quoting_c.pyx @@ -0,0 +1,451 @@ +from cpython.exc cimport PyErr_NoMemory +from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc +from cpython.unicode cimport ( + PyUnicode_DATA, + PyUnicode_DecodeASCII, + PyUnicode_DecodeUTF8Stateful, + PyUnicode_GET_LENGTH, + PyUnicode_KIND, + PyUnicode_READ, +) +from libc.stdint cimport uint8_t, uint64_t +from libc.string cimport memcpy, memset + +from string import ascii_letters, digits + + +cdef str GEN_DELIMS = ":/?#[]@" +cdef str SUB_DELIMS_WITHOUT_QS = "!$'()*," +cdef str SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + '+?=;' +cdef str RESERVED = GEN_DELIMS + SUB_DELIMS +cdef str UNRESERVED = ascii_letters + digits + '-._~' +cdef str ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS +cdef str QS = '+&=;' + +DEF BUF_SIZE = 8 * 1024 # 8KiB + +cdef inline Py_UCS4 _to_hex(uint8_t v) noexcept: + if v < 10: + return (v+0x30) # ord('0') == 0x30 + else: + return (v+0x41-10) # ord('A') == 0x41 + + +cdef inline int _from_hex(Py_UCS4 v) noexcept: + if '0' <= v <= '9': + return (v) - 0x30 # ord('0') == 0x30 + elif 'A' <= v <= 'F': + return (v) - 0x41 + 10 # ord('A') == 0x41 + elif 'a' <= v <= 'f': + return (v) - 0x61 + 10 # ord('a') == 0x61 + else: + return -1 + + +cdef inline int _is_lower_hex(Py_UCS4 v) noexcept: + return 'a' <= v <= 'f' + + +cdef inline long _restore_ch(Py_UCS4 d1, Py_UCS4 d2): + cdef int digit1 = _from_hex(d1) + if digit1 < 0: + return -1 + cdef int digit2 = _from_hex(d2) + if digit2 < 0: + return -1 + return digit1 << 4 | digit2 + + +cdef uint8_t ALLOWED_TABLE[16] +cdef uint8_t ALLOWED_NOTQS_TABLE[16] + + +cdef inline bint bit_at(uint8_t array[], uint64_t ch) noexcept: + return array[ch >> 3] & (1 << (ch & 7)) + + +cdef inline void set_bit(uint8_t array[], uint64_t ch) noexcept: + array[ch >> 3] |= (1 << (ch & 7)) + + +memset(ALLOWED_TABLE, 0, sizeof(ALLOWED_TABLE)) +memset(ALLOWED_NOTQS_TABLE, 0, sizeof(ALLOWED_NOTQS_TABLE)) + +for i in range(128): + if chr(i) in ALLOWED: + set_bit(ALLOWED_TABLE, i) + set_bit(ALLOWED_NOTQS_TABLE, i) + if chr(i) in QS: + set_bit(ALLOWED_NOTQS_TABLE, i) + +# ----------------- writer --------------------------- + +cdef struct Writer: + char *buf + bint heap_allocated_buf + Py_ssize_t size + Py_ssize_t pos + bint changed + + +cdef inline void _init_writer(Writer* writer, char* buf): + writer.buf = buf + writer.heap_allocated_buf = False + writer.size = BUF_SIZE + writer.pos = 0 + writer.changed = 0 + + +cdef inline void _release_writer(Writer* writer): + if writer.heap_allocated_buf: + PyMem_Free(writer.buf) + + +cdef inline int _write_char(Writer* writer, Py_UCS4 ch, bint changed): + cdef char * buf + cdef Py_ssize_t size + + if writer.pos == writer.size: + # reallocate + size = writer.size + BUF_SIZE + if not writer.heap_allocated_buf: + buf = PyMem_Malloc(size) + if buf == NULL: + PyErr_NoMemory() + return -1 + memcpy(buf, writer.buf, writer.size) + writer.heap_allocated_buf = True + else: + buf = PyMem_Realloc(writer.buf, size) + if buf == NULL: + PyErr_NoMemory() + return -1 + writer.buf = buf + writer.size = size + writer.buf[writer.pos] = ch + writer.pos += 1 + writer.changed |= changed + return 0 + + +cdef inline int _write_pct(Writer* writer, uint8_t ch, bint changed): + if _write_char(writer, '%', changed) < 0: + return -1 + if _write_char(writer, _to_hex(ch >> 4), changed) < 0: + return -1 + return _write_char(writer, _to_hex(ch & 0x0f), changed) + + +cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol): + cdef uint64_t utf = symbol + + if utf < 0x80: + return _write_pct(writer, utf, True) + elif utf < 0x800: + if _write_pct(writer, (0xc0 | (utf >> 6)), True) < 0: + return -1 + return _write_pct(writer, (0x80 | (utf & 0x3f)), True) + elif 0xD800 <= utf <= 0xDFFF: + # surogate pair, ignored + return 0 + elif utf < 0x10000: + if _write_pct(writer, (0xe0 | (utf >> 12)), True) < 0: + return -1 + if _write_pct(writer, (0x80 | ((utf >> 6) & 0x3f)), + True) < 0: + return -1 + return _write_pct(writer, (0x80 | (utf & 0x3f)), True) + elif utf > 0x10FFFF: + # symbol is too large + return 0 + else: + if _write_pct(writer, (0xf0 | (utf >> 18)), True) < 0: + return -1 + if _write_pct(writer, (0x80 | ((utf >> 12) & 0x3f)), + True) < 0: + return -1 + if _write_pct(writer, (0x80 | ((utf >> 6) & 0x3f)), + True) < 0: + return -1 + return _write_pct(writer, (0x80 | (utf & 0x3f)), True) + + +# --------------------- end writer -------------------------- + + +cdef class _Quoter: + cdef bint _qs + cdef bint _requote + + cdef uint8_t _safe_table[16] + cdef uint8_t _protected_table[16] + + def __init__( + self, *, str safe='', str protected='', bint qs=False, bint requote=True, + ): + cdef Py_UCS4 ch + + self._qs = qs + self._requote = requote + + if not self._qs: + memcpy(self._safe_table, + ALLOWED_NOTQS_TABLE, + sizeof(self._safe_table)) + else: + memcpy(self._safe_table, + ALLOWED_TABLE, + sizeof(self._safe_table)) + for ch in safe: + if ord(ch) > 127: + raise ValueError("Only safe symbols with ORD < 128 are allowed") + set_bit(self._safe_table, ch) + + memset(self._protected_table, 0, sizeof(self._protected_table)) + for ch in protected: + if ord(ch) > 127: + raise ValueError("Only safe symbols with ORD < 128 are allowed") + set_bit(self._safe_table, ch) + set_bit(self._protected_table, ch) + + def __call__(self, val): + if val is None: + return None + if type(val) is not str: + if isinstance(val, str): + # derived from str + val = str(val) + else: + raise TypeError("Argument should be str") + return self._do_quote_or_skip(val) + + cdef str _do_quote_or_skip(self, str val): + cdef char[BUF_SIZE] buffer + cdef Py_UCS4 ch + cdef Py_ssize_t length = PyUnicode_GET_LENGTH(val) + cdef Py_ssize_t idx = length + cdef bint must_quote = 0 + cdef Writer writer + cdef int kind = PyUnicode_KIND(val) + cdef const void *data = PyUnicode_DATA(val) + + # If everything in the string is in the safe + # table and all ASCII, we can skip quoting + while idx: + idx -= 1 + ch = PyUnicode_READ(kind, data, idx) + if ch >= 128 or not bit_at(self._safe_table, ch): + must_quote = 1 + break + + if not must_quote: + return val + + _init_writer(&writer, &buffer[0]) + try: + return self._do_quote(val, length, kind, data, &writer) + finally: + _release_writer(&writer) + + cdef str _do_quote( + self, + str val, + Py_ssize_t length, + int kind, + const void *data, + Writer *writer + ): + cdef Py_UCS4 ch + cdef long chl + cdef int changed + cdef Py_ssize_t idx = 0 + + while idx < length: + ch = PyUnicode_READ(kind, data, idx) + idx += 1 + if ch == '%' and self._requote and idx <= length - 2: + chl = _restore_ch( + PyUnicode_READ(kind, data, idx), + PyUnicode_READ(kind, data, idx + 1) + ) + if chl != -1: + ch = chl + idx += 2 + if ch < 128: + if bit_at(self._protected_table, ch): + if _write_pct(writer, ch, True) < 0: + raise + continue + + if bit_at(self._safe_table, ch): + if _write_char(writer, ch, True) < 0: + raise + continue + + changed = (_is_lower_hex(PyUnicode_READ(kind, data, idx - 2)) or + _is_lower_hex(PyUnicode_READ(kind, data, idx - 1))) + if _write_pct(writer, ch, changed) < 0: + raise + continue + else: + ch = '%' + + if self._write(writer, ch) < 0: + raise + + if not writer.changed: + return val + else: + return PyUnicode_DecodeASCII(writer.buf, writer.pos, "strict") + + cdef inline int _write(self, Writer *writer, Py_UCS4 ch): + if self._qs: + if ch == ' ': + return _write_char(writer, '+', True) + + if ch < 128 and bit_at(self._safe_table, ch): + return _write_char(writer, ch, False) + + return _write_utf8(writer, ch) + + +cdef class _Unquoter: + cdef str _ignore + cdef bint _has_ignore + cdef str _unsafe + cdef bytes _unsafe_bytes + cdef Py_ssize_t _unsafe_bytes_len + cdef const unsigned char * _unsafe_bytes_char + cdef bint _qs + cdef bint _plus # to match urllib.parse.unquote_plus + cdef _Quoter _quoter + cdef _Quoter _qs_quoter + + def __init__(self, *, ignore="", unsafe="", qs=False, plus=False): + self._ignore = ignore + self._has_ignore = bool(self._ignore) + self._unsafe = unsafe + # unsafe may only be extended ascii characters (0-255) + self._unsafe_bytes = self._unsafe.encode('ascii') + self._unsafe_bytes_len = len(self._unsafe_bytes) + self._unsafe_bytes_char = self._unsafe_bytes + self._qs = qs + self._plus = plus + self._quoter = _Quoter() + self._qs_quoter = _Quoter(qs=True) + + def __call__(self, val): + if val is None: + return None + if type(val) is not str: + if isinstance(val, str): + # derived from str + val = str(val) + else: + raise TypeError("Argument should be str") + return self._do_unquote(val) + + cdef str _do_unquote(self, str val): + cdef Py_ssize_t length = PyUnicode_GET_LENGTH(val) + if length == 0: + return val + + cdef list ret = [] + cdef char buffer[4] + cdef Py_ssize_t buflen = 0 + cdef Py_ssize_t consumed + cdef str unquoted + cdef Py_UCS4 ch = 0 + cdef long chl = 0 + cdef Py_ssize_t idx = 0 + cdef Py_ssize_t start_pct + cdef int kind = PyUnicode_KIND(val) + cdef const void *data = PyUnicode_DATA(val) + cdef bint changed = 0 + while idx < length: + ch = PyUnicode_READ(kind, data, idx) + idx += 1 + if ch == '%' and idx <= length - 2: + changed = 1 + chl = _restore_ch( + PyUnicode_READ(kind, data, idx), + PyUnicode_READ(kind, data, idx + 1) + ) + if chl != -1: + ch = chl + idx += 2 + assert buflen < 4 + buffer[buflen] = ch + buflen += 1 + try: + unquoted = PyUnicode_DecodeUTF8Stateful(buffer, buflen, + NULL, &consumed) + except UnicodeDecodeError: + start_pct = idx - buflen * 3 + buffer[0] = ch + buflen = 1 + ret.append(val[start_pct : idx - 3]) + try: + unquoted = PyUnicode_DecodeUTF8Stateful(buffer, buflen, + NULL, &consumed) + except UnicodeDecodeError: + buflen = 0 + ret.append(val[idx - 3 : idx]) + continue + if not unquoted: + assert consumed == 0 + continue + assert consumed == buflen + buflen = 0 + if self._qs and unquoted in '+=&;': + ret.append(self._qs_quoter(unquoted)) + elif ( + (self._unsafe_bytes_len and unquoted in self._unsafe) or + (self._has_ignore and unquoted in self._ignore) + ): + ret.append(self._quoter(unquoted)) + else: + ret.append(unquoted) + continue + else: + ch = '%' + + if buflen: + start_pct = idx - 1 - buflen * 3 + ret.append(val[start_pct : idx - 1]) + buflen = 0 + + if ch == '+': + if ( + (not self._qs and not self._plus) or + (self._unsafe_bytes_len and self._is_char_unsafe(ch)) + ): + ret.append('+') + else: + changed = 1 + ret.append(' ') + continue + + if self._unsafe_bytes_len and self._is_char_unsafe(ch): + changed = 1 + ret.append('%') + h = hex(ord(ch)).upper()[2:] + for ch in h: + ret.append(ch) + continue + + ret.append(ch) + + if not changed: + return val + + if buflen: + ret.append(val[length - buflen * 3 : length]) + + return ''.join(ret) + + cdef inline bint _is_char_unsafe(self, Py_UCS4 ch): + for i in range(self._unsafe_bytes_len): + if ch == self._unsafe_bytes_char[i]: + return True + return False diff --git a/source/yarl/_quoting_py.py b/source/yarl/_quoting_py.py new file mode 100644 index 0000000000000000000000000000000000000000..80bf07febf30d2ab3702c08c635ab83c20db877a --- /dev/null +++ b/source/yarl/_quoting_py.py @@ -0,0 +1,213 @@ +import codecs +import re +from string import ascii_letters, ascii_lowercase, digits +from typing import Union, overload + +BASCII_LOWERCASE = ascii_lowercase.encode("ascii") +BPCT_ALLOWED = {f"%{i:02X}".encode("ascii") for i in range(256)} +GEN_DELIMS = ":/?#[]@" +SUB_DELIMS_WITHOUT_QS = "!$'()*," +SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + "+&=;" +RESERVED = GEN_DELIMS + SUB_DELIMS +UNRESERVED = ascii_letters + digits + "-._~" +ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS + + +_IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]") +_IS_HEX_STR = re.compile("[A-Fa-f0-9][A-Fa-f0-9]") + +utf8_decoder = codecs.getincrementaldecoder("utf-8") + + +class _Quoter: + def __init__( + self, + *, + safe: str = "", + protected: str = "", + qs: bool = False, + requote: bool = True, + ) -> None: + self._safe = safe + self._protected = protected + self._qs = qs + self._requote = requote + + @overload + def __call__(self, val: str) -> str: ... + @overload + def __call__(self, val: None) -> None: ... + def __call__(self, val: Union[str, None]) -> Union[str, None]: + if val is None: + return None + if not isinstance(val, str): + raise TypeError("Argument should be str") + if not val: + return "" + bval = val.encode("utf8", errors="ignore") + ret = bytearray() + pct = bytearray() + safe = self._safe + safe += ALLOWED + if not self._qs: + safe += "+&=;" + safe += self._protected + bsafe = safe.encode("ascii") + idx = 0 + while idx < len(bval): + ch = bval[idx] + idx += 1 + + if pct: + if ch in BASCII_LOWERCASE: + ch = ch - 32 # convert to uppercase + pct.append(ch) + if len(pct) == 3: # pragma: no branch # peephole optimizer + buf = pct[1:] + if not _IS_HEX.match(buf): + ret.extend(b"%25") + pct.clear() + idx -= 2 + continue + try: + unquoted = chr(int(pct[1:].decode("ascii"), base=16)) + except ValueError: + ret.extend(b"%25") + pct.clear() + idx -= 2 + continue + + if unquoted in self._protected: + ret.extend(pct) + elif unquoted in safe: + ret.append(ord(unquoted)) + else: + ret.extend(pct) + pct.clear() + + # special case, if we have only one char after "%" + elif len(pct) == 2 and idx == len(bval): + ret.extend(b"%25") + pct.clear() + idx -= 1 + + continue + + elif ch == ord("%") and self._requote: + pct.clear() + pct.append(ch) + + # special case if "%" is last char + if idx == len(bval): + ret.extend(b"%25") + + continue + + if self._qs and ch == ord(" "): + ret.append(ord("+")) + continue + if ch in bsafe: + ret.append(ch) + continue + + ret.extend((f"%{ch:02X}").encode("ascii")) + + ret2 = ret.decode("ascii") + if ret2 == val: + return val + return ret2 + + +class _Unquoter: + def __init__( + self, + *, + ignore: str = "", + unsafe: str = "", + qs: bool = False, + plus: bool = False, + ) -> None: + self._ignore = ignore + self._unsafe = unsafe + self._qs = qs + self._plus = plus # to match urllib.parse.unquote_plus + self._quoter = _Quoter() + self._qs_quoter = _Quoter(qs=True) + + @overload + def __call__(self, val: str) -> str: ... + @overload + def __call__(self, val: None) -> None: ... + def __call__(self, val: Union[str, None]) -> Union[str, None]: + if val is None: + return None + if not isinstance(val, str): + raise TypeError("Argument should be str") + if not val: + return "" + decoder = utf8_decoder() + ret = [] + idx = 0 + while idx < len(val): + ch = val[idx] + idx += 1 + if ch == "%" and idx <= len(val) - 2: + pct = val[idx : idx + 2] + if _IS_HEX_STR.fullmatch(pct): + b = bytes([int(pct, base=16)]) + idx += 2 + try: + unquoted = decoder.decode(b) + except UnicodeDecodeError: + start_pct = idx - 3 - len(decoder.buffer) * 3 + ret.append(val[start_pct : idx - 3]) + decoder.reset() + try: + unquoted = decoder.decode(b) + except UnicodeDecodeError: + ret.append(val[idx - 3 : idx]) + continue + if not unquoted: + continue + if self._qs and unquoted in "+=&;": + to_add = self._qs_quoter(unquoted) + if to_add is None: # pragma: no cover + raise RuntimeError("Cannot quote None") + ret.append(to_add) + elif unquoted in self._unsafe or unquoted in self._ignore: + to_add = self._quoter(unquoted) + if to_add is None: # pragma: no cover + raise RuntimeError("Cannot quote None") + ret.append(to_add) + else: + ret.append(unquoted) + continue + + if decoder.buffer: + start_pct = idx - 1 - len(decoder.buffer) * 3 + ret.append(val[start_pct : idx - 1]) + decoder.reset() + + if ch == "+": + if (not self._qs and not self._plus) or ch in self._unsafe: + ret.append("+") + else: + ret.append(" ") + continue + + if ch in self._unsafe: + ret.append("%") + h = hex(ord(ch)).upper()[2:] + for ch in h: + ret.append(ch) + continue + + ret.append(ch) + + if decoder.buffer: + ret.append(val[-len(decoder.buffer) * 3 :]) + + ret2 = "".join(ret) + if ret2 == val: + return val + return ret2 diff --git a/source/yarl/_url.py b/source/yarl/_url.py new file mode 100644 index 0000000000000000000000000000000000000000..527a576ee68f8fbeff941ef67bb1590f940862d0 --- /dev/null +++ b/source/yarl/_url.py @@ -0,0 +1,1622 @@ +import re +import sys +import warnings +from collections.abc import Mapping, Sequence +from enum import Enum +from functools import _CacheInfo, lru_cache +from ipaddress import ip_address +from typing import ( + TYPE_CHECKING, + Any, + NoReturn, + TypedDict, + TypeVar, + Union, + cast, + overload, +) +from urllib.parse import SplitResult, uses_relative + +import idna +from multidict import MultiDict, MultiDictProxy, istr +from propcache.api import under_cached_property as cached_property + +from ._parse import ( + USES_AUTHORITY, + SplitURLType, + make_netloc, + query_to_pairs, + split_netloc, + split_url, + unsplit_result, +) +from ._path import normalize_path, normalize_path_segments +from ._query import ( + Query, + QueryVariable, + SimpleQuery, + get_str_query, + get_str_query_from_iterable, + get_str_query_from_sequence_iterable, +) +from ._quoters import ( + FRAGMENT_QUOTER, + FRAGMENT_REQUOTER, + PATH_QUOTER, + PATH_REQUOTER, + PATH_SAFE_UNQUOTER, + PATH_UNQUOTER, + QS_UNQUOTER, + QUERY_QUOTER, + QUERY_REQUOTER, + QUOTER, + REQUOTER, + UNQUOTER, + human_quote, +) + +DEFAULT_PORTS = {"http": 80, "https": 443, "ws": 80, "wss": 443, "ftp": 21} +USES_RELATIVE = frozenset(uses_relative) + +# Special schemes https://url.spec.whatwg.org/#special-scheme +# are not allowed to have an empty host https://url.spec.whatwg.org/#url-representation +SCHEME_REQUIRES_HOST = frozenset(("http", "https", "ws", "wss", "ftp")) + + +# reg-name: unreserved / pct-encoded / sub-delims +# this pattern matches anything that is *not* in those classes. and is only used +# on lower-cased ASCII values. +NOT_REG_NAME = re.compile( + r""" + # any character not in the unreserved or sub-delims sets, plus % + # (validated with the additional check for pct-encoded sequences below) + [^a-z0-9\-._~!$&'()*+,;=%] + | + # % only allowed if it is part of a pct-encoded + # sequence of 2 hex digits. + %(?![0-9a-f]{2}) + """, + re.VERBOSE, +) + +_T = TypeVar("_T") + +if sys.version_info >= (3, 11): + from typing import Self +else: + Self = Any + + +class UndefinedType(Enum): + """Singleton type for use with not set sentinel values.""" + + _singleton = 0 + + +UNDEFINED = UndefinedType._singleton + + +class CacheInfo(TypedDict): + """Host encoding cache.""" + + idna_encode: _CacheInfo + idna_decode: _CacheInfo + ip_address: _CacheInfo + host_validate: _CacheInfo + encode_host: _CacheInfo + + +class _InternalURLCache(TypedDict, total=False): + _val: SplitURLType + _origin: "URL" + absolute: bool + hash: int + scheme: str + raw_authority: str + authority: str + raw_user: Union[str, None] + user: Union[str, None] + raw_password: Union[str, None] + password: Union[str, None] + raw_host: Union[str, None] + host: Union[str, None] + host_subcomponent: Union[str, None] + host_port_subcomponent: Union[str, None] + port: Union[int, None] + explicit_port: Union[int, None] + raw_path: str + path: str + _parsed_query: list[tuple[str, str]] + query: "MultiDictProxy[str]" + raw_query_string: str + query_string: str + path_qs: str + raw_path_qs: str + raw_fragment: str + fragment: str + raw_parts: tuple[str, ...] + parts: tuple[str, ...] + parent: "URL" + raw_name: str + name: str + raw_suffix: str + suffix: str + raw_suffixes: tuple[str, ...] + suffixes: tuple[str, ...] + + +def rewrite_module(obj: _T) -> _T: + obj.__module__ = "yarl" + return obj + + +@lru_cache +def encode_url(url_str: str) -> "URL": + """Parse unencoded URL.""" + cache: _InternalURLCache = {} + host: Union[str, None] + scheme, netloc, path, query, fragment = split_url(url_str) + if not netloc: # netloc + host = "" + else: + if ":" in netloc or "@" in netloc or "[" in netloc: + # Complex netloc + username, password, host, port = split_netloc(netloc) + else: + username = password = port = None + host = netloc + if host is None: + if scheme in SCHEME_REQUIRES_HOST: + msg = ( + "Invalid URL: host is required for " + f"absolute urls with the {scheme} scheme" + ) + raise ValueError(msg) + else: + host = "" + host = _encode_host(host, validate_host=False) + # Remove brackets as host encoder adds back brackets for IPv6 addresses + cache["raw_host"] = host[1:-1] if "[" in host else host + cache["explicit_port"] = port + if password is None and username is None: + # Fast path for URLs without user, password + netloc = host if port is None else f"{host}:{port}" + cache["raw_user"] = None + cache["raw_password"] = None + else: + raw_user = REQUOTER(username) if username else username + raw_password = REQUOTER(password) if password else password + netloc = make_netloc(raw_user, raw_password, host, port) + cache["raw_user"] = raw_user + cache["raw_password"] = raw_password + + if path: + path = PATH_REQUOTER(path) + if netloc and "." in path: + path = normalize_path(path) + if query: + query = QUERY_REQUOTER(query) + if fragment: + fragment = FRAGMENT_REQUOTER(fragment) + + cache["scheme"] = scheme + cache["raw_path"] = "/" if not path and netloc else path + cache["raw_query_string"] = query + cache["raw_fragment"] = fragment + + self = object.__new__(URL) + self._scheme = scheme + self._netloc = netloc + self._path = path + self._query = query + self._fragment = fragment + self._cache = cache + return self + + +@lru_cache +def pre_encoded_url(url_str: str) -> "URL": + """Parse pre-encoded URL.""" + self = object.__new__(URL) + val = split_url(url_str) + self._scheme, self._netloc, self._path, self._query, self._fragment = val + self._cache = {} + return self + + +@lru_cache +def build_pre_encoded_url( + scheme: str, + authority: str, + user: Union[str, None], + password: Union[str, None], + host: str, + port: Union[int, None], + path: str, + query_string: str, + fragment: str, +) -> "URL": + """Build a pre-encoded URL from parts.""" + self = object.__new__(URL) + self._scheme = scheme + if authority: + self._netloc = authority + elif host: + if port is not None: + port = None if port == DEFAULT_PORTS.get(scheme) else port + if user is None and password is None: + self._netloc = host if port is None else f"{host}:{port}" + else: + self._netloc = make_netloc(user, password, host, port) + else: + self._netloc = "" + self._path = path + self._query = query_string + self._fragment = fragment + self._cache = {} + return self + + +def from_parts_uncached( + scheme: str, netloc: str, path: str, query: str, fragment: str +) -> "URL": + """Create a new URL from parts.""" + self = object.__new__(URL) + self._scheme = scheme + self._netloc = netloc + self._path = path + self._query = query + self._fragment = fragment + self._cache = {} + return self + + +from_parts = lru_cache(from_parts_uncached) + + +@rewrite_module +class URL: + # Don't derive from str + # follow pathlib.Path design + # probably URL will not suffer from pathlib problems: + # it's intended for libraries like aiohttp, + # not to be passed into standard library functions like os.open etc. + + # URL grammar (RFC 3986) + # pct-encoded = "%" HEXDIG HEXDIG + # reserved = gen-delims / sub-delims + # gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + # sub-delims = "!" / "$" / "&" / "'" / "(" / ")" + # / "*" / "+" / "," / ";" / "=" + # unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + # URI = scheme ":" hier-part [ "?" query ] [ "#" fragment ] + # hier-part = "//" authority path-abempty + # / path-absolute + # / path-rootless + # / path-empty + # scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) + # authority = [ userinfo "@" ] host [ ":" port ] + # userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) + # host = IP-literal / IPv4address / reg-name + # IP-literal = "[" ( IPv6address / IPvFuture ) "]" + # IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" ) + # IPv6address = 6( h16 ":" ) ls32 + # / "::" 5( h16 ":" ) ls32 + # / [ h16 ] "::" 4( h16 ":" ) ls32 + # / [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 + # / [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 + # / [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 + # / [ *4( h16 ":" ) h16 ] "::" ls32 + # / [ *5( h16 ":" ) h16 ] "::" h16 + # / [ *6( h16 ":" ) h16 ] "::" + # ls32 = ( h16 ":" h16 ) / IPv4address + # ; least-significant 32 bits of address + # h16 = 1*4HEXDIG + # ; 16 bits of address represented in hexadecimal + # IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet + # dec-octet = DIGIT ; 0-9 + # / %x31-39 DIGIT ; 10-99 + # / "1" 2DIGIT ; 100-199 + # / "2" %x30-34 DIGIT ; 200-249 + # / "25" %x30-35 ; 250-255 + # reg-name = *( unreserved / pct-encoded / sub-delims ) + # port = *DIGIT + # path = path-abempty ; begins with "/" or is empty + # / path-absolute ; begins with "/" but not "//" + # / path-noscheme ; begins with a non-colon segment + # / path-rootless ; begins with a segment + # / path-empty ; zero characters + # path-abempty = *( "/" segment ) + # path-absolute = "/" [ segment-nz *( "/" segment ) ] + # path-noscheme = segment-nz-nc *( "/" segment ) + # path-rootless = segment-nz *( "/" segment ) + # path-empty = 0 + # segment = *pchar + # segment-nz = 1*pchar + # segment-nz-nc = 1*( unreserved / pct-encoded / sub-delims / "@" ) + # ; non-zero-length segment without any colon ":" + # pchar = unreserved / pct-encoded / sub-delims / ":" / "@" + # query = *( pchar / "/" / "?" ) + # fragment = *( pchar / "/" / "?" ) + # URI-reference = URI / relative-ref + # relative-ref = relative-part [ "?" query ] [ "#" fragment ] + # relative-part = "//" authority path-abempty + # / path-absolute + # / path-noscheme + # / path-empty + # absolute-URI = scheme ":" hier-part [ "?" query ] + __slots__ = ("_cache", "_scheme", "_netloc", "_path", "_query", "_fragment") + + _cache: _InternalURLCache + _scheme: str + _netloc: str + _path: str + _query: str + _fragment: str + + def __new__( + cls, + val: Union[str, SplitResult, "URL", UndefinedType] = UNDEFINED, + *, + encoded: bool = False, + strict: Union[bool, None] = None, + ) -> "URL": + if strict is not None: # pragma: no cover + warnings.warn("strict parameter is ignored") + if type(val) is str: + return pre_encoded_url(val) if encoded else encode_url(val) + if type(val) is cls: + return val + if type(val) is SplitResult: + if not encoded: + raise ValueError("Cannot apply decoding to SplitResult") + return from_parts(*val) + if isinstance(val, str): + return pre_encoded_url(str(val)) if encoded else encode_url(str(val)) + if val is UNDEFINED: + # Special case for UNDEFINED since it might be unpickling and we do + # not want to cache as the `__set_state__` call would mutate the URL + # object in the `pre_encoded_url` or `encoded_url` caches. + self = object.__new__(URL) + self._scheme = self._netloc = self._path = self._query = self._fragment = "" + self._cache = {} + return self + raise TypeError("Constructor parameter should be str") + + @classmethod + def build( + cls, + *, + scheme: str = "", + authority: str = "", + user: Union[str, None] = None, + password: Union[str, None] = None, + host: str = "", + port: Union[int, None] = None, + path: str = "", + query: Union[Query, None] = None, + query_string: str = "", + fragment: str = "", + encoded: bool = False, + ) -> "URL": + """Creates and returns a new URL""" + + if authority and (user or password or host or port): + raise ValueError( + 'Can\'t mix "authority" with "user", "password", "host" or "port".' + ) + if port is not None and not isinstance(port, int): + raise TypeError(f"The port is required to be int, got {type(port)!r}.") + if port and not host: + raise ValueError('Can\'t build URL with "port" but without "host".') + if query and query_string: + raise ValueError('Only one of "query" or "query_string" should be passed') + if ( + scheme is None # type: ignore[redundant-expr] + or authority is None # type: ignore[redundant-expr] + or host is None # type: ignore[redundant-expr] + or path is None # type: ignore[redundant-expr] + or query_string is None # type: ignore[redundant-expr] + or fragment is None + ): + raise TypeError( + 'NoneType is illegal for "scheme", "authority", "host", "path", ' + '"query_string", and "fragment" args, use empty string instead.' + ) + + if query: + query_string = get_str_query(query) or "" + + if encoded: + return build_pre_encoded_url( + scheme, + authority, + user, + password, + host, + port, + path, + query_string, + fragment, + ) + + self = object.__new__(URL) + self._scheme = scheme + _host: Union[str, None] = None + if authority: + user, password, _host, port = split_netloc(authority) + _host = _encode_host(_host, validate_host=False) if _host else "" + elif host: + _host = _encode_host(host, validate_host=True) + else: + self._netloc = "" + + if _host is not None: + if port is not None: + port = None if port == DEFAULT_PORTS.get(scheme) else port + if user is None and password is None: + self._netloc = _host if port is None else f"{_host}:{port}" + else: + self._netloc = make_netloc(user, password, _host, port, True) + + path = PATH_QUOTER(path) if path else path + if path and self._netloc: + if "." in path: + path = normalize_path(path) + if path[0] != "/": + msg = ( + "Path in a URL with authority should " + "start with a slash ('/') if set" + ) + raise ValueError(msg) + + self._path = path + if not query and query_string: + query_string = QUERY_QUOTER(query_string) + self._query = query_string + self._fragment = FRAGMENT_QUOTER(fragment) if fragment else fragment + self._cache = {} + return self + + def __init_subclass__(cls) -> NoReturn: + raise TypeError(f"Inheriting a class {cls!r} from URL is forbidden") + + def __str__(self) -> str: + if not self._path and self._netloc and (self._query or self._fragment): + path = "/" + else: + path = self._path + if (port := self.explicit_port) is not None and port == DEFAULT_PORTS.get( + self._scheme + ): + # port normalization - using None for default ports to remove from rendering + # https://datatracker.ietf.org/doc/html/rfc3986.html#section-6.2.3 + host = self.host_subcomponent + netloc = make_netloc(self.raw_user, self.raw_password, host, None) + else: + netloc = self._netloc + return unsplit_result(self._scheme, netloc, path, self._query, self._fragment) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}('{str(self)}')" + + def __bytes__(self) -> bytes: + return str(self).encode("ascii") + + def __eq__(self, other: object) -> bool: + if type(other) is not URL: + return NotImplemented + + path1 = "/" if not self._path and self._netloc else self._path + path2 = "/" if not other._path and other._netloc else other._path + return ( + self._scheme == other._scheme + and self._netloc == other._netloc + and path1 == path2 + and self._query == other._query + and self._fragment == other._fragment + ) + + def __hash__(self) -> int: + if (ret := self._cache.get("hash")) is None: + path = "/" if not self._path and self._netloc else self._path + ret = self._cache["hash"] = hash( + (self._scheme, self._netloc, path, self._query, self._fragment) + ) + return ret + + def __le__(self, other: object) -> bool: + if type(other) is not URL: + return NotImplemented + return self._val <= other._val + + def __lt__(self, other: object) -> bool: + if type(other) is not URL: + return NotImplemented + return self._val < other._val + + def __ge__(self, other: object) -> bool: + if type(other) is not URL: + return NotImplemented + return self._val >= other._val + + def __gt__(self, other: object) -> bool: + if type(other) is not URL: + return NotImplemented + return self._val > other._val + + def __truediv__(self, name: str) -> "URL": + if not isinstance(name, str): + return NotImplemented # type: ignore[unreachable] + return self._make_child((str(name),)) + + def __mod__(self, query: Query) -> "URL": + return self.update_query(query) + + def __bool__(self) -> bool: + return bool(self._netloc or self._path or self._query or self._fragment) + + def __getstate__(self) -> tuple[SplitResult]: + return (tuple.__new__(SplitResult, self._val),) + + def __setstate__( + self, state: Union[tuple[SplitURLType], tuple[None, _InternalURLCache]] + ) -> None: + if state[0] is None and isinstance(state[1], dict): + # default style pickle + val = state[1]["_val"] + else: + unused: list[object] + val, *unused = state + self._scheme, self._netloc, self._path, self._query, self._fragment = val + self._cache = {} + + def _cache_netloc(self) -> None: + """Cache the netloc parts of the URL.""" + c = self._cache + split_loc = split_netloc(self._netloc) + c["raw_user"], c["raw_password"], c["raw_host"], c["explicit_port"] = split_loc + + def is_absolute(self) -> bool: + """A check for absolute URLs. + + Return True for absolute ones (having scheme or starting + with //), False otherwise. + + Is is preferred to call the .absolute property instead + as it is cached. + """ + return self.absolute + + def is_default_port(self) -> bool: + """A check for default port. + + Return True if port is default for specified scheme, + e.g. 'http://python.org' or 'http://python.org:80', False + otherwise. + + Return False for relative URLs. + + """ + if (explicit := self.explicit_port) is None: + # If the explicit port is None, then the URL must be + # using the default port unless its a relative URL + # which does not have an implicit port / default port + return self._netloc != "" + return explicit == DEFAULT_PORTS.get(self._scheme) + + def origin(self) -> "URL": + """Return an URL with scheme, host and port parts only. + + user, password, path, query and fragment are removed. + + """ + # TODO: add a keyword-only option for keeping user/pass maybe? + return self._origin + + @cached_property + def _val(self) -> SplitURLType: + return (self._scheme, self._netloc, self._path, self._query, self._fragment) + + @cached_property + def _origin(self) -> "URL": + """Return an URL with scheme, host and port parts only. + + user, password, path, query and fragment are removed. + """ + if not (netloc := self._netloc): + raise ValueError("URL should be absolute") + if not (scheme := self._scheme): + raise ValueError("URL should have scheme") + if "@" in netloc: + encoded_host = self.host_subcomponent + netloc = make_netloc(None, None, encoded_host, self.explicit_port) + elif not self._path and not self._query and not self._fragment: + return self + return from_parts(scheme, netloc, "", "", "") + + def relative(self) -> "URL": + """Return a relative part of the URL. + + scheme, user, password, host and port are removed. + + """ + if not self._netloc: + raise ValueError("URL should be absolute") + return from_parts("", "", self._path, self._query, self._fragment) + + @cached_property + def absolute(self) -> bool: + """A check for absolute URLs. + + Return True for absolute ones (having scheme or starting + with //), False otherwise. + + """ + # `netloc`` is an empty string for relative URLs + # Checking `netloc` is faster than checking `hostname` + # because `hostname` is a property that does some extra work + # to parse the host from the `netloc` + return self._netloc != "" + + @cached_property + def scheme(self) -> str: + """Scheme for absolute URLs. + + Empty string for relative URLs or URLs starting with // + + """ + return self._scheme + + @cached_property + def raw_authority(self) -> str: + """Encoded authority part of URL. + + Empty string for relative URLs. + + """ + return self._netloc + + @cached_property + def authority(self) -> str: + """Decoded authority part of URL. + + Empty string for relative URLs. + + """ + return make_netloc(self.user, self.password, self.host, self.port) + + @cached_property + def raw_user(self) -> Union[str, None]: + """Encoded user part of URL. + + None if user is missing. + + """ + # not .username + self._cache_netloc() + return self._cache["raw_user"] + + @cached_property + def user(self) -> Union[str, None]: + """Decoded user part of URL. + + None if user is missing. + + """ + if (raw_user := self.raw_user) is None: + return None + return UNQUOTER(raw_user) + + @cached_property + def raw_password(self) -> Union[str, None]: + """Encoded password part of URL. + + None if password is missing. + + """ + self._cache_netloc() + return self._cache["raw_password"] + + @cached_property + def password(self) -> Union[str, None]: + """Decoded password part of URL. + + None if password is missing. + + """ + if (raw_password := self.raw_password) is None: + return None + return UNQUOTER(raw_password) + + @cached_property + def raw_host(self) -> Union[str, None]: + """Encoded host part of URL. + + None for relative URLs. + + When working with IPv6 addresses, use the `host_subcomponent` property instead + as it will return the host subcomponent with brackets. + """ + # Use host instead of hostname for sake of shortness + # May add .hostname prop later + self._cache_netloc() + return self._cache["raw_host"] + + @cached_property + def host(self) -> Union[str, None]: + """Decoded host part of URL. + + None for relative URLs. + + """ + if (raw := self.raw_host) is None: + return None + if raw and raw[-1].isdigit() or ":" in raw: + # IP addresses are never IDNA encoded + return raw + return _idna_decode(raw) + + @cached_property + def host_subcomponent(self) -> Union[str, None]: + """Return the host subcomponent part of URL. + + None for relative URLs. + + https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 + + `IP-literal = "[" ( IPv6address / IPvFuture ) "]"` + + Examples: + - `http://example.com:8080` -> `example.com` + - `http://example.com:80` -> `example.com` + - `https://127.0.0.1:8443` -> `127.0.0.1` + - `https://[::1]:8443` -> `[::1]` + - `http://[::1]` -> `[::1]` + + """ + if (raw := self.raw_host) is None: + return None + return f"[{raw}]" if ":" in raw else raw + + @cached_property + def host_port_subcomponent(self) -> Union[str, None]: + """Return the host and port subcomponent part of URL. + + Trailing dots are removed from the host part. + + This value is suitable for use in the Host header of an HTTP request. + + None for relative URLs. + + https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 + `IP-literal = "[" ( IPv6address / IPvFuture ) "]"` + https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.3 + port = *DIGIT + + Examples: + - `http://example.com:8080` -> `example.com:8080` + - `http://example.com:80` -> `example.com` + - `http://example.com.:80` -> `example.com` + - `https://127.0.0.1:8443` -> `127.0.0.1:8443` + - `https://[::1]:8443` -> `[::1]:8443` + - `http://[::1]` -> `[::1]` + + """ + if (raw := self.raw_host) is None: + return None + if raw[-1] == ".": + # Remove all trailing dots from the netloc as while + # they are valid FQDNs in DNS, TLS validation fails. + # See https://github.com/aio-libs/aiohttp/issues/3636. + # To avoid string manipulation we only call rstrip if + # the last character is a dot. + raw = raw.rstrip(".") + port = self.explicit_port + if port is None or port == DEFAULT_PORTS.get(self._scheme): + return f"[{raw}]" if ":" in raw else raw + return f"[{raw}]:{port}" if ":" in raw else f"{raw}:{port}" + + @cached_property + def port(self) -> Union[int, None]: + """Port part of URL, with scheme-based fallback. + + None for relative URLs or URLs without explicit port and + scheme without default port substitution. + + """ + if (explicit_port := self.explicit_port) is not None: + return explicit_port + return DEFAULT_PORTS.get(self._scheme) + + @cached_property + def explicit_port(self) -> Union[int, None]: + """Port part of URL, without scheme-based fallback. + + None for relative URLs or URLs without explicit port. + + """ + self._cache_netloc() + return self._cache["explicit_port"] + + @cached_property + def raw_path(self) -> str: + """Encoded path of URL. + + / for absolute URLs without path part. + + """ + return self._path if self._path or not self._netloc else "/" + + @cached_property + def path(self) -> str: + """Decoded path of URL. + + / for absolute URLs without path part. + + """ + return PATH_UNQUOTER(self._path) if self._path else "/" if self._netloc else "" + + @cached_property + def path_safe(self) -> str: + """Decoded path of URL. + + / for absolute URLs without path part. + + / (%2F) and % (%25) are not decoded + + """ + if self._path: + return PATH_SAFE_UNQUOTER(self._path) + return "/" if self._netloc else "" + + @cached_property + def _parsed_query(self) -> list[tuple[str, str]]: + """Parse query part of URL.""" + return query_to_pairs(self._query) + + @cached_property + def query(self) -> "MultiDictProxy[str]": + """A MultiDictProxy representing parsed query parameters in decoded + representation. + + Empty value if URL has no query part. + + """ + return MultiDictProxy(MultiDict(self._parsed_query)) + + @cached_property + def raw_query_string(self) -> str: + """Encoded query part of URL. + + Empty string if query is missing. + + """ + return self._query + + @cached_property + def query_string(self) -> str: + """Decoded query part of URL. + + Empty string if query is missing. + + """ + return QS_UNQUOTER(self._query) if self._query else "" + + @cached_property + def path_qs(self) -> str: + """Decoded path of URL with query.""" + return self.path if not (q := self.query_string) else f"{self.path}?{q}" + + @cached_property + def raw_path_qs(self) -> str: + """Encoded path of URL with query.""" + if q := self._query: + return f"{self._path}?{q}" if self._path or not self._netloc else f"/?{q}" + return self._path if self._path or not self._netloc else "/" + + @cached_property + def raw_fragment(self) -> str: + """Encoded fragment part of URL. + + Empty string if fragment is missing. + + """ + return self._fragment + + @cached_property + def fragment(self) -> str: + """Decoded fragment part of URL. + + Empty string if fragment is missing. + + """ + return UNQUOTER(self._fragment) if self._fragment else "" + + @cached_property + def raw_parts(self) -> tuple[str, ...]: + """A tuple containing encoded *path* parts. + + ('/',) for absolute URLs if *path* is missing. + + """ + path = self._path + if self._netloc: + return ("/", *path[1:].split("/")) if path else ("/",) + if path and path[0] == "/": + return ("/", *path[1:].split("/")) + return tuple(path.split("/")) + + @cached_property + def parts(self) -> tuple[str, ...]: + """A tuple containing decoded *path* parts. + + ('/',) for absolute URLs if *path* is missing. + + """ + return tuple(UNQUOTER(part) for part in self.raw_parts) + + @cached_property + def parent(self) -> "URL": + """A new URL with last part of path removed and cleaned up query and + fragment. + + """ + path = self._path + if not path or path == "/": + if self._fragment or self._query: + return from_parts(self._scheme, self._netloc, path, "", "") + return self + parts = path.split("/") + return from_parts(self._scheme, self._netloc, "/".join(parts[:-1]), "", "") + + @cached_property + def raw_name(self) -> str: + """The last part of raw_parts.""" + parts = self.raw_parts + if not self._netloc: + return parts[-1] + parts = parts[1:] + return parts[-1] if parts else "" + + @cached_property + def name(self) -> str: + """The last part of parts.""" + return UNQUOTER(self.raw_name) + + @cached_property + def raw_suffix(self) -> str: + name = self.raw_name + i = name.rfind(".") + return name[i:] if 0 < i < len(name) - 1 else "" + + @cached_property + def suffix(self) -> str: + return UNQUOTER(self.raw_suffix) + + @cached_property + def raw_suffixes(self) -> tuple[str, ...]: + name = self.raw_name + if name.endswith("."): + return () + name = name.lstrip(".") + return tuple("." + suffix for suffix in name.split(".")[1:]) + + @cached_property + def suffixes(self) -> tuple[str, ...]: + return tuple(UNQUOTER(suffix) for suffix in self.raw_suffixes) + + def _make_child(self, paths: "Sequence[str]", encoded: bool = False) -> "URL": + """ + add paths to self._path, accounting for absolute vs relative paths, + keep existing, but do not create new, empty segments + """ + parsed: list[str] = [] + needs_normalize: bool = False + for idx, path in enumerate(reversed(paths)): + # empty segment of last is not removed + last = idx == 0 + if path and path[0] == "/": + raise ValueError( + f"Appending path {path!r} starting from slash is forbidden" + ) + # We need to quote the path if it is not already encoded + # This cannot be done at the end because the existing + # path is already quoted and we do not want to double quote + # the existing path. + path = path if encoded else PATH_QUOTER(path) + needs_normalize |= "." in path + segments = path.split("/") + segments.reverse() + # remove trailing empty segment for all but the last path + parsed += segments[1:] if not last and segments[0] == "" else segments + + if (path := self._path) and (old_segments := path.split("/")): + # If the old path ends with a slash, the last segment is an empty string + # and should be removed before adding the new path segments. + old = old_segments[:-1] if old_segments[-1] == "" else old_segments + old.reverse() + parsed += old + + # If the netloc is present, inject a leading slash when adding a + # path to an absolute URL where there was none before. + if (netloc := self._netloc) and parsed and parsed[-1] != "": + parsed.append("") + + parsed.reverse() + if not netloc or not needs_normalize: + return from_parts(self._scheme, netloc, "/".join(parsed), "", "") + + path = "/".join(normalize_path_segments(parsed)) + # If normalizing the path segments removed the leading slash, add it back. + if path and path[0] != "/": + path = f"/{path}" + return from_parts(self._scheme, netloc, path, "", "") + + def with_scheme(self, scheme: str) -> "URL": + """Return a new URL with scheme replaced.""" + # N.B. doesn't cleanup query/fragment + if not isinstance(scheme, str): + raise TypeError("Invalid scheme type") + lower_scheme = scheme.lower() + netloc = self._netloc + if not netloc and lower_scheme in SCHEME_REQUIRES_HOST: + msg = ( + "scheme replacement is not allowed for " + f"relative URLs for the {lower_scheme} scheme" + ) + raise ValueError(msg) + return from_parts(lower_scheme, netloc, self._path, self._query, self._fragment) + + def with_user(self, user: Union[str, None]) -> "URL": + """Return a new URL with user replaced. + + Autoencode user if needed. + + Clear user/password if user is None. + + """ + # N.B. doesn't cleanup query/fragment + if user is None: + password = None + elif isinstance(user, str): + user = QUOTER(user) + password = self.raw_password + else: + raise TypeError("Invalid user type") + if not (netloc := self._netloc): + raise ValueError("user replacement is not allowed for relative URLs") + encoded_host = self.host_subcomponent or "" + netloc = make_netloc(user, password, encoded_host, self.explicit_port) + return from_parts(self._scheme, netloc, self._path, self._query, self._fragment) + + def with_password(self, password: Union[str, None]) -> "URL": + """Return a new URL with password replaced. + + Autoencode password if needed. + + Clear password if argument is None. + + """ + # N.B. doesn't cleanup query/fragment + if password is None: + pass + elif isinstance(password, str): + password = QUOTER(password) + else: + raise TypeError("Invalid password type") + if not (netloc := self._netloc): + raise ValueError("password replacement is not allowed for relative URLs") + encoded_host = self.host_subcomponent or "" + port = self.explicit_port + netloc = make_netloc(self.raw_user, password, encoded_host, port) + return from_parts(self._scheme, netloc, self._path, self._query, self._fragment) + + def with_host(self, host: str) -> "URL": + """Return a new URL with host replaced. + + Autoencode host if needed. + + Changing host for relative URLs is not allowed, use .join() + instead. + + """ + # N.B. doesn't cleanup query/fragment + if not isinstance(host, str): + raise TypeError("Invalid host type") + if not (netloc := self._netloc): + raise ValueError("host replacement is not allowed for relative URLs") + if not host: + raise ValueError("host removing is not allowed") + encoded_host = _encode_host(host, validate_host=True) if host else "" + port = self.explicit_port + netloc = make_netloc(self.raw_user, self.raw_password, encoded_host, port) + return from_parts(self._scheme, netloc, self._path, self._query, self._fragment) + + def with_port(self, port: Union[int, None]) -> "URL": + """Return a new URL with port replaced. + + Clear port to default if None is passed. + + """ + # N.B. doesn't cleanup query/fragment + if port is not None: + if isinstance(port, bool) or not isinstance(port, int): + raise TypeError(f"port should be int or None, got {type(port)}") + if not (0 <= port <= 65535): + raise ValueError(f"port must be between 0 and 65535, got {port}") + if not (netloc := self._netloc): + raise ValueError("port replacement is not allowed for relative URLs") + encoded_host = self.host_subcomponent or "" + netloc = make_netloc(self.raw_user, self.raw_password, encoded_host, port) + return from_parts(self._scheme, netloc, self._path, self._query, self._fragment) + + def with_path( + self, + path: str, + *, + encoded: bool = False, + keep_query: bool = False, + keep_fragment: bool = False, + ) -> "URL": + """Return a new URL with path replaced.""" + netloc = self._netloc + if not encoded: + path = PATH_QUOTER(path) + if netloc: + path = normalize_path(path) if "." in path else path + if path and path[0] != "/": + path = f"/{path}" + query = self._query if keep_query else "" + fragment = self._fragment if keep_fragment else "" + return from_parts(self._scheme, netloc, path, query, fragment) + + @overload + def with_query(self, query: Query) -> "URL": ... + + @overload + def with_query(self, **kwargs: QueryVariable) -> "URL": ... + + def with_query(self, *args: Any, **kwargs: Any) -> "URL": + """Return a new URL with query part replaced. + + Accepts any Mapping (e.g. dict, multidict.MultiDict instances) + or str, autoencode the argument if needed. + + A sequence of (key, value) pairs is supported as well. + + It also can take an arbitrary number of keyword arguments. + + Clear query if None is passed. + + """ + # N.B. doesn't cleanup query/fragment + query = get_str_query(*args, **kwargs) or "" + return from_parts_uncached( + self._scheme, self._netloc, self._path, query, self._fragment + ) + + @overload + def extend_query(self, query: Query) -> "URL": ... + + @overload + def extend_query(self, **kwargs: QueryVariable) -> "URL": ... + + def extend_query(self, *args: Any, **kwargs: Any) -> "URL": + """Return a new URL with query part combined with the existing. + + This method will not remove existing query parameters. + + Example: + >>> url = URL('http://example.com/?a=1&b=2') + >>> url.extend_query(a=3, c=4) + URL('http://example.com/?a=1&b=2&a=3&c=4') + """ + if not (new_query := get_str_query(*args, **kwargs)): + return self + if query := self._query: + # both strings are already encoded so we can use a simple + # string join + query += new_query if query[-1] == "&" else f"&{new_query}" + else: + query = new_query + return from_parts_uncached( + self._scheme, self._netloc, self._path, query, self._fragment + ) + + @overload + def update_query(self, query: Query) -> "URL": ... + + @overload + def update_query(self, **kwargs: QueryVariable) -> "URL": ... + + def update_query(self, *args: Any, **kwargs: Any) -> "URL": + """Return a new URL with query part updated. + + This method will overwrite existing query parameters. + + Example: + >>> url = URL('http://example.com/?a=1&b=2') + >>> url.update_query(a=3, c=4) + URL('http://example.com/?a=3&b=2&c=4') + """ + in_query: Union[ + str, + Mapping[str, QueryVariable], + Sequence[tuple[Union[str, istr], SimpleQuery]], + None, + ] + if kwargs: + if args: + msg = "Either kwargs or single query parameter must be present" + raise ValueError(msg) + in_query = kwargs + elif len(args) == 1: + in_query = args[0] + else: + raise ValueError("Either kwargs or single query parameter must be present") + + if in_query is None: + query = "" + elif not in_query: + query = self._query + elif isinstance(in_query, Mapping): + qm: MultiDict[QueryVariable] = MultiDict(self._parsed_query) + qm.update(in_query) + query = get_str_query_from_sequence_iterable(qm.items()) + elif isinstance(in_query, str): + qstr: MultiDict[str] = MultiDict(self._parsed_query) + qstr.update(query_to_pairs(in_query)) + query = get_str_query_from_iterable(qstr.items()) + elif isinstance(in_query, (bytes, bytearray, memoryview)): + msg = "Invalid query type: bytes, bytearray and memoryview are forbidden" + raise TypeError(msg) + elif isinstance(in_query, Sequence): + # We don't expect sequence values if we're given a list of pairs + # already; only mappings like builtin `dict` which can't have the + # same key pointing to multiple values are allowed to use + # `_query_seq_pairs`. + if TYPE_CHECKING: + in_query = cast( + Sequence[tuple[Union[str, istr], SimpleQuery]], in_query + ) + qs: MultiDict[SimpleQuery] = MultiDict(self._parsed_query) + qs.update(in_query) + query = get_str_query_from_iterable(qs.items()) + else: + raise TypeError( + "Invalid query type: only str, mapping or " + "sequence of (key, value) pairs is allowed" + ) + return from_parts_uncached( + self._scheme, self._netloc, self._path, query, self._fragment + ) + + def without_query_params(self, *query_params: str) -> "URL": + """Remove some keys from query part and return new URL.""" + params_to_remove = set(query_params) & self.query.keys() + if not params_to_remove: + return self + return self.with_query( + tuple( + (name, value) + for name, value in self.query.items() + if name not in params_to_remove + ) + ) + + def with_fragment(self, fragment: Union[str, None]) -> "URL": + """Return a new URL with fragment replaced. + + Autoencode fragment if needed. + + Clear fragment to default if None is passed. + + """ + # N.B. doesn't cleanup query/fragment + if fragment is None: + raw_fragment = "" + elif not isinstance(fragment, str): + raise TypeError("Invalid fragment type") + else: + raw_fragment = FRAGMENT_QUOTER(fragment) + if self._fragment == raw_fragment: + return self + return from_parts( + self._scheme, self._netloc, self._path, self._query, raw_fragment + ) + + def with_name( + self, + name: str, + *, + keep_query: bool = False, + keep_fragment: bool = False, + ) -> "URL": + """Return a new URL with name (last part of path) replaced. + + Query and fragment parts are cleaned up. + + Name is encoded if needed. + + """ + # N.B. DOES cleanup query/fragment + if not isinstance(name, str): + raise TypeError("Invalid name type") + if "/" in name: + raise ValueError("Slash in name is not allowed") + name = PATH_QUOTER(name) + if name in (".", ".."): + raise ValueError(". and .. values are forbidden") + parts = list(self.raw_parts) + if netloc := self._netloc: + if len(parts) == 1: + parts.append(name) + else: + parts[-1] = name + parts[0] = "" # replace leading '/' + else: + parts[-1] = name + if parts[0] == "/": + parts[0] = "" # replace leading '/' + + query = self._query if keep_query else "" + fragment = self._fragment if keep_fragment else "" + return from_parts(self._scheme, netloc, "/".join(parts), query, fragment) + + def with_suffix( + self, + suffix: str, + *, + keep_query: bool = False, + keep_fragment: bool = False, + ) -> "URL": + """Return a new URL with suffix (file extension of name) replaced. + + Query and fragment parts are cleaned up. + + suffix is encoded if needed. + """ + if not isinstance(suffix, str): + raise TypeError("Invalid suffix type") + if suffix and not suffix[0] == "." or suffix == "." or "/" in suffix: + raise ValueError(f"Invalid suffix {suffix!r}") + name = self.raw_name + if not name: + raise ValueError(f"{self!r} has an empty name") + old_suffix = self.raw_suffix + suffix = PATH_QUOTER(suffix) + name = name + suffix if not old_suffix else name[: -len(old_suffix)] + suffix + if name in (".", ".."): + raise ValueError(". and .. values are forbidden") + parts = list(self.raw_parts) + if netloc := self._netloc: + if len(parts) == 1: + parts.append(name) + else: + parts[-1] = name + parts[0] = "" # replace leading '/' + else: + parts[-1] = name + if parts[0] == "/": + parts[0] = "" # replace leading '/' + + query = self._query if keep_query else "" + fragment = self._fragment if keep_fragment else "" + return from_parts(self._scheme, netloc, "/".join(parts), query, fragment) + + def join(self, url: "URL") -> "URL": + """Join URLs + + Construct a full (“absolute”) URL by combining a “base URL” + (self) with another URL (url). + + Informally, this uses components of the base URL, in + particular the addressing scheme, the network location and + (part of) the path, to provide missing components in the + relative URL. + + """ + if type(url) is not URL: + raise TypeError("url should be URL") + + scheme = url._scheme or self._scheme + if scheme != self._scheme or scheme not in USES_RELATIVE: + return url + + # scheme is in uses_authority as uses_authority is a superset of uses_relative + if (join_netloc := url._netloc) and scheme in USES_AUTHORITY: + return from_parts(scheme, join_netloc, url._path, url._query, url._fragment) + + orig_path = self._path + if join_path := url._path: + if join_path[0] == "/": + path = join_path + elif not orig_path: + path = f"/{join_path}" + elif orig_path[-1] == "/": + path = f"{orig_path}{join_path}" + else: + # … + # and relativizing ".." + # parts[0] is / for absolute urls, + # this join will add a double slash there + path = "/".join([*self.parts[:-1], ""]) + join_path + # which has to be removed + if orig_path[0] == "/": + path = path[1:] + path = normalize_path(path) if "." in path else path + else: + path = orig_path + + return from_parts( + scheme, + self._netloc, + path, + url._query if join_path or url._query else self._query, + url._fragment if join_path or url._fragment else self._fragment, + ) + + def joinpath(self, *other: str, encoded: bool = False) -> "URL": + """Return a new URL with the elements in other appended to the path.""" + return self._make_child(other, encoded=encoded) + + def human_repr(self) -> str: + """Return decoded human readable string for URL representation.""" + user = human_quote(self.user, "#/:?@[]") + password = human_quote(self.password, "#/:?@[]") + if (host := self.host) and ":" in host: + host = f"[{host}]" + path = human_quote(self.path, "#?") + if TYPE_CHECKING: + assert path is not None + query_string = "&".join( + "{}={}".format(human_quote(k, "#&+;="), human_quote(v, "#&+;=")) + for k, v in self.query.items() + ) + fragment = human_quote(self.fragment, "") + if TYPE_CHECKING: + assert fragment is not None + netloc = make_netloc(user, password, host, self.explicit_port) + return unsplit_result(self._scheme, netloc, path, query_string, fragment) + + +_DEFAULT_IDNA_SIZE = 256 +_DEFAULT_ENCODE_SIZE = 512 + + +@lru_cache(_DEFAULT_IDNA_SIZE) +def _idna_decode(raw: str) -> str: + try: + return idna.decode(raw.encode("ascii")) + except UnicodeError: # e.g. '::1' + return raw.encode("ascii").decode("idna") + + +@lru_cache(_DEFAULT_IDNA_SIZE) +def _idna_encode(host: str) -> str: + try: + return idna.encode(host, uts46=True).decode("ascii") + except UnicodeError: + return host.encode("idna").decode("ascii") + + +@lru_cache(_DEFAULT_ENCODE_SIZE) +def _encode_host(host: str, validate_host: bool) -> str: + """Encode host part of URL.""" + # If the host ends with a digit or contains a colon, its likely + # an IP address. + if host and (host[-1].isdigit() or ":" in host): + raw_ip, sep, zone = host.partition("%") + # If it looks like an IP, we check with _ip_compressed_version + # and fall-through if its not an IP address. This is a performance + # optimization to avoid parsing IP addresses as much as possible + # because it is orders of magnitude slower than almost any other + # operation this library does. + # Might be an IP address, check it + # + # IP Addresses can look like: + # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2 + # - 127.0.0.1 (last character is a digit) + # - 2001:db8::ff00:42:8329 (contains a colon) + # - 2001:db8::ff00:42:8329%eth0 (contains a colon) + # - [2001:db8::ff00:42:8329] (contains a colon -- brackets should + # have been removed before it gets here) + # Rare IP Address formats are not supported per: + # https://datatracker.ietf.org/doc/html/rfc3986#section-7.4 + # + # IP parsing is slow, so its wrapped in an LRU + try: + ip = ip_address(raw_ip) + except ValueError: + pass + else: + # These checks should not happen in the + # LRU to keep the cache size small + host = ip.compressed + if ip.version == 6: + return f"[{host}%{zone}]" if sep else f"[{host}]" + return f"{host}%{zone}" if sep else host + + # IDNA encoding is slow, skip it for ASCII-only strings + if host.isascii(): + # Check for invalid characters explicitly; _idna_encode() does this + # for non-ascii host names. + host = host.lower() + if validate_host and (invalid := NOT_REG_NAME.search(host)): + value, pos, extra = invalid.group(), invalid.start(), "" + if value == "@" or (value == ":" and "@" in host[pos:]): + # this looks like an authority string + extra = ( + ", if the value includes a username or password, " + "use 'authority' instead of 'host'" + ) + raise ValueError( + f"Host {host!r} cannot contain {value!r} (at position {pos}){extra}" + ) from None + return host + + return _idna_encode(host) + + +@rewrite_module +def cache_clear() -> None: + """Clear all LRU caches.""" + _idna_encode.cache_clear() + _idna_decode.cache_clear() + _encode_host.cache_clear() + + +@rewrite_module +def cache_info() -> CacheInfo: + """Report cache statistics.""" + return { + "idna_encode": _idna_encode.cache_info(), + "idna_decode": _idna_decode.cache_info(), + "ip_address": _encode_host.cache_info(), + "host_validate": _encode_host.cache_info(), + "encode_host": _encode_host.cache_info(), + } + + +@rewrite_module +def cache_configure( + *, + idna_encode_size: Union[int, None] = _DEFAULT_IDNA_SIZE, + idna_decode_size: Union[int, None] = _DEFAULT_IDNA_SIZE, + ip_address_size: Union[int, None, UndefinedType] = UNDEFINED, + host_validate_size: Union[int, None, UndefinedType] = UNDEFINED, + encode_host_size: Union[int, None, UndefinedType] = UNDEFINED, +) -> None: + """Configure LRU cache sizes.""" + global _idna_decode, _idna_encode, _encode_host + # ip_address_size, host_validate_size are no longer + # used, but are kept for backwards compatibility. + if ip_address_size is not UNDEFINED or host_validate_size is not UNDEFINED: + warnings.warn( + "cache_configure() no longer accepts the " + "ip_address_size or host_validate_size arguments, " + "they are used to set the encode_host_size instead " + "and will be removed in the future", + DeprecationWarning, + stacklevel=2, + ) + + if encode_host_size is not None: + for size in (ip_address_size, host_validate_size): + if size is None: + encode_host_size = None + elif encode_host_size is UNDEFINED: + if size is not UNDEFINED: + encode_host_size = size + elif size is not UNDEFINED: + if TYPE_CHECKING: + assert isinstance(size, int) + assert isinstance(encode_host_size, int) + encode_host_size = max(size, encode_host_size) + if encode_host_size is UNDEFINED: + encode_host_size = _DEFAULT_ENCODE_SIZE + + _encode_host = lru_cache(encode_host_size)(_encode_host.__wrapped__) + _idna_decode = lru_cache(idna_decode_size)(_idna_decode.__wrapped__) + _idna_encode = lru_cache(idna_encode_size)(_idna_encode.__wrapped__) diff --git a/source/yarl/py.typed b/source/yarl/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..dcf2c804da5e19d617a03a6c68aa128d1d1f89a0 --- /dev/null +++ b/source/yarl/py.typed @@ -0,0 +1 @@ +# Placeholder diff --git a/source/zmq/__init__.pxd b/source/zmq/__init__.pxd new file mode 100644 index 0000000000000000000000000000000000000000..37b8362e2a07a77d41ff9f6d3b589364a5e00a05 --- /dev/null +++ b/source/zmq/__init__.pxd @@ -0,0 +1 @@ +from zmq.backend.cython cimport Context, Frame, Socket, libzmq diff --git a/source/zmq/__init__.py b/source/zmq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0326be0b64dc82b37d383fc6dd5687b6219eb7 --- /dev/null +++ b/source/zmq/__init__.py @@ -0,0 +1,97 @@ +"""Python bindings for 0MQ""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +import os +import sys +from contextlib import contextmanager + + +@contextmanager +def _libs_on_path(): + """context manager for libs directory on $PATH + + Works around mysterious issue where os.add_dll_directory + does not resolve imports (conda-forge Python >= 3.8) + """ + + if not sys.platform.startswith("win"): + yield + return + + libs_dir = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + os.pardir, + "pyzmq.libs", + ) + ) + if not os.path.exists(libs_dir): + # no bundled libs + yield + return + + path_before = os.environ.get("PATH") + try: + os.environ["PATH"] = os.pathsep.join([path_before or "", libs_dir]) + yield + finally: + if path_before is None: + os.environ.pop("PATH") + else: + os.environ["PATH"] = path_before + + +# zmq top-level imports + +# workaround for Windows +with _libs_on_path(): + from zmq import backend + +from . import constants # noqa +from .constants import * # noqa +from zmq.backend import * # noqa +from zmq import sugar +from zmq.sugar import * # noqa + + +def get_includes(): + """Return a list of directories to include for linking against pyzmq with cython.""" + from os.path import abspath, dirname, exists, join, pardir + + base = dirname(__file__) + parent = abspath(join(base, pardir)) + includes = [parent] + [join(parent, base, subdir) for subdir in ('utils',)] + if exists(join(parent, base, 'include')): + includes.append(join(parent, base, 'include')) + return includes + + +def get_library_dirs(): + """Return a list of directories used to link against pyzmq's bundled libzmq.""" + from os.path import abspath, dirname, join, pardir + + base = dirname(__file__) + parent = abspath(join(base, pardir)) + return [join(parent, base)] + + +COPY_THRESHOLD = 65536 +# zmq.DRAFT_API represents _both_ the current runtime-loaded libzmq +# and pyzmq were built with drafts, +# which is required for pyzmq draft support +DRAFT_API: bool = backend.has('draft') and backend.PYZMQ_DRAFT_API + +__all__ = ( + [ + 'get_includes', + 'COPY_THRESHOLD', + 'DRAFT_API', + ] + + constants.__all__ + + sugar.__all__ + + backend.__all__ +) diff --git a/source/zmq/__init__.pyi b/source/zmq/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..56f4bdcaf0c08c199b0da853bae063d5d392fe5a --- /dev/null +++ b/source/zmq/__init__.pyi @@ -0,0 +1,28 @@ +from typing import List + +from . import backend, sugar + +COPY_THRESHOLD: int +DRAFT_API: bool +__version__: str + +# mypy doesn't like overwriting symbols with * so be explicit +# about what comes from backend, not from sugar +# see tools/backend_imports.py to generate this list +# note: `x as x` is required for re-export +# see https://github.com/python/mypy/issues/2190 +from .backend import IPC_PATH_MAX_LEN as IPC_PATH_MAX_LEN +from .backend import curve_keypair as curve_keypair +from .backend import curve_public as curve_public +from .backend import has as has +from .backend import proxy as proxy +from .backend import proxy_steerable as proxy_steerable +from .backend import strerror as strerror +from .backend import zmq_errno as zmq_errno +from .backend import zmq_poll as zmq_poll +from .constants import * +from .error import * +from .sugar import * + +def get_includes() -> list[str]: ... +def get_library_dirs() -> list[str]: ... diff --git a/source/zmq/_future.py b/source/zmq/_future.py new file mode 100644 index 0000000000000000000000000000000000000000..e598de14f768dc8540684bfa79589c00eeeceed0 --- /dev/null +++ b/source/zmq/_future.py @@ -0,0 +1,737 @@ +"""Future-returning APIs for coroutines.""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import warnings +from asyncio import Future +from collections import deque +from functools import partial +from itertools import chain +from typing import ( + Any, + Awaitable, + Callable, + NamedTuple, + TypeVar, + cast, +) + +import zmq as _zmq +from zmq import EVENTS, POLLIN, POLLOUT + + +class _FutureEvent(NamedTuple): + future: Future + kind: str + args: tuple + kwargs: dict + msg: Any + timer: Any + + +# These are incomplete classes and need a Mixin for compatibility with an eventloop +# defining the following attributes: +# +# _Future +# _READ +# _WRITE +# _default_loop() + + +class _Async: + """Mixin for common async logic""" + + _current_loop: Any = None + _Future: type[Future] + + def _get_loop(self) -> Any: + """Get event loop + + Notice if event loop has changed, + and register init_io_state on activation of a new event loop + """ + if self._current_loop is None: + self._current_loop = self._default_loop() + self._init_io_state(self._current_loop) + return self._current_loop + current_loop = self._default_loop() + if current_loop is not self._current_loop: + # warn? This means a socket is being used in multiple loops! + self._current_loop = current_loop + self._init_io_state(current_loop) + return current_loop + + def _default_loop(self) -> Any: + raise NotImplementedError("Must be implemented in a subclass") + + def _init_io_state(self, loop=None) -> None: + pass + + +class _AsyncPoller(_Async, _zmq.Poller): + """Poller that returns a Future on poll, instead of blocking.""" + + _socket_class: type[_AsyncSocket] + _READ: int + _WRITE: int + raw_sockets: list[Any] + + def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None: + """Schedule callback for a raw socket""" + raise NotImplementedError() + + def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None: + """Unschedule callback for a raw socket""" + raise NotImplementedError() + + def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: # type: ignore + """Return a Future for a poll event""" + future = self._Future() + if timeout == 0: + try: + result = super().poll(0) + except Exception as e: + future.set_exception(e) + else: + future.set_result(result) + return future + + loop = self._get_loop() + + # register Future to be called as soon as any event is available on any socket + watcher = self._Future() + + # watch raw sockets: + raw_sockets: list[Any] = [] + + def wake_raw(*args): + if not watcher.done(): + watcher.set_result(None) + + watcher.add_done_callback( + lambda f: self._unwatch_raw_sockets(loop, *raw_sockets) + ) + + wrapped_sockets: list[_AsyncSocket] = [] + + def _clear_wrapper_io(f): + for s in wrapped_sockets: + s._clear_io_state() + + for socket, mask in self.sockets: + if isinstance(socket, _zmq.Socket): + if not isinstance(socket, self._socket_class): + # it's a blocking zmq.Socket, wrap it in async + socket = self._socket_class.from_socket(socket) + wrapped_sockets.append(socket) + if mask & _zmq.POLLIN: + socket._add_recv_event('poll', future=watcher) + if mask & _zmq.POLLOUT: + socket._add_send_event('poll', future=watcher) + else: + raw_sockets.append(socket) + evt = 0 + if mask & _zmq.POLLIN: + evt |= self._READ + if mask & _zmq.POLLOUT: + evt |= self._WRITE + self._watch_raw_socket(loop, socket, evt, wake_raw) + + def on_poll_ready(f): + if future.done(): + return + if watcher.cancelled(): + try: + future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass + return + if watcher.exception(): + future.set_exception(watcher.exception()) + else: + try: + result = super(_AsyncPoller, self).poll(0) + except Exception as e: + future.set_exception(e) + else: + future.set_result(result) + + watcher.add_done_callback(on_poll_ready) + + if wrapped_sockets: + watcher.add_done_callback(_clear_wrapper_io) + + if timeout is not None and timeout > 0: + # schedule cancel to fire on poll timeout, if any + def trigger_timeout(): + if not watcher.done(): + watcher.set_result(None) + + timeout_handle = loop.call_later(1e-3 * timeout, trigger_timeout) + + def cancel_timeout(f): + if hasattr(timeout_handle, 'cancel'): + timeout_handle.cancel() + else: + loop.remove_timeout(timeout_handle) + + future.add_done_callback(cancel_timeout) + + def cancel_watcher(f): + if not watcher.done(): + watcher.cancel() + + future.add_done_callback(cancel_watcher) + + return future + + +class _NoTimer: + @staticmethod + def cancel(): + pass + + +T = TypeVar("T", bound="_AsyncSocket") + + +class _AsyncSocket(_Async, _zmq.Socket[Future]): + # Warning : these class variables are only here to allow to call super().__setattr__. + # They be overridden at instance initialization and not shared in the whole class + _recv_futures = None + _send_futures = None + _state = 0 + _shadow_sock: _zmq.Socket + _poller_class = _AsyncPoller + _fd = None + + def __init__( + self, + context=None, + socket_type=-1, + io_loop=None, + _from_socket: _zmq.Socket | None = None, + **kwargs, + ) -> None: + if isinstance(context, _zmq.Socket): + context, _from_socket = (None, context) + if _from_socket is not None: + super().__init__(shadow=_from_socket.underlying) # type: ignore + self._shadow_sock = _from_socket + else: + super().__init__(context, socket_type, **kwargs) # type: ignore + self._shadow_sock = _zmq.Socket.shadow(self.underlying) + + if io_loop is not None: + warnings.warn( + f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2." + " The currently active loop will always be used.", + DeprecationWarning, + stacklevel=3, + ) + self._recv_futures = deque() + self._send_futures = deque() + self._state = 0 + self._fd = self._shadow_sock.FD + + @classmethod + def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: + """Create an async socket from an existing Socket""" + return cls(_from_socket=socket, io_loop=io_loop) + + def close(self, linger: int | None = None) -> None: + if not self.closed and self._fd is not None: + event_list: list[_FutureEvent] = list( + chain(self._recv_futures or [], self._send_futures or []) + ) + for event in event_list: + if not event.future.done(): + try: + event.future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass + self._clear_io_state() + super().close(linger=linger) + + close.__doc__ = _zmq.Socket.close.__doc__ + + def get(self, key): + result = super().get(key) + if key == EVENTS: + self._schedule_remaining_events(result) + return result + + get.__doc__ = _zmq.Socket.get.__doc__ + + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: + """Receive a complete multipart zmq message. + + Returns a Future whose result will be a multipart message. + """ + return self._add_recv_event( + 'recv_multipart', kwargs=dict(flags=flags, copy=copy, track=track) + ) + + def recv( # type: ignore + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[bytes | _zmq.Frame]: + """Receive a single zmq frame. + + Returns a Future, whose result will be the received frame. + + Recommend using recv_multipart instead. + """ + return self._add_recv_event( + 'recv', kwargs=dict(flags=flags, copy=copy, track=track) + ) + + def recv_into( # type: ignore + self, buf, /, *, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + """Receive a single zmq frame into a pre-allocated buffer. + + Returns a Future, whose result will be the number of bytes received. + """ + return self._add_recv_event( + 'recv_into', args=(buf,), kwargs=dict(nbytes=nbytes, flags=flags) + ) + + def send_multipart( # type: ignore + self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs + ) -> Awaitable[_zmq.MessageTracker | None]: + """Send a complete multipart zmq message. + + Returns a Future that resolves when sending is complete. + """ + kwargs['flags'] = flags + kwargs['copy'] = copy + kwargs['track'] = track + return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs) + + def send( # type: ignore + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + **kwargs: Any, + ) -> Awaitable[_zmq.MessageTracker | None]: + """Send a single zmq frame. + + Returns a Future that resolves when sending is complete. + + Recommend using send_multipart instead. + """ + kwargs['flags'] = flags + kwargs['copy'] = copy + kwargs['track'] = track + kwargs.update(dict(flags=flags, copy=copy, track=track)) + return self._add_send_event('send', msg=data, kwargs=kwargs) + + def _deserialize(self, recvd, load): + """Deserialize with Futures""" + f = self._Future() + + def _chain(_): + """Chain result through serialization to recvd""" + if f.done(): + # chained future may be cancelled, which means nobody is going to get this result + # if it's an error, that's no big deal (probably zmq.Again), + # but if it's a successful recv, this is a dropped message! + if not recvd.cancelled() and recvd.exception() is None: + warnings.warn( + # is there a useful stacklevel? + # ideally, it would point to where `f.cancel()` was called + f"Future {f} completed while awaiting {recvd}. A message has been dropped!", + RuntimeWarning, + ) + return + if recvd.exception(): + f.set_exception(recvd.exception()) + else: + buf = recvd.result() + try: + loaded = load(buf) + except Exception as e: + f.set_exception(e) + else: + f.set_result(loaded) + + recvd.add_done_callback(_chain) + + def _chain_cancel(_): + """Chain cancellation from f to recvd""" + if recvd.done(): + return + if f.cancelled(): + recvd.cancel() + + f.add_done_callback(_chain_cancel) + + return f + + def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore + """poll the socket for events + + returns a Future for the poll results. + """ + + if self.closed: + raise _zmq.ZMQError(_zmq.ENOTSUP) + + p = self._poller_class() + p.register(self, flags) + poll_future = cast(Future, p.poll(timeout)) + + future = self._Future() + + def unwrap_result(f): + if future.done(): + return + if poll_future.cancelled(): + try: + future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass + return + if f.exception(): + future.set_exception(poll_future.exception()) + else: + evts = dict(poll_future.result()) + future.set_result(evts.get(self, 0)) + + if poll_future.done(): + # hook up result if already done + unwrap_result(poll_future) + else: + poll_future.add_done_callback(unwrap_result) + + def cancel_poll(future): + """Cancel underlying poll if request has been cancelled""" + if not poll_future.done(): + try: + poll_future.cancel() + except RuntimeError: + # RuntimeError may be called during teardown + pass + + future.add_done_callback(cancel_poll) + + return future + + def _add_timeout(self, future, timeout): + """Add a timeout for a send or recv Future""" + + def future_timeout(): + if future.done(): + # future already resolved, do nothing + return + + # raise EAGAIN + future.set_exception(_zmq.Again()) + + return self._call_later(timeout, future_timeout) + + def _call_later(self, delay, callback): + """Schedule a function to be called later + + Override for different IOLoop implementations + + Tornado and asyncio happen to both have ioloop.call_later + with the same signature. + """ + return self._get_loop().call_later(delay, callback) + + @staticmethod + def _remove_finished_future(future, event_list, event=None): + """Make sure that futures are removed from the event list when they resolve + + Avoids delaying cleanup until the next send/recv event, + which may never come. + """ + # "future" instance is shared between sockets, but each socket has its own event list. + if not event_list: + return + # only unconsumed events (e.g. cancelled calls) + # will be present when this happens + try: + event_list.remove(event) + except ValueError: + # usually this will have been removed by being consumed + return + + def _add_recv_event( + self, + kind: str, + *, + args: tuple | None = None, + kwargs: dict[str, Any] | None = None, + future: Future | None = None, + ) -> Future: + """Add a recv event, returning the corresponding Future""" + f = future or self._Future() + if args is None: + args = () + if kwargs is None: + kwargs = {} + if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT: + # short-circuit non-blocking calls + recv = getattr(self._shadow_sock, kind) + try: + r = recv(*args, **kwargs) + except Exception as e: + f.set_exception(e) + else: + f.set_result(r) + return f + + timer = _NoTimer + if hasattr(_zmq, 'RCVTIMEO'): + timeout_ms = self._shadow_sock.rcvtimeo + if timeout_ms >= 0: + timer = self._add_timeout(f, timeout_ms * 1e-3) + + # we add it to the list of futures before we add the timeout as the + # timeout will remove the future from recv_futures to avoid leaks + _future_event = _FutureEvent( + f, kind, args=args, kwargs=kwargs, msg=None, timer=timer + ) + self._recv_futures.append(_future_event) + + if self._shadow_sock.get(EVENTS) & POLLIN: + # recv immediately, if we can + self._handle_recv() + if self._recv_futures and _future_event in self._recv_futures: + # Don't let the Future sit in _recv_events after it's done + # no need to register this if we've already been handled + # (i.e. immediately-resolved recv) + f.add_done_callback( + partial( + self._remove_finished_future, + event_list=self._recv_futures, + event=_future_event, + ) + ) + self._add_io_state(POLLIN) + return f + + def _add_send_event(self, kind, msg=None, kwargs=None, future=None): + """Add a send event, returning the corresponding Future""" + f = future or self._Future() + # attempt send with DONTWAIT if no futures are waiting + # short-circuit for sends that will resolve immediately + # only call if no send Futures are waiting + if kind in ('send', 'send_multipart') and not self._send_futures: + flags = kwargs.get('flags', 0) + nowait_kwargs = kwargs.copy() + nowait_kwargs['flags'] = flags | _zmq.DONTWAIT + + # short-circuit non-blocking calls + send = getattr(self._shadow_sock, kind) + # track if the send resolved or not + # (EAGAIN if DONTWAIT is not set should proceed with) + finish_early = True + try: + r = send(msg, **nowait_kwargs) + except _zmq.Again as e: + if flags & _zmq.DONTWAIT: + f.set_exception(e) + else: + # EAGAIN raised and DONTWAIT not requested, + # proceed with async send + finish_early = False + except Exception as e: + f.set_exception(e) + else: + f.set_result(r) + + if finish_early: + # short-circuit resolved, return finished Future + # schedule wake for recv if there are any receivers waiting + if self._recv_futures: + self._schedule_remaining_events() + return f + + timer = _NoTimer + if hasattr(_zmq, 'SNDTIMEO'): + timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO) + if timeout_ms >= 0: + timer = self._add_timeout(f, timeout_ms * 1e-3) + + # we add it to the list of futures before we add the timeout as the + # timeout will remove the future from recv_futures to avoid leaks + _future_event = _FutureEvent( + f, kind, args=(), kwargs=kwargs, msg=msg, timer=timer + ) + self._send_futures.append(_future_event) + # Don't let the Future sit in _send_futures after it's done + f.add_done_callback( + partial( + self._remove_finished_future, + event_list=self._send_futures, + event=_future_event, + ) + ) + + self._add_io_state(POLLOUT) + return f + + def _handle_recv(self): + """Handle recv events""" + if not self._shadow_sock.get(EVENTS) & POLLIN: + # event triggered, but state may have been changed between trigger and callback + return + f = None + while self._recv_futures: + f, kind, args, kwargs, _, timer = self._recv_futures.popleft() + # skip any cancelled futures + if f.done(): + f = None + else: + break + + if not self._recv_futures: + self._drop_io_state(POLLIN) + + if f is None: + return + + timer.cancel() + + if kind == 'poll': + # on poll event, just signal ready, nothing else. + f.set_result(None) + return + elif kind == 'recv_multipart': + recv = self._shadow_sock.recv_multipart + elif kind == 'recv': + recv = self._shadow_sock.recv + elif kind == 'recv_into': + recv = self._shadow_sock.recv_into + else: + raise ValueError(f"Unhandled recv event type: {kind!r}") + + kwargs['flags'] |= _zmq.DONTWAIT + try: + result = recv(*args, **kwargs) + except Exception as e: + f.set_exception(e) + else: + f.set_result(result) + + def _handle_send(self): + if not self._shadow_sock.get(EVENTS) & POLLOUT: + # event triggered, but state may have been changed between trigger and callback + return + f = None + while self._send_futures: + f, kind, args, kwargs, msg, timer = self._send_futures.popleft() + # skip any cancelled futures + if f.done(): + f = None + else: + break + + if not self._send_futures: + self._drop_io_state(POLLOUT) + + if f is None: + return + + timer.cancel() + + if kind == 'poll': + # on poll event, just signal ready, nothing else. + f.set_result(None) + return + elif kind == 'send_multipart': + send = self._shadow_sock.send_multipart + elif kind == 'send': + send = self._shadow_sock.send + else: + raise ValueError(f"Unhandled send event type: {kind!r}") + + kwargs['flags'] |= _zmq.DONTWAIT + try: + result = send(msg, **kwargs) + except Exception as e: + f.set_exception(e) + else: + f.set_result(result) + + # event masking from ZMQStream + def _handle_events(self, fd=0, events=0): + """Dispatch IO events to _handle_recv, etc.""" + if self._shadow_sock.closed: + return + + zmq_events = self._shadow_sock.get(EVENTS) + if zmq_events & _zmq.POLLIN: + self._handle_recv() + if zmq_events & _zmq.POLLOUT: + self._handle_send() + self._schedule_remaining_events() + + def _schedule_remaining_events(self, events=None): + """Schedule a call to handle_events next loop iteration + + If there are still events to handle. + """ + # edge-triggered handling + # allow passing events in, in case this is triggered by retrieving events, + # so we don't have to retrieve it twice. + if self._state == 0: + # not watching for anything, nothing to schedule + return + if events is None: + events = self._shadow_sock.get(EVENTS) + if events & self._state: + self._call_later(0, self._handle_events) + + def _add_io_state(self, state): + """Add io_state to poller.""" + if self._state != state: + state = self._state = self._state | state + self._update_handler(self._state) + + def _drop_io_state(self, state): + """Stop poller from watching an io_state.""" + if self._state & state: + self._state = self._state & (~state) + self._update_handler(self._state) + + def _update_handler(self, state): + """Update IOLoop handler with state. + + zmq FD is always read-only. + """ + # ensure loop is registered and init_io has been called + # if there are any events to watch for + if state: + self._get_loop() + self._schedule_remaining_events() + + def _init_io_state(self, loop=None): + """initialize the ioloop event handler""" + if loop is None: + loop = self._get_loop() + loop.add_handler(self._shadow_sock, self._handle_events, self._READ) + self._call_later(0, self._handle_events) + + def _clear_io_state(self): + """unregister the ioloop event handler + + called once during close + """ + fd = self._shadow_sock + if self._shadow_sock.closed: + fd = self._fd + if self._current_loop is not None: + self._current_loop.remove_handler(fd) diff --git a/source/zmq/_future.pyi b/source/zmq/_future.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e22315e82598249391a3bbffbbf2ecc5dd0d12d5 --- /dev/null +++ b/source/zmq/_future.pyi @@ -0,0 +1,95 @@ +"""type annotations for async sockets""" + +from __future__ import annotations + +from asyncio import Future +from pickle import DEFAULT_PROTOCOL +from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload + +import zmq as _zmq + +class _AsyncPoller(_zmq.Poller): + _socket_class: type[_AsyncSocket] + + def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore + +T = TypeVar("T", bound="_AsyncSocket") + +class _AsyncSocket(_zmq.Socket[Future]): + @classmethod + def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ... + def send( # type: ignore + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: int | None = None, + group: str | None = None, + ) -> Awaitable[_zmq.MessageTracker | None]: ... + @overload # type: ignore + def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ... + @overload + def recv( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[bytes]: ... + @overload + def recv( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[_zmq.Frame]: ... + @overload + def recv( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[bytes | _zmq.Frame]: ... + def recv_into( # type: ignore + self, buffer: Any, /, *, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: ... + def send_multipart( # type: ignore + self, + msg_parts: Sequence, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: int | None = None, + group: str | None = None, + ) -> Awaitable[_zmq.MessageTracker | None]: ... + @overload # type: ignore + def recv_multipart( + self, flags: int = 0, *, track: bool = False + ) -> Awaitable[list[bytes]]: ... + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[list[bytes]]: ... + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[list[_zmq.Frame]]: ... + @overload + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ... + + # serialization wrappers + + def send_string( # type: ignore + self, + u: str, + flags: int = 0, + copy: bool = True, + *, + encoding: str = 'utf-8', + **kwargs, + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_string( # type: ignore + self, flags: int = 0, encoding: str = 'utf-8' + ) -> Awaitable[str]: ... + def send_pyobj( # type: ignore + self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore + def send_json( # type: ignore + self, obj: Any, flags: int = 0, **kwargs + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore + def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore diff --git a/source/zmq/_typing.py b/source/zmq/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..9833a1b9d2bc5d66207e782773c2a09d8878ebcf --- /dev/null +++ b/source/zmq/_typing.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + try: + from typing_extensions import TypeAlias + except ImportError: + TypeAlias = type # type: ignore diff --git a/source/zmq/asyncio.py b/source/zmq/asyncio.py new file mode 100644 index 0000000000000000000000000000000000000000..22bbd14d423cc2fb31efa7846681ab4420a2a6fd --- /dev/null +++ b/source/zmq/asyncio.py @@ -0,0 +1,224 @@ +"""AsyncIO support for zmq + +Requires asyncio and Python 3. +""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import asyncio +import selectors +import sys +import warnings +from asyncio import Future, SelectorEventLoop +from weakref import WeakKeyDictionary + +import zmq as _zmq +from zmq import _future + +# registry of asyncio loop : selector thread +_selectors: WeakKeyDictionary = WeakKeyDictionary() + + +class ProactorSelectorThreadWarning(RuntimeWarning): + """Warning class for notifying about the extra thread spawned by tornado + + We automatically support proactor via tornado's AddThreadSelectorEventLoop""" + + +def _get_selector_windows( + asyncio_loop, +) -> asyncio.AbstractEventLoop: + """Get selector-compatible loop + + Returns an object with ``add_reader`` family of methods, + either the loop itself or a SelectorThread instance. + + Workaround Windows proactor removal of + *reader methods, which we need for zmq sockets. + """ + + if asyncio_loop in _selectors: + return _selectors[asyncio_loop] + + # detect add_reader instead of checking for proactor? + if hasattr(asyncio, "ProactorEventLoop") and isinstance( + asyncio_loop, + asyncio.ProactorEventLoop, # type: ignore + ): + try: + from tornado.platform.asyncio import AddThreadSelectorEventLoop + except ImportError: + raise RuntimeError( + "Proactor event loop does not implement add_reader family of methods required for zmq." + " zmq will work with proactor if tornado >= 6.1 can be found." + " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`" + " or install 'tornado>=6.1' to avoid this error." + ) + + warnings.warn( + "Proactor event loop does not implement add_reader family of methods required for zmq." + " Registering an additional selector thread for add_reader support via tornado." + " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`" + " to avoid this warning.", + RuntimeWarning, + # stacklevel 5 matches most likely zmq.asyncio.Context().socket() + stacklevel=5, + ) + + selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop( + asyncio_loop + ) # type: ignore + + # patch loop.close to also close the selector thread + loop_close = asyncio_loop.close + + def _close_selector_and_loop(): + # restore original before calling selector.close, + # which in turn calls eventloop.close! + asyncio_loop.close = loop_close + _selectors.pop(asyncio_loop, None) + selector_loop.close() + + asyncio_loop.close = _close_selector_and_loop # type: ignore # mypy bug - assign a function to method + return selector_loop + else: + return asyncio_loop + + +def _get_selector_noop(loop) -> asyncio.AbstractEventLoop: + """no-op on non-Windows""" + return loop + + +if sys.platform == "win32": + _get_selector = _get_selector_windows +else: + _get_selector = _get_selector_noop + + +class _AsyncIO: + _Future = Future + _WRITE = selectors.EVENT_WRITE + _READ = selectors.EVENT_READ + + def _default_loop(self): + try: + return asyncio.get_running_loop() + except RuntimeError: + warnings.warn( + "No running event loop. zmq.asyncio should be used from within an asyncio loop.", + RuntimeWarning, + stacklevel=4, + ) + # get_event_loop deprecated in 3.10: + return asyncio.get_event_loop() + + +class Poller(_AsyncIO, _future._AsyncPoller): + """Poller returning asyncio.Future for poll results.""" + + def _watch_raw_socket(self, loop, socket, evt, f): + """Schedule callback for a raw socket""" + selector = _get_selector(loop) + if evt & self._READ: + selector.add_reader(socket, lambda *args: f()) + if evt & self._WRITE: + selector.add_writer(socket, lambda *args: f()) + + def _unwatch_raw_sockets(self, loop, *sockets): + """Unschedule callback for a raw socket""" + selector = _get_selector(loop) + for socket in sockets: + selector.remove_reader(socket) + selector.remove_writer(socket) + + +class Socket(_AsyncIO, _future._AsyncSocket): + """Socket returning asyncio Futures for send/recv/poll methods.""" + + _poller_class = Poller + + def _get_selector(self, io_loop=None): + if io_loop is None: + io_loop = self._get_loop() + return _get_selector(io_loop) + + def _init_io_state(self, io_loop=None): + """initialize the ioloop event handler""" + self._get_selector(io_loop).add_reader( + self._fd, lambda: self._handle_events(0, 0) + ) + + def _clear_io_state(self): + """clear any ioloop event handler + + called once at close + """ + loop = self._current_loop + if loop and not loop.is_closed() and self._fd != -1: + self._get_selector(loop).remove_reader(self._fd) + + +Poller._socket_class = Socket + + +class Context(_zmq.Context[Socket]): + """Context for creating asyncio-compatible Sockets""" + + _socket_class = Socket + + # avoid sharing instance with base Context class + _instance = None + + # overload with no changes to satisfy pyright + def __init__( + self: Context, + io_threads: int | _zmq.Context = 1, + shadow: _zmq.Context | int = 0, + ) -> None: + super().__init__(io_threads, shadow) # type: ignore + + +class ZMQEventLoop(SelectorEventLoop): + """DEPRECATED: AsyncIO eventloop using zmq_poll. + + pyzmq sockets should work with any asyncio event loop as of pyzmq 17. + """ + + def __init__(self, selector=None): + _deprecated() + return super().__init__(selector) + + +_loop = None + + +def _deprecated(): + if _deprecated.called: # type: ignore + return + _deprecated.called = True # type: ignore + + warnings.warn( + "ZMQEventLoop and zmq.asyncio.install are deprecated in pyzmq 17. Special eventloop integration is no longer needed.", + DeprecationWarning, + stacklevel=3, + ) + + +_deprecated.called = False # type: ignore + + +def install(): + """DEPRECATED: No longer needed in pyzmq 17""" + _deprecated() + + +__all__ = [ + "Context", + "Socket", + "Poller", + "ZMQEventLoop", + "install", +] diff --git a/source/zmq/auth/__init__.py b/source/zmq/auth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebacab71272ab4e1b7999ea8d6a875c08c6b861b --- /dev/null +++ b/source/zmq/auth/__init__.py @@ -0,0 +1,13 @@ +"""Utilities for ZAP authentication. + +To run authentication in a background thread, see :mod:`zmq.auth.thread`. +For integration with the asyncio event loop, see :mod:`zmq.auth.asyncio`. + +Authentication examples are provided in the pyzmq codebase, under +`/examples/security/`. + +.. versionadded:: 14.1 +""" + +from .base import * +from .certs import * diff --git a/source/zmq/auth/asyncio.py b/source/zmq/auth/asyncio.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4915c12784538761dab70cee8e34395a1df066 --- /dev/null +++ b/source/zmq/auth/asyncio.py @@ -0,0 +1,66 @@ +"""ZAP Authenticator integrated with the asyncio IO loop. + +.. versionadded:: 15.2 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import asyncio +import warnings +from typing import Any, Optional + +import zmq +from zmq.asyncio import Poller + +from .base import Authenticator + + +class AsyncioAuthenticator(Authenticator): + """ZAP authentication for use in the asyncio IO loop""" + + __poller: Optional[Poller] + __task: Any + + def __init__( + self, + context: Optional["zmq.Context"] = None, + loop: Any = None, + encoding: str = 'utf-8', + log: Any = None, + ): + super().__init__(context, encoding, log) + if loop is not None: + warnings.warn( + f"{self.__class__.__name__}(loop) is deprecated and ignored", + DeprecationWarning, + stacklevel=2, + ) + self.__poller = None + self.__task = None + + async def __handle_zap(self) -> None: + while self.__poller is not None: + events = await self.__poller.poll() + if self.zap_socket in dict(events): + msg = self.zap_socket.recv_multipart() + await self.handle_zap_message(msg) + + def start(self) -> None: + """Start ZAP authentication""" + super().start() + self.__poller = Poller() + self.__poller.register(self.zap_socket, zmq.POLLIN) + self.__task = asyncio.ensure_future(self.__handle_zap()) + + def stop(self) -> None: + """Stop ZAP authentication""" + if self.__task: + self.__task.cancel() + if self.__poller: + self.__poller.unregister(self.zap_socket) + self.__poller = None + super().stop() + + +__all__ = ["AsyncioAuthenticator"] diff --git a/source/zmq/auth/base.py b/source/zmq/auth/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c862b60c14f0625de1cf8c0006357760dce88d29 --- /dev/null +++ b/source/zmq/auth/base.py @@ -0,0 +1,445 @@ +"""Base implementation of 0MQ authentication.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +import os +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union + +import zmq +from zmq.error import _check_version +from zmq.utils import z85 + +from .certs import load_certificates + +CURVE_ALLOW_ANY = '*' +VERSION = b'1.0' + + +class Authenticator: + """Implementation of ZAP authentication for zmq connections. + + This authenticator class does not register with an event loop. As a result, + you will need to manually call `handle_zap_message`:: + + auth = zmq.Authenticator() + auth.allow("127.0.0.1") + auth.start() + while True: + await auth.handle_zap_msg(auth.zap_socket.recv_multipart()) + + Alternatively, you can register `auth.zap_socket` with a poller. + + Since many users will want to run ZAP in a way that does not block the + main thread, other authentication classes (such as :mod:`zmq.auth.thread`) + are provided. + + Note: + + - libzmq provides four levels of security: default NULL (which the Authenticator does + not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see. + - until you add policies, all incoming NULL connections are allowed. + (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied. + - GSSAPI requires no configuration. + """ + + context: "zmq.Context" + encoding: str + allow_any: bool + credentials_providers: Dict[str, Any] + zap_socket: "zmq.Socket" + _allowed: Set[str] + _denied: Set[str] + passwords: Dict[str, Dict[str, str]] + certs: Dict[str, Dict[bytes, Any]] + log: Any + + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + ): + _check_version((4, 0), "security") + self.context = context or zmq.Context.instance() + self.encoding = encoding + self.allow_any = False + self.credentials_providers = {} + self.zap_socket = None # type: ignore + self._allowed = set() + self._denied = set() + # passwords is a dict keyed by domain and contains values + # of dicts with username:password pairs. + self.passwords = {} + # certs is dict keyed by domain and contains values + # of dicts keyed by the public keys from the specified location. + self.certs = {} + self.log = log or logging.getLogger('zmq.auth') + + def start(self) -> None: + """Create and bind the ZAP socket""" + self.zap_socket = self.context.socket(zmq.REP, socket_class=zmq.Socket) + self.zap_socket.linger = 1 + self.zap_socket.bind("inproc://zeromq.zap.01") + self.log.debug("Starting") + + def stop(self) -> None: + """Close the ZAP socket""" + if self.zap_socket: + self.zap_socket.close() + self.zap_socket = None # type: ignore + + def allow(self, *addresses: str) -> None: + """Allow IP address(es). + + Connections from addresses not explicitly allowed will be rejected. + + - For NULL, all clients from this address will be accepted. + - For real auth setups, they will be allowed to continue with authentication. + + allow is mutually exclusive with deny. + """ + if self._denied: + raise ValueError("Only use allow or deny, not both") + self.log.debug("Allowing %s", ','.join(addresses)) + self._allowed.update(addresses) + + def deny(self, *addresses: str) -> None: + """Deny IP address(es). + + Addresses not explicitly denied will be allowed to continue with authentication. + + deny is mutually exclusive with allow. + """ + if self._allowed: + raise ValueError("Only use a allow or deny, not both") + self.log.debug("Denying %s", ','.join(addresses)) + self._denied.update(addresses) + + def configure_plain( + self, domain: str = '*', passwords: Optional[Dict[str, str]] = None + ) -> None: + """Configure PLAIN authentication for a given domain. + + PLAIN authentication uses a plain-text password file. + To cover all domains, use "*". + You can modify the password file at any time; it is reloaded automatically. + """ + if passwords: + self.passwords[domain] = passwords + self.log.debug("Configure plain: %s", domain) + + def configure_curve( + self, domain: str = '*', location: Union[str, os.PathLike] = "." + ) -> None: + """Configure CURVE authentication for a given domain. + + CURVE authentication uses a directory that holds all public client certificates, + i.e. their public keys. + + To cover all domains, use "*". + + You can add and remove certificates in that directory at any time. configure_curve must be called + every time certificates are added or removed, in order to update the Authenticator's state + + To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location. + """ + # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise + # treat location as a directory that holds the certificates. + self.log.debug("Configure curve: %s[%s]", domain, location) + if location == CURVE_ALLOW_ANY: + self.allow_any = True + else: + self.allow_any = False + try: + self.certs[domain] = load_certificates(location) + except Exception as e: + self.log.error("Failed to load CURVE certs from %s: %s", location, e) + + def configure_curve_callback( + self, domain: str = '*', credentials_provider: Any = None + ) -> None: + """Configure CURVE authentication for a given domain. + + CURVE authentication using a callback function validating + the client public key according to a custom mechanism, e.g. checking the + key against records in a db. credentials_provider is an object of a class which + implements a callback method accepting two parameters (domain and key), e.g.:: + + class CredentialsProvider(object): + + def __init__(self): + ...e.g. db connection + + def callback(self, domain, key): + valid = ...lookup key and/or domain in db + if valid: + logging.info('Authorizing: {0}, {1}'.format(domain, key)) + return True + else: + logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key)) + return False + + To cover all domains, use "*". + """ + + self.allow_any = False + + if credentials_provider is not None: + self.credentials_providers[domain] = credentials_provider + else: + self.log.error("None credentials_provider provided for domain:%s", domain) + + def curve_user_id(self, client_public_key: bytes) -> str: + """Return the User-Id corresponding to a CURVE client's public key + + Default implementation uses the z85-encoding of the public key. + + Override to define a custom mapping of public key : user-id + + This is only called on successful authentication. + + Parameters + ---------- + client_public_key: bytes + The client public key used for the given message + + Returns + ------- + user_id: unicode + The user ID as text + """ + return z85.encode(client_public_key).decode('ascii') + + def configure_gssapi( + self, domain: str = '*', location: Optional[str] = None + ) -> None: + """Configure GSSAPI authentication + + Currently this is a no-op because there is nothing to configure with GSSAPI. + """ + + async def handle_zap_message(self, msg: List[bytes]): + """Perform ZAP authentication""" + if len(msg) < 6: + self.log.error("Invalid ZAP message, not enough frames: %r", msg) + if len(msg) < 2: + self.log.error("Not enough information to reply") + else: + self._send_zap_reply(msg[1], b"400", b"Not enough frames") + return + + version, request_id, domain, address, identity, mechanism = msg[:6] + credentials = msg[6:] + + domain = domain.decode(self.encoding, 'replace') + address = address.decode(self.encoding, 'replace') + + if version != VERSION: + self.log.error("Invalid ZAP version: %r", msg) + self._send_zap_reply(request_id, b"400", b"Invalid version") + return + + self.log.debug( + "version: %r, request_id: %r, domain: %r," + " address: %r, identity: %r, mechanism: %r", + version, + request_id, + domain, + address, + identity, + mechanism, + ) + + # Is address is explicitly allowed or _denied? + allowed = False + denied = False + reason = b"NO ACCESS" + + if self._allowed: + if address in self._allowed: + allowed = True + self.log.debug("PASSED (allowed) address=%s", address) + else: + denied = True + reason = b"Address not allowed" + self.log.debug("DENIED (not allowed) address=%s", address) + + elif self._denied: + if address in self._denied: + denied = True + reason = b"Address denied" + self.log.debug("DENIED (denied) address=%s", address) + else: + allowed = True + self.log.debug("PASSED (not denied) address=%s", address) + + # Perform authentication mechanism-specific checks if necessary + username = "anonymous" + if not denied: + if mechanism == b'NULL' and not allowed: + # For NULL, we allow if the address wasn't denied + self.log.debug("ALLOWED (NULL)") + allowed = True + + elif mechanism == b'PLAIN': + # For PLAIN, even a _alloweded address must authenticate + if len(credentials) != 2: + self.log.error("Invalid PLAIN credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + username, password = ( + c.decode(self.encoding, 'replace') for c in credentials + ) + allowed, reason = self._authenticate_plain(domain, username, password) + + elif mechanism == b'CURVE': + # For CURVE, even a _alloweded address must authenticate + if len(credentials) != 1: + self.log.error("Invalid CURVE credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + key = credentials[0] + allowed, reason = await self._authenticate_curve(domain, key) + if allowed: + username = self.curve_user_id(key) + + elif mechanism == b'GSSAPI': + if len(credentials) != 1: + self.log.error("Invalid GSSAPI credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + # use principal as user-id for now + principal = credentials[0] + username = principal.decode("utf8") + allowed, reason = self._authenticate_gssapi(domain, principal) + + if allowed: + self._send_zap_reply(request_id, b"200", b"OK", username) + else: + self._send_zap_reply(request_id, b"400", reason) + + def _authenticate_plain( + self, domain: str, username: str, password: str + ) -> Tuple[bool, bytes]: + """PLAIN ZAP authentication""" + allowed = False + reason = b"" + if self.passwords: + # If no domain is not specified then use the default domain + if not domain: + domain = '*' + + if domain in self.passwords: + if username in self.passwords[domain]: + if password == self.passwords[domain][username]: + allowed = True + else: + reason = b"Invalid password" + else: + reason = b"Invalid username" + else: + reason = b"Invalid domain" + + if allowed: + self.log.debug( + "ALLOWED (PLAIN) domain=%s username=%s password=%s", + domain, + username, + password, + ) + else: + self.log.debug("DENIED %s", reason) + + else: + reason = b"No passwords defined" + self.log.debug("DENIED (PLAIN) %s", reason) + + return allowed, reason + + async def _authenticate_curve( + self, domain: str, client_key: bytes + ) -> Tuple[bool, bytes]: + """CURVE ZAP authentication""" + allowed = False + reason = b"" + if self.allow_any: + allowed = True + reason = b"OK" + self.log.debug("ALLOWED (CURVE allow any client)") + elif self.credentials_providers != {}: + # If no explicit domain is specified then use the default domain + if not domain: + domain = '*' + + if domain in self.credentials_providers: + z85_client_key = z85.encode(client_key) + # Callback to check if key is Allowed + r = self.credentials_providers[domain].callback(domain, z85_client_key) + if isinstance(r, Awaitable): + r = await r + if r: + allowed = True + reason = b"OK" + else: + reason = b"Unknown key" + + status = "ALLOWED" if allowed else "DENIED" + self.log.debug( + "%s (CURVE auth_callback) domain=%s client_key=%s", + status, + domain, + z85_client_key, + ) + else: + reason = b"Unknown domain" + else: + # If no explicit domain is specified then use the default domain + if not domain: + domain = '*' + + if domain in self.certs: + # The certs dict stores keys in z85 format, convert binary key to z85 bytes + z85_client_key = z85.encode(client_key) + if self.certs[domain].get(z85_client_key): + allowed = True + reason = b"OK" + else: + reason = b"Unknown key" + + status = "ALLOWED" if allowed else "DENIED" + self.log.debug( + "%s (CURVE) domain=%s client_key=%s", + status, + domain, + z85_client_key, + ) + else: + reason = b"Unknown domain" + + return allowed, reason + + def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]: + """Nothing to do for GSSAPI, which has already been handled by an external service.""" + self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal) + return True, b'OK' + + def _send_zap_reply( + self, + request_id: bytes, + status_code: bytes, + status_text: bytes, + user_id: str = 'anonymous', + ) -> None: + """Send a ZAP reply to finish the authentication.""" + user_id = user_id if status_code == b'200' else b'' + if isinstance(user_id, str): + user_id = user_id.encode(self.encoding, 'replace') + metadata = b'' # not currently used + self.log.debug("ZAP reply code=%s text=%s", status_code, status_text) + reply = [VERSION, request_id, status_code, status_text, user_id, metadata] + self.zap_socket.send_multipart(reply) + + +__all__ = ['Authenticator', 'CURVE_ALLOW_ANY'] diff --git a/source/zmq/auth/certs.py b/source/zmq/auth/certs.py new file mode 100644 index 0000000000000000000000000000000000000000..d60ae005dc111da80576adcdc378390ba6026527 --- /dev/null +++ b/source/zmq/auth/certs.py @@ -0,0 +1,140 @@ +"""0MQ authentication related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import datetime +import glob +import os +from typing import Dict, Optional, Tuple, Union + +import zmq + +_cert_secret_banner = """# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE **Secret** Certificate +# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +""" + + +_cert_public_banner = """# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE Public Certificate +# Exchange securely, or use a secure mechanism to verify the contents +# of this file after exchange. Store public certificates in your home +# directory, in the .curve subdirectory. + +""" + + +def _write_key_file( + key_filename: Union[str, os.PathLike], + banner: str, + public_key: Union[str, bytes], + secret_key: Optional[Union[str, bytes]] = None, + metadata: Optional[Dict[str, str]] = None, + encoding: str = 'utf-8', +) -> None: + """Create a certificate file""" + if isinstance(public_key, bytes): + public_key = public_key.decode(encoding) + if isinstance(secret_key, bytes): + secret_key = secret_key.decode(encoding) + with open(key_filename, 'w', encoding='utf8') as f: + f.write(banner.format(datetime.datetime.now())) + + f.write('metadata\n') + if metadata: + for k, v in metadata.items(): + if isinstance(k, bytes): + k = k.decode(encoding) + if isinstance(v, bytes): + v = v.decode(encoding) + f.write(f" {k} = {v}\n") + + f.write('curve\n') + f.write(f" public-key = \"{public_key}\"\n") + + if secret_key: + f.write(f" secret-key = \"{secret_key}\"\n") + + +def create_certificates( + key_dir: Union[str, os.PathLike], + name: str, + metadata: Optional[Dict[str, str]] = None, +) -> Tuple[str, str]: + """Create zmq certificates. + + Returns the file paths to the public and secret certificate files. + """ + public_key, secret_key = zmq.curve_keypair() + base_filename = os.path.join(key_dir, name) + secret_key_file = f"{base_filename}.key_secret" + public_key_file = f"{base_filename}.key" + now = datetime.datetime.now() + + _write_key_file(public_key_file, _cert_public_banner.format(now), public_key) + + _write_key_file( + secret_key_file, + _cert_secret_banner.format(now), + public_key, + secret_key=secret_key, + metadata=metadata, + ) + + return public_key_file, secret_key_file + + +def load_certificate( + filename: Union[str, os.PathLike], +) -> Tuple[bytes, Optional[bytes]]: + """Load public and secret key from a zmq certificate. + + Returns (public_key, secret_key) + + If the certificate file only contains the public key, + secret_key will be None. + + If there is no public key found in the file, ValueError will be raised. + """ + public_key = None + secret_key = None + if not os.path.exists(filename): + raise OSError(f"Invalid certificate file: {filename}") + + with open(filename, 'rb') as f: + for line in f: + line = line.strip() + if line.startswith(b'#'): + continue + if line.startswith(b'public-key'): + public_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if line.startswith(b'secret-key'): + secret_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if public_key and secret_key: + break + + if public_key is None: + raise ValueError(f"No public key found in {filename}") + + return public_key, secret_key + + +def load_certificates(directory: Union[str, os.PathLike] = '.') -> Dict[bytes, bool]: + """Load public keys from all certificates in a directory""" + certs = {} + if not os.path.isdir(directory): + raise OSError(f"Invalid certificate directory: {directory}") + # Follow czmq pattern of public keys stored in *.key files. + glob_string = os.path.join(directory, "*.key") + + cert_files = glob.glob(glob_string) + for cert_file in cert_files: + public_key, _ = load_certificate(cert_file) + if public_key: + certs[public_key] = True + return certs + + +__all__ = ['create_certificates', 'load_certificate', 'load_certificates'] diff --git a/source/zmq/auth/ioloop.py b/source/zmq/auth/ioloop.py new file mode 100644 index 0000000000000000000000000000000000000000..f87f068e7ef92004a7bbad93e77e445a607e0093 --- /dev/null +++ b/source/zmq/auth/ioloop.py @@ -0,0 +1,48 @@ +"""ZAP Authenticator integrated with the tornado IOLoop. + +.. versionadded:: 14.1 +.. deprecated:: 25 + Use asyncio.AsyncioAuthenticator instead. + Since tornado runs on asyncio, the asyncio authenticator + offers the same functionality in tornado. +""" + +import warnings + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from typing import Any, Optional + +import zmq + +from .asyncio import AsyncioAuthenticator + +warnings.warn( + "zmq.auth.ioloop.IOLoopAuthenticator is deprecated. Use zmq.auth.asyncio.AsyncioAuthenticator", + DeprecationWarning, + stacklevel=2, +) + + +class IOLoopAuthenticator(AsyncioAuthenticator): + """ZAP authentication for use in the tornado IOLoop""" + + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + io_loop: Any = None, + ): + loop = None + if io_loop is not None: + warnings.warn( + f"{self.__class__.__name__}(io_loop) is deprecated and ignored", + DeprecationWarning, + stacklevel=2, + ) + loop = io_loop.asyncio_loop + super().__init__(context=context, encoding=encoding, log=log, loop=loop) + + +__all__ = ['IOLoopAuthenticator'] diff --git a/source/zmq/auth/thread.py b/source/zmq/auth/thread.py new file mode 100644 index 0000000000000000000000000000000000000000..a227c4bd5974f95e3b5b844a1ce0de34f9454f10 --- /dev/null +++ b/source/zmq/auth/thread.py @@ -0,0 +1,139 @@ +"""ZAP Authenticator in a Python Thread. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import asyncio +from threading import Event, Thread +from typing import Any, List, Optional + +import zmq +import zmq.asyncio + +from .base import Authenticator + + +class AuthenticationThread(Thread): + """A Thread for running a zmq Authenticator + + This is run in the background by ThreadAuthenticator + """ + + pipe: zmq.Socket + loop: asyncio.AbstractEventLoop + authenticator: Authenticator + poller: Optional[zmq.asyncio.Poller] = None + + def __init__( + self, + authenticator: Authenticator, + pipe: zmq.Socket, + ) -> None: + super().__init__(daemon=True) + self.authenticator = authenticator + self.log = authenticator.log + self.pipe = pipe + + self.started = Event() + + def run(self) -> None: + """Start the Authentication Agent thread task""" + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(self._run()) + finally: + if self.pipe: + self.pipe.close() + self.pipe = None # type: ignore + + loop.close() + + async def _run(self): + self.poller = zmq.asyncio.Poller() + self.poller.register(self.pipe, zmq.POLLIN) + self.poller.register(self.authenticator.zap_socket, zmq.POLLIN) + self.started.set() + + while True: + events = dict(await self.poller.poll()) + if self.pipe in events: + msg = self.pipe.recv_multipart() + if self._handle_pipe_message(msg): + return + if self.authenticator.zap_socket in events: + msg = self.authenticator.zap_socket.recv_multipart() + await self.authenticator.handle_zap_message(msg) + + def _handle_pipe_message(self, msg: List[bytes]) -> bool: + command = msg[0] + self.log.debug("auth received API command %r", command) + + if command == b'TERMINATE': + return True + + else: + self.log.error("Invalid auth command from API: %r", command) + self.pipe.send(b'ERROR') + + return False + + +class ThreadAuthenticator(Authenticator): + """Run ZAP authentication in a background thread""" + + pipe: "zmq.Socket" + pipe_endpoint: str = '' + thread: AuthenticationThread + + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + ): + super().__init__(context=context, encoding=encoding, log=log) + self.pipe = None # type: ignore + self.pipe_endpoint = f"inproc://{id(self)}.inproc" + self.thread = None # type: ignore + + def start(self) -> None: + """Start the authentication thread""" + # start the Authenticator + super().start() + + # create a socket pair to communicate with auth thread. + self.pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket) + self.pipe.linger = 1 + self.pipe.bind(self.pipe_endpoint) + thread_pipe = self.context.socket(zmq.PAIR, socket_class=zmq.Socket) + thread_pipe.linger = 1 + thread_pipe.connect(self.pipe_endpoint) + self.thread = AuthenticationThread(authenticator=self, pipe=thread_pipe) + self.thread.start() + if not self.thread.started.wait(timeout=10): + raise RuntimeError("Authenticator thread failed to start") + + def stop(self) -> None: + """Stop the authentication thread""" + if self.pipe: + self.pipe.send(b'TERMINATE') + if self.is_alive(): + self.thread.join() + self.thread = None # type: ignore + self.pipe.close() + self.pipe = None # type: ignore + super().stop() + + def is_alive(self) -> bool: + """Is the ZAP thread currently running?""" + return bool(self.thread and self.thread.is_alive()) + + def __del__(self) -> None: + self.stop() + + +__all__ = ['ThreadAuthenticator'] diff --git a/source/zmq/backend/__init__.py b/source/zmq/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15f696d9d3ad5275443eb3145b9c4087f1e6609f --- /dev/null +++ b/source/zmq/backend/__init__.py @@ -0,0 +1,34 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +import platform + +from .select import public_api, select_backend + +if 'PYZMQ_BACKEND' in os.environ: + backend = os.environ['PYZMQ_BACKEND'] + if backend in ('cython', 'cffi'): + backend = f'zmq.backend.{backend}' + _ns = select_backend(backend) +else: + # default to cython, fallback to cffi + # (reverse on PyPy) + if platform.python_implementation() == 'PyPy': + first, second = ('zmq.backend.cffi', 'zmq.backend.cython') + else: + first, second = ('zmq.backend.cython', 'zmq.backend.cffi') + + try: + _ns = select_backend(first) + except Exception as original_error: + try: + _ns = select_backend(second) + except ImportError: + raise original_error from None + +globals().update(_ns) + +__all__ = public_api diff --git a/source/zmq/backend/__init__.pyi b/source/zmq/backend/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2d51e876e72d8d4d59151828ba794d2d89120044 --- /dev/null +++ b/source/zmq/backend/__init__.pyi @@ -0,0 +1,123 @@ +from typing import Any, Callable, List, Optional, Set, Tuple, TypeVar, Union, overload + +from typing_extensions import Literal + +import zmq + +from .select import select_backend + +# avoid collision in Frame.bytes +_bytestr = bytes + +T = TypeVar("T") + +class Frame: + buffer: Any + bytes: bytes + more: bool + tracker: Any + def __init__( + self, + data: Any = None, + track: bool = False, + copy: bool | None = None, + copy_threshold: int | None = None, + ): ... + def copy_fast(self: T) -> T: ... + def get(self, option: int) -> int | _bytestr | str: ... + def set(self, option: int, value: int | _bytestr | str) -> None: ... + +class Socket: + underlying: int + context: zmq.Context + copy_threshold: int + + # specific option types + FD: int + + def __init__( + self, + context: Context | None = None, + socket_type: int = 0, + shadow: int = 0, + copy_threshold: int | None = zmq.COPY_THRESHOLD, + ) -> None: ... + def close(self, linger: int | None = ...) -> None: ... + def get(self, option: int) -> int | bytes | str: ... + def set(self, option: int, value: int | bytes | str) -> None: ... + def connect(self, url: str): ... + def disconnect(self, url: str) -> None: ... + def bind(self, url: str): ... + def unbind(self, url: str) -> None: ... + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + track: bool = ..., + ) -> zmq.MessageTracker | None: ... + @overload + def recv( + self, + flags: int = ..., + *, + copy: Literal[False], + track: bool = ..., + ) -> zmq.Frame: ... + @overload + def recv( + self, + flags: int = ..., + *, + copy: Literal[True], + track: bool = ..., + ) -> bytes: ... + @overload + def recv( + self, + flags: int = ..., + track: bool = False, + ) -> bytes: ... + @overload + def recv( + self, + flags: int | None = ..., + copy: bool = ..., + track: bool | None = False, + ) -> zmq.Frame | bytes: ... + def recv_into(self, buf, /, *, nbytes: int = 0, flags: int = 0) -> int: ... + def monitor(self, addr: str | None, events: int) -> None: ... + # draft methods + def join(self, group: str) -> None: ... + def leave(self, group: str) -> None: ... + +class Context: + underlying: int + def __init__(self, io_threads: int = 1, shadow: int = 0): ... + def get(self, option: int) -> int | bytes | str: ... + def set(self, option: int, value: int | bytes | str) -> None: ... + def socket(self, socket_type: int) -> Socket: ... + def term(self) -> None: ... + +IPC_PATH_MAX_LEN: int +PYZMQ_DRAFT_API: bool + +def has(capability: str) -> bool: ... +def curve_keypair() -> tuple[bytes, bytes]: ... +def curve_public(secret_key: bytes) -> bytes: ... +def strerror(errno: int | None = ...) -> str: ... +def zmq_errno() -> int: ... +def zmq_version() -> str: ... +def zmq_version_info() -> tuple[int, int, int]: ... +def zmq_poll( + sockets: list[Any], timeout: int | None = ... +) -> list[tuple[Socket, int]]: ... +def proxy(frontend: Socket, backend: Socket, capture: Socket | None = None) -> int: ... +def proxy_steerable( + frontend: Socket, + backend: Socket, + capture: Socket | None = ..., + control: Socket | None = ..., +) -> int: ... + +monitored_queue = Callable | None diff --git a/source/zmq/backend/cffi/README.md b/source/zmq/backend/cffi/README.md new file mode 100644 index 0000000000000000000000000000000000000000..00bb32989dcfbc787760075fd6dc4801568ab0a2 --- /dev/null +++ b/source/zmq/backend/cffi/README.md @@ -0,0 +1 @@ +PyZMQ's CFFI support is designed only for (Unix) systems conforming to `have_sys_un_h = True`. diff --git a/source/zmq/backend/cffi/__init__.py b/source/zmq/backend/cffi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce929a7a6a6bbe84c6a1cdd873095152837cd44c --- /dev/null +++ b/source/zmq/backend/cffi/__init__.py @@ -0,0 +1,38 @@ +"""CFFI backend (for PyPy)""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +# for clearer error message on missing cffi +import cffi # noqa + +from zmq.backend.cffi import _poll, context, devices, error, message, socket, utils + +from ._cffi import ffi +from ._cffi import lib as C + + +def zmq_version_info(): + """Get libzmq version as tuple of ints""" + major = ffi.new('int*') + minor = ffi.new('int*') + patch = ffi.new('int*') + + C.zmq_version(major, minor, patch) + + return (int(major[0]), int(minor[0]), int(patch[0])) + + +__all__ = ["zmq_version_info"] +for submod in (error, message, context, socket, _poll, devices, utils): + __all__.extend(submod.__all__) + +from ._poll import * +from .context import * +from .devices import * +from .error import * +from .message import * +from .socket import * +from .utils import * + +monitored_queue = None diff --git a/source/zmq/backend/cffi/_cdefs.h b/source/zmq/backend/cffi/_cdefs.h new file mode 100644 index 0000000000000000000000000000000000000000..98f7ce6208af75fa666a2e16bf25892ed2474d62 --- /dev/null +++ b/source/zmq/backend/cffi/_cdefs.h @@ -0,0 +1,98 @@ +void zmq_version(int *major, int *minor, int *patch); + +void* zmq_socket(void *context, int type); +int zmq_close(void *socket); + +int zmq_bind(void *socket, const char *endpoint); +int zmq_connect(void *socket, const char *endpoint); + +int zmq_errno(void); +const char * zmq_strerror(int errnum); + +int zmq_unbind(void *socket, const char *endpoint); +int zmq_disconnect(void *socket, const char *endpoint); +void* zmq_ctx_new(); +int zmq_ctx_destroy(void *context); +int zmq_ctx_get(void *context, int opt); +int zmq_ctx_set(void *context, int opt, int optval); +int zmq_proxy(void *frontend, void *backend, void *capture); +int zmq_proxy_steerable(void *frontend, + void *backend, + void *capture, + void *control); +int zmq_socket_monitor(void *socket, const char *addr, int events); + +int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key); +int zmq_curve_public (char *z85_public_key, char *z85_secret_key); +int zmq_has (const char *capability); + +typedef struct { ...; } zmq_msg_t; +typedef ... zmq_free_fn; + +int zmq_msg_init(zmq_msg_t *msg); +int zmq_msg_init_size(zmq_msg_t *msg, size_t size); +int zmq_msg_init_data(zmq_msg_t *msg, + void *data, + size_t size, + zmq_free_fn *ffn, + void *hint); + +size_t zmq_msg_size(zmq_msg_t *msg); +void *zmq_msg_data(zmq_msg_t *msg); +int zmq_msg_close(zmq_msg_t *msg); + +int zmq_msg_copy(zmq_msg_t *dst, zmq_msg_t *src); +int zmq_msg_send(zmq_msg_t *msg, void *socket, int flags); +int zmq_msg_recv(zmq_msg_t *msg, void *socket, int flags); +int zmq_recv(void *socket, void *buf, int nbytes, int flags); + +int zmq_getsockopt(void *socket, + int option_name, + void *option_value, + size_t *option_len); + +int zmq_setsockopt(void *socket, + int option_name, + const void *option_value, + size_t option_len); + +typedef int... ZMQ_FD_T; + +typedef struct +{ + void *socket; + ZMQ_FD_T fd; + short events; + short revents; +} zmq_pollitem_t; + +int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout); + +// draft poller +void *zmq_poller_new (); +int zmq_poller_destroy (void **poller_p_); +int zmq_poller_add (void *poller_, void *socket_, void *user_data_, short events_); +int zmq_poller_fd (void *poller_, ZMQ_FD_T *fd_); + +// miscellany +void * memcpy(void *restrict s1, const void *restrict s2, size_t n); +void * malloc(size_t sz); +void free(void *p); +int get_ipc_path_max_len(void); + +typedef struct { ...; } mutex_t; + +typedef struct _zhint { + void *sock; + mutex_t *mutex; + size_t id; +} zhint; + +mutex_t* mutex_allocate(); + +int zmq_wrap_msg_init_data(zmq_msg_t *msg, + void *data, + size_t size, + void *hint); + +#define PYZMQ_DRAFT_API ... diff --git a/source/zmq/backend/cffi/_cffi_src.c b/source/zmq/backend/cffi/_cffi_src.c new file mode 100644 index 0000000000000000000000000000000000000000..691be3c782d3d8209dc67e12d241e749e08440cb --- /dev/null +++ b/source/zmq/backend/cffi/_cffi_src.c @@ -0,0 +1,50 @@ +#include +#include + +#include "pyversion_compat.h" +#include "mutex.h" +#include "ipcmaxlen.h" +#include "zmq_compat.h" +#include + +typedef struct _zhint { + void *sock; + mutex_t *mutex; + size_t id; +} zhint; + +void free_python_msg(void *data, void *vhint) { + zmq_msg_t msg; + zhint *hint = (zhint *)vhint; + int rc; + if (hint != NULL) { + zmq_msg_init_size(&msg, sizeof(size_t)); + memcpy(zmq_msg_data(&msg), &hint->id, sizeof(size_t)); + rc = mutex_lock(hint->mutex); + if (rc != 0) { + fprintf(stderr, "pyzmq-gc mutex lock failed rc=%d\n", rc); + } + rc = zmq_msg_send(&msg, hint->sock, 0); + if (rc < 0) { + /* + * gc socket could have been closed, e.g. during process teardown. + * If so, ignore the failure because there's nothing to do. + */ + if (zmq_errno() != ENOTSOCK) { + fprintf(stderr, "pyzmq-gc send failed: %s\n", + zmq_strerror(zmq_errno())); + } + } + rc = mutex_unlock(hint->mutex); + if (rc != 0) { + fprintf(stderr, "pyzmq-gc mutex unlock failed rc=%d\n", rc); + } + zmq_msg_close(&msg); + free(hint); + } +} + +int zmq_wrap_msg_init_data(zmq_msg_t *msg, void *data, size_t size, + void *hint) { + return zmq_msg_init_data(msg, data, size, free_python_msg, hint); +} diff --git a/source/zmq/backend/cffi/_poll.py b/source/zmq/backend/cffi/_poll.py new file mode 100644 index 0000000000000000000000000000000000000000..63e2763b9c57b065b963df9b0a1f86cd9fe30af0 --- /dev/null +++ b/source/zmq/backend/cffi/_poll.py @@ -0,0 +1,92 @@ +"""zmq poll function""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +try: + from time import monotonic +except ImportError: + from time import clock as monotonic + +import warnings + +from zmq.error import InterruptedSystemCall, _check_rc + +from ._cffi import ffi +from ._cffi import lib as C + + +def _make_zmq_pollitem(socket, flags): + zmq_socket = socket._zmq_socket + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = zmq_socket + zmq_pollitem.fd = 0 + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + + +def _make_zmq_pollitem_fromfd(socket_fd, flags): + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = ffi.NULL + zmq_pollitem.fd = socket_fd + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + + +def zmq_poll(sockets, timeout): + cffi_pollitem_list = [] + low_level_to_socket_obj = {} + from zmq import Socket + + for item in sockets: + if isinstance(item[0], Socket): + low_level_to_socket_obj[item[0]._zmq_socket] = item + cffi_pollitem_list.append(_make_zmq_pollitem(item[0], item[1])) + else: + if not isinstance(item[0], int): + # not an FD, get it from fileno() + item = (item[0].fileno(), item[1]) + low_level_to_socket_obj[item[0]] = item + cffi_pollitem_list.append(_make_zmq_pollitem_fromfd(item[0], item[1])) + items = ffi.new('zmq_pollitem_t[]', cffi_pollitem_list) + list_length = ffi.cast('int', len(cffi_pollitem_list)) + while True: + c_timeout = ffi.cast('long', timeout) + start = monotonic() + rc = C.zmq_poll(items, list_length, c_timeout) + try: + _check_rc(rc) + except InterruptedSystemCall: + if timeout > 0: + ms_passed = int(1000 * (monotonic() - start)) + if ms_passed < 0: + # don't allow negative ms_passed, + # which can happen on old Python versions without time.monotonic. + warnings.warn( + f"Negative elapsed time for interrupted poll: {ms_passed}." + " Did the clock change?", + RuntimeWarning, + ) + ms_passed = 0 + timeout = max(0, timeout - ms_passed) + continue + else: + break + result = [] + for item in items: + if item.revents > 0: + if item.socket != ffi.NULL: + result.append( + ( + low_level_to_socket_obj[item.socket][0], + item.revents, + ) + ) + else: + result.append((item.fd, item.revents)) + return result + + +__all__ = ['zmq_poll'] diff --git a/source/zmq/backend/cffi/context.py b/source/zmq/backend/cffi/context.py new file mode 100644 index 0000000000000000000000000000000000000000..23a69ecf1d97c1353470ebec03601370b8a5eb73 --- /dev/null +++ b/source/zmq/backend/cffi/context.py @@ -0,0 +1,77 @@ +"""zmq Context class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.constants import EINVAL, IO_THREADS +from zmq.error import InterruptedSystemCall, ZMQError, _check_rc + +from ._cffi import ffi +from ._cffi import lib as C + + +class Context: + _zmq_ctx = None + _iothreads = None + _closed = True + _shadow = False + + def __init__(self, io_threads=1, shadow=None): + if shadow: + self._zmq_ctx = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + if not io_threads >= 0: + raise ZMQError(EINVAL) + + self._zmq_ctx = C.zmq_ctx_new() + if self._zmq_ctx == ffi.NULL: + raise ZMQError(C.zmq_errno()) + if not shadow: + C.zmq_ctx_set(self._zmq_ctx, IO_THREADS, io_threads) + self._closed = False + + @property + def underlying(self): + """The address of the underlying libzmq context""" + return int(ffi.cast('size_t', self._zmq_ctx)) + + @property + def closed(self): + return self._closed + + def set(self, option, value): + """set a context option + + see zmq_ctx_set + """ + rc = C.zmq_ctx_set(self._zmq_ctx, option, value) + _check_rc(rc) + + def get(self, option): + """get context option + + see zmq_ctx_get + """ + rc = C.zmq_ctx_get(self._zmq_ctx, option) + _check_rc(rc, error_without_errno=False) + return rc + + def term(self): + if self.closed: + return + + rc = C.zmq_ctx_destroy(self._zmq_ctx) + try: + _check_rc(rc) + except InterruptedSystemCall: + # ignore interrupted term + # see PEP 475 notes about close & EINTR for why + pass + + self._zmq_ctx = None + self._closed = True + + +__all__ = ['Context'] diff --git a/source/zmq/backend/cffi/devices.py b/source/zmq/backend/cffi/devices.py new file mode 100644 index 0000000000000000000000000000000000000000..a906be6e8709ce1e16ffedd231d2aff872978ba2 --- /dev/null +++ b/source/zmq/backend/cffi/devices.py @@ -0,0 +1,59 @@ +"""zmq device functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi +from ._cffi import lib as C +from .socket import Socket +from .utils import _retry_sys_call + + +def proxy(frontend, backend, capture=None): + if isinstance(capture, Socket): + capture = capture._zmq_socket + else: + capture = ffi.NULL + + _retry_sys_call(C.zmq_proxy, frontend._zmq_socket, backend._zmq_socket, capture) + + +def proxy_steerable(frontend, backend, capture=None, control=None): + """proxy_steerable(frontend, backend, capture, control) + + Start a zeromq proxy with control flow. + + .. versionadded:: libzmq-4.1 + .. versionadded:: 18.0 + + Parameters + ---------- + frontend : Socket + The Socket instance for the incoming traffic. + backend : Socket + The Socket instance for the outbound traffic. + capture : Socket (optional) + The Socket instance for capturing traffic. + control : Socket (optional) + The Socket instance for control flow. + """ + if isinstance(capture, Socket): + capture = capture._zmq_socket + else: + capture = ffi.NULL + + if isinstance(control, Socket): + control = control._zmq_socket + else: + control = ffi.NULL + + _retry_sys_call( + C.zmq_proxy_steerable, + frontend._zmq_socket, + backend._zmq_socket, + capture, + control, + ) + + +__all__ = ['proxy', 'proxy_steerable'] diff --git a/source/zmq/backend/cffi/error.py b/source/zmq/backend/cffi/error.py new file mode 100644 index 0000000000000000000000000000000000000000..5561394ddecc5a47c3767f6921cfa80df5f6a3c1 --- /dev/null +++ b/source/zmq/backend/cffi/error.py @@ -0,0 +1,16 @@ +"""zmq error functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi +from ._cffi import lib as C + + +def strerror(errno): + return ffi.string(C.zmq_strerror(errno)).decode() + + +zmq_errno = C.zmq_errno + +__all__ = ['strerror', 'zmq_errno'] diff --git a/source/zmq/backend/cffi/message.py b/source/zmq/backend/cffi/message.py new file mode 100644 index 0000000000000000000000000000000000000000..94bb8c96fc7c2a79cd0b20dda2abca5643599c4b --- /dev/null +++ b/source/zmq/backend/cffi/message.py @@ -0,0 +1,225 @@ +"""Dummy Frame object""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import errno +from threading import Event + +import zmq +import zmq.error +from zmq.constants import ETERM + +from ._cffi import ffi +from ._cffi import lib as C + +zmq_gc = None + +try: + from __pypy__.bufferable import bufferable as maybe_bufferable +except ImportError: + maybe_bufferable = object + + +def _content(obj): + """Return content of obj as bytes""" + if type(obj) is bytes: + return obj + if not isinstance(obj, memoryview): + obj = memoryview(obj) + return obj.tobytes() + + +def _check_rc(rc): + err = C.zmq_errno() + if rc == -1: + if err == errno.EINTR: + raise zmq.error.InterrruptedSystemCall(err) + elif err == errno.EAGAIN: + raise zmq.error.Again(errno) + elif err == ETERM: + raise zmq.error.ContextTerminated(err) + else: + raise zmq.error.ZMQError(err) + return 0 + + +class Frame(maybe_bufferable): + _data = None + tracker = None + closed = False + more = False + _buffer = None + _bytes = None + _failed_init = False + tracker_event = None + zmq_msg = None + + def __init__(self, data=None, track=False, copy=None, copy_threshold=None): + self._failed_init = True + + self.zmq_msg = ffi.cast('zmq_msg_t[1]', C.malloc(ffi.sizeof("zmq_msg_t"))) + + # self.tracker should start finished + # except in the case where we are sharing memory with libzmq + if track: + self.tracker = zmq._FINISHED_TRACKER + + if isinstance(data, str): + raise TypeError( + "Unicode strings are not allowed. Only: bytes, buffer interfaces." + ) + + if data is None: + rc = C.zmq_msg_init(self.zmq_msg) + _check_rc(rc) + self._failed_init = False + return + + self._data = data + if type(data) is bytes: + # avoid unnecessary copy on .bytes access + self._bytes = data + + self._buffer = memoryview(data) + if not self._buffer.contiguous: + raise BufferError("memoryview: underlying buffer is not contiguous") + # from_buffer silently copies if memory is not contiguous + c_data = ffi.from_buffer(self._buffer) + data_len_c = self._buffer.nbytes + + if copy is None: + if copy_threshold and data_len_c < copy_threshold: + copy = True + else: + copy = False + + if copy: + # copy message data instead of sharing memory + rc = C.zmq_msg_init_size(self.zmq_msg, data_len_c) + _check_rc(rc) + ffi.buffer(C.zmq_msg_data(self.zmq_msg), data_len_c)[:] = self._buffer + self._failed_init = False + return + + # Getting here means that we are doing a true zero-copy Frame, + # where libzmq and Python are sharing memory. + # Hook up garbage collection with MessageTracker and zmq_free_fn + + # Event and MessageTracker for monitoring when zmq is done with data: + if track: + evt = Event() + self.tracker_event = evt + self.tracker = zmq.MessageTracker(evt) + # create the hint for zmq_free_fn + # two pointers: the zmq_gc context and a message to be sent to the zmq_gc PULL socket + # allows libzmq to signal to Python when it is done with Python-owned memory. + global zmq_gc + if zmq_gc is None: + from zmq.utils.garbage import gc as zmq_gc + # can't use ffi.new because it will be freed at the wrong time! + hint = ffi.cast("zhint[1]", C.malloc(ffi.sizeof("zhint"))) + hint[0].id = zmq_gc.store(data, self.tracker_event) + if not zmq_gc._push_mutex: + zmq_gc._push_mutex = C.mutex_allocate() + + hint[0].mutex = ffi.cast("mutex_t*", zmq_gc._push_mutex) + hint[0].sock = ffi.cast("void*", zmq_gc._push_socket.underlying) + + # calls zmq_wrap_msg_init_data with the C.free_python_msg callback + rc = C.zmq_wrap_msg_init_data( + self.zmq_msg, + c_data, + data_len_c, + hint, + ) + if rc != 0: + C.free(hint) + C.free(self.zmq_msg) + _check_rc(rc) + self._failed_init = False + + def __del__(self): + if not self.closed and not self._failed_init: + self.close() + + def close(self): + if self.closed or self._failed_init or self.zmq_msg is None: + return + self.closed = True + rc = C.zmq_msg_close(self.zmq_msg) + C.free(self.zmq_msg) + self.zmq_msg = None + if rc != 0: + _check_rc(rc) + + def _buffer_from_zmq_msg(self): + """one-time extract buffer from zmq_msg + + for Frames created by recv + """ + if self._data is None: + self._data = ffi.buffer( + C.zmq_msg_data(self.zmq_msg), C.zmq_msg_size(self.zmq_msg) + ) + if self._buffer is None: + self._buffer = memoryview(self._data) + + @property + def buffer(self): + if self._buffer is None: + self._buffer_from_zmq_msg() + return self._buffer + + @property + def bytes(self): + if self._bytes is None: + self._bytes = self.buffer.tobytes() + return self._bytes + + def __len__(self): + return self.buffer.nbytes + + def __eq__(self, other): + return self.bytes == _content(other) + + @property + def done(self): + return self.tracker.done() + + def __buffer__(self, flags): + return self.buffer + + def __copy__(self): + """Create a shallow copy of the message. + + This does not copy the contents of the Frame, just the pointer. + This will increment the 0MQ ref count of the message, but not + the ref count of the Python object. That is only done once when + the Python is first turned into a 0MQ message. + """ + return self.fast_copy() + + def fast_copy(self): + """Fast shallow copy of the Frame. + + Does not copy underlying data. + """ + new_msg = Frame() + # This does not copy the contents, but just increases the ref-count + # of the zmq_msg by one. + C.zmq_msg_copy(new_msg.zmq_msg, self.zmq_msg) + # Copy the ref to underlying data + new_msg._data = self._data + new_msg._buffer = self._buffer + + # Frame copies share the tracker and tracker_event + new_msg.tracker_event = self.tracker_event + new_msg.tracker = self.tracker + + return new_msg + + +Message = Frame + +__all__ = ['Frame', 'Message'] diff --git a/source/zmq/backend/cffi/socket.py b/source/zmq/backend/cffi/socket.py new file mode 100644 index 0000000000000000000000000000000000000000..6a47dad909a62cd28044f88a1c8eb0c7062f7ea9 --- /dev/null +++ b/source/zmq/backend/cffi/socket.py @@ -0,0 +1,435 @@ +"""zmq Socket class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import errno as errno_mod +import warnings + +import zmq +from zmq.constants import SocketOption, _OptType +from zmq.error import ZMQError, _check_rc, _check_version + +from ._cffi import ffi +from ._cffi import lib as C +from .message import Frame +from .utils import _retry_sys_call + +nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length) + + +def new_uint64_pointer(): + return ffi.new('uint64_t*'), nsp(ffi.sizeof('uint64_t')) + + +def new_int64_pointer(): + return ffi.new('int64_t*'), nsp(ffi.sizeof('int64_t')) + + +def new_int_pointer(): + return ffi.new('int*'), nsp(ffi.sizeof('int')) + + +def new_binary_data(length): + return ffi.new(f'char[{length:d}]'), nsp(ffi.sizeof('char') * length) + + +def value_uint64_pointer(val): + return ffi.new('uint64_t*', val), ffi.sizeof('uint64_t') + + +def value_int64_pointer(val): + return ffi.new('int64_t*', val), ffi.sizeof('int64_t') + + +def value_int_pointer(val): + return ffi.new('int*', val), ffi.sizeof('int') + + +def value_binary_data(val, length): + return ffi.new(f'char[{length + 1:d}]', val), ffi.sizeof('char') * length + + +_fd_size = ffi.sizeof('ZMQ_FD_T') +ZMQ_FD_64BIT = _fd_size == 8 + +IPC_PATH_MAX_LEN = C.get_ipc_path_max_len() + + +def new_pointer_from_opt(option, length=0): + opt_type = getattr(option, "_opt_type", _OptType.int) + + if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd): + return new_int64_pointer() + elif opt_type == _OptType.bytes: + return new_binary_data(length) + else: + # default + return new_int_pointer() + + +def value_from_opt_pointer(option, opt_pointer, length=0): + try: + option = SocketOption(option) + except ValueError: + # unrecognized option, + # assume from the future, + # let EINVAL raise + opt_type = _OptType.int + else: + opt_type = option._opt_type + + if opt_type == _OptType.bytes: + return ffi.buffer(opt_pointer, length)[:] + else: + return int(opt_pointer[0]) + + +def initialize_opt_pointer(option, value, length=0): + opt_type = getattr(option, "_opt_type", _OptType.int) + if opt_type == _OptType.int64 or (ZMQ_FD_64BIT and opt_type == _OptType.fd): + return value_int64_pointer(value) + elif opt_type == _OptType.bytes: + return value_binary_data(value, length) + else: + return value_int_pointer(value) + + +class Socket: + context = None + socket_type = None + _zmq_socket = None + _closed = None + _ref = None + _shadow = False + _draft_poller = None + _draft_poller_ptr = None + copy_threshold = 0 + + def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None): + if copy_threshold is None: + copy_threshold = zmq.COPY_THRESHOLD + self.copy_threshold = copy_threshold + + self.context = context + self._draft_poller = self._draft_poller_ptr = None + if shadow: + self._zmq_socket = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type) + if self._zmq_socket == ffi.NULL: + raise ZMQError() + self._closed = False + + @property + def underlying(self): + """The address of the underlying libzmq socket""" + return int(ffi.cast('size_t', self._zmq_socket)) + + def _check_closed_deep(self): + """thorough check of whether the socket has been closed, + even if by another entity (e.g. ctx.destroy). + + Only used by the `closed` property. + + returns True if closed, False otherwise + """ + if self._closed: + return True + try: + self.get(zmq.TYPE) + except ZMQError as e: + if e.errno == zmq.ENOTSOCK: + self._closed = True + return True + elif e.errno == zmq.ETERM: + pass + else: + raise + return False + + @property + def closed(self): + return self._check_closed_deep() + + def close(self, linger=None): + rc = 0 + if not self._closed and hasattr(self, '_zmq_socket'): + if self._draft_poller_ptr is not None: + rc = C.zmq_poller_destroy(self._draft_poller_ptr) + self._draft_poller = self._draft_poller_ptr = None + + if self._zmq_socket is not None: + if linger is not None: + self.set(zmq.LINGER, linger) + rc = C.zmq_close(self._zmq_socket) + self._closed = True + if rc < 0: + _check_rc(rc) + + def bind(self, address): + if isinstance(address, str): + address_b = address.encode('utf8') + else: + address_b = address + if isinstance(address, bytes): + address = address_b.decode('utf8') + rc = C.zmq_bind(self._zmq_socket, address_b) + if rc < 0: + if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG: + path = address.split('://', 1)[-1] + msg = ( + f'ipc path "{path}" is longer than {IPC_PATH_MAX_LEN} ' + 'characters (sizeof(sockaddr_un.sun_path)).' + ) + raise ZMQError(C.zmq_errno(), msg=msg) + elif C.zmq_errno() == errno_mod.ENOENT: + path = address.split('://', 1)[-1] + msg = f'No such file or directory for ipc path "{path}".' + raise ZMQError(C.zmq_errno(), msg=msg) + else: + _check_rc(rc) + + def unbind(self, address): + if isinstance(address, str): + address = address.encode('utf8') + rc = C.zmq_unbind(self._zmq_socket, address) + _check_rc(rc) + + def connect(self, address): + if isinstance(address, str): + address = address.encode('utf8') + rc = C.zmq_connect(self._zmq_socket, address) + _check_rc(rc) + + def disconnect(self, address): + if isinstance(address, str): + address = address.encode('utf8') + rc = C.zmq_disconnect(self._zmq_socket, address) + _check_rc(rc) + + def set(self, option, value): + length = None + if isinstance(value, str): + raise TypeError("unicode not allowed, use bytes") + + try: + option = SocketOption(option) + except ValueError: + # unrecognized option, + # assume from the future, + # let EINVAL raise + opt_type = _OptType.int + else: + opt_type = option._opt_type + + if isinstance(value, bytes): + if opt_type != _OptType.bytes: + raise TypeError(f"not a bytes sockopt: {option}") + length = len(value) + + c_value_pointer, c_sizet = initialize_opt_pointer(option, value, length) + + _retry_sys_call( + C.zmq_setsockopt, + self._zmq_socket, + option, + ffi.cast('void*', c_value_pointer), + c_sizet, + ) + + def get(self, option): + try: + option = SocketOption(option) + except ValueError: + # unrecognized option, + # assume from the future, + # let EINVAL raise + opt_type = _OptType.int + else: + opt_type = option._opt_type + + if option == zmq.FD and self._draft_poller is not None: + c_value_pointer, _ = new_pointer_from_opt(option) + C.zmq_poller_fd(self._draft_poller, ffi.cast('void*', c_value_pointer)) + return int(c_value_pointer[0]) + + c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255) + + try: + _retry_sys_call( + C.zmq_getsockopt, + self._zmq_socket, + option, + c_value_pointer, + c_sizet_pointer, + ) + except ZMQError as e: + if ( + option == SocketOption.FD + and e.errno == zmq.Errno.EINVAL + and self.get(SocketOption.THREAD_SAFE) + ): + _check_version((4, 3, 2), "draft socket FD support via zmq_poller_fd") + if not zmq.DRAFT_API: + raise RuntimeError("libzmq must be built with draft support") + warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2) + + # create a poller and retrieve its fd + self._draft_poller_ptr = ffi.new("void*[1]") + self._draft_poller_ptr[0] = self._draft_poller = C.zmq_poller_new() + if self._draft_poller == ffi.NULL: + # failed (why?), raise original error + self._draft_poller_ptr = self._draft_poller = None + raise + # register self with poller + rc = C.zmq_poller_add( + self._draft_poller, + self._zmq_socket, + ffi.NULL, + zmq.POLLIN | zmq.POLLOUT, + ) + _check_rc(rc) + # use poller fd as proxy for ours + rc = C.zmq_poller_fd( + self._draft_poller, ffi.cast('void *', c_value_pointer) + ) + _check_rc(rc) + return int(c_value_pointer[0]) + else: + raise + + sz = c_sizet_pointer[0] + v = value_from_opt_pointer(option, c_value_pointer, sz) + if ( + option != zmq.SocketOption.ROUTING_ID + and opt_type == _OptType.bytes + and v.endswith(b'\0') + ): + v = v[:-1] + return v + + def _send_copy(self, buf, flags): + """Send a copy of a bufferable""" + zmq_msg = ffi.new('zmq_msg_t*') + if not isinstance(buf, bytes): + # cast any bufferable data to bytes via memoryview + buf = memoryview(buf).tobytes() + + c_message = ffi.new('char[]', buf) + rc = C.zmq_msg_init_size(zmq_msg, len(buf)) + _check_rc(rc) + C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(buf)) + _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags) + rc2 = C.zmq_msg_close(zmq_msg) + _check_rc(rc2) + + def _send_frame(self, frame, flags): + """Send a Frame on this socket in a non-copy manner.""" + # Always copy the Frame so the original message isn't garbage collected. + # This doesn't do a real copy, just a reference. + frame_copy = frame.fast_copy() + zmq_msg = frame_copy.zmq_msg + _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags) + tracker = frame_copy.tracker + frame_copy.close() + return tracker + + def send(self, data, flags=0, copy=False, track=False): + if isinstance(data, str): + raise TypeError("Message must be in bytes, not a unicode object") + + if copy and not isinstance(data, Frame): + return self._send_copy(data, flags) + else: + close_frame = False + if isinstance(data, Frame): + if track and not data.tracker: + raise ValueError('Not a tracked message') + frame = data + else: + if self.copy_threshold: + buf = memoryview(data) + # always copy messages smaller than copy_threshold + if buf.nbytes < self.copy_threshold: + self._send_copy(buf, flags) + return zmq._FINISHED_TRACKER + frame = Frame(data, track=track, copy_threshold=self.copy_threshold) + close_frame = True + + tracker = self._send_frame(frame, flags) + if close_frame: + frame.close() + return tracker + + def recv(self, flags=0, copy=True, track=False): + if copy: + zmq_msg = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg) + else: + frame = zmq.Frame(track=track) + zmq_msg = frame.zmq_msg + + try: + _retry_sys_call(C.zmq_msg_recv, zmq_msg, self._zmq_socket, flags) + except Exception: + if copy: + C.zmq_msg_close(zmq_msg) + raise + + if not copy: + return frame + + _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg)) + _bytes = _buffer[:] + rc = C.zmq_msg_close(zmq_msg) + _check_rc(rc) + return _bytes + + def recv_into(self, buffer, /, *, nbytes: int = 0, flags: int = 0) -> int: + view = memoryview(buffer) + if not view.contiguous: + raise BufferError("Can only recv_into contiguous buffers") + if view.readonly: + raise BufferError("Cannot recv_into readonly buffer") + if nbytes < 0: + raise ValueError(f"{nbytes=} must be non-negative") + view_bytes = view.nbytes + if nbytes == 0: + nbytes = view_bytes + elif nbytes > view_bytes: + raise ValueError(f"{nbytes=} too big for memoryview of {view_bytes}B") + c_buf = ffi.from_buffer(view) + rc: int = _retry_sys_call(C.zmq_recv, self._zmq_socket, c_buf, nbytes, flags) + _check_rc(rc) + return rc + + def monitor(self, addr, events=-1): + """s.monitor(addr, flags) + + Start publishing socket events on inproc. + See libzmq docs for zmq_monitor for details. + + Note: requires libzmq >= 3.2 + + Parameters + ---------- + addr : str + The inproc url used for monitoring. Passing None as + the addr will cause an existing socket monitor to be + deregistered. + events : int [default: zmq.EVENT_ALL] + The zmq event bitmask for which events will be sent to the monitor. + """ + if events < 0: + events = zmq.EVENT_ALL + if addr is None: + addr = ffi.NULL + if isinstance(addr, str): + addr = addr.encode('utf8') + C.zmq_socket_monitor(self._zmq_socket, addr, events) + + +__all__ = ['Socket', 'IPC_PATH_MAX_LEN'] diff --git a/source/zmq/backend/cffi/utils.py b/source/zmq/backend/cffi/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8dcc1c50d5ffa605630fa60aa157ff6e8329cabb --- /dev/null +++ b/source/zmq/backend/cffi/utils.py @@ -0,0 +1,80 @@ +"""miscellaneous zmq_utils wrapping""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.error import InterruptedSystemCall, _check_rc, _check_version + +from ._cffi import ffi +from ._cffi import lib as C + + +def has(capability): + """Check for zmq capability by name (e.g. 'ipc', 'curve') + + .. versionadded:: libzmq-4.1 + .. versionadded:: 14.1 + """ + _check_version((4, 1), 'zmq.has') + if isinstance(capability, str): + capability = capability.encode('utf8') + return bool(C.zmq_has(capability)) + + +def curve_keypair(): + """generate a Z85 key pair for use with zmq.CURVE security + + Requires libzmq (≥ 4.0) to have been built with CURVE support. + + Returns + ------- + (public, secret) : two bytestrings + The public and private key pair as 40 byte z85-encoded bytestrings. + """ + public = ffi.new('char[64]') + private = ffi.new('char[64]') + rc = C.zmq_curve_keypair(public, private) + _check_rc(rc) + return ffi.buffer(public)[:40], ffi.buffer(private)[:40] + + +def curve_public(private): + """Compute the public key corresponding to a private key for use + with zmq.CURVE security + + Requires libzmq (≥ 4.2) to have been built with CURVE support. + + Parameters + ---------- + private + The private key as a 40 byte z85-encoded bytestring + Returns + ------- + bytestring + The public key as a 40 byte z85-encoded bytestring. + """ + if isinstance(private, str): + private = private.encode('utf8') + _check_version((4, 2), "curve_public") + public = ffi.new('char[64]') + rc = C.zmq_curve_public(public, private) + _check_rc(rc) + return ffi.buffer(public)[:40] + + +def _retry_sys_call(f, *args, **kwargs): + """make a call, retrying if interrupted with EINTR""" + while True: + rc = f(*args) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + return rc + + +PYZMQ_DRAFT_API: bool = bool(C.PYZMQ_DRAFT_API) + +__all__ = ['has', 'curve_keypair', 'curve_public', 'PYZMQ_DRAFT_API'] diff --git a/source/zmq/backend/cython/__init__.pxd b/source/zmq/backend/cython/__init__.pxd new file mode 100644 index 0000000000000000000000000000000000000000..94c7f8a3fb3b246dec15c64768147889f6593579 --- /dev/null +++ b/source/zmq/backend/cython/__init__.pxd @@ -0,0 +1 @@ +from zmq.backend.cython._zmq cimport Context, Frame, Socket diff --git a/source/zmq/backend/cython/__init__.py b/source/zmq/backend/cython/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4179a80a3165e15a6f01b602d0281e64254e06 --- /dev/null +++ b/source/zmq/backend/cython/__init__.py @@ -0,0 +1,15 @@ +"""Python bindings for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from . import _zmq + +# mq not in __all__ +from ._zmq import * # noqa +from ._zmq import monitored_queue # noqa + +Message = _zmq.Frame + +__all__ = ["Message"] +__all__.extend(_zmq.__all__) diff --git a/source/zmq/backend/cython/_externs.pxd b/source/zmq/backend/cython/_externs.pxd new file mode 100644 index 0000000000000000000000000000000000000000..dfe0744ddf1da8f8e25bd948d5bcc9517197fbde --- /dev/null +++ b/source/zmq/backend/cython/_externs.pxd @@ -0,0 +1,13 @@ +cdef extern from "mutex.h" nogil: + ctypedef struct mutex_t: + pass + cdef mutex_t* mutex_allocate() + cdef void mutex_dallocate(mutex_t*) + cdef int mutex_lock(mutex_t*) + cdef int mutex_unlock(mutex_t*) + +cdef extern from "getpid_compat.h": + cdef int getpid() + +cdef extern from "ipcmaxlen.h": + cdef int get_ipc_path_max_len() diff --git a/source/zmq/backend/cython/_zmq.abi3.so b/source/zmq/backend/cython/_zmq.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6d3013d210cdcf6f49323d7ac567ed530971fd39 --- /dev/null +++ b/source/zmq/backend/cython/_zmq.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1eb00d8890b3ba39b98e5505f7e80ca902974c80cb24801ece23baf5856924d +size 266809 diff --git a/source/zmq/backend/cython/_zmq.pxd b/source/zmq/backend/cython/_zmq.pxd new file mode 100644 index 0000000000000000000000000000000000000000..a7e4f244294c6d954fb8989bb161dd25dcb36a6f --- /dev/null +++ b/source/zmq/backend/cython/_zmq.pxd @@ -0,0 +1,52 @@ +# cython: language_level = 3str +"""zmq Cython backend augmented declarations""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.backend.cython.libzmq cimport zmq_msg_t + +cdef class Context: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Context is a shadow wrapper of another + cdef int _pid # the pid of the process which created me (for fork safety) + + cdef public bint closed # bool property for a closed context. + cdef inline int _term(self) + +cdef class MessageTracker(object): + cdef set events # Message Event objects to track. + cdef set peers # Other Message or MessageTracker objects. + +cdef class Frame: + + cdef zmq_msg_t zmq_msg + cdef object _data # The actual message data as a Python object. + cdef object _buffer # A Python memoryview of the message contents + cdef object _bytes # A bytes copy of the message. + cdef bint _failed_init # flag to hold failed init + cdef public object tracker_event # Event for use with zmq_free_fn. + cdef public object tracker # MessageTracker object. + cdef public bint more # whether RCVMORE was set + + cdef Frame fast_copy(self) # Create shallow copy of Message object. + +cdef class Socket: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Socket is a shadow wrapper of another + # Hold on to a reference to the context to make sure it is not garbage + # collected until the socket it done with it. + cdef public Context context # The zmq Context object that owns this. + cdef public bint _closed # bool property for a closed socket. + cdef public int copy_threshold # threshold below which pyzmq will always copy messages + cdef int _pid # the pid of the process which created me (for fork safety) + cdef void *_draft_poller # The C handle for the zmq poller for draft socket zmq.FD support + + # cpdef methods for direct-cython access: + cpdef object send(self, data, int flags=*, bint copy=*, bint track=*) + cpdef object recv(self, int flags=*, bint copy=*, bint track=*) + cpdef int recv_into(self, buffer, int nbytes=*, int flags=*) diff --git a/source/zmq/backend/cython/_zmq.py b/source/zmq/backend/cython/_zmq.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f53a54047f761b81e05b82e94a48d1275cdf2d --- /dev/null +++ b/source/zmq/backend/cython/_zmq.py @@ -0,0 +1,2048 @@ +# cython: language_level = 3str +# cython: freethreading_compatible = True +"""Cython backend for pyzmq""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +try: + import cython + + if not cython.compiled: + raise ImportError() +except ImportError: + from pathlib import Path + + zmq_root = Path(__file__).parents[3] + msg = f""" + Attempting to import zmq Cython backend, which has not been compiled. + + This probably means you are importing zmq from its source tree. + if this is what you want, make sure to do an in-place build first: + + pip install -e '{zmq_root}' + + If it is not, then '{zmq_root}' is probably on your sys.path, + when it shouldn't be. Is that your current working directory? + + If neither of those is true and this file is actually installed, + something seems to have gone wrong with the install! + Please report at https://github.com/zeromq/pyzmq/issues + """ + raise ImportError(msg) + +import warnings +from threading import Event +from time import monotonic +from weakref import ref + +import cython as C +from cython import ( + NULL, + Py_ssize_t, + address, + bint, + cast, + cclass, + cfunc, + char, + declare, + inline, + nogil, + p_char, + p_void, + pointer, + size_t, + sizeof, +) +from cython.cimports.cpython.buffer import ( + Py_buffer, + PyBUF_ANY_CONTIGUOUS, + PyBUF_WRITABLE, + PyBuffer_Release, + PyObject_GetBuffer, +) +from cython.cimports.cpython.bytes import ( + PyBytes_AsString, + PyBytes_FromStringAndSize, + PyBytes_Size, +) +from cython.cimports.cpython.exc import PyErr_CheckSignals +from cython.cimports.libc.errno import EAGAIN, EINTR, ENAMETOOLONG, ENOENT, ENOTSOCK +from cython.cimports.libc.stdint import uint32_t +from cython.cimports.libc.stdio import fprintf +from cython.cimports.libc.stdio import stderr as cstderr +from cython.cimports.libc.stdlib import free, malloc +from cython.cimports.libc.string import memcpy +from cython.cimports.zmq.backend.cython import libzmq +from cython.cimports.zmq.backend.cython._externs import ( + get_ipc_path_max_len, + getpid, + mutex_allocate, + mutex_lock, + mutex_t, + mutex_unlock, +) +from cython.cimports.zmq.backend.cython.libzmq import ( + ZMQ_ENOTSOCK, + ZMQ_ETERM, + ZMQ_EVENT_ALL, + ZMQ_FD, + ZMQ_IDENTITY, + ZMQ_IO_THREADS, + ZMQ_LINGER, + ZMQ_POLLIN, + ZMQ_POLLOUT, + ZMQ_RCVMORE, + ZMQ_ROUTER, + ZMQ_SNDMORE, + ZMQ_THREAD_SAFE, + ZMQ_TYPE, + _zmq_version, + fd_t, + int64_t, + zmq_bind, + zmq_close, + zmq_connect, + zmq_ctx_destroy, + zmq_ctx_get, + zmq_ctx_new, + zmq_ctx_set, + zmq_curve_keypair, + zmq_curve_public, + zmq_disconnect, + zmq_free_fn, + zmq_getsockopt, + zmq_has, + zmq_join, + zmq_leave, + zmq_msg_close, + zmq_msg_copy, + zmq_msg_data, + zmq_msg_get, + zmq_msg_gets, + zmq_msg_group, + zmq_msg_init, + zmq_msg_init_data, + zmq_msg_init_size, + zmq_msg_recv, + zmq_msg_routing_id, + zmq_msg_send, + zmq_msg_set, + zmq_msg_set_group, + zmq_msg_set_routing_id, + zmq_msg_size, + zmq_msg_t, + zmq_poller_add, + zmq_poller_destroy, + zmq_poller_fd, + zmq_poller_new, + zmq_pollitem_t, + zmq_proxy, + zmq_proxy_steerable, + zmq_recv, + zmq_setsockopt, + zmq_socket, + zmq_socket_monitor, + zmq_strerror, + zmq_unbind, +) +from cython.cimports.zmq.backend.cython.libzmq import zmq_errno as _zmq_errno +from cython.cimports.zmq.backend.cython.libzmq import zmq_poll as zmq_poll_c + +import zmq +from zmq.constants import SocketOption, _OptType +from zmq.error import ( + Again, + ContextTerminated, + InterruptedSystemCall, + ZMQError, + _check_version, +) + +IPC_PATH_MAX_LEN: int = get_ipc_path_max_len() + +PYZMQ_DRAFT_API: bool = bool(libzmq.PYZMQ_DRAFT_API) + + +@cfunc +@inline +@C.exceptval(-1) +def _check_rc(rc: C.int, error_without_errno: bint = False) -> C.int: + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + errno: C.int = _zmq_errno() + PyErr_CheckSignals() + if errno == 0 and not error_without_errno: + return 0 + if rc == -1: # if rc < -1, it's a bug in libzmq. Should we warn? + if errno == EINTR: + raise InterruptedSystemCall(errno) + elif errno == EAGAIN: + raise Again(errno) + elif errno == ZMQ_ETERM: + raise ContextTerminated(errno) + else: + raise ZMQError(errno) + return 0 + + +# message Frame class + +_zhint = C.struct( + sock=p_void, + mutex=pointer(mutex_t), + id=size_t, +) + + +@cfunc +@nogil +def free_python_msg(data: p_void, vhint: p_void) -> C.int: + """A pure-C function for DECREF'ing Python-owned message data. + + Sends a message on a PUSH socket + + The hint is a `zhint` struct with two values: + + sock (void *): pointer to the Garbage Collector's PUSH socket + id (size_t): the id to be used to construct a zmq_msg_t that should be sent on a PUSH socket, + signaling the Garbage Collector to remove its reference to the object. + + When the Garbage Collector's PULL socket receives the message, + it deletes its reference to the object, + allowing Python to free the memory. + """ + msg = declare(zmq_msg_t) + msg_ptr: pointer(zmq_msg_t) = address(msg) + hint: pointer(_zhint) = cast(pointer(_zhint), vhint) + rc: C.int + + if hint != NULL: + zmq_msg_init_size(msg_ptr, sizeof(size_t)) + memcpy(zmq_msg_data(msg_ptr), address(hint.id), sizeof(size_t)) + rc = mutex_lock(hint.mutex) + if rc != 0: + fprintf(cstderr, "pyzmq-gc mutex lock failed rc=%d\n", rc) + rc = zmq_msg_send(msg_ptr, hint.sock, 0) + if rc < 0: + # gc socket could have been closed, e.g. during process teardown. + # If so, ignore the failure because there's nothing to do. + if _zmq_errno() != ZMQ_ENOTSOCK: + fprintf( + cstderr, "pyzmq-gc send failed: %s\n", zmq_strerror(_zmq_errno()) + ) + rc = mutex_unlock(hint.mutex) + if rc != 0: + fprintf(cstderr, "pyzmq-gc mutex unlock failed rc=%d\n", rc) + + zmq_msg_close(msg_ptr) + free(hint) + return 0 + + +@cfunc +@inline +def _copy_zmq_msg_bytes(zmq_msg: pointer(zmq_msg_t)) -> bytes: + """Copy the data from a zmq_msg_t""" + data_c: p_char = NULL + data_len_c: Py_ssize_t + data_c = cast(p_char, zmq_msg_data(zmq_msg)) + data_len_c = zmq_msg_size(zmq_msg) + return PyBytes_FromStringAndSize(data_c, data_len_c) + + +@cfunc +@inline +def _asbuffer(obj, data_c: pointer(p_void), writable: bint = False) -> size_t: + """Get a C buffer from a memoryview""" + pybuf = declare(Py_buffer) + flags: C.int = PyBUF_ANY_CONTIGUOUS + if writable: + flags |= PyBUF_WRITABLE + rc: C.int = PyObject_GetBuffer(obj, address(pybuf), flags) + if rc < 0: + raise ValueError("Couldn't create buffer") + data_c[0] = pybuf.buf + data_size: size_t = pybuf.len + PyBuffer_Release(address(pybuf)) + return data_size + + +_gc = None + + +@cclass +class Frame: + def __init__( + self, data=None, track=False, copy=None, copy_threshold=None, **kwargs + ): + rc: C.int + data_c: p_char = NULL + data_len_c: Py_ssize_t = 0 + hint: pointer(_zhint) + if copy_threshold is None: + copy_threshold = zmq.COPY_THRESHOLD + + c_copy_threshold: C.size_t = 0 + if copy_threshold is not None: + c_copy_threshold = copy_threshold + + zmq_msg_ptr: pointer(zmq_msg_t) = address(self.zmq_msg) + # init more as False + self.more = False + + # Save the data object in case the user wants the the data as a str. + self._data = data + self._failed_init = True # bool switch for dealloc + self._buffer = None # buffer view of data + self._bytes = None # bytes copy of data + + self.tracker_event = None + self.tracker = None + # self.tracker should start finished + # except in the case where we are sharing memory with libzmq + if track: + self.tracker = zmq._FINISHED_TRACKER + + if isinstance(data, str): + raise TypeError("Str objects not allowed. Only: bytes, buffer interfaces.") + + if data is None: + rc = zmq_msg_init(zmq_msg_ptr) + _check_rc(rc) + self._failed_init = False + return + + data_len_c = _asbuffer(data, cast(pointer(p_void), address(data_c))) + + # copy unspecified, apply copy_threshold + c_copy: bint = True + if copy is None: + if c_copy_threshold and data_len_c < c_copy_threshold: + c_copy = True + else: + c_copy = False + else: + c_copy = copy + + if c_copy: + # copy message data instead of sharing memory + rc = zmq_msg_init_size(zmq_msg_ptr, data_len_c) + _check_rc(rc) + memcpy(zmq_msg_data(zmq_msg_ptr), data_c, data_len_c) + self._failed_init = False + return + + # Getting here means that we are doing a true zero-copy Frame, + # where libzmq and Python are sharing memory. + # Hook up garbage collection with MessageTracker and zmq_free_fn + + # Event and MessageTracker for monitoring when zmq is done with data: + if track: + evt = Event() + self.tracker_event = evt + self.tracker = zmq.MessageTracker(evt) + # create the hint for zmq_free_fn + # two pointers: the gc context and a message to be sent to the gc PULL socket + # allows libzmq to signal to Python when it is done with Python-owned memory. + global _gc + if _gc is None: + from zmq.utils.garbage import gc as _gc + + hint: pointer(_zhint) = cast(pointer(_zhint), malloc(sizeof(_zhint))) + hint.id = _gc.store(data, self.tracker_event) + if not _gc._push_mutex: + hint.mutex = mutex_allocate() + _gc._push_mutex = cast(size_t, hint.mutex) + else: + hint.mutex = cast(pointer(mutex_t), cast(size_t, _gc._push_mutex)) + hint.sock = cast(p_void, cast(size_t, _gc._push_socket.underlying)) + + rc = zmq_msg_init_data( + zmq_msg_ptr, + cast(p_void, data_c), + data_len_c, + cast(pointer(zmq_free_fn), free_python_msg), + cast(p_void, hint), + ) + if rc != 0: + free(hint) + _check_rc(rc) + self._failed_init = False + + def __dealloc__(self): + if self._failed_init: + return + # decrease the 0MQ ref-count of zmq_msg + with nogil: + rc: C.int = zmq_msg_close(address(self.zmq_msg)) + _check_rc(rc) + + def __copy__(self): + return self.fast_copy() + + def fast_copy(self) -> Frame: + new_msg: Frame = Frame() + # This does not copy the contents, but just increases the ref-count + # of the zmq_msg by one. + zmq_msg_copy(address(new_msg.zmq_msg), address(self.zmq_msg)) + # Copy the ref to data so the copy won't create a copy when str is + # called. + if self._data is not None: + new_msg._data = self._data + if self._buffer is not None: + new_msg._buffer = self._buffer + if self._bytes is not None: + new_msg._bytes = self._bytes + + # Frame copies share the tracker and tracker_event + new_msg.tracker_event = self.tracker_event + new_msg.tracker = self.tracker + + return new_msg + + # buffer interface code adapted from petsc4py by Lisandro Dalcin, a BSD project + + def __getbuffer__(self, buffer: pointer(Py_buffer), flags: C.int): # noqa: F821 + # new-style (memoryview) buffer interface + buffer.buf = zmq_msg_data(address(self.zmq_msg)) + buffer.len = zmq_msg_size(address(self.zmq_msg)) + + buffer.obj = self + buffer.readonly = 0 + buffer.format = "B" + buffer.ndim = 1 + buffer.shape = address(buffer.len) + buffer.strides = NULL + buffer.suboffsets = NULL + buffer.itemsize = 1 + buffer.internal = NULL + + def __len__(self) -> size_t: + """Return the length of the message in bytes.""" + sz: size_t = zmq_msg_size(address(self.zmq_msg)) + return sz + + @property + def buffer(self): + """A memoryview of the message contents.""" + _buffer = self._buffer and self._buffer() + if _buffer is not None: + return _buffer + _buffer = memoryview(self) + self._buffer = ref(_buffer) + return _buffer + + @property + def bytes(self): + """The message content as a Python bytes object. + + The first time this property is accessed, a copy of the message + contents is made. From then on that same copy of the message is + returned. + """ + if self._bytes is None: + self._bytes = _copy_zmq_msg_bytes(address(self.zmq_msg)) + return self._bytes + + def get(self, option): + """ + Get a Frame option or property. + + See the 0MQ API documentation for zmq_msg_get and zmq_msg_gets + for details on specific options. + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + .. versionchanged:: 14.3 + add support for zmq_msg_gets (requires libzmq-4.1) + All message properties are strings. + + .. versionchanged:: 17.0 + Added support for `routing_id` and `group`. + Only available if draft API is enabled + with libzmq >= 4.2. + """ + rc: C.int = 0 + property_c: p_char = NULL + + # zmq_msg_get + if isinstance(option, int): + rc = zmq_msg_get(address(self.zmq_msg), option) + _check_rc(rc) + return rc + + if option == 'routing_id': + routing_id: uint32_t = zmq_msg_routing_id(address(self.zmq_msg)) + if routing_id == 0: + _check_rc(-1) + return routing_id + elif option == 'group': + buf = zmq_msg_group(address(self.zmq_msg)) + if buf == NULL: + _check_rc(-1) + return buf.decode('utf8') + + # zmq_msg_gets + _check_version((4, 1), "get string properties") + if isinstance(option, str): + option = option.encode('utf8') + + if not isinstance(option, bytes): + raise TypeError(f"expected str, got: {option!r}") + + property_c = option + + result: p_char = cast(p_char, zmq_msg_gets(address(self.zmq_msg), property_c)) + if result == NULL: + _check_rc(-1) + return result.decode('utf8') + + def set(self, option, value): + """Set a Frame option. + + See the 0MQ API documentation for zmq_msg_set + for details on specific options. + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + .. versionchanged:: 17.0 + Added support for `routing_id` and `group`. + Only available if draft API is enabled + with libzmq >= 4.2. + """ + rc: C.int + + if option == 'routing_id': + routing_id: uint32_t = value + rc = zmq_msg_set_routing_id(address(self.zmq_msg), routing_id) + _check_rc(rc) + return + elif option == 'group': + if isinstance(value, str): + value = value.encode('utf8') + rc = zmq_msg_set_group(address(self.zmq_msg), value) + _check_rc(rc) + return + + rc = zmq_msg_set(address(self.zmq_msg), option, value) + _check_rc(rc) + + +@cclass +class Context: + """ + Manage the lifecycle of a 0MQ context. + + Parameters + ---------- + io_threads : int + The number of IO threads. + """ + + def __init__(self, io_threads: C.int = 1, shadow: size_t = 0): + self.handle = NULL + self._pid = 0 + self._shadow = False + + if shadow: + self.handle = cast(p_void, shadow) + self._shadow = True + else: + self._shadow = False + self.handle = zmq_ctx_new() + + if self.handle == NULL: + raise ZMQError() + + rc: C.int = 0 + if not self._shadow: + rc = zmq_ctx_set(self.handle, ZMQ_IO_THREADS, io_threads) + _check_rc(rc) + + self.closed = False + self._pid = getpid() + + @property + def underlying(self): + """The address of the underlying libzmq context""" + return cast(size_t, self.handle) + + @cfunc + @inline + def _term(self) -> C.int: + rc: C.int = 0 + if self.handle != NULL and not self.closed and getpid() == self._pid: + with nogil: + rc = zmq_ctx_destroy(self.handle) + self.handle = NULL + return rc + + def term(self): + """ + Close or terminate the context. + + This can be called to close the context by hand. If this is not called, + the context will automatically be closed when it is garbage collected. + """ + rc: C.int = self._term() + try: + _check_rc(rc) + except InterruptedSystemCall: + # ignore interrupted term + # see PEP 475 notes about close & EINTR for why + pass + + self.closed = True + + def set(self, option: C.int, optval): + """ + Set a context option. + + See the 0MQ API documentation for zmq_ctx_set + for details on specific options. + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + Parameters + ---------- + option : int + The option to set. Available values will depend on your + version of libzmq. Examples include:: + + zmq.IO_THREADS, zmq.MAX_SOCKETS + + optval : int + The value of the option to set. + """ + optval_int_c: C.int + rc: C.int + + if self.closed: + raise RuntimeError("Context has been destroyed") + + if not isinstance(optval, int): + raise TypeError(f'expected int, got: {optval!r}') + optval_int_c = optval + rc = zmq_ctx_set(self.handle, option, optval_int_c) + _check_rc(rc) + + def get(self, option: C.int): + """ + Get the value of a context option. + + See the 0MQ API documentation for zmq_ctx_get + for details on specific options. + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + Parameters + ---------- + option : int + The option to get. Available values will depend on your + version of libzmq. Examples include:: + + zmq.IO_THREADS, zmq.MAX_SOCKETS + + Returns + ------- + optval : int + The value of the option as an integer. + """ + rc: C.int + + if self.closed: + raise RuntimeError("Context has been destroyed") + + rc = zmq_ctx_get(self.handle, option) + _check_rc(rc, error_without_errno=False) + return rc + + +@cfunc +@inline +def _c_addr(addr) -> bytes: + """cast an address input to bytes + + Expects a str, but accepts bytes + and raises informative TypeError otherwise. + """ + if isinstance(addr, str): + addr = addr.encode("utf-8") + try: + c_addr: bytes = addr + except TypeError: + raise TypeError(f"Expected addr to be str, got addr={addr!r}") + return c_addr + + +@cclass +class Socket: + """ + A 0MQ socket. + + These objects will generally be constructed via the socket() method of a Context object. + + Note: 0MQ Sockets are *not* threadsafe. **DO NOT** share them across threads. + + Parameters + ---------- + context : Context + The 0MQ Context this Socket belongs to. + socket_type : int + The socket type, which can be any of the 0MQ socket types: + REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, XPUB, XSUB. + + See Also + -------- + .Context.socket : method for creating a socket bound to a Context. + """ + + def __init__( + self, + context=None, + socket_type: C.int = -1, + shadow: size_t = 0, + copy_threshold=None, + ): + # pre-init + self.handle = NULL + self._draft_poller = NULL + self._pid = 0 + self._shadow = False + self.context = None + + if copy_threshold is None: + copy_threshold = zmq.COPY_THRESHOLD + self.copy_threshold = copy_threshold + + self.handle = NULL + self.context = context + if shadow: + self._shadow = True + self.handle = cast(p_void, shadow) + else: + if context is None: + raise TypeError("context must be specified") + if socket_type < 0: + raise TypeError("socket_type must be specified") + self._shadow = False + self.handle = zmq_socket(self.context.handle, socket_type) + if self.handle == NULL: + raise ZMQError() + self._closed = False + self._pid = getpid() + + @property + def underlying(self): + """The address of the underlying libzmq socket""" + return cast(size_t, self.handle) + + @property + def closed(self): + """Whether the socket is closed""" + return _check_closed_deep(self) + + def close(self, linger: int | None = None): + """ + Close the socket. + + If linger is specified, LINGER sockopt will be set prior to closing. + + This can be called to close the socket by hand. If this is not + called, the socket will automatically be closed when it is + garbage collected. + """ + rc: C.int = 0 + linger_c: C.int + setlinger: bint = False + + if linger is not None: + linger_c = linger + setlinger = True + + if self.handle != NULL and not self._closed and getpid() == self._pid: + if setlinger: + zmq_setsockopt(self.handle, ZMQ_LINGER, address(linger_c), sizeof(int)) + + # teardown draft poller + if self._draft_poller != NULL: + zmq_poller_destroy(address(self._draft_poller)) + self._draft_poller = NULL + + rc = zmq_close(self.handle) + if rc < 0 and _zmq_errno() != ENOTSOCK: + # ignore ENOTSOCK (closed by Context) + _check_rc(rc) + self._closed = True + self.handle = NULL + + def set(self, option: C.int, optval): + """ + Set socket options. + + See the 0MQ API documentation for details on specific options. + + Parameters + ---------- + option : int + The option to set. Available values will depend on your + version of libzmq. Examples include:: + + zmq.SUBSCRIBE, UNSUBSCRIBE, IDENTITY, HWM, LINGER, FD + + optval : int or bytes + The value of the option to set. + + Notes + ----- + .. warning:: + + All options other than zmq.SUBSCRIBE, zmq.UNSUBSCRIBE and + zmq.LINGER only take effect for subsequent socket bind/connects. + """ + optval_int64_c: int64_t + optval_int_c: C.int + optval_c: p_char + sz: Py_ssize_t + + _check_closed(self) + if isinstance(optval, str): + raise TypeError("unicode not allowed, use setsockopt_string") + + try: + sopt = SocketOption(option) + except ValueError: + # unrecognized option, + # assume from the future, + # let EINVAL raise + opt_type = _OptType.int + else: + opt_type = sopt._opt_type + + if opt_type == _OptType.bytes: + if not isinstance(optval, bytes): + raise TypeError(f'expected bytes, got: {optval!r}') + optval_c = PyBytes_AsString(optval) + sz = PyBytes_Size(optval) + _setsockopt(self.handle, option, optval_c, sz) + elif opt_type == _OptType.int64: + if not isinstance(optval, int): + raise TypeError(f'expected int, got: {optval!r}') + optval_int64_c = optval + _setsockopt(self.handle, option, address(optval_int64_c), sizeof(int64_t)) + else: + # default is to assume int, which is what most new sockopts will be + # this lets pyzmq work with newer libzmq which may add constants + # pyzmq has not yet added, rather than artificially raising. Invalid + # sockopts will still raise just the same, but it will be libzmq doing + # the raising. + if not isinstance(optval, int): + raise TypeError(f'expected int, got: {optval!r}') + optval_int_c = optval + _setsockopt(self.handle, option, address(optval_int_c), sizeof(int)) + + def get(self, option: C.int): + """ + Get the value of a socket option. + + See the 0MQ API documentation for details on specific options. + + .. versionchanged:: 27 + Added experimental support for ZMQ_FD for draft sockets via `zmq_poller_fd`. + Requires libzmq >=4.3.2 built with draft support. + + Parameters + ---------- + option : int + The option to get. Available values will depend on your + version of libzmq. Examples include:: + + zmq.IDENTITY, HWM, LINGER, FD, EVENTS + + Returns + ------- + optval : int or bytes + The value of the option as a bytestring or int. + """ + optval_int64_c = declare(int64_t) + optval_int_c = declare(C.int) + optval_fd_c = declare(fd_t) + identity_str_c = declare(char[255]) + sz: size_t + + _check_closed(self) + + try: + sopt = SocketOption(option) + except ValueError: + # unrecognized option, + # assume from the future, + # let EINVAL raise + opt_type = _OptType.int + else: + opt_type = sopt._opt_type + + if opt_type == _OptType.bytes: + sz = 255 + _getsockopt(self.handle, option, cast(p_void, identity_str_c), address(sz)) + # strip null-terminated strings *except* identity + if ( + option != ZMQ_IDENTITY + and sz > 0 + and (cast(p_char, identity_str_c))[sz - 1] == b'\0' + ): + sz -= 1 + result = PyBytes_FromStringAndSize(cast(p_char, identity_str_c), sz) + elif opt_type == _OptType.int64: + sz = sizeof(int64_t) + _getsockopt( + self.handle, option, cast(p_void, address(optval_int64_c)), address(sz) + ) + result = optval_int64_c + elif option == ZMQ_FD and self._draft_poller != NULL: + # draft sockets use FD of a draft zmq_poller as proxy + rc = zmq_poller_fd(self._draft_poller, address(optval_fd_c)) + _check_rc(rc) + result = optval_fd_c + elif opt_type == _OptType.fd: + sz = sizeof(fd_t) + try: + _getsockopt( + self.handle, option, cast(p_void, address(optval_fd_c)), address(sz) + ) + except ZMQError as e: + # threadsafe sockets don't support ZMQ_FD (yet!) + # fallback on zmq_poller_fd as proxy with the same behavior + # until libzmq fixes this. + # if upstream fixes it, this branch will never be taken + if ( + option == ZMQ_FD + and e.errno == zmq.Errno.EINVAL + and self.get(ZMQ_THREAD_SAFE) + ): + _check_version( + (4, 3, 2), "draft socket FD support via zmq_poller_fd" + ) + if not zmq.DRAFT_API: + raise RuntimeError( + "libzmq and pyzmq must be built with draft support" + ) + warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2) + + # create a poller and retrieve its fd + self._draft_poller = zmq_poller_new() + if self._draft_poller == NULL: + # failed (why?), raise original error + raise + # register self with poller + rc = zmq_poller_add( + self._draft_poller, self.handle, NULL, ZMQ_POLLIN | ZMQ_POLLOUT + ) + _check_rc(rc) + # use poller fd as proxy for ours + rc = zmq_poller_fd(self._draft_poller, address(optval_fd_c)) + _check_rc(rc) + else: + raise + result = optval_fd_c + else: + # default is to assume int, which is what most new sockopts will be + # this lets pyzmq work with newer libzmq which may add constants + # pyzmq has not yet added, rather than artificially raising. Invalid + # sockopts will still raise just the same, but it will be libzmq doing + # the raising. + sz = sizeof(int) + _getsockopt( + self.handle, option, cast(p_void, address(optval_int_c)), address(sz) + ) + result = optval_int_c + + return result + + def bind(self, addr: str | bytes): + """ + Bind the socket to an address. + + This causes the socket to listen on a network port. Sockets on the + other side of this connection will use ``Socket.connect(addr)`` to + connect to this socket. + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported include + tcp, udp, pgm, epgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + """ + _addr_bytes: bytes = _c_addr(addr) + c_addr: p_char = _addr_bytes + _check_closed(self) + rc: C.int = zmq_bind(self.handle, c_addr) + if rc != 0: + _errno: C.int = _zmq_errno() + _ipc_max: C.int = get_ipc_path_max_len() + if _ipc_max and _errno == ENAMETOOLONG: + path = addr.split('://', 1)[-1] + msg = ( + f'ipc path "{path}" is longer than {_ipc_max} ' + 'characters (sizeof(sockaddr_un.sun_path)). ' + 'zmq.IPC_PATH_MAX_LEN constant can be used ' + 'to check addr length (if it is defined).' + ) + raise ZMQError(msg=msg) + elif _errno == ENOENT: + path = addr.split('://', 1)[-1] + msg = f'No such file or directory for ipc path "{path}".' + raise ZMQError(msg=msg) + while True: + try: + _check_rc(rc) + except InterruptedSystemCall: + rc = zmq_bind(self.handle, c_addr) + continue + else: + break + + def connect(self, addr: str | bytes) -> None: + """ + Connect to a remote 0MQ socket. + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported are + tcp, udp, pgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + """ + rc: C.int + _addr_bytes: bytes = _c_addr(addr) + c_addr: p_char = _addr_bytes + _check_closed(self) + + while True: + try: + rc = zmq_connect(self.handle, c_addr) + _check_rc(rc) + except InterruptedSystemCall: + # retry syscall + continue + else: + break + + def unbind(self, addr: str | bytes): + """ + Unbind from an address (undoes a call to bind). + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported are + tcp, udp, pgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + """ + _addr_bytes: bytes = _c_addr(addr) + c_addr: p_char = _addr_bytes + _check_closed(self) + rc: C.int = zmq_unbind(self.handle, c_addr) + if rc != 0: + raise ZMQError() + + def disconnect(self, addr: str | bytes): + """ + Disconnect from a remote 0MQ socket (undoes a call to connect). + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported are + tcp, udp, pgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + """ + _addr_bytes: bytes = _c_addr(addr) + c_addr: p_char = _addr_bytes + _check_closed(self) + + rc: C.int = zmq_disconnect(self.handle, c_addr) + if rc != 0: + raise ZMQError() + + def monitor(self, addr: str | bytes | None, events: C.int = ZMQ_EVENT_ALL): + """ + Start publishing socket events on inproc. + See libzmq docs for zmq_monitor for details. + + While this function is available from libzmq 3.2, + pyzmq cannot parse monitor messages from libzmq prior to 4.0. + + .. versionadded: libzmq-3.2 + .. versionadded: 14.0 + + Parameters + ---------- + addr : str | None + The inproc url used for monitoring. Passing None as + the addr will cause an existing socket monitor to be + deregistered. + events : int + default: zmq.EVENT_ALL + The zmq event bitmask for which events will be sent to the monitor. + """ + c_addr: p_char = NULL + if addr is not None: + _addr_bytes: bytes = _c_addr(addr) + c_addr: p_char = _addr_bytes + _check_closed(self) + + _check_rc(zmq_socket_monitor(self.handle, c_addr, events)) + + def join(self, group: str | bytes): + """ + Join a RADIO-DISH group + + Only for DISH sockets. + + libzmq and pyzmq must have been built with ZMQ_BUILD_DRAFT_API + + .. versionadded:: 17 + """ + _check_version((4, 2), "RADIO-DISH") + if not zmq.DRAFT_API: + raise RuntimeError("libzmq and pyzmq must be built with draft support") + if isinstance(group, str): + group = group.encode('utf8') + c_group: bytes = group + rc: C.int = zmq_join(self.handle, c_group) + _check_rc(rc) + + def leave(self, group): + """ + Leave a RADIO-DISH group + + Only for DISH sockets. + + libzmq and pyzmq must have been built with ZMQ_BUILD_DRAFT_API + + .. versionadded:: 17 + """ + _check_version((4, 2), "RADIO-DISH") + if not zmq.DRAFT_API: + raise RuntimeError("libzmq and pyzmq must be built with draft support") + rc: C.int = zmq_leave(self.handle, group) + _check_rc(rc) + + def send(self, data, flags=0, copy: bint = True, track: bint = False): + """ + Send a single zmq message frame on this socket. + + This queues the message to be sent by the IO thread at a later time. + + With flags=NOBLOCK, this raises :class:`ZMQError` if the queue is full; + otherwise, this waits until space is available. + See :class:`Poller` for more general non-blocking I/O. + + Parameters + ---------- + data : bytes, Frame, memoryview + The content of the message. This can be any object that provides + the Python buffer API (`memoryview(data)` can be called). + flags : int + 0, NOBLOCK, SNDMORE, or NOBLOCK|SNDMORE. + copy : bool + Should the message be sent in a copying or non-copying manner. + track : bool + Should the message be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + + Returns + ------- + None : if `copy` or not track + None if message was sent, raises an exception otherwise. + MessageTracker : if track and not copy + a MessageTracker object, whose `done` property will + be False until the send is completed. + + Raises + ------ + TypeError + If a unicode object is passed + ValueError + If `track=True`, but an untracked Frame is passed. + ZMQError + for any of the reasons zmq_msg_send might fail (including + if NOBLOCK is set and the outgoing queue is full). + + """ + _check_closed(self) + + if isinstance(data, str): + raise TypeError("unicode not allowed, use send_string") + + if copy and not isinstance(data, Frame): + return _send_copy(self.handle, data, flags) + else: + if isinstance(data, Frame): + if track and not data.tracker: + raise ValueError('Not a tracked message') + msg = data + else: + if self.copy_threshold: + buf = memoryview(data) + nbytes: size_t = buf.nbytes + copy_threshold: size_t = self.copy_threshold + # always copy messages smaller than copy_threshold + if nbytes < copy_threshold: + _send_copy(self.handle, buf, flags) + return zmq._FINISHED_TRACKER + msg = Frame(data, track=track, copy_threshold=self.copy_threshold) + return _send_frame(self.handle, msg, flags) + + def recv(self, flags=0, copy: bint = True, track: bint = False): + """ + Receive a message. + + With flags=NOBLOCK, this raises :class:`ZMQError` if no messages have + arrived; otherwise, this waits until a message arrives. + See :class:`Poller` for more general non-blocking I/O. + + Parameters + ---------- + flags : int + 0 or NOBLOCK. + copy : bool + Should the message be received in a copying or non-copying manner? + If False a Frame object is returned, if True a string copy of + message is returned. + track : bool + Should the message be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + + Returns + ------- + msg : bytes or Frame + The received message frame. If `copy` is False, then it will be a Frame, + otherwise it will be bytes. + + Raises + ------ + ZMQError + for any of the reasons zmq_msg_recv might fail (including if + NOBLOCK is set and no new messages have arrived). + """ + _check_closed(self) + + if copy: + return _recv_copy(self.handle, flags) + else: + frame = _recv_frame(self.handle, flags, track) + more: bint = False + sz: size_t = sizeof(bint) + _getsockopt( + self.handle, ZMQ_RCVMORE, cast(p_void, address(more)), address(sz) + ) + frame.more = more + return frame + + def recv_into(self, buffer, /, *, nbytes=0, flags=0) -> C.int: + """ + Receive up to nbytes bytes from the socket, + storing the data into a buffer rather than allocating a new Frame. + + The next message frame can be discarded by receiving into an empty buffer:: + + sock.recv_into(bytearray()) + + .. versionadded:: 26.4 + + Parameters + ---------- + buffer : memoryview + Any object providing the buffer interface (i.e. `memoryview(buffer)` works), + where the memoryview is contiguous and writable. + nbytes: int, default=0 + The maximum number of bytes to receive. + If nbytes is not specified (or 0), receive up to the size available in the given buffer. + If the next frame is larger than this, the frame will be truncated and message content discarded. + flags: int, default=0 + See `socket.recv` + + Returns + ------- + bytes_received: int + Returns the number of bytes received. + This is always the size of the received frame. + If the returned `bytes_received` is larger than `nbytes` (or size of `buffer` if `nbytes=0`), + the message has been truncated and the rest of the frame discarded. + Truncated data cannot be recovered. + + Raises + ------ + ZMQError + for any of the reasons `zmq_recv` might fail. + BufferError + for invalid buffers, such as readonly or not contiguous. + """ + c_flags: C.int = flags + _check_closed(self) + c_nbytes: size_t = nbytes + if c_nbytes < 0: + raise ValueError(f"{nbytes=} must be non-negative") + view = memoryview(buffer) + c_data = declare(pointer(C.void)) + view_bytes: C.size_t = _asbuffer(view, address(c_data), True) + if nbytes == 0: + c_nbytes = view_bytes + elif c_nbytes > view_bytes: + raise ValueError(f"{nbytes=} too big for memoryview of {view_bytes}B") + + # call zmq_recv, with retries + while True: + with nogil: + rc: C.int = zmq_recv(self.handle, c_data, c_nbytes, c_flags) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + return rc + + +# inline socket methods + + +@inline +@cfunc +def _check_closed(s: Socket): + """raise ENOTSUP if socket is closed + + Does not do a deep check + """ + if s._closed: + raise ZMQError(ENOTSOCK) + + +@inline +@cfunc +def _check_closed_deep(s: Socket) -> bint: + """thorough check of whether the socket has been closed, + even if by another entity (e.g. ctx.destroy). + + Only used by the `closed` property. + + returns True if closed, False otherwise + """ + rc: C.int + errno: C.int + stype = declare(C.int) + sz: size_t = sizeof(int) + + if s._closed: + return True + else: + rc = zmq_getsockopt( + s.handle, ZMQ_TYPE, cast(p_void, address(stype)), address(sz) + ) + if rc < 0: + errno = _zmq_errno() + if errno == ENOTSOCK: + s._closed = True + return True + elif errno == ZMQ_ETERM: + # don't raise ETERM when checking if we're closed + return False + else: + _check_rc(rc) + return False + + +@cfunc +@inline +def _recv_frame(handle: p_void, flags: C.int = 0, track: bint = False) -> Frame: + """Receive a message in a non-copying manner and return a Frame.""" + rc: C.int + msg = zmq.Frame(track=track) + cmsg: Frame = msg + + while True: + with nogil: + rc = zmq_msg_recv(address(cmsg.zmq_msg), handle, flags) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + return msg + + +@cfunc +@inline +def _recv_copy(handle: p_void, flags: C.int = 0): + """Receive a message and return a copy""" + zmq_msg = declare(zmq_msg_t) + zmq_msg_p: pointer(zmq_msg_t) = address(zmq_msg) + rc: C.int = zmq_msg_init(zmq_msg_p) + _check_rc(rc) + while True: + with nogil: + rc = zmq_msg_recv(zmq_msg_p, handle, flags) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + except Exception: + zmq_msg_close(zmq_msg_p) # ensure msg is closed on failure + raise + else: + break + + msg_bytes = _copy_zmq_msg_bytes(zmq_msg_p) + zmq_msg_close(zmq_msg_p) + return msg_bytes + + +@cfunc +@inline +def _send_frame(handle: p_void, msg: Frame, flags: C.int = 0): + """Send a Frame on this socket in a non-copy manner.""" + rc: C.int + msg_copy: Frame + + # Always copy so the original message isn't garbage collected. + # This doesn't do a real copy, just a reference. + msg_copy = msg.fast_copy() + + while True: + with nogil: + rc = zmq_msg_send(address(msg_copy.zmq_msg), handle, flags) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + + return msg.tracker + + +@cfunc +@inline +def _send_copy(handle: p_void, buf, flags: C.int = 0): + """Send a message on this socket by copying its content.""" + rc: C.int + msg = declare(zmq_msg_t) + c_bytes = declare(p_void) + + # copy to c array: + c_bytes_len = _asbuffer(buf, address(c_bytes)) + + # Copy the msg before sending. This avoids any complications with + # the GIL, etc. + # If zmq_msg_init_* fails we must not call zmq_msg_close (Bus Error) + rc = zmq_msg_init_size(address(msg), c_bytes_len) + _check_rc(rc) + + while True: + with nogil: + memcpy(zmq_msg_data(address(msg)), c_bytes, zmq_msg_size(address(msg))) + rc = zmq_msg_send(address(msg), handle, flags) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + except Exception: + zmq_msg_close(address(msg)) # close the unused msg + raise # raise original exception + else: + rc = zmq_msg_close(address(msg)) + _check_rc(rc) + break + + +@cfunc +@inline +def _getsockopt(handle: p_void, option: C.int, optval: p_void, sz: pointer(size_t)): + """getsockopt, retrying interrupted calls + + checks rc, raising ZMQError on failure. + """ + rc: C.int = 0 + while True: + rc = zmq_getsockopt(handle, option, optval, sz) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + + +@cfunc +@inline +def _setsockopt(handle: p_void, option: C.int, optval: p_void, sz: size_t): + """setsockopt, retrying interrupted calls + + checks rc, raising ZMQError on failure. + """ + rc: C.int = 0 + while True: + rc = zmq_setsockopt(handle, option, optval, sz) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + + +# General utility functions + + +def zmq_errno() -> C.int: + """Return the integer errno of the most recent zmq error.""" + return _zmq_errno() + + +def strerror(errno: C.int) -> str: + """ + Return the error string given the error number. + """ + str_e: bytes = zmq_strerror(errno) + return str_e.decode("utf8", "replace") + + +def zmq_version_info() -> tuple[int, int, int]: + """Return the version of ZeroMQ itself as a 3-tuple of ints.""" + major: C.int = 0 + minor: C.int = 0 + patch: C.int = 0 + _zmq_version(address(major), address(minor), address(patch)) + return (major, minor, patch) + + +def has(capability: str) -> bool: + """Check for zmq capability by name (e.g. 'ipc', 'curve') + + .. versionadded:: libzmq-4.1 + .. versionadded:: 14.1 + """ + _check_version((4, 1), 'zmq.has') + ccap: bytes = capability.encode('utf8') + return bool(zmq_has(ccap)) + + +def curve_keypair() -> tuple[bytes, bytes]: + """generate a Z85 key pair for use with zmq.CURVE security + + Requires libzmq (≥ 4.0) to have been built with CURVE support. + + .. versionadded:: libzmq-4.0 + .. versionadded:: 14.0 + + Returns + ------- + public: bytes + The public key as 40 byte z85-encoded bytestring. + private: bytes + The private key as 40 byte z85-encoded bytestring. + """ + rc: C.int + public_key = declare(char[64]) + secret_key = declare(char[64]) + _check_version((4, 0), "curve_keypair") + # see huge comment in libzmq/src/random.cpp + # about threadsafety of random initialization + rc = zmq_curve_keypair(public_key, secret_key) + _check_rc(rc) + return public_key, secret_key + + +def curve_public(secret_key) -> bytes: + """Compute the public key corresponding to a secret key for use + with zmq.CURVE security + + Requires libzmq (≥ 4.2) to have been built with CURVE support. + + Parameters + ---------- + private + The private key as a 40 byte z85-encoded bytestring + + Returns + ------- + bytes + The public key as a 40 byte z85-encoded bytestring + """ + if isinstance(secret_key, str): + secret_key = secret_key.encode('utf8') + if not len(secret_key) == 40: + raise ValueError('secret key must be a 40 byte z85 encoded string') + + rc: C.int + public_key = declare(char[64]) + c_secret_key: pointer(char) = secret_key + _check_version((4, 2), "curve_public") + # see huge comment in libzmq/src/random.cpp + # about threadsafety of random initialization + rc = zmq_curve_public(public_key, c_secret_key) + _check_rc(rc) + return public_key[:40] + + +# polling +def zmq_poll(sockets, timeout: C.int = -1): + """zmq_poll(sockets, timeout=-1) + + Poll a set of 0MQ sockets, native file descs. or sockets. + + Parameters + ---------- + sockets : list of tuples of (socket, flags) + Each element of this list is a two-tuple containing a socket + and a flags. The socket may be a 0MQ socket or any object with + a ``fileno()`` method. The flags can be zmq.POLLIN (for detecting + for incoming messages), zmq.POLLOUT (for detecting that send is OK) + or zmq.POLLIN|zmq.POLLOUT for detecting both. + timeout : int + The number of milliseconds to poll for. Negative means no timeout. + """ + rc: C.int + i: C.int + fileno: fd_t + events: C.int + pollitems: pointer(zmq_pollitem_t) = NULL + nsockets: C.int = len(sockets) + + if nsockets == 0: + return [] + + pollitems = cast(pointer(zmq_pollitem_t), malloc(nsockets * sizeof(zmq_pollitem_t))) + if pollitems == NULL: + raise MemoryError("Could not allocate poll items") + + for i in range(nsockets): + s, events = sockets[i] + if isinstance(s, Socket): + pollitems[i].socket = cast(Socket, s).handle + pollitems[i].fd = 0 + pollitems[i].events = events + pollitems[i].revents = 0 + elif isinstance(s, int): + fileno = s + pollitems[i].socket = NULL + pollitems[i].fd = fileno + pollitems[i].events = events + pollitems[i].revents = 0 + elif hasattr(s, 'fileno'): + try: + fileno = int(s.fileno()) + except Exception: + free(pollitems) + raise ValueError('fileno() must return a valid integer fd') + else: + pollitems[i].socket = NULL + pollitems[i].fd = fileno + pollitems[i].events = events + pollitems[i].revents = 0 + else: + free(pollitems) + raise TypeError( + "Socket must be a 0MQ socket, an integer fd or have " + f"a fileno() method: {s!r}" + ) + + ms_passed: C.int = 0 + tic: C.int + try: + while True: + start: C.int = monotonic() + with nogil: + rc = zmq_poll_c(pollitems, nsockets, timeout) + try: + _check_rc(rc) + except InterruptedSystemCall: + if timeout > 0: + tic = monotonic() + ms_passed = int(1000 * (tic - start)) + if ms_passed < 0: + # don't allow negative ms_passed, + # which can happen on old Python versions without time.monotonic. + warnings.warn( + f"Negative elapsed time for interrupted poll: {ms_passed}." + " Did the clock change?", + RuntimeWarning, + ) + # treat this case the same as no time passing, + # since it should be rare and not happen twice in a row. + ms_passed = 0 + timeout = max(0, timeout - ms_passed) + continue + else: + break + except Exception: + free(pollitems) + raise + + results = [] + for i in range(nsockets): + revents = pollitems[i].revents + # for compatibility with select.poll: + # - only return sockets with non-zero status + # - return the fd for plain sockets + if revents > 0: + if pollitems[i].socket != NULL: + s = sockets[i][0] + else: + s = pollitems[i].fd + results.append((s, revents)) + + free(pollitems) + return results + + +def proxy(frontend: Socket, backend: Socket, capture: Socket = None): + """ + Start a zeromq proxy (replacement for device). + + .. versionadded:: libzmq-3.2 + .. versionadded:: 13.0 + + Parameters + ---------- + frontend : Socket + The Socket instance for the incoming traffic. + backend : Socket + The Socket instance for the outbound traffic. + capture : Socket (optional) + The Socket instance for capturing traffic. + """ + rc: C.int = 0 + capture_handle: p_void + if isinstance(capture, Socket): + capture_handle = capture.handle + else: + capture_handle = NULL + while True: + with nogil: + rc = zmq_proxy(frontend.handle, backend.handle, capture_handle) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + return rc + + +def proxy_steerable( + frontend: Socket, + backend: Socket, + capture: Socket = None, + control: Socket = None, +): + """ + Start a zeromq proxy with control flow. + + .. versionadded:: libzmq-4.1 + .. versionadded:: 18.0 + + Parameters + ---------- + frontend : Socket + The Socket instance for the incoming traffic. + backend : Socket + The Socket instance for the outbound traffic. + capture : Socket (optional) + The Socket instance for capturing traffic. + control : Socket (optional) + The Socket instance for control flow. + """ + rc: C.int = 0 + capture_handle: p_void + if isinstance(capture, Socket): + capture_handle = capture.handle + else: + capture_handle = NULL + if isinstance(control, Socket): + control_handle = control.handle + else: + control_handle = NULL + while True: + with nogil: + rc = zmq_proxy_steerable( + frontend.handle, backend.handle, capture_handle, control_handle + ) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + return rc + + +# monitored queue - like proxy (predates libzmq proxy) +# but supports ROUTER-ROUTER devices +@cfunc +@inline +@nogil +def _mq_relay( + in_socket: p_void, + out_socket: p_void, + side_socket: p_void, + msg: zmq_msg_t, + side_msg: zmq_msg_t, + id_msg: zmq_msg_t, + swap_ids: bint, +) -> C.int: + rc: C.int + flags: C.int + flagsz = declare(size_t) + more = declare(int) + flagsz = sizeof(int) + + if swap_ids: # both router, must send second identity first + # recv two ids into msg, id_msg + rc = zmq_msg_recv(address(msg), in_socket, 0) + if rc < 0: + return rc + + rc = zmq_msg_recv(address(id_msg), in_socket, 0) + if rc < 0: + return rc + + # send second id (id_msg) first + # !!!! always send a copy before the original !!!! + rc = zmq_msg_copy(address(side_msg), address(id_msg)) + if rc < 0: + return rc + rc = zmq_msg_send(address(side_msg), out_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + rc = zmq_msg_send(address(id_msg), side_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + # send first id (msg) second + rc = zmq_msg_copy(address(side_msg), address(msg)) + if rc < 0: + return rc + rc = zmq_msg_send(address(side_msg), out_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + rc = zmq_msg_send(address(msg), side_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + while True: + rc = zmq_msg_recv(address(msg), in_socket, 0) + if rc < 0: + return rc + # assert (rc == 0) + rc = zmq_getsockopt(in_socket, ZMQ_RCVMORE, address(more), address(flagsz)) + if rc < 0: + return rc + flags = 0 + if more: + flags |= ZMQ_SNDMORE + + rc = zmq_msg_copy(address(side_msg), address(msg)) + if rc < 0: + return rc + if flags: + rc = zmq_msg_send(address(side_msg), out_socket, flags) + if rc < 0: + return rc + # only SNDMORE for side-socket + rc = zmq_msg_send(address(msg), side_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + else: + rc = zmq_msg_send(address(side_msg), out_socket, 0) + if rc < 0: + return rc + rc = zmq_msg_send(address(msg), side_socket, 0) + if rc < 0: + return rc + break + return rc + + +@cfunc +@inline +@nogil +def _mq_inline( + in_socket: p_void, + out_socket: p_void, + side_socket: p_void, + in_msg_ptr: pointer(zmq_msg_t), + out_msg_ptr: pointer(zmq_msg_t), + swap_ids: bint, +) -> C.int: + """ + inner C function for monitored_queue + """ + + msg: zmq_msg_t = declare(zmq_msg_t) + rc: C.int = zmq_msg_init(address(msg)) + id_msg = declare(zmq_msg_t) + rc = zmq_msg_init(address(id_msg)) + if rc < 0: + return rc + side_msg = declare(zmq_msg_t) + rc = zmq_msg_init(address(side_msg)) + if rc < 0: + return rc + + items = declare(zmq_pollitem_t[2]) + items[0].socket = in_socket + items[0].events = ZMQ_POLLIN + items[0].fd = items[0].revents = 0 + items[1].socket = out_socket + items[1].events = ZMQ_POLLIN + items[1].fd = items[1].revents = 0 + + while True: + # wait for the next message to process + rc = zmq_poll_c(address(items[0]), 2, -1) + if rc < 0: + return rc + if items[0].revents & ZMQ_POLLIN: + # send in_prefix to side socket + rc = zmq_msg_copy(address(side_msg), in_msg_ptr) + if rc < 0: + return rc + rc = zmq_msg_send(address(side_msg), side_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + # relay the rest of the message + rc = _mq_relay( + in_socket, out_socket, side_socket, msg, side_msg, id_msg, swap_ids + ) + if rc < 0: + return rc + if items[1].revents & ZMQ_POLLIN: + # send out_prefix to side socket + rc = zmq_msg_copy(address(side_msg), out_msg_ptr) + if rc < 0: + return rc + rc = zmq_msg_send(address(side_msg), side_socket, ZMQ_SNDMORE) + if rc < 0: + return rc + # relay the rest of the message + rc = _mq_relay( + out_socket, in_socket, side_socket, msg, side_msg, id_msg, swap_ids + ) + if rc < 0: + return rc + return rc + + +def monitored_queue( + in_socket: Socket, + out_socket: Socket, + mon_socket: Socket, + in_prefix: bytes = b'in', + out_prefix: bytes = b'out', +): + """ + Start a monitored queue device. + + A monitored queue is very similar to the zmq.proxy device (monitored queue came first). + + Differences from zmq.proxy: + + - monitored_queue supports both in and out being ROUTER sockets + (via swapping IDENTITY prefixes). + - monitor messages are prefixed, making in and out messages distinguishable. + + Parameters + ---------- + in_socket : zmq.Socket + One of the sockets to the Queue. Its messages will be prefixed with + 'in'. + out_socket : zmq.Socket + One of the sockets to the Queue. Its messages will be prefixed with + 'out'. The only difference between in/out socket is this prefix. + mon_socket : zmq.Socket + This socket sends out every message received by each of the others + with an in/out prefix specifying which one it was. + in_prefix : str + Prefix added to broadcast messages from in_socket. + out_prefix : str + Prefix added to broadcast messages from out_socket. + """ + ins: p_void = in_socket.handle + outs: p_void = out_socket.handle + mons: p_void = mon_socket.handle + in_msg = declare(zmq_msg_t) + out_msg = declare(zmq_msg_t) + swap_ids: bint + msg_c: p_void = NULL + msg_c_len = declare(Py_ssize_t) + rc: C.int + + # force swap_ids if both ROUTERs + swap_ids = in_socket.type == ZMQ_ROUTER and out_socket.type == ZMQ_ROUTER + + # build zmq_msg objects from str prefixes + msg_c_len = _asbuffer(in_prefix, address(msg_c)) + rc = zmq_msg_init_size(address(in_msg), msg_c_len) + _check_rc(rc) + + memcpy(zmq_msg_data(address(in_msg)), msg_c, zmq_msg_size(address(in_msg))) + + msg_c_len = _asbuffer(out_prefix, address(msg_c)) + + rc = zmq_msg_init_size(address(out_msg), msg_c_len) + _check_rc(rc) + + while True: + with nogil: + memcpy( + zmq_msg_data(address(out_msg)), msg_c, zmq_msg_size(address(out_msg)) + ) + rc = _mq_inline( + ins, outs, mons, address(in_msg), address(out_msg), swap_ids + ) + try: + _check_rc(rc) + except InterruptedSystemCall: + continue + else: + break + return rc + + +__all__ = [ + 'IPC_PATH_MAX_LEN', + 'PYZMQ_DRAFT_API', + 'Context', + 'Socket', + 'Frame', + 'has', + 'curve_keypair', + 'curve_public', + 'zmq_version_info', + 'zmq_errno', + 'zmq_poll', + 'strerror', + 'proxy', + 'proxy_steerable', +] diff --git a/source/zmq/backend/cython/constant_enums.pxi b/source/zmq/backend/cython/constant_enums.pxi new file mode 100644 index 0000000000000000000000000000000000000000..811d2a84f56f1b96d98a3f4a57cfbacf50c69a08 --- /dev/null +++ b/source/zmq/backend/cython/constant_enums.pxi @@ -0,0 +1,250 @@ +cdef extern from "zmq.h" nogil: + enum: PYZMQ_DRAFT_API + enum: ZMQ_VERSION + enum: ZMQ_VERSION_MAJOR + enum: ZMQ_VERSION_MINOR + enum: ZMQ_VERSION_PATCH + enum: ZMQ_IO_THREADS + enum: ZMQ_MAX_SOCKETS + enum: ZMQ_SOCKET_LIMIT + enum: ZMQ_THREAD_PRIORITY + enum: ZMQ_THREAD_SCHED_POLICY + enum: ZMQ_MAX_MSGSZ + enum: ZMQ_MSG_T_SIZE + enum: ZMQ_THREAD_AFFINITY_CPU_ADD + enum: ZMQ_THREAD_AFFINITY_CPU_REMOVE + enum: ZMQ_THREAD_NAME_PREFIX + enum: ZMQ_STREAMER + enum: ZMQ_FORWARDER + enum: ZMQ_QUEUE + enum: ZMQ_EAGAIN "EAGAIN" + enum: ZMQ_EFAULT "EFAULT" + enum: ZMQ_EINVAL "EINVAL" + enum: ZMQ_ENOTSUP "ENOTSUP" + enum: ZMQ_EPROTONOSUPPORT "EPROTONOSUPPORT" + enum: ZMQ_ENOBUFS "ENOBUFS" + enum: ZMQ_ENETDOWN "ENETDOWN" + enum: ZMQ_EADDRINUSE "EADDRINUSE" + enum: ZMQ_EADDRNOTAVAIL "EADDRNOTAVAIL" + enum: ZMQ_ECONNREFUSED "ECONNREFUSED" + enum: ZMQ_EINPROGRESS "EINPROGRESS" + enum: ZMQ_ENOTSOCK "ENOTSOCK" + enum: ZMQ_EMSGSIZE "EMSGSIZE" + enum: ZMQ_EAFNOSUPPORT "EAFNOSUPPORT" + enum: ZMQ_ENETUNREACH "ENETUNREACH" + enum: ZMQ_ECONNABORTED "ECONNABORTED" + enum: ZMQ_ECONNRESET "ECONNRESET" + enum: ZMQ_ENOTCONN "ENOTCONN" + enum: ZMQ_ETIMEDOUT "ETIMEDOUT" + enum: ZMQ_EHOSTUNREACH "EHOSTUNREACH" + enum: ZMQ_ENETRESET "ENETRESET" + enum: ZMQ_EFSM "EFSM" + enum: ZMQ_ENOCOMPATPROTO "ENOCOMPATPROTO" + enum: ZMQ_ETERM "ETERM" + enum: ZMQ_EMTHREAD "EMTHREAD" + enum: ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED + enum: ZMQ_PROTOCOL_ERROR_ZMTP_UNSPECIFIED + enum: ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND + enum: ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE + enum: ZMQ_PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME + enum: ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA + enum: ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC + enum: ZMQ_PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH + enum: ZMQ_PROTOCOL_ERROR_ZAP_UNSPECIFIED + enum: ZMQ_PROTOCOL_ERROR_ZAP_MALFORMED_REPLY + enum: ZMQ_PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID + enum: ZMQ_PROTOCOL_ERROR_ZAP_BAD_VERSION + enum: ZMQ_PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE + enum: ZMQ_PROTOCOL_ERROR_ZAP_INVALID_METADATA + enum: ZMQ_EVENT_CONNECTED + enum: ZMQ_EVENT_CONNECT_DELAYED + enum: ZMQ_EVENT_CONNECT_RETRIED + enum: ZMQ_EVENT_LISTENING + enum: ZMQ_EVENT_BIND_FAILED + enum: ZMQ_EVENT_ACCEPTED + enum: ZMQ_EVENT_ACCEPT_FAILED + enum: ZMQ_EVENT_CLOSED + enum: ZMQ_EVENT_CLOSE_FAILED + enum: ZMQ_EVENT_DISCONNECTED + enum: ZMQ_EVENT_MONITOR_STOPPED + enum: ZMQ_EVENT_HANDSHAKE_FAILED_NO_DETAIL + enum: ZMQ_EVENT_HANDSHAKE_SUCCEEDED + enum: ZMQ_EVENT_HANDSHAKE_FAILED_PROTOCOL + enum: ZMQ_EVENT_HANDSHAKE_FAILED_AUTH + enum: ZMQ_EVENT_ALL_V1 + enum: ZMQ_EVENT_ALL + enum: ZMQ_EVENT_PIPES_STATS + enum: ZMQ_EVENT_ALL_V2 + enum: ZMQ_DONTWAIT + enum: ZMQ_SNDMORE + enum: ZMQ_NOBLOCK + enum: ZMQ_MORE + enum: ZMQ_SHARED + enum: ZMQ_SRCFD + enum: ZMQ_NORM_FIXED + enum: ZMQ_NORM_CC + enum: ZMQ_NORM_CCL + enum: ZMQ_NORM_CCE + enum: ZMQ_NORM_CCE_ECNONLY + enum: ZMQ_POLLIN + enum: ZMQ_POLLOUT + enum: ZMQ_POLLERR + enum: ZMQ_POLLPRI + enum: ZMQ_RECONNECT_STOP_CONN_REFUSED + enum: ZMQ_RECONNECT_STOP_HANDSHAKE_FAILED + enum: ZMQ_RECONNECT_STOP_AFTER_DISCONNECT + enum: ZMQ_NOTIFY_CONNECT + enum: ZMQ_NOTIFY_DISCONNECT + enum: ZMQ_NULL + enum: ZMQ_PLAIN + enum: ZMQ_CURVE + enum: ZMQ_GSSAPI + enum: ZMQ_HWM + enum: ZMQ_AFFINITY + enum: ZMQ_ROUTING_ID + enum: ZMQ_SUBSCRIBE + enum: ZMQ_UNSUBSCRIBE + enum: ZMQ_RATE + enum: ZMQ_RECOVERY_IVL + enum: ZMQ_SNDBUF + enum: ZMQ_RCVBUF + enum: ZMQ_RCVMORE + enum: ZMQ_FD + enum: ZMQ_EVENTS + enum: ZMQ_TYPE + enum: ZMQ_LINGER + enum: ZMQ_RECONNECT_IVL + enum: ZMQ_BACKLOG + enum: ZMQ_RECONNECT_IVL_MAX + enum: ZMQ_MAXMSGSIZE + enum: ZMQ_SNDHWM + enum: ZMQ_RCVHWM + enum: ZMQ_MULTICAST_HOPS + enum: ZMQ_RCVTIMEO + enum: ZMQ_SNDTIMEO + enum: ZMQ_LAST_ENDPOINT + enum: ZMQ_ROUTER_MANDATORY + enum: ZMQ_TCP_KEEPALIVE + enum: ZMQ_TCP_KEEPALIVE_CNT + enum: ZMQ_TCP_KEEPALIVE_IDLE + enum: ZMQ_TCP_KEEPALIVE_INTVL + enum: ZMQ_IMMEDIATE + enum: ZMQ_XPUB_VERBOSE + enum: ZMQ_ROUTER_RAW + enum: ZMQ_IPV6 + enum: ZMQ_MECHANISM + enum: ZMQ_PLAIN_SERVER + enum: ZMQ_PLAIN_USERNAME + enum: ZMQ_PLAIN_PASSWORD + enum: ZMQ_CURVE_SERVER + enum: ZMQ_CURVE_PUBLICKEY + enum: ZMQ_CURVE_SECRETKEY + enum: ZMQ_CURVE_SERVERKEY + enum: ZMQ_PROBE_ROUTER + enum: ZMQ_REQ_CORRELATE + enum: ZMQ_REQ_RELAXED + enum: ZMQ_CONFLATE + enum: ZMQ_ZAP_DOMAIN + enum: ZMQ_ROUTER_HANDOVER + enum: ZMQ_TOS + enum: ZMQ_CONNECT_ROUTING_ID + enum: ZMQ_GSSAPI_SERVER + enum: ZMQ_GSSAPI_PRINCIPAL + enum: ZMQ_GSSAPI_SERVICE_PRINCIPAL + enum: ZMQ_GSSAPI_PLAINTEXT + enum: ZMQ_HANDSHAKE_IVL + enum: ZMQ_SOCKS_PROXY + enum: ZMQ_XPUB_NODROP + enum: ZMQ_BLOCKY + enum: ZMQ_XPUB_MANUAL + enum: ZMQ_XPUB_WELCOME_MSG + enum: ZMQ_STREAM_NOTIFY + enum: ZMQ_INVERT_MATCHING + enum: ZMQ_HEARTBEAT_IVL + enum: ZMQ_HEARTBEAT_TTL + enum: ZMQ_HEARTBEAT_TIMEOUT + enum: ZMQ_XPUB_VERBOSER + enum: ZMQ_CONNECT_TIMEOUT + enum: ZMQ_TCP_MAXRT + enum: ZMQ_THREAD_SAFE + enum: ZMQ_MULTICAST_MAXTPDU + enum: ZMQ_VMCI_BUFFER_SIZE + enum: ZMQ_VMCI_BUFFER_MIN_SIZE + enum: ZMQ_VMCI_BUFFER_MAX_SIZE + enum: ZMQ_VMCI_CONNECT_TIMEOUT + enum: ZMQ_USE_FD + enum: ZMQ_GSSAPI_PRINCIPAL_NAMETYPE + enum: ZMQ_GSSAPI_SERVICE_PRINCIPAL_NAMETYPE + enum: ZMQ_BINDTODEVICE + enum: ZMQ_IDENTITY + enum: ZMQ_CONNECT_RID + enum: ZMQ_TCP_ACCEPT_FILTER + enum: ZMQ_IPC_FILTER_PID + enum: ZMQ_IPC_FILTER_UID + enum: ZMQ_IPC_FILTER_GID + enum: ZMQ_IPV4ONLY + enum: ZMQ_DELAY_ATTACH_ON_CONNECT + enum: ZMQ_FAIL_UNROUTABLE + enum: ZMQ_ROUTER_BEHAVIOR + enum: ZMQ_ZAP_ENFORCE_DOMAIN + enum: ZMQ_LOOPBACK_FASTPATH + enum: ZMQ_METADATA + enum: ZMQ_MULTICAST_LOOP + enum: ZMQ_ROUTER_NOTIFY + enum: ZMQ_XPUB_MANUAL_LAST_VALUE + enum: ZMQ_SOCKS_USERNAME + enum: ZMQ_SOCKS_PASSWORD + enum: ZMQ_IN_BATCH_SIZE + enum: ZMQ_OUT_BATCH_SIZE + enum: ZMQ_WSS_KEY_PEM + enum: ZMQ_WSS_CERT_PEM + enum: ZMQ_WSS_TRUST_PEM + enum: ZMQ_WSS_HOSTNAME + enum: ZMQ_WSS_TRUST_SYSTEM + enum: ZMQ_ONLY_FIRST_SUBSCRIBE + enum: ZMQ_RECONNECT_STOP + enum: ZMQ_HELLO_MSG + enum: ZMQ_DISCONNECT_MSG + enum: ZMQ_PRIORITY + enum: ZMQ_BUSY_POLL + enum: ZMQ_HICCUP_MSG + enum: ZMQ_XSUB_VERBOSE_UNSUBSCRIBE + enum: ZMQ_TOPICS_COUNT + enum: ZMQ_NORM_MODE + enum: ZMQ_NORM_UNICAST_NACK + enum: ZMQ_NORM_BUFFER_SIZE + enum: ZMQ_NORM_SEGMENT_SIZE + enum: ZMQ_NORM_BLOCK_SIZE + enum: ZMQ_NORM_NUM_PARITY + enum: ZMQ_NORM_NUM_AUTOPARITY + enum: ZMQ_NORM_PUSH + enum: ZMQ_PAIR + enum: ZMQ_PUB + enum: ZMQ_SUB + enum: ZMQ_REQ + enum: ZMQ_REP + enum: ZMQ_DEALER + enum: ZMQ_ROUTER + enum: ZMQ_PULL + enum: ZMQ_PUSH + enum: ZMQ_XPUB + enum: ZMQ_XSUB + enum: ZMQ_STREAM + enum: ZMQ_XREQ + enum: ZMQ_XREP + enum: ZMQ_SERVER + enum: ZMQ_CLIENT + enum: ZMQ_RADIO + enum: ZMQ_DISH + enum: ZMQ_GATHER + enum: ZMQ_SCATTER + enum: ZMQ_DGRAM + enum: ZMQ_PEER + enum: ZMQ_CHANNEL diff --git a/source/zmq/backend/cython/libzmq.pxd b/source/zmq/backend/cython/libzmq.pxd new file mode 100644 index 0000000000000000000000000000000000000000..2e88f58ebf39ca3044668cafa663b632faf17a2c --- /dev/null +++ b/source/zmq/backend/cython/libzmq.pxd @@ -0,0 +1,128 @@ +"""All the C imports for 0MQ""" + +# +# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see . +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Import the C header files +#----------------------------------------------------------------------------- + +# common includes, such as zmq compat, pyversion_compat +# make sure we load pyversion compat in every Cython module +cdef extern from "pyversion_compat.h": + pass + +# were it not for Windows, +# we could cimport these from libc.stdint +cdef extern from "zmq_compat.h": + ctypedef signed long long int64_t "pyzmq_int64_t" + ctypedef unsigned int uint32_t "pyzmq_uint32_t" + +include "constant_enums.pxi" + +cdef extern from "zmq.h" nogil: + + void _zmq_version "zmq_version"(int *major, int *minor, int *patch) + + ctypedef int fd_t "ZMQ_FD_T" + + enum: errno + const char *zmq_strerror (int errnum) + int zmq_errno() + + void *zmq_ctx_new () + int zmq_ctx_destroy (void *context) + int zmq_ctx_set (void *context, int option, int optval) + int zmq_ctx_get (void *context, int option) + void *zmq_init (int io_threads) + int zmq_term (void *context) + + # blackbox def for zmq_msg_t + ctypedef void * zmq_msg_t "zmq_msg_t" + + ctypedef void zmq_free_fn(void *data, void *hint) + + int zmq_msg_init (zmq_msg_t *msg) + int zmq_msg_init_size (zmq_msg_t *msg, size_t size) + int zmq_msg_init_data (zmq_msg_t *msg, void *data, + size_t size, zmq_free_fn *ffn, void *hint) + int zmq_msg_send (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_recv (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_close (zmq_msg_t *msg) + int zmq_msg_move (zmq_msg_t *dest, zmq_msg_t *src) + int zmq_msg_copy (zmq_msg_t *dest, zmq_msg_t *src) + void *zmq_msg_data (zmq_msg_t *msg) + size_t zmq_msg_size (zmq_msg_t *msg) + int zmq_msg_more (zmq_msg_t *msg) + int zmq_msg_get (zmq_msg_t *msg, int option) + int zmq_msg_set (zmq_msg_t *msg, int option, int optval) + const char *zmq_msg_gets (zmq_msg_t *msg, const char *property) + int zmq_has (const char *capability) + + void *zmq_socket (void *context, int type) + int zmq_close (void *s) + int zmq_setsockopt (void *s, int option, void *optval, size_t optvallen) + int zmq_getsockopt (void *s, int option, void *optval, size_t *optvallen) + int zmq_bind (void *s, char *addr) + int zmq_connect (void *s, char *addr) + int zmq_unbind (void *s, char *addr) + int zmq_disconnect (void *s, char *addr) + + int zmq_socket_monitor (void *s, char *addr, int flags) + + # send/recv + int zmq_send (void *s, const void *buf, size_t n, int flags) + int zmq_recv (void *s, void *buf, size_t n, int flags) + + ctypedef struct zmq_pollitem_t: + void *socket + fd_t fd + short events + short revents + + int zmq_poll (zmq_pollitem_t *items, int nitems, long timeout) + + int zmq_proxy (void *frontend, void *backend, void *capture) + int zmq_proxy_steerable (void *frontend, + void *backend, + void *capture, + void *control) + + int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key) + int zmq_curve_public (char *z85_public_key, char *z85_secret_key) + + # 4.2 draft + int zmq_join (void *s, const char *group) + int zmq_leave (void *s, const char *group) + + int zmq_msg_set_routing_id(zmq_msg_t *msg, uint32_t routing_id) + uint32_t zmq_msg_routing_id(zmq_msg_t *msg) + int zmq_msg_set_group(zmq_msg_t *msg, const char *group) + const char *zmq_msg_group(zmq_msg_t *msg) + + void *zmq_poller_new () + int zmq_poller_destroy (void **poller_p_) + int zmq_poller_add (void *poller_, void *socket_, void *user_data_, short events_) + int zmq_poller_modify (void *poller_, void *socket_, short events_) + int zmq_poller_remove (void *poller_, void *socket_) + int zmq_poller_fd (void *poller_, fd_t *fd_) diff --git a/source/zmq/backend/select.py b/source/zmq/backend/select.py new file mode 100644 index 0000000000000000000000000000000000000000..17cbea8285337fa370983f18bb4ad90db686aa37 --- /dev/null +++ b/source/zmq/backend/select.py @@ -0,0 +1,41 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from importlib import import_module +from typing import Dict + +public_api = [ + 'Context', + 'Socket', + 'Frame', + 'Message', + 'proxy', + 'proxy_steerable', + 'zmq_poll', + 'strerror', + 'zmq_errno', + 'has', + 'curve_keypair', + 'curve_public', + 'zmq_version_info', + 'IPC_PATH_MAX_LEN', + 'PYZMQ_DRAFT_API', +] + + +def select_backend(name: str) -> Dict: + """Select the pyzmq backend""" + try: + mod = import_module(name) + except ImportError: + raise + except Exception as e: + raise ImportError(f"Importing {name} failed with {e}") from e + ns = { + # private API + 'monitored_queue': mod.monitored_queue, + } + ns.update({key: getattr(mod, key) for key in public_api}) + return ns diff --git a/source/zmq/constants.py b/source/zmq/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..cddf433a15d9102b9162425f8d8a54aa125f783e --- /dev/null +++ b/source/zmq/constants.py @@ -0,0 +1,974 @@ +"""zmq constants as enums""" + +from __future__ import annotations + +import errno +import sys +from enum import Enum, IntEnum, IntFlag + +_HAUSNUMERO = 156384712 + + +class Errno(IntEnum): + """libzmq error codes + + .. versionadded:: 23 + """ + + EAGAIN = errno.EAGAIN + EFAULT = errno.EFAULT + EINVAL = errno.EINVAL + + if sys.platform.startswith("win"): + # Windows: libzmq uses errno.h + # while Python errno prefers WSA* variants + # many of these were introduced to errno.h in vs2010 + # ref: https://github.com/python/cpython/blob/3.9/Modules/errnomodule.c#L10-L37 + # source: https://docs.microsoft.com/en-us/cpp/c-runtime-library/errno-constants + ENOTSUP = 129 + EPROTONOSUPPORT = 135 + ENOBUFS = 119 + ENETDOWN = 116 + EADDRINUSE = 100 + EADDRNOTAVAIL = 101 + ECONNREFUSED = 107 + EINPROGRESS = 112 + ENOTSOCK = 128 + EMSGSIZE = 115 + EAFNOSUPPORT = 102 + ENETUNREACH = 118 + ECONNABORTED = 106 + ECONNRESET = 108 + ENOTCONN = 126 + ETIMEDOUT = 138 + EHOSTUNREACH = 110 + ENETRESET = 117 + + else: + ENOTSUP = getattr(errno, "ENOTSUP", _HAUSNUMERO + 1) + EPROTONOSUPPORT = getattr(errno, "EPROTONOSUPPORT", _HAUSNUMERO + 2) + ENOBUFS = getattr(errno, "ENOBUFS", _HAUSNUMERO + 3) + ENETDOWN = getattr(errno, "ENETDOWN", _HAUSNUMERO + 4) + EADDRINUSE = getattr(errno, "EADDRINUSE", _HAUSNUMERO + 5) + EADDRNOTAVAIL = getattr(errno, "EADDRNOTAVAIL", _HAUSNUMERO + 6) + ECONNREFUSED = getattr(errno, "ECONNREFUSED", _HAUSNUMERO + 7) + EINPROGRESS = getattr(errno, "EINPROGRESS", _HAUSNUMERO + 8) + ENOTSOCK = getattr(errno, "ENOTSOCK", _HAUSNUMERO + 9) + EMSGSIZE = getattr(errno, "EMSGSIZE", _HAUSNUMERO + 10) + EAFNOSUPPORT = getattr(errno, "EAFNOSUPPORT", _HAUSNUMERO + 11) + ENETUNREACH = getattr(errno, "ENETUNREACH", _HAUSNUMERO + 12) + ECONNABORTED = getattr(errno, "ECONNABORTED", _HAUSNUMERO + 13) + ECONNRESET = getattr(errno, "ECONNRESET", _HAUSNUMERO + 14) + ENOTCONN = getattr(errno, "ENOTCONN", _HAUSNUMERO + 15) + ETIMEDOUT = getattr(errno, "ETIMEDOUT", _HAUSNUMERO + 16) + EHOSTUNREACH = getattr(errno, "EHOSTUNREACH", _HAUSNUMERO + 17) + ENETRESET = getattr(errno, "ENETRESET", _HAUSNUMERO + 18) + + # Native 0MQ error codes + EFSM = _HAUSNUMERO + 51 + ENOCOMPATPROTO = _HAUSNUMERO + 52 + ETERM = _HAUSNUMERO + 53 + EMTHREAD = _HAUSNUMERO + 54 + + +class ContextOption(IntEnum): + """Options for Context.get/set + + .. versionadded:: 23 + """ + + IO_THREADS = 1 + MAX_SOCKETS = 2 + SOCKET_LIMIT = 3 + THREAD_PRIORITY = 3 + THREAD_SCHED_POLICY = 4 + MAX_MSGSZ = 5 + MSG_T_SIZE = 6 + THREAD_AFFINITY_CPU_ADD = 7 + THREAD_AFFINITY_CPU_REMOVE = 8 + THREAD_NAME_PREFIX = 9 + + +class SocketType(IntEnum): + """zmq socket types + + .. versionadded:: 23 + """ + + PAIR = 0 + PUB = 1 + SUB = 2 + REQ = 3 + REP = 4 + DEALER = 5 + ROUTER = 6 + PULL = 7 + PUSH = 8 + XPUB = 9 + XSUB = 10 + STREAM = 11 + + # deprecated aliases + XREQ = DEALER + XREP = ROUTER + + # DRAFT socket types + SERVER = 12 + CLIENT = 13 + RADIO = 14 + DISH = 15 + GATHER = 16 + SCATTER = 17 + DGRAM = 18 + PEER = 19 + CHANNEL = 20 + + +class _OptType(Enum): + int = 'int' + int64 = 'int64' + bytes = 'bytes' + fd = 'fd' + + +class SocketOption(IntEnum): + """Options for Socket.get/set + + .. versionadded:: 23 + """ + + _opt_type: _OptType + + def __new__(cls, value: int, opt_type: _OptType = _OptType.int): + """Attach option type as `._opt_type`""" + obj = int.__new__(cls, value) + obj._value_ = value + obj._opt_type = opt_type + return obj + + HWM = 1 + AFFINITY = 4, _OptType.int64 + ROUTING_ID = 5, _OptType.bytes + SUBSCRIBE = 6, _OptType.bytes + UNSUBSCRIBE = 7, _OptType.bytes + RATE = 8 + RECOVERY_IVL = 9 + SNDBUF = 11 + RCVBUF = 12 + RCVMORE = 13 + FD = 14, _OptType.fd + EVENTS = 15 + TYPE = 16 + LINGER = 17 + RECONNECT_IVL = 18 + BACKLOG = 19 + RECONNECT_IVL_MAX = 21 + MAXMSGSIZE = 22, _OptType.int64 + SNDHWM = 23 + RCVHWM = 24 + MULTICAST_HOPS = 25 + RCVTIMEO = 27 + SNDTIMEO = 28 + LAST_ENDPOINT = 32, _OptType.bytes + ROUTER_MANDATORY = 33 + TCP_KEEPALIVE = 34 + TCP_KEEPALIVE_CNT = 35 + TCP_KEEPALIVE_IDLE = 36 + TCP_KEEPALIVE_INTVL = 37 + IMMEDIATE = 39 + XPUB_VERBOSE = 40 + ROUTER_RAW = 41 + IPV6 = 42 + MECHANISM = 43 + PLAIN_SERVER = 44 + PLAIN_USERNAME = 45, _OptType.bytes + PLAIN_PASSWORD = 46, _OptType.bytes + CURVE_SERVER = 47 + CURVE_PUBLICKEY = 48, _OptType.bytes + CURVE_SECRETKEY = 49, _OptType.bytes + CURVE_SERVERKEY = 50, _OptType.bytes + PROBE_ROUTER = 51 + REQ_CORRELATE = 52 + REQ_RELAXED = 53 + CONFLATE = 54 + ZAP_DOMAIN = 55, _OptType.bytes + ROUTER_HANDOVER = 56 + TOS = 57 + CONNECT_ROUTING_ID = 61, _OptType.bytes + GSSAPI_SERVER = 62 + GSSAPI_PRINCIPAL = 63, _OptType.bytes + GSSAPI_SERVICE_PRINCIPAL = 64, _OptType.bytes + GSSAPI_PLAINTEXT = 65 + HANDSHAKE_IVL = 66 + SOCKS_PROXY = 68, _OptType.bytes + XPUB_NODROP = 69 + BLOCKY = 70 + XPUB_MANUAL = 71 + XPUB_WELCOME_MSG = 72, _OptType.bytes + STREAM_NOTIFY = 73 + INVERT_MATCHING = 74 + HEARTBEAT_IVL = 75 + HEARTBEAT_TTL = 76 + HEARTBEAT_TIMEOUT = 77 + XPUB_VERBOSER = 78 + CONNECT_TIMEOUT = 79 + TCP_MAXRT = 80 + THREAD_SAFE = 81 + MULTICAST_MAXTPDU = 84 + VMCI_BUFFER_SIZE = 85, _OptType.int64 + VMCI_BUFFER_MIN_SIZE = 86, _OptType.int64 + VMCI_BUFFER_MAX_SIZE = 87, _OptType.int64 + VMCI_CONNECT_TIMEOUT = 88 + USE_FD = 89 + GSSAPI_PRINCIPAL_NAMETYPE = 90 + GSSAPI_SERVICE_PRINCIPAL_NAMETYPE = 91 + BINDTODEVICE = 92, _OptType.bytes + + # Deprecated options and aliases + # must not use name-assignment, must have the same value + IDENTITY = ROUTING_ID + CONNECT_RID = CONNECT_ROUTING_ID + TCP_ACCEPT_FILTER = 38, _OptType.bytes + IPC_FILTER_PID = 58 + IPC_FILTER_UID = 59 + IPC_FILTER_GID = 60 + IPV4ONLY = 31 + DELAY_ATTACH_ON_CONNECT = IMMEDIATE + FAIL_UNROUTABLE = ROUTER_MANDATORY + ROUTER_BEHAVIOR = ROUTER_MANDATORY + + # Draft socket options + ZAP_ENFORCE_DOMAIN = 93 + LOOPBACK_FASTPATH = 94 + METADATA = 95, _OptType.bytes + MULTICAST_LOOP = 96 + ROUTER_NOTIFY = 97 + XPUB_MANUAL_LAST_VALUE = 98 + SOCKS_USERNAME = 99, _OptType.bytes + SOCKS_PASSWORD = 100, _OptType.bytes + IN_BATCH_SIZE = 101 + OUT_BATCH_SIZE = 102 + WSS_KEY_PEM = 103, _OptType.bytes + WSS_CERT_PEM = 104, _OptType.bytes + WSS_TRUST_PEM = 105, _OptType.bytes + WSS_HOSTNAME = 106, _OptType.bytes + WSS_TRUST_SYSTEM = 107 + ONLY_FIRST_SUBSCRIBE = 108 + RECONNECT_STOP = 109 + HELLO_MSG = 110, _OptType.bytes + DISCONNECT_MSG = 111, _OptType.bytes + PRIORITY = 112 + # 4.3.5 + BUSY_POLL = 113 + HICCUP_MSG = 114, _OptType.bytes + XSUB_VERBOSE_UNSUBSCRIBE = 115 + TOPICS_COUNT = 116 + NORM_MODE = 117 + NORM_UNICAST_NACK = 118 + NORM_BUFFER_SIZE = 119 + NORM_SEGMENT_SIZE = 120 + NORM_BLOCK_SIZE = 121 + NORM_NUM_PARITY = 122 + NORM_NUM_AUTOPARITY = 123 + NORM_PUSH = 124 + + +class MessageOption(IntEnum): + """Options on zmq.Frame objects + + .. versionadded:: 23 + """ + + MORE = 1 + SHARED = 3 + # Deprecated message options + SRCFD = 2 + + +class Flag(IntFlag): + """Send/recv flags + + .. versionadded:: 23 + """ + + DONTWAIT = 1 + SNDMORE = 2 + NOBLOCK = DONTWAIT + + +class RouterNotify(IntEnum): + """Values for zmq.ROUTER_NOTIFY socket option + + .. versionadded:: 26 + .. versionadded:: libzmq-4.3.0 (draft) + """ + + @staticmethod + def _global_name(name): + return f"NOTIFY_{name}" + + CONNECT = 1 + DISCONNECT = 2 + + +class NormMode(IntEnum): + """Values for zmq.NORM_MODE socket option + + .. versionadded:: 26 + .. versionadded:: libzmq-4.3.5 (draft) + """ + + @staticmethod + def _global_name(name): + return f"NORM_{name}" + + FIXED = 0 + CC = 1 + CCL = 2 + CCE = 3 + CCE_ECNONLY = 4 + + +class SecurityMechanism(IntEnum): + """Security mechanisms (as returned by ``socket.get(zmq.MECHANISM)``) + + .. versionadded:: 23 + """ + + NULL = 0 + PLAIN = 1 + CURVE = 2 + GSSAPI = 3 + + +class ReconnectStop(IntEnum): + """Select behavior for socket.reconnect_stop + + .. versionadded:: 25 + """ + + @staticmethod + def _global_name(name): + return f"RECONNECT_STOP_{name}" + + CONN_REFUSED = 0x1 + HANDSHAKE_FAILED = 0x2 + AFTER_DISCONNECT = 0x4 + + +class Event(IntFlag): + """Socket monitoring events + + .. versionadded:: 23 + """ + + @staticmethod + def _global_name(name): + if name.startswith("PROTOCOL_ERROR_"): + return name + else: + # add EVENT_ prefix + return "EVENT_" + name + + PROTOCOL_ERROR_WS_UNSPECIFIED = 0x30000000 + PROTOCOL_ERROR_ZMTP_UNSPECIFIED = 0x10000000 + PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND = 0x10000001 + PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE = 0x10000002 + PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE = 0x10000003 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED = 0x10000011 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE = 0x10000012 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO = 0x10000013 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE = 0x10000014 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR = 0x10000015 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY = 0x10000016 + PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME = 0x10000017 + PROTOCOL_ERROR_ZMTP_INVALID_METADATA = 0x10000018 + + PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC = 0x11000001 + PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH = 0x11000002 + PROTOCOL_ERROR_ZAP_UNSPECIFIED = 0x20000000 + PROTOCOL_ERROR_ZAP_MALFORMED_REPLY = 0x20000001 + PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID = 0x20000002 + PROTOCOL_ERROR_ZAP_BAD_VERSION = 0x20000003 + PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE = 0x20000004 + PROTOCOL_ERROR_ZAP_INVALID_METADATA = 0x20000005 + + # define event types _after_ overlapping protocol error masks + CONNECTED = 0x0001 + CONNECT_DELAYED = 0x0002 + CONNECT_RETRIED = 0x0004 + LISTENING = 0x0008 + BIND_FAILED = 0x0010 + ACCEPTED = 0x0020 + ACCEPT_FAILED = 0x0040 + CLOSED = 0x0080 + CLOSE_FAILED = 0x0100 + DISCONNECTED = 0x0200 + MONITOR_STOPPED = 0x0400 + + HANDSHAKE_FAILED_NO_DETAIL = 0x0800 + HANDSHAKE_SUCCEEDED = 0x1000 + HANDSHAKE_FAILED_PROTOCOL = 0x2000 + HANDSHAKE_FAILED_AUTH = 0x4000 + + ALL_V1 = 0xFFFF + ALL = ALL_V1 + + # DRAFT Socket monitoring events + PIPES_STATS = 0x10000 + ALL_V2 = ALL_V1 | PIPES_STATS + + +class PollEvent(IntFlag): + """Which events to poll for in poll methods + + .. versionadded: 23 + """ + + POLLIN = 1 + POLLOUT = 2 + POLLERR = 4 + POLLPRI = 8 + + +class DeviceType(IntEnum): + """Device type constants for zmq.device + + .. versionadded: 23 + """ + + STREAMER = 1 + FORWARDER = 2 + QUEUE = 3 + + +# AUTOGENERATED_BELOW_HERE + + +IO_THREADS: int = ContextOption.IO_THREADS +MAX_SOCKETS: int = ContextOption.MAX_SOCKETS +SOCKET_LIMIT: int = ContextOption.SOCKET_LIMIT +THREAD_PRIORITY: int = ContextOption.THREAD_PRIORITY +THREAD_SCHED_POLICY: int = ContextOption.THREAD_SCHED_POLICY +MAX_MSGSZ: int = ContextOption.MAX_MSGSZ +MSG_T_SIZE: int = ContextOption.MSG_T_SIZE +THREAD_AFFINITY_CPU_ADD: int = ContextOption.THREAD_AFFINITY_CPU_ADD +THREAD_AFFINITY_CPU_REMOVE: int = ContextOption.THREAD_AFFINITY_CPU_REMOVE +THREAD_NAME_PREFIX: int = ContextOption.THREAD_NAME_PREFIX +STREAMER: int = DeviceType.STREAMER +FORWARDER: int = DeviceType.FORWARDER +QUEUE: int = DeviceType.QUEUE +EAGAIN: int = Errno.EAGAIN +EFAULT: int = Errno.EFAULT +EINVAL: int = Errno.EINVAL +ENOTSUP: int = Errno.ENOTSUP +EPROTONOSUPPORT: int = Errno.EPROTONOSUPPORT +ENOBUFS: int = Errno.ENOBUFS +ENETDOWN: int = Errno.ENETDOWN +EADDRINUSE: int = Errno.EADDRINUSE +EADDRNOTAVAIL: int = Errno.EADDRNOTAVAIL +ECONNREFUSED: int = Errno.ECONNREFUSED +EINPROGRESS: int = Errno.EINPROGRESS +ENOTSOCK: int = Errno.ENOTSOCK +EMSGSIZE: int = Errno.EMSGSIZE +EAFNOSUPPORT: int = Errno.EAFNOSUPPORT +ENETUNREACH: int = Errno.ENETUNREACH +ECONNABORTED: int = Errno.ECONNABORTED +ECONNRESET: int = Errno.ECONNRESET +ENOTCONN: int = Errno.ENOTCONN +ETIMEDOUT: int = Errno.ETIMEDOUT +EHOSTUNREACH: int = Errno.EHOSTUNREACH +ENETRESET: int = Errno.ENETRESET +EFSM: int = Errno.EFSM +ENOCOMPATPROTO: int = Errno.ENOCOMPATPROTO +ETERM: int = Errno.ETERM +EMTHREAD: int = Errno.EMTHREAD +PROTOCOL_ERROR_WS_UNSPECIFIED: int = Event.PROTOCOL_ERROR_WS_UNSPECIFIED +PROTOCOL_ERROR_ZMTP_UNSPECIFIED: int = Event.PROTOCOL_ERROR_ZMTP_UNSPECIFIED +PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND: int = ( + Event.PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND +) +PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE: int = Event.PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE +PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE: int = Event.PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY +) +PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME: int = ( + Event.PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME +) +PROTOCOL_ERROR_ZMTP_INVALID_METADATA: int = Event.PROTOCOL_ERROR_ZMTP_INVALID_METADATA +PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC: int = Event.PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC +PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH: int = ( + Event.PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH +) +PROTOCOL_ERROR_ZAP_UNSPECIFIED: int = Event.PROTOCOL_ERROR_ZAP_UNSPECIFIED +PROTOCOL_ERROR_ZAP_MALFORMED_REPLY: int = Event.PROTOCOL_ERROR_ZAP_MALFORMED_REPLY +PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID: int = Event.PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID +PROTOCOL_ERROR_ZAP_BAD_VERSION: int = Event.PROTOCOL_ERROR_ZAP_BAD_VERSION +PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE: int = ( + Event.PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE +) +PROTOCOL_ERROR_ZAP_INVALID_METADATA: int = Event.PROTOCOL_ERROR_ZAP_INVALID_METADATA +EVENT_CONNECTED: int = Event.CONNECTED +EVENT_CONNECT_DELAYED: int = Event.CONNECT_DELAYED +EVENT_CONNECT_RETRIED: int = Event.CONNECT_RETRIED +EVENT_LISTENING: int = Event.LISTENING +EVENT_BIND_FAILED: int = Event.BIND_FAILED +EVENT_ACCEPTED: int = Event.ACCEPTED +EVENT_ACCEPT_FAILED: int = Event.ACCEPT_FAILED +EVENT_CLOSED: int = Event.CLOSED +EVENT_CLOSE_FAILED: int = Event.CLOSE_FAILED +EVENT_DISCONNECTED: int = Event.DISCONNECTED +EVENT_MONITOR_STOPPED: int = Event.MONITOR_STOPPED +EVENT_HANDSHAKE_FAILED_NO_DETAIL: int = Event.HANDSHAKE_FAILED_NO_DETAIL +EVENT_HANDSHAKE_SUCCEEDED: int = Event.HANDSHAKE_SUCCEEDED +EVENT_HANDSHAKE_FAILED_PROTOCOL: int = Event.HANDSHAKE_FAILED_PROTOCOL +EVENT_HANDSHAKE_FAILED_AUTH: int = Event.HANDSHAKE_FAILED_AUTH +EVENT_ALL_V1: int = Event.ALL_V1 +EVENT_ALL: int = Event.ALL +EVENT_PIPES_STATS: int = Event.PIPES_STATS +EVENT_ALL_V2: int = Event.ALL_V2 +DONTWAIT: int = Flag.DONTWAIT +SNDMORE: int = Flag.SNDMORE +NOBLOCK: int = Flag.NOBLOCK +MORE: int = MessageOption.MORE +SHARED: int = MessageOption.SHARED +SRCFD: int = MessageOption.SRCFD +NORM_FIXED: int = NormMode.FIXED +NORM_CC: int = NormMode.CC +NORM_CCL: int = NormMode.CCL +NORM_CCE: int = NormMode.CCE +NORM_CCE_ECNONLY: int = NormMode.CCE_ECNONLY +POLLIN: int = PollEvent.POLLIN +POLLOUT: int = PollEvent.POLLOUT +POLLERR: int = PollEvent.POLLERR +POLLPRI: int = PollEvent.POLLPRI +RECONNECT_STOP_CONN_REFUSED: int = ReconnectStop.CONN_REFUSED +RECONNECT_STOP_HANDSHAKE_FAILED: int = ReconnectStop.HANDSHAKE_FAILED +RECONNECT_STOP_AFTER_DISCONNECT: int = ReconnectStop.AFTER_DISCONNECT +NOTIFY_CONNECT: int = RouterNotify.CONNECT +NOTIFY_DISCONNECT: int = RouterNotify.DISCONNECT +NULL: int = SecurityMechanism.NULL +PLAIN: int = SecurityMechanism.PLAIN +CURVE: int = SecurityMechanism.CURVE +GSSAPI: int = SecurityMechanism.GSSAPI +HWM: int = SocketOption.HWM +AFFINITY: int = SocketOption.AFFINITY +ROUTING_ID: int = SocketOption.ROUTING_ID +SUBSCRIBE: int = SocketOption.SUBSCRIBE +UNSUBSCRIBE: int = SocketOption.UNSUBSCRIBE +RATE: int = SocketOption.RATE +RECOVERY_IVL: int = SocketOption.RECOVERY_IVL +SNDBUF: int = SocketOption.SNDBUF +RCVBUF: int = SocketOption.RCVBUF +RCVMORE: int = SocketOption.RCVMORE +FD: int = SocketOption.FD +EVENTS: int = SocketOption.EVENTS +TYPE: int = SocketOption.TYPE +LINGER: int = SocketOption.LINGER +RECONNECT_IVL: int = SocketOption.RECONNECT_IVL +BACKLOG: int = SocketOption.BACKLOG +RECONNECT_IVL_MAX: int = SocketOption.RECONNECT_IVL_MAX +MAXMSGSIZE: int = SocketOption.MAXMSGSIZE +SNDHWM: int = SocketOption.SNDHWM +RCVHWM: int = SocketOption.RCVHWM +MULTICAST_HOPS: int = SocketOption.MULTICAST_HOPS +RCVTIMEO: int = SocketOption.RCVTIMEO +SNDTIMEO: int = SocketOption.SNDTIMEO +LAST_ENDPOINT: int = SocketOption.LAST_ENDPOINT +ROUTER_MANDATORY: int = SocketOption.ROUTER_MANDATORY +TCP_KEEPALIVE: int = SocketOption.TCP_KEEPALIVE +TCP_KEEPALIVE_CNT: int = SocketOption.TCP_KEEPALIVE_CNT +TCP_KEEPALIVE_IDLE: int = SocketOption.TCP_KEEPALIVE_IDLE +TCP_KEEPALIVE_INTVL: int = SocketOption.TCP_KEEPALIVE_INTVL +IMMEDIATE: int = SocketOption.IMMEDIATE +XPUB_VERBOSE: int = SocketOption.XPUB_VERBOSE +ROUTER_RAW: int = SocketOption.ROUTER_RAW +IPV6: int = SocketOption.IPV6 +MECHANISM: int = SocketOption.MECHANISM +PLAIN_SERVER: int = SocketOption.PLAIN_SERVER +PLAIN_USERNAME: int = SocketOption.PLAIN_USERNAME +PLAIN_PASSWORD: int = SocketOption.PLAIN_PASSWORD +CURVE_SERVER: int = SocketOption.CURVE_SERVER +CURVE_PUBLICKEY: int = SocketOption.CURVE_PUBLICKEY +CURVE_SECRETKEY: int = SocketOption.CURVE_SECRETKEY +CURVE_SERVERKEY: int = SocketOption.CURVE_SERVERKEY +PROBE_ROUTER: int = SocketOption.PROBE_ROUTER +REQ_CORRELATE: int = SocketOption.REQ_CORRELATE +REQ_RELAXED: int = SocketOption.REQ_RELAXED +CONFLATE: int = SocketOption.CONFLATE +ZAP_DOMAIN: int = SocketOption.ZAP_DOMAIN +ROUTER_HANDOVER: int = SocketOption.ROUTER_HANDOVER +TOS: int = SocketOption.TOS +CONNECT_ROUTING_ID: int = SocketOption.CONNECT_ROUTING_ID +GSSAPI_SERVER: int = SocketOption.GSSAPI_SERVER +GSSAPI_PRINCIPAL: int = SocketOption.GSSAPI_PRINCIPAL +GSSAPI_SERVICE_PRINCIPAL: int = SocketOption.GSSAPI_SERVICE_PRINCIPAL +GSSAPI_PLAINTEXT: int = SocketOption.GSSAPI_PLAINTEXT +HANDSHAKE_IVL: int = SocketOption.HANDSHAKE_IVL +SOCKS_PROXY: int = SocketOption.SOCKS_PROXY +XPUB_NODROP: int = SocketOption.XPUB_NODROP +BLOCKY: int = SocketOption.BLOCKY +XPUB_MANUAL: int = SocketOption.XPUB_MANUAL +XPUB_WELCOME_MSG: int = SocketOption.XPUB_WELCOME_MSG +STREAM_NOTIFY: int = SocketOption.STREAM_NOTIFY +INVERT_MATCHING: int = SocketOption.INVERT_MATCHING +HEARTBEAT_IVL: int = SocketOption.HEARTBEAT_IVL +HEARTBEAT_TTL: int = SocketOption.HEARTBEAT_TTL +HEARTBEAT_TIMEOUT: int = SocketOption.HEARTBEAT_TIMEOUT +XPUB_VERBOSER: int = SocketOption.XPUB_VERBOSER +CONNECT_TIMEOUT: int = SocketOption.CONNECT_TIMEOUT +TCP_MAXRT: int = SocketOption.TCP_MAXRT +THREAD_SAFE: int = SocketOption.THREAD_SAFE +MULTICAST_MAXTPDU: int = SocketOption.MULTICAST_MAXTPDU +VMCI_BUFFER_SIZE: int = SocketOption.VMCI_BUFFER_SIZE +VMCI_BUFFER_MIN_SIZE: int = SocketOption.VMCI_BUFFER_MIN_SIZE +VMCI_BUFFER_MAX_SIZE: int = SocketOption.VMCI_BUFFER_MAX_SIZE +VMCI_CONNECT_TIMEOUT: int = SocketOption.VMCI_CONNECT_TIMEOUT +USE_FD: int = SocketOption.USE_FD +GSSAPI_PRINCIPAL_NAMETYPE: int = SocketOption.GSSAPI_PRINCIPAL_NAMETYPE +GSSAPI_SERVICE_PRINCIPAL_NAMETYPE: int = SocketOption.GSSAPI_SERVICE_PRINCIPAL_NAMETYPE +BINDTODEVICE: int = SocketOption.BINDTODEVICE +IDENTITY: int = SocketOption.IDENTITY +CONNECT_RID: int = SocketOption.CONNECT_RID +TCP_ACCEPT_FILTER: int = SocketOption.TCP_ACCEPT_FILTER +IPC_FILTER_PID: int = SocketOption.IPC_FILTER_PID +IPC_FILTER_UID: int = SocketOption.IPC_FILTER_UID +IPC_FILTER_GID: int = SocketOption.IPC_FILTER_GID +IPV4ONLY: int = SocketOption.IPV4ONLY +DELAY_ATTACH_ON_CONNECT: int = SocketOption.DELAY_ATTACH_ON_CONNECT +FAIL_UNROUTABLE: int = SocketOption.FAIL_UNROUTABLE +ROUTER_BEHAVIOR: int = SocketOption.ROUTER_BEHAVIOR +ZAP_ENFORCE_DOMAIN: int = SocketOption.ZAP_ENFORCE_DOMAIN +LOOPBACK_FASTPATH: int = SocketOption.LOOPBACK_FASTPATH +METADATA: int = SocketOption.METADATA +MULTICAST_LOOP: int = SocketOption.MULTICAST_LOOP +ROUTER_NOTIFY: int = SocketOption.ROUTER_NOTIFY +XPUB_MANUAL_LAST_VALUE: int = SocketOption.XPUB_MANUAL_LAST_VALUE +SOCKS_USERNAME: int = SocketOption.SOCKS_USERNAME +SOCKS_PASSWORD: int = SocketOption.SOCKS_PASSWORD +IN_BATCH_SIZE: int = SocketOption.IN_BATCH_SIZE +OUT_BATCH_SIZE: int = SocketOption.OUT_BATCH_SIZE +WSS_KEY_PEM: int = SocketOption.WSS_KEY_PEM +WSS_CERT_PEM: int = SocketOption.WSS_CERT_PEM +WSS_TRUST_PEM: int = SocketOption.WSS_TRUST_PEM +WSS_HOSTNAME: int = SocketOption.WSS_HOSTNAME +WSS_TRUST_SYSTEM: int = SocketOption.WSS_TRUST_SYSTEM +ONLY_FIRST_SUBSCRIBE: int = SocketOption.ONLY_FIRST_SUBSCRIBE +RECONNECT_STOP: int = SocketOption.RECONNECT_STOP +HELLO_MSG: int = SocketOption.HELLO_MSG +DISCONNECT_MSG: int = SocketOption.DISCONNECT_MSG +PRIORITY: int = SocketOption.PRIORITY +BUSY_POLL: int = SocketOption.BUSY_POLL +HICCUP_MSG: int = SocketOption.HICCUP_MSG +XSUB_VERBOSE_UNSUBSCRIBE: int = SocketOption.XSUB_VERBOSE_UNSUBSCRIBE +TOPICS_COUNT: int = SocketOption.TOPICS_COUNT +NORM_MODE: int = SocketOption.NORM_MODE +NORM_UNICAST_NACK: int = SocketOption.NORM_UNICAST_NACK +NORM_BUFFER_SIZE: int = SocketOption.NORM_BUFFER_SIZE +NORM_SEGMENT_SIZE: int = SocketOption.NORM_SEGMENT_SIZE +NORM_BLOCK_SIZE: int = SocketOption.NORM_BLOCK_SIZE +NORM_NUM_PARITY: int = SocketOption.NORM_NUM_PARITY +NORM_NUM_AUTOPARITY: int = SocketOption.NORM_NUM_AUTOPARITY +NORM_PUSH: int = SocketOption.NORM_PUSH +PAIR: int = SocketType.PAIR +PUB: int = SocketType.PUB +SUB: int = SocketType.SUB +REQ: int = SocketType.REQ +REP: int = SocketType.REP +DEALER: int = SocketType.DEALER +ROUTER: int = SocketType.ROUTER +PULL: int = SocketType.PULL +PUSH: int = SocketType.PUSH +XPUB: int = SocketType.XPUB +XSUB: int = SocketType.XSUB +STREAM: int = SocketType.STREAM +XREQ: int = SocketType.XREQ +XREP: int = SocketType.XREP +SERVER: int = SocketType.SERVER +CLIENT: int = SocketType.CLIENT +RADIO: int = SocketType.RADIO +DISH: int = SocketType.DISH +GATHER: int = SocketType.GATHER +SCATTER: int = SocketType.SCATTER +DGRAM: int = SocketType.DGRAM +PEER: int = SocketType.PEER +CHANNEL: int = SocketType.CHANNEL + +__all__: list[str] = [ + "ContextOption", + "IO_THREADS", + "MAX_SOCKETS", + "SOCKET_LIMIT", + "THREAD_PRIORITY", + "THREAD_SCHED_POLICY", + "MAX_MSGSZ", + "MSG_T_SIZE", + "THREAD_AFFINITY_CPU_ADD", + "THREAD_AFFINITY_CPU_REMOVE", + "THREAD_NAME_PREFIX", + "DeviceType", + "STREAMER", + "FORWARDER", + "QUEUE", + "Enum", + "Errno", + "EAGAIN", + "EFAULT", + "EINVAL", + "ENOTSUP", + "EPROTONOSUPPORT", + "ENOBUFS", + "ENETDOWN", + "EADDRINUSE", + "EADDRNOTAVAIL", + "ECONNREFUSED", + "EINPROGRESS", + "ENOTSOCK", + "EMSGSIZE", + "EAFNOSUPPORT", + "ENETUNREACH", + "ECONNABORTED", + "ECONNRESET", + "ENOTCONN", + "ETIMEDOUT", + "EHOSTUNREACH", + "ENETRESET", + "EFSM", + "ENOCOMPATPROTO", + "ETERM", + "EMTHREAD", + "Event", + "PROTOCOL_ERROR_WS_UNSPECIFIED", + "PROTOCOL_ERROR_ZMTP_UNSPECIFIED", + "PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND", + "PROTOCOL_ERROR_ZMTP_INVALID_SEQUENCE", + "PROTOCOL_ERROR_ZMTP_KEY_EXCHANGE", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_HELLO", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY", + "PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_WELCOME", + "PROTOCOL_ERROR_ZMTP_INVALID_METADATA", + "PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC", + "PROTOCOL_ERROR_ZMTP_MECHANISM_MISMATCH", + "PROTOCOL_ERROR_ZAP_UNSPECIFIED", + "PROTOCOL_ERROR_ZAP_MALFORMED_REPLY", + "PROTOCOL_ERROR_ZAP_BAD_REQUEST_ID", + "PROTOCOL_ERROR_ZAP_BAD_VERSION", + "PROTOCOL_ERROR_ZAP_INVALID_STATUS_CODE", + "PROTOCOL_ERROR_ZAP_INVALID_METADATA", + "EVENT_CONNECTED", + "EVENT_CONNECT_DELAYED", + "EVENT_CONNECT_RETRIED", + "EVENT_LISTENING", + "EVENT_BIND_FAILED", + "EVENT_ACCEPTED", + "EVENT_ACCEPT_FAILED", + "EVENT_CLOSED", + "EVENT_CLOSE_FAILED", + "EVENT_DISCONNECTED", + "EVENT_MONITOR_STOPPED", + "EVENT_HANDSHAKE_FAILED_NO_DETAIL", + "EVENT_HANDSHAKE_SUCCEEDED", + "EVENT_HANDSHAKE_FAILED_PROTOCOL", + "EVENT_HANDSHAKE_FAILED_AUTH", + "EVENT_ALL_V1", + "EVENT_ALL", + "EVENT_PIPES_STATS", + "EVENT_ALL_V2", + "Flag", + "DONTWAIT", + "SNDMORE", + "NOBLOCK", + "IntEnum", + "IntFlag", + "MessageOption", + "MORE", + "SHARED", + "SRCFD", + "NormMode", + "NORM_FIXED", + "NORM_CC", + "NORM_CCL", + "NORM_CCE", + "NORM_CCE_ECNONLY", + "PollEvent", + "POLLIN", + "POLLOUT", + "POLLERR", + "POLLPRI", + "ReconnectStop", + "RECONNECT_STOP_CONN_REFUSED", + "RECONNECT_STOP_HANDSHAKE_FAILED", + "RECONNECT_STOP_AFTER_DISCONNECT", + "RouterNotify", + "NOTIFY_CONNECT", + "NOTIFY_DISCONNECT", + "SecurityMechanism", + "NULL", + "PLAIN", + "CURVE", + "GSSAPI", + "SocketOption", + "HWM", + "AFFINITY", + "ROUTING_ID", + "SUBSCRIBE", + "UNSUBSCRIBE", + "RATE", + "RECOVERY_IVL", + "SNDBUF", + "RCVBUF", + "RCVMORE", + "FD", + "EVENTS", + "TYPE", + "LINGER", + "RECONNECT_IVL", + "BACKLOG", + "RECONNECT_IVL_MAX", + "MAXMSGSIZE", + "SNDHWM", + "RCVHWM", + "MULTICAST_HOPS", + "RCVTIMEO", + "SNDTIMEO", + "LAST_ENDPOINT", + "ROUTER_MANDATORY", + "TCP_KEEPALIVE", + "TCP_KEEPALIVE_CNT", + "TCP_KEEPALIVE_IDLE", + "TCP_KEEPALIVE_INTVL", + "IMMEDIATE", + "XPUB_VERBOSE", + "ROUTER_RAW", + "IPV6", + "MECHANISM", + "PLAIN_SERVER", + "PLAIN_USERNAME", + "PLAIN_PASSWORD", + "CURVE_SERVER", + "CURVE_PUBLICKEY", + "CURVE_SECRETKEY", + "CURVE_SERVERKEY", + "PROBE_ROUTER", + "REQ_CORRELATE", + "REQ_RELAXED", + "CONFLATE", + "ZAP_DOMAIN", + "ROUTER_HANDOVER", + "TOS", + "CONNECT_ROUTING_ID", + "GSSAPI_SERVER", + "GSSAPI_PRINCIPAL", + "GSSAPI_SERVICE_PRINCIPAL", + "GSSAPI_PLAINTEXT", + "HANDSHAKE_IVL", + "SOCKS_PROXY", + "XPUB_NODROP", + "BLOCKY", + "XPUB_MANUAL", + "XPUB_WELCOME_MSG", + "STREAM_NOTIFY", + "INVERT_MATCHING", + "HEARTBEAT_IVL", + "HEARTBEAT_TTL", + "HEARTBEAT_TIMEOUT", + "XPUB_VERBOSER", + "CONNECT_TIMEOUT", + "TCP_MAXRT", + "THREAD_SAFE", + "MULTICAST_MAXTPDU", + "VMCI_BUFFER_SIZE", + "VMCI_BUFFER_MIN_SIZE", + "VMCI_BUFFER_MAX_SIZE", + "VMCI_CONNECT_TIMEOUT", + "USE_FD", + "GSSAPI_PRINCIPAL_NAMETYPE", + "GSSAPI_SERVICE_PRINCIPAL_NAMETYPE", + "BINDTODEVICE", + "IDENTITY", + "CONNECT_RID", + "TCP_ACCEPT_FILTER", + "IPC_FILTER_PID", + "IPC_FILTER_UID", + "IPC_FILTER_GID", + "IPV4ONLY", + "DELAY_ATTACH_ON_CONNECT", + "FAIL_UNROUTABLE", + "ROUTER_BEHAVIOR", + "ZAP_ENFORCE_DOMAIN", + "LOOPBACK_FASTPATH", + "METADATA", + "MULTICAST_LOOP", + "ROUTER_NOTIFY", + "XPUB_MANUAL_LAST_VALUE", + "SOCKS_USERNAME", + "SOCKS_PASSWORD", + "IN_BATCH_SIZE", + "OUT_BATCH_SIZE", + "WSS_KEY_PEM", + "WSS_CERT_PEM", + "WSS_TRUST_PEM", + "WSS_HOSTNAME", + "WSS_TRUST_SYSTEM", + "ONLY_FIRST_SUBSCRIBE", + "RECONNECT_STOP", + "HELLO_MSG", + "DISCONNECT_MSG", + "PRIORITY", + "BUSY_POLL", + "HICCUP_MSG", + "XSUB_VERBOSE_UNSUBSCRIBE", + "TOPICS_COUNT", + "NORM_MODE", + "NORM_UNICAST_NACK", + "NORM_BUFFER_SIZE", + "NORM_SEGMENT_SIZE", + "NORM_BLOCK_SIZE", + "NORM_NUM_PARITY", + "NORM_NUM_AUTOPARITY", + "NORM_PUSH", + "SocketType", + "PAIR", + "PUB", + "SUB", + "REQ", + "REP", + "DEALER", + "ROUTER", + "PULL", + "PUSH", + "XPUB", + "XSUB", + "STREAM", + "XREQ", + "XREP", + "SERVER", + "CLIENT", + "RADIO", + "DISH", + "GATHER", + "SCATTER", + "DGRAM", + "PEER", + "CHANNEL", +] diff --git a/source/zmq/decorators.py b/source/zmq/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd80ebc764fb0b83146821ea85af3ce4aad8196 --- /dev/null +++ b/source/zmq/decorators.py @@ -0,0 +1,190 @@ +"""Decorators for running functions with context/sockets. + +.. versionadded:: 15.3 + +Like using Contexts and Sockets as context managers, but with decorator syntax. +Context and sockets are closed at the end of the function. + +For example:: + + from zmq.decorators import context, socket + + @context() + @socket(zmq.PUSH) + def work(ctx, push): + ... +""" + +from __future__ import annotations + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +__all__ = ( + 'context', + 'socket', +) + +from functools import wraps + +import zmq + + +class _Decorator: + '''The mini decorator factory''' + + def __init__(self, target=None): + self._target = target + + def __call__(self, *dec_args, **dec_kwargs): + """ + The main logic of decorator + + Here is how those arguments works:: + + @out_decorator(*dec_args, *dec_kwargs) + def func(*wrap_args, **wrap_kwargs): + ... + + And in the ``wrapper``, we simply create ``self.target`` instance via + ``with``:: + + target = self.get_target(*args, **kwargs) + with target(*dec_args, **dec_kwargs) as obj: + ... + + """ + kw_name, dec_args, dec_kwargs = self.process_decorator_args( + *dec_args, **dec_kwargs + ) + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + target = self.get_target(*args, **kwargs) + + with target(*dec_args, **dec_kwargs) as obj: + # insert our object into args + if kw_name and kw_name not in kwargs: + kwargs[kw_name] = obj + elif kw_name and kw_name in kwargs: + raise TypeError( + f"{func.__name__}() got multiple values for" + f" argument '{kw_name}'" + ) + else: + args = args + (obj,) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + def get_target(self, *args, **kwargs): + """Return the target function + + Allows modifying args/kwargs to be passed. + """ + return self._target + + def process_decorator_args(self, *args, **kwargs): + """Process args passed to the decorator. + + args not consumed by the decorator will be passed to the target factory + (Context/Socket constructor). + """ + kw_name = None + + if isinstance(kwargs.get('name'), str): + kw_name = kwargs.pop('name') + elif len(args) >= 1 and isinstance(args[0], str): + kw_name = args[0] + args = args[1:] + + return kw_name, args, kwargs + + +class _ContextDecorator(_Decorator): + """Decorator subclass for Contexts""" + + def __init__(self): + super().__init__(zmq.Context) + + +class _SocketDecorator(_Decorator): + """Decorator subclass for sockets + + Gets the context from other args. + """ + + def process_decorator_args(self, *args, **kwargs): + """Also grab context_name out of kwargs""" + kw_name, args, kwargs = super().process_decorator_args(*args, **kwargs) + self.context_name = kwargs.pop('context_name', 'context') + return kw_name, args, kwargs + + def get_target(self, *args, **kwargs): + """Get context, based on call-time args""" + context = self._get_context(*args, **kwargs) + return context.socket + + def _get_context(self, *args, **kwargs): + """ + Find the ``zmq.Context`` from ``args`` and ``kwargs`` at call time. + + First, if there is an keyword argument named ``context`` and it is a + ``zmq.Context`` instance , we will take it. + + Second, we check all the ``args``, take the first ``zmq.Context`` + instance. + + Finally, we will provide default Context -- ``zmq.Context.instance`` + + :return: a ``zmq.Context`` instance + """ + if self.context_name in kwargs: + ctx = kwargs[self.context_name] + + if isinstance(ctx, zmq.Context): + return ctx + + for arg in args: + if isinstance(arg, zmq.Context): + return arg + # not specified by any decorator + return zmq.Context.instance() + + +def context(*args, **kwargs): + """Decorator for adding a Context to a function. + + Usage:: + + @context() + def foo(ctx): + ... + + .. versionadded:: 15.3 + + :param str name: the keyword argument passed to decorated function + """ + return _ContextDecorator()(*args, **kwargs) + + +def socket(*args, **kwargs): + """Decorator for adding a socket to a function. + + Usage:: + + @socket(zmq.PUSH) + def foo(push): + ... + + .. versionadded:: 15.3 + + :param str name: the keyword argument passed to decorated function + :param str context_name: the keyword only argument to identify context + object + """ + return _SocketDecorator()(*args, **kwargs) diff --git a/source/zmq/devices/__init__.py b/source/zmq/devices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ccf2aa471104cccb82fe55e91a14ce81e1b25a --- /dev/null +++ b/source/zmq/devices/__init__.py @@ -0,0 +1,30 @@ +"""0MQ Device classes for running in background threads or processes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +from zmq import DeviceType, proxy +from zmq.devices import ( + basedevice, + monitoredqueue, + monitoredqueuedevice, + proxydevice, + proxysteerabledevice, +) +from zmq.devices.basedevice import * +from zmq.devices.monitoredqueue import * +from zmq.devices.monitoredqueuedevice import * +from zmq.devices.proxydevice import * +from zmq.devices.proxysteerabledevice import * + +__all__ = [] +for submod in ( + basedevice, + proxydevice, + proxysteerabledevice, + monitoredqueue, + monitoredqueuedevice, +): + __all__.extend(submod.__all__) # type: ignore diff --git a/source/zmq/devices/basedevice.py b/source/zmq/devices/basedevice.py new file mode 100644 index 0000000000000000000000000000000000000000..5039fd70d03e9713beb6954016f4946e11052e75 --- /dev/null +++ b/source/zmq/devices/basedevice.py @@ -0,0 +1,310 @@ +"""Classes for running 0MQ Devices in the background.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +from multiprocessing import Process +from threading import Thread +from typing import Any, Callable, List, Optional, Tuple + +import zmq +from zmq import ENOTSOCK, ETERM, PUSH, QUEUE, Context, ZMQBindError, ZMQError, proxy + + +class Device: + """A 0MQ Device to be run in the background. + + You do not pass Socket instances to this, but rather Socket types:: + + Device(device_type, in_socket_type, out_socket_type) + + For instance:: + + dev = Device(zmq.QUEUE, zmq.DEALER, zmq.ROUTER) + + Similar to zmq.device, but socket types instead of sockets themselves are + passed, and the sockets are created in the work thread, to avoid issues + with thread safety. As a result, additional bind_{in|out} and + connect_{in|out} methods and setsockopt_{in|out} allow users to specify + connections for the sockets. + + Parameters + ---------- + device_type : int + The 0MQ Device type + {in|out}_type : int + zmq socket types, to be passed later to context.socket(). e.g. + zmq.PUB, zmq.SUB, zmq.REQ. If out_type is < 0, then in_socket is used + for both in_socket and out_socket. + + Methods + ------- + bind_{in_out}(iface) + passthrough for ``{in|out}_socket.bind(iface)``, to be called in the thread + connect_{in_out}(iface) + passthrough for ``{in|out}_socket.connect(iface)``, to be called in the + thread + setsockopt_{in_out}(opt,value) + passthrough for ``{in|out}_socket.setsockopt(opt, value)``, to be called in + the thread + + Attributes + ---------- + daemon : bool + sets whether the thread should be run as a daemon + Default is true, because if it is false, the thread will not + exit unless it is killed + context_factory : callable + This is a class attribute. + Function for creating the Context. This will be Context.instance + in ThreadDevices, and Context in ProcessDevices. The only reason + it is not instance() in ProcessDevices is that there may be a stale + Context instance already initialized, and the forked environment + should *never* try to use it. + """ + + context_factory: Callable[[], zmq.Context] = Context.instance + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + daemon: bool + device_type: int + in_type: int + out_type: int + + _in_binds: List[str] + _in_connects: List[str] + _in_sockopts: List[Tuple[int, Any]] + _out_binds: List[str] + _out_connects: List[str] + _out_sockopts: List[Tuple[int, Any]] + _random_addrs: List[str] + _sockets: List[zmq.Socket] + + def __init__( + self, + device_type: int = QUEUE, + in_type: Optional[int] = None, + out_type: Optional[int] = None, + ) -> None: + self.device_type = device_type + if in_type is None: + raise TypeError("in_type must be specified") + if out_type is None: + raise TypeError("out_type must be specified") + self.in_type = in_type + self.out_type = out_type + self._in_binds = [] + self._in_connects = [] + self._in_sockopts = [] + self._out_binds = [] + self._out_connects = [] + self._out_sockopts = [] + self._random_addrs = [] + self.daemon = True + self.done = False + self._sockets = [] + + def bind_in(self, addr: str) -> None: + """Enqueue ZMQ address for binding on in_socket. + + See zmq.Socket.bind for details. + """ + self._in_binds.append(addr) + + def bind_in_to_random_port(self, addr: str, *args, **kwargs) -> int: + """Enqueue a random port on the given interface for binding on + in_socket. + + See zmq.Socket.bind_to_random_port for details. + + .. versionadded:: 18.0 + """ + port = self._reserve_random_port(addr, *args, **kwargs) + + self.bind_in(f'{addr}:{port}') + + return port + + def connect_in(self, addr: str) -> None: + """Enqueue ZMQ address for connecting on in_socket. + + See zmq.Socket.connect for details. + """ + self._in_connects.append(addr) + + def setsockopt_in(self, opt: int, value: Any) -> None: + """Enqueue setsockopt(opt, value) for in_socket + + See zmq.Socket.setsockopt for details. + """ + self._in_sockopts.append((opt, value)) + + def bind_out(self, addr: str) -> None: + """Enqueue ZMQ address for binding on out_socket. + + See zmq.Socket.bind for details. + """ + self._out_binds.append(addr) + + def bind_out_to_random_port(self, addr: str, *args, **kwargs) -> int: + """Enqueue a random port on the given interface for binding on + out_socket. + + See zmq.Socket.bind_to_random_port for details. + + .. versionadded:: 18.0 + """ + port = self._reserve_random_port(addr, *args, **kwargs) + + self.bind_out(f'{addr}:{port}') + + return port + + def connect_out(self, addr: str): + """Enqueue ZMQ address for connecting on out_socket. + + See zmq.Socket.connect for details. + """ + self._out_connects.append(addr) + + def setsockopt_out(self, opt: int, value: Any): + """Enqueue setsockopt(opt, value) for out_socket + + See zmq.Socket.setsockopt for details. + """ + self._out_sockopts.append((opt, value)) + + def _reserve_random_port(self, addr: str, *args, **kwargs) -> int: + with Context() as ctx: + with ctx.socket(PUSH) as binder: + for i in range(5): + port = binder.bind_to_random_port(addr, *args, **kwargs) + + new_addr = f'{addr}:{port}' + + if new_addr in self._random_addrs: + continue + else: + break + else: + raise ZMQBindError("Could not reserve random port.") + + self._random_addrs.append(new_addr) + + return port + + def _setup_sockets(self) -> Tuple[zmq.Socket, zmq.Socket]: + ctx: zmq.Context[zmq.Socket] = self.context_factory() # type: ignore + self._context = ctx + + # create the sockets + ins = ctx.socket(self.in_type) + self._sockets.append(ins) + if self.out_type < 0: + outs = ins + else: + outs = ctx.socket(self.out_type) + self._sockets.append(outs) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt, value in self._in_sockopts: + ins.setsockopt(opt, value) + for opt, value in self._out_sockopts: + outs.setsockopt(opt, value) + + for iface in self._in_binds: + ins.bind(iface) + for iface in self._out_binds: + outs.bind(iface) + + for iface in self._in_connects: + ins.connect(iface) + for iface in self._out_connects: + outs.connect(iface) + + return ins, outs + + def run_device(self) -> None: + """The runner method. + + Do not call me directly, instead call ``self.start()``, just like a Thread. + """ + ins, outs = self._setup_sockets() + proxy(ins, outs) + + def _close_sockets(self): + """Cleanup sockets we created""" + for s in self._sockets: + if s and not s.closed: + s.close() + + def run(self) -> None: + """wrap run_device in try/catch ETERM""" + try: + self.run_device() + except ZMQError as e: + if e.errno in {ETERM, ENOTSOCK}: + # silence TERM, ENOTSOCK errors, because this should be a clean shutdown + pass + else: + raise + finally: + self.done = True + self._close_sockets() + + def start(self) -> None: + """Start the device. Override me in subclass for other launchers.""" + return self.run() + + def join(self, timeout: Optional[float] = None) -> None: + """wait for me to finish, like Thread.join. + + Reimplemented appropriately by subclasses.""" + tic = time.monotonic() + toc = tic + while not self.done and not (timeout is not None and toc - tic > timeout): + time.sleep(0.001) + toc = time.monotonic() + + +class BackgroundDevice(Device): + """Base class for launching Devices in background processes and threads.""" + + launcher: Any = None + _launch_class: Any = None + + def start(self) -> None: + self.launcher = self._launch_class(target=self.run) + self.launcher.daemon = self.daemon + return self.launcher.start() + + def join(self, timeout: Optional[float] = None) -> None: + return self.launcher.join(timeout=timeout) + + +class ThreadDevice(BackgroundDevice): + """A Device that will be run in a background Thread. + + See Device for details. + """ + + _launch_class = Thread + + +class ProcessDevice(BackgroundDevice): + """A Device that will be run in a background Process. + + See Device for details. + """ + + _launch_class = Process + context_factory = Context + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + +__all__ = ['Device', 'ThreadDevice', 'ProcessDevice'] diff --git a/source/zmq/devices/monitoredqueue.py b/source/zmq/devices/monitoredqueue.py new file mode 100644 index 0000000000000000000000000000000000000000..f590457a8696979889fe3b5f4b7604f6149d69b7 --- /dev/null +++ b/source/zmq/devices/monitoredqueue.py @@ -0,0 +1,51 @@ +"""pure Python monitored_queue function + +For use when Cython extension is unavailable (PyPy). + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from typing import Callable + +import zmq +from zmq.backend import monitored_queue as _backend_mq + + +def _relay(ins, outs, sides, prefix, swap_ids): + msg = ins.recv_multipart() + if swap_ids: + msg[:2] = msg[:2][::-1] + outs.send_multipart(msg) + sides.send_multipart([prefix] + msg) + + +def _monitored_queue( + in_socket, out_socket, mon_socket, in_prefix=b'in', out_prefix=b'out' +): + swap_ids = in_socket.type == zmq.ROUTER and out_socket.type == zmq.ROUTER + + poller = zmq.Poller() + poller.register(in_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) + while True: + events = dict(poller.poll()) + if in_socket in events: + _relay(in_socket, out_socket, mon_socket, in_prefix, swap_ids) + if out_socket in events: + _relay(out_socket, in_socket, mon_socket, out_prefix, swap_ids) + + +monitored_queue: Callable +if _backend_mq is not None: + monitored_queue = _backend_mq # type: ignore +else: + # backend has no monitored_queue + monitored_queue = _monitored_queue + + +__all__ = ['monitored_queue'] diff --git a/source/zmq/devices/monitoredqueuedevice.py b/source/zmq/devices/monitoredqueuedevice.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcc5629964e3c1fb19cf042c4ec4e16444ac8aa --- /dev/null +++ b/source/zmq/devices/monitoredqueuedevice.py @@ -0,0 +1,60 @@ +"""MonitoredQueue classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq import PUB +from zmq.devices.monitoredqueue import monitored_queue +from zmq.devices.proxydevice import ProcessProxy, Proxy, ProxyBase, ThreadProxy + + +class MonitoredQueueBase(ProxyBase): + """Base class for overriding methods.""" + + _in_prefix = b'' + _out_prefix = b'' + + def __init__( + self, in_type, out_type, mon_type=PUB, in_prefix=b'in', out_prefix=b'out' + ): + ProxyBase.__init__(self, in_type=in_type, out_type=out_type, mon_type=mon_type) + + self._in_prefix = in_prefix + self._out_prefix = out_prefix + + def run_device(self): + ins, outs, mons = self._setup_sockets() + monitored_queue(ins, outs, mons, self._in_prefix, self._out_prefix) + + +class MonitoredQueue(MonitoredQueueBase, Proxy): + """Class for running monitored_queue in the background. + + See zmq.devices.Device for most of the spec. MonitoredQueue differs from Proxy, + only in that it adds a ``prefix`` to messages sent on the monitor socket, + with a different prefix for each direction. + + MQ also supports ROUTER on both sides, which zmq.proxy does not. + + If a message arrives on `in_sock`, it will be prefixed with `in_prefix` on the monitor socket. + If it arrives on out_sock, it will be prefixed with `out_prefix`. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + + +class ThreadMonitoredQueue(MonitoredQueueBase, ThreadProxy): + """Run zmq.monitored_queue in a background thread. + + See MonitoredQueue and Proxy for details. + """ + + +class ProcessMonitoredQueue(MonitoredQueueBase, ProcessProxy): + """Run zmq.monitored_queue in a separate process. + + See MonitoredQueue and Proxy for details. + """ + + +__all__ = ['MonitoredQueue', 'ThreadMonitoredQueue', 'ProcessMonitoredQueue'] diff --git a/source/zmq/devices/proxydevice.py b/source/zmq/devices/proxydevice.py new file mode 100644 index 0000000000000000000000000000000000000000..f2af06793c27bbf4c9a9c33a377e2acd6ded5c09 --- /dev/null +++ b/source/zmq/devices/proxydevice.py @@ -0,0 +1,104 @@ +"""Proxy classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.devices.basedevice import Device, ProcessDevice, ThreadDevice + + +class ProxyBase: + """Base class for overriding methods.""" + + def __init__(self, in_type, out_type, mon_type=zmq.PUB): + Device.__init__(self, in_type=in_type, out_type=out_type) + self.mon_type = mon_type + self._mon_binds = [] + self._mon_connects = [] + self._mon_sockopts = [] + + def bind_mon(self, addr): + """Enqueue ZMQ address for binding on mon_socket. + + See zmq.Socket.bind for details. + """ + self._mon_binds.append(addr) + + def bind_mon_to_random_port(self, addr, *args, **kwargs): + """Enqueue a random port on the given interface for binding on + mon_socket. + + See zmq.Socket.bind_to_random_port for details. + + .. versionadded:: 18.0 + """ + port = self._reserve_random_port(addr, *args, **kwargs) + + self.bind_mon(f'{addr}:{port}') + + return port + + def connect_mon(self, addr): + """Enqueue ZMQ address for connecting on mon_socket. + + See zmq.Socket.connect for details. + """ + self._mon_connects.append(addr) + + def setsockopt_mon(self, opt, value): + """Enqueue setsockopt(opt, value) for mon_socket + + See zmq.Socket.setsockopt for details. + """ + self._mon_sockopts.append((opt, value)) + + def _setup_sockets(self): + ins, outs = Device._setup_sockets(self) + ctx = self._context + mons = ctx.socket(self.mon_type) + self._sockets.append(mons) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt, value in self._mon_sockopts: + mons.setsockopt(opt, value) + + for iface in self._mon_binds: + mons.bind(iface) + + for iface in self._mon_connects: + mons.connect(iface) + + return ins, outs, mons + + def run_device(self): + ins, outs, mons = self._setup_sockets() + zmq.proxy(ins, outs, mons) + + +class Proxy(ProxyBase, Device): + """Threadsafe Proxy object. + + See zmq.devices.Device for most of the spec. This subclass adds a + _mon version of each _{in|out} method, for configuring the + monitor socket. + + A Proxy is a 3-socket ZMQ Device that functions just like a + QUEUE, except each message is also sent out on the monitor socket. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + + +class ThreadProxy(ProxyBase, ThreadDevice): + """Proxy in a Thread. See Proxy for more.""" + + +class ProcessProxy(ProxyBase, ProcessDevice): + """Proxy in a Process. See Proxy for more.""" + + +__all__ = [ + 'Proxy', + 'ThreadProxy', + 'ProcessProxy', +] diff --git a/source/zmq/devices/proxysteerabledevice.py b/source/zmq/devices/proxysteerabledevice.py new file mode 100644 index 0000000000000000000000000000000000000000..256a1e0498c907791da79935c7ed0f35faf90ce0 --- /dev/null +++ b/source/zmq/devices/proxysteerabledevice.py @@ -0,0 +1,106 @@ +"""Classes for running a steerable ZMQ proxy""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.devices.proxydevice import ProcessProxy, Proxy, ThreadProxy + + +class ProxySteerableBase: + """Base class for overriding methods.""" + + def __init__(self, in_type, out_type, mon_type=zmq.PUB, ctrl_type=None): + super().__init__(in_type=in_type, out_type=out_type, mon_type=mon_type) + self.ctrl_type = ctrl_type + self._ctrl_binds = [] + self._ctrl_connects = [] + self._ctrl_sockopts = [] + + def bind_ctrl(self, addr): + """Enqueue ZMQ address for binding on ctrl_socket. + + See zmq.Socket.bind for details. + """ + self._ctrl_binds.append(addr) + + def bind_ctrl_to_random_port(self, addr, *args, **kwargs): + """Enqueue a random port on the given interface for binding on + ctrl_socket. + + See zmq.Socket.bind_to_random_port for details. + """ + port = self._reserve_random_port(addr, *args, **kwargs) + + self.bind_ctrl(f'{addr}:{port}') + + return port + + def connect_ctrl(self, addr): + """Enqueue ZMQ address for connecting on ctrl_socket. + + See zmq.Socket.connect for details. + """ + self._ctrl_connects.append(addr) + + def setsockopt_ctrl(self, opt, value): + """Enqueue setsockopt(opt, value) for ctrl_socket + + See zmq.Socket.setsockopt for details. + """ + self._ctrl_sockopts.append((opt, value)) + + def _setup_sockets(self): + ins, outs, mons = super()._setup_sockets() + ctx = self._context + ctrls = ctx.socket(self.ctrl_type) + self._sockets.append(ctrls) + + for opt, value in self._ctrl_sockopts: + ctrls.setsockopt(opt, value) + + for iface in self._ctrl_binds: + ctrls.bind(iface) + + for iface in self._ctrl_connects: + ctrls.connect(iface) + + return ins, outs, mons, ctrls + + def run_device(self): + ins, outs, mons, ctrls = self._setup_sockets() + zmq.proxy_steerable(ins, outs, mons, ctrls) + + +class ProxySteerable(ProxySteerableBase, Proxy): + """Class for running a steerable proxy in the background. + + See zmq.devices.Proxy for most of the spec. If the control socket is not + NULL, the proxy supports control flow, provided by the socket. + + If PAUSE is received on this socket, the proxy suspends its activities. If + RESUME is received, it goes on. If TERMINATE is received, it terminates + smoothly. If the control socket is NULL, the proxy behave exactly as if + zmq.devices.Proxy had been used. + + This subclass adds a _ctrl version of each _{in|out} + method, for configuring the control socket. + + .. versionadded:: libzmq-4.1 + .. versionadded:: 18.0 + """ + + +class ThreadProxySteerable(ProxySteerableBase, ThreadProxy): + """ProxySteerable in a Thread. See ProxySteerable for details.""" + + +class ProcessProxySteerable(ProxySteerableBase, ProcessProxy): + """ProxySteerable in a Process. See ProxySteerable for details.""" + + +__all__ = [ + 'ProxySteerable', + 'ThreadProxySteerable', + 'ProcessProxySteerable', +] diff --git a/source/zmq/error.py b/source/zmq/error.py new file mode 100644 index 0000000000000000000000000000000000000000..8a07a51fe1fc55aaeb0e1908d5f2251b0d32e9ac --- /dev/null +++ b/source/zmq/error.py @@ -0,0 +1,229 @@ +"""0MQ Error classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +from errno import EINTR + + +class DraftFDWarning(RuntimeWarning): + """Warning for using experimental FD on draft sockets. + + .. versionadded:: 27 + """ + + def __init__(self, msg=""): + if not msg: + msg = ( + "pyzmq's back-fill socket.FD support on thread-safe sockets is experimental, and may be removed." + " This warning will go away automatically if/when libzmq implements socket.FD on thread-safe sockets." + " You can suppress this warning with `warnings.simplefilter('ignore', zmq.error.DraftFDWarning)" + ) + super().__init__(msg) + + +class ZMQBaseError(Exception): + """Base exception class for 0MQ errors in Python.""" + + +class ZMQError(ZMQBaseError): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : str + Description of the error or None. + """ + + errno: int | None = None + strerror: str + + def __init__(self, errno: int | None = None, msg: str | None = None): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : string + Description of the error or None. + """ + from zmq.backend import strerror, zmq_errno + + if errno is None: + errno = zmq_errno() + if isinstance(errno, int): + self.errno = errno + if msg is None: + self.strerror = strerror(errno) + else: + self.strerror = msg + else: + if msg is None: + self.strerror = str(errno) + else: + self.strerror = msg + # flush signals, because there could be a SIGINT + # waiting to pounce, resulting in uncaught exceptions. + # Doing this here means getting SIGINT during a blocking + # libzmq call will raise a *catchable* KeyboardInterrupt + # PyErr_CheckSignals() + + def __str__(self) -> str: + return self.strerror + + def __repr__(self) -> str: + return f"{self.__class__.__name__}('{str(self)}')" + + +class ZMQBindError(ZMQBaseError): + """An error for ``Socket.bind_to_random_port()``. + + See Also + -------- + .Socket.bind_to_random_port + """ + + +class NotDone(ZMQBaseError): + """Raised when timeout is reached while waiting for 0MQ to finish with a Message + + See Also + -------- + .MessageTracker.wait : object for tracking when ZeroMQ is done + """ + + +class ContextTerminated(ZMQError): + """Wrapper for zmq.ETERM + + .. versionadded:: 13.0 + """ + + def __init__(self, errno="ignored", msg="ignored"): + from zmq import ETERM + + super().__init__(ETERM) + + +class Again(ZMQError): + """Wrapper for zmq.EAGAIN + + .. versionadded:: 13.0 + """ + + def __init__(self, errno="ignored", msg="ignored"): + from zmq import EAGAIN + + super().__init__(EAGAIN) + + +class InterruptedSystemCall(ZMQError, InterruptedError): + """Wrapper for EINTR + + This exception should be caught internally in pyzmq + to retry system calls, and not propagate to the user. + + .. versionadded:: 14.7 + """ + + errno = EINTR + strerror: str + + def __init__(self, errno="ignored", msg="ignored"): + super().__init__(EINTR) + + def __str__(self): + s = super().__str__() + return s + ": This call should have been retried. Please report this to pyzmq." + + +def _check_rc(rc, errno=None, error_without_errno=True): + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + if rc == -1: + if errno is None: + from zmq.backend import zmq_errno + + errno = zmq_errno() + if errno == 0 and not error_without_errno: + return + from zmq import EAGAIN, ETERM + + if errno == EINTR: + raise InterruptedSystemCall(errno) + elif errno == EAGAIN: + raise Again(errno) + elif errno == ETERM: + raise ContextTerminated(errno) + else: + raise ZMQError(errno) + + +_zmq_version_info = None +_zmq_version = None + + +class ZMQVersionError(NotImplementedError): + """Raised when a feature is not provided by the linked version of libzmq. + + .. versionadded:: 14.2 + """ + + min_version = None + + def __init__(self, min_version: str, msg: str = "Feature"): + global _zmq_version + if _zmq_version is None: + from zmq import zmq_version + + _zmq_version = zmq_version() + self.msg = msg + self.min_version = min_version + self.version = _zmq_version + + def __repr__(self): + return f"ZMQVersionError('{str(self)}')" + + def __str__(self): + return f"{self.msg} requires libzmq >= {self.min_version}, have {self.version}" + + +def _check_version( + min_version_info: tuple[int] | tuple[int, int] | tuple[int, int, int], + msg: str = "Feature", +): + """Check for libzmq + + raises ZMQVersionError if current zmq version is not at least min_version + + min_version_info is a tuple of integers, and will be compared against zmq.zmq_version_info(). + """ + global _zmq_version_info + if _zmq_version_info is None: + from zmq import zmq_version_info + + _zmq_version_info = zmq_version_info() + if _zmq_version_info < min_version_info: + min_version = ".".join(str(v) for v in min_version_info) + raise ZMQVersionError(min_version, msg) + + +__all__ = [ + "DraftFDWarning", + "ZMQBaseError", + "ZMQBindError", + "ZMQError", + "NotDone", + "ContextTerminated", + "InterruptedSystemCall", + "Again", + "ZMQVersionError", +] diff --git a/source/zmq/eventloop/__init__.py b/source/zmq/eventloop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a99e143ccf3035db55cb5a34caa13e98ba793e9a --- /dev/null +++ b/source/zmq/eventloop/__init__.py @@ -0,0 +1,5 @@ +"""Tornado eventloop integration for pyzmq""" + +from tornado.ioloop import IOLoop + +__all__ = ['IOLoop'] diff --git a/source/zmq/eventloop/_deprecated.py b/source/zmq/eventloop/_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..06bce13c199bcdc58dc611f661dc348b71e272cf --- /dev/null +++ b/source/zmq/eventloop/_deprecated.py @@ -0,0 +1,212 @@ +"""tornado IOLoop API with zmq compatibility + +If you have tornado ≥ 3.0, this is a subclass of tornado's IOLoop, +otherwise we ship a minimal subset of tornado in zmq.eventloop.minitornado. + +The minimal shipped version of tornado's IOLoop does not include +support for concurrent futures - this will only be available if you +have tornado ≥ 3.0. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import warnings +from typing import Tuple + +from zmq import ETERM, POLLERR, POLLIN, POLLOUT, Poller, ZMQError + +tornado_version: Tuple = () +try: + import tornado + + tornado_version = tornado.version_info +except (ImportError, AttributeError): + pass + +from .minitornado.ioloop import PeriodicCallback, PollIOLoop +from .minitornado.log import gen_log + + +class DelayedCallback(PeriodicCallback): + """Schedules the given callback to be called once. + + The callback is called once, after callback_time milliseconds. + + `start` must be called after the DelayedCallback is created. + + The timeout is calculated from when `start` is called. + """ + + def __init__(self, callback, callback_time, io_loop=None): + # PeriodicCallback require callback_time to be positive + warnings.warn( + """DelayedCallback is deprecated. + Use loop.add_timeout instead.""", + DeprecationWarning, + ) + callback_time = max(callback_time, 1e-3) + super().__init__(callback, callback_time, io_loop) + + def start(self): + """Starts the timer.""" + self._running = True + self._firstrun = True + self._next_timeout = time.time() + self.callback_time / 1000.0 + self.io_loop.add_timeout(self._next_timeout, self._run) + + def _run(self): + if not self._running: + return + self._running = False + try: + self.callback() + except Exception: + gen_log.error("Error in delayed callback", exc_info=True) + + +class ZMQPoller: + """A poller that can be used in the tornado IOLoop. + + This simply wraps a regular zmq.Poller, scaling the timeout + by 1000, so that it is in seconds rather than milliseconds. + """ + + def __init__(self): + self._poller = Poller() + + @staticmethod + def _map_events(events): + """translate IOLoop.READ/WRITE/ERROR event masks into zmq.POLLIN/OUT/ERR""" + z_events = 0 + if events & IOLoop.READ: + z_events |= POLLIN + if events & IOLoop.WRITE: + z_events |= POLLOUT + if events & IOLoop.ERROR: + z_events |= POLLERR + return z_events + + @staticmethod + def _remap_events(z_events): + """translate zmq.POLLIN/OUT/ERR event masks into IOLoop.READ/WRITE/ERROR""" + events = 0 + if z_events & POLLIN: + events |= IOLoop.READ + if z_events & POLLOUT: + events |= IOLoop.WRITE + if z_events & POLLERR: + events |= IOLoop.ERROR + return events + + def register(self, fd, events): + return self._poller.register(fd, self._map_events(events)) + + def modify(self, fd, events): + return self._poller.modify(fd, self._map_events(events)) + + def unregister(self, fd): + return self._poller.unregister(fd) + + def poll(self, timeout): + """poll in seconds rather than milliseconds. + + Event masks will be IOLoop.READ/WRITE/ERROR + """ + z_events = self._poller.poll(1000 * timeout) + return [(fd, self._remap_events(evt)) for (fd, evt) in z_events] + + def close(self): + pass + + +class ZMQIOLoop(PollIOLoop): + """ZMQ subclass of tornado's IOLoop + + Minor modifications, so that .current/.instance return self + """ + + _zmq_impl = ZMQPoller + + def initialize(self, impl=None, **kwargs): + impl = self._zmq_impl() if impl is None else impl + super().initialize(impl=impl, **kwargs) + + @classmethod + def instance(cls, *args, **kwargs): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + # install ZMQIOLoop as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(cls) + loop = PollIOLoop.instance(*args, **kwargs) + if not isinstance(loop, cls): + warnings.warn( + f"IOLoop.current expected instance of {cls!r}, got {loop!r}", + RuntimeWarning, + stacklevel=2, + ) + return loop + + @classmethod + def current(cls, *args, **kwargs): + """Returns the current thread’s IOLoop.""" + # install ZMQIOLoop as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(cls) + loop = PollIOLoop.current(*args, **kwargs) + if not isinstance(loop, cls): + warnings.warn( + f"IOLoop.current expected instance of {cls!r}, got {loop!r}", + RuntimeWarning, + stacklevel=2, + ) + return loop + + def start(self): + try: + super().start() + except ZMQError as e: + if e.errno == ETERM: + # quietly return on ETERM + pass + else: + raise + + +# public API name +IOLoop = ZMQIOLoop + + +def install(): + """set the tornado IOLoop instance with the pyzmq IOLoop. + + After calling this function, tornado's IOLoop.instance() and pyzmq's + IOLoop.instance() will return the same object. + + An assertion error will be raised if tornado's IOLoop has been initialized + prior to calling this function. + """ + from tornado import ioloop + + # check if tornado's IOLoop is already initialized to something other + # than the pyzmq IOLoop instance: + assert ( + not ioloop.IOLoop.initialized() + ) or ioloop.IOLoop.instance() is IOLoop.instance(), ( + "tornado IOLoop already initialized" + ) + + if tornado_version >= (3,): + # tornado 3 has an official API for registering new defaults, yay! + ioloop.IOLoop.configure(ZMQIOLoop) + else: + # we have to set the global instance explicitly + ioloop.IOLoop._instance = IOLoop.instance() diff --git a/source/zmq/eventloop/future.py b/source/zmq/eventloop/future.py new file mode 100644 index 0000000000000000000000000000000000000000..0f34f0ef93a8947e622b74e8ebedf79003727e0a --- /dev/null +++ b/source/zmq/eventloop/future.py @@ -0,0 +1,104 @@ +"""Future-returning APIs for tornado coroutines. + +.. seealso:: + + :mod:`zmq.asyncio` + +""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import asyncio +import warnings +from typing import Any + +from tornado.concurrent import Future +from tornado.ioloop import IOLoop + +import zmq as _zmq +from zmq._future import _AsyncPoller, _AsyncSocket + + +class CancelledError(Exception): + pass + + +class _TornadoFuture(Future): + """Subclass Tornado Future, reinstating cancellation.""" + + def cancel(self): + if self.done(): + return False + self.set_exception(CancelledError()) + return True + + def cancelled(self): + return self.done() and isinstance(self.exception(), CancelledError) + + +class _CancellableTornadoTimeout: + def __init__(self, loop, timeout): + self.loop = loop + self.timeout = timeout + + def cancel(self): + self.loop.remove_timeout(self.timeout) + + +# mixin for tornado/asyncio compatibility + + +class _AsyncTornado: + _Future: type[asyncio.Future] = _TornadoFuture + _READ = IOLoop.READ + _WRITE = IOLoop.WRITE + + def _default_loop(self): + return IOLoop.current() + + def _call_later(self, delay, callback): + io_loop = self._get_loop() + timeout = io_loop.call_later(delay, callback) + return _CancellableTornadoTimeout(io_loop, timeout) + + +class Poller(_AsyncTornado, _AsyncPoller): + def _watch_raw_socket(self, loop, socket, evt, f): + """Schedule callback for a raw socket""" + loop.add_handler(socket, lambda *args: f(), evt) + + def _unwatch_raw_sockets(self, loop, *sockets): + """Unschedule callback for a raw socket""" + for socket in sockets: + loop.remove_handler(socket) + + +class Socket(_AsyncTornado, _AsyncSocket): + _poller_class = Poller + + +Poller._socket_class = Socket + + +class Context(_zmq.Context[Socket]): + # avoid sharing instance with base Context class + _instance = None + + io_loop = None + + @staticmethod + def _socket_class(self, socket_type): + return Socket(self, socket_type) + + def __init__(self: Context, *args: Any, **kwargs: Any) -> None: + io_loop = kwargs.pop('io_loop', None) + if io_loop is not None: + warnings.warn( + f"{self.__class__.__name__}(io_loop) argument is deprecated in pyzmq 22.2." + " The currently active loop will always be used.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) # type: ignore diff --git a/source/zmq/eventloop/ioloop.py b/source/zmq/eventloop/ioloop.py new file mode 100644 index 0000000000000000000000000000000000000000..dccb92a14eba7190c8571c0ee880ab5542e31bb7 --- /dev/null +++ b/source/zmq/eventloop/ioloop.py @@ -0,0 +1,37 @@ +"""tornado IOLoop API with zmq compatibility + +This module is deprecated in pyzmq 17. +To use zmq with tornado, +eventloop integration is no longer required +and tornado itself should be used. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import warnings + + +def _deprecated(): + warnings.warn( + "zmq.eventloop.ioloop is deprecated in pyzmq 17." + " pyzmq now works with default tornado and asyncio eventloops.", + DeprecationWarning, + stacklevel=3, + ) + + +_deprecated() + +from tornado.ioloop import * # noqa +from tornado.ioloop import IOLoop + +ZMQIOLoop = IOLoop + + +def install(): + """DEPRECATED + + pyzmq 17 no longer needs any special integration for tornado. + """ + _deprecated() diff --git a/source/zmq/eventloop/zmqstream.py b/source/zmq/eventloop/zmqstream.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b5d297b19f973d19271625a2508ec6f24ee858 --- /dev/null +++ b/source/zmq/eventloop/zmqstream.py @@ -0,0 +1,688 @@ +# Derived from iostream.py from tornado 1.0, Copyright 2009 Facebook +# Used under Apache License Version 2.0 +# +# Modifications are Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +"""A utility class for event-based messaging on a zmq socket using tornado. + +.. seealso:: + + - :mod:`zmq.asyncio` + - :mod:`zmq.eventloop.future` +""" + +from __future__ import annotations + +import asyncio +import pickle +import warnings +from queue import Queue +from typing import Any, Awaitable, Callable, Literal, Sequence, cast, overload + +from tornado.ioloop import IOLoop +from tornado.log import gen_log + +import zmq +import zmq._future +from zmq import POLLIN, POLLOUT +from zmq.utils import jsonapi + + +class ZMQStream: + """A utility class to register callbacks when a zmq socket sends and receives + + For use with tornado IOLoop. + + There are three main methods + + Methods: + + * **on_recv(callback, copy=True):** + register a callback to be run every time the socket has something to receive + * **on_send(callback):** + register a callback to be run every time you call send + * **send_multipart(self, msg, flags=0, copy=False, callback=None):** + perform a send that will trigger the callback + if callback is passed, on_send is also called. + + There are also send_multipart(), send_json(), send_pyobj() + + Three other methods for deactivating the callbacks: + + * **stop_on_recv():** + turn off the recv callback + * **stop_on_send():** + turn off the send callback + + which simply call ``on_(None)``. + + The entire socket interface, excluding direct recv methods, is also + provided, primarily through direct-linking the methods. + e.g. + + >>> stream.bind is stream.socket.bind + True + + + .. versionadded:: 25 + + send/recv callbacks can be coroutines. + + .. versionchanged:: 25 + + ZMQStreams only support base zmq.Socket classes (this has always been true, but not enforced). + If ZMQStreams are created with e.g. async Socket subclasses, + a RuntimeWarning will be shown, + and the socket cast back to the default zmq.Socket + before connecting events. + + Previously, using async sockets (or any zmq.Socket subclass) would result in undefined behavior for the + arguments passed to callback functions. + Now, the callback functions reliably get the return value of the base `zmq.Socket` send/recv_multipart methods + (the list of message frames). + """ + + socket: zmq.Socket + io_loop: IOLoop + poller: zmq.Poller + _send_queue: Queue + _recv_callback: Callable | None + _send_callback: Callable | None + _close_callback: Callable | None + _state: int = 0 + _flushed: bool = False + _recv_copy: bool = False + _fd: int + + def __init__(self, socket: zmq.Socket, io_loop: IOLoop | None = None): + if isinstance(socket, zmq._future._AsyncSocket): + warnings.warn( + f"""ZMQStream only supports the base zmq.Socket class. + + Use zmq.Socket(shadow=other_socket) + or `ctx.socket(zmq.{socket._type_name}, socket_class=zmq.Socket)` + to create a base zmq.Socket object, + no matter what other kind of socket your Context creates. + """, + RuntimeWarning, + stacklevel=2, + ) + # shadow back to base zmq.Socket, + # otherwise callbacks like `on_recv` will get the wrong types. + socket = zmq.Socket(shadow=socket) + self.socket = socket + + # IOLoop.current() is deprecated if called outside the event loop + # that means + self.io_loop = io_loop or IOLoop.current() + self.poller = zmq.Poller() + self._fd = cast(int, self.socket.FD) + + self._send_queue = Queue() + self._recv_callback = None + self._send_callback = None + self._close_callback = None + self._recv_copy = False + self._flushed = False + + self._state = 0 + self._init_io_state() + + # shortcircuit some socket methods + self.bind = self.socket.bind + self.bind_to_random_port = self.socket.bind_to_random_port + self.connect = self.socket.connect + self.setsockopt = self.socket.setsockopt + self.getsockopt = self.socket.getsockopt + self.setsockopt_string = self.socket.setsockopt_string + self.getsockopt_string = self.socket.getsockopt_string + self.setsockopt_unicode = self.socket.setsockopt_unicode + self.getsockopt_unicode = self.socket.getsockopt_unicode + + def stop_on_recv(self): + """Disable callback and automatic receiving.""" + return self.on_recv(None) + + def stop_on_send(self): + """Disable callback on sending.""" + return self.on_send(None) + + def stop_on_err(self): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + def on_err(self, callback: Callable): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + @overload + def on_recv( + self, + callback: Callable[[list[bytes]], Any], + ) -> None: ... + + @overload + def on_recv( + self, + callback: Callable[[list[bytes]], Any], + copy: Literal[True], + ) -> None: ... + + @overload + def on_recv( + self, + callback: Callable[[list[zmq.Frame]], Any], + copy: Literal[False], + ) -> None: ... + + @overload + def on_recv( + self, + callback: Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any], + copy: bool = ..., + ): ... + + def on_recv( + self, + callback: Callable[[list[zmq.Frame]], Any] | Callable[[list[bytes]], Any], + copy: bool = True, + ) -> None: + """Register a callback for when a message is ready to recv. + + There can be only one callback registered at a time, so each + call to `on_recv` replaces previously registered callbacks. + + on_recv(None) disables recv event polling. + + Use on_recv_stream(callback) instead, to register a callback that will receive + both this ZMQStream and the message, instead of just the message. + + Parameters + ---------- + + callback : callable + callback must take exactly one argument, which will be a + list, as returned by socket.recv_multipart() + if callback is None, recv callbacks are disabled. + copy : bool + copy is passed directly to recv, so if copy is False, + callback will receive Message objects. If copy is True, + then callback will receive bytes/str objects. + + Returns : None + """ + + self._check_closed() + assert callback is None or callable(callback) + self._recv_callback = callback + self._recv_copy = copy + if callback is None: + self._drop_io_state(zmq.POLLIN) + else: + self._add_io_state(zmq.POLLIN) + + @overload + def on_recv_stream( + self, + callback: Callable[[ZMQStream, list[bytes]], Any], + ) -> None: ... + + @overload + def on_recv_stream( + self, + callback: Callable[[ZMQStream, list[bytes]], Any], + copy: Literal[True], + ) -> None: ... + + @overload + def on_recv_stream( + self, + callback: Callable[[ZMQStream, list[zmq.Frame]], Any], + copy: Literal[False], + ) -> None: ... + + @overload + def on_recv_stream( + self, + callback: ( + Callable[[ZMQStream, list[zmq.Frame]], Any] + | Callable[[ZMQStream, list[bytes]], Any] + ), + copy: bool = ..., + ): ... + + def on_recv_stream( + self, + callback: ( + Callable[[ZMQStream, list[zmq.Frame]], Any] + | Callable[[ZMQStream, list[bytes]], Any] + ), + copy: bool = True, + ): + """Same as on_recv, but callback will get this stream as first argument + + callback must take exactly two arguments, as it will be called as:: + + callback(stream, msg) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_recv() + else: + + def stream_callback(msg): + return callback(self, msg) + + self.on_recv(stream_callback, copy=copy) + + def on_send( + self, callback: Callable[[Sequence[Any], zmq.MessageTracker | None], Any] + ): + """Register a callback to be called on each send + + There will be two arguments:: + + callback(msg, status) + + * `msg` will be the list of sendable objects that was just sent + * `status` will be the return result of socket.send_multipart(msg) - + MessageTracker or None. + + Non-copying sends return a MessageTracker object whose + `done` attribute will be True when the send is complete. + This allows users to track when an object is safe to write to + again. + + The second argument will always be None if copy=True + on the send. + + Use on_send_stream(callback) to register a callback that will be passed + this ZMQStream as the first argument, in addition to the other two. + + on_send(None) disables recv event polling. + + Parameters + ---------- + + callback : callable + callback must take exactly two arguments, which will be + the message being sent (always a list), + and the return result of socket.send_multipart(msg) - + MessageTracker or None. + + if callback is None, send callbacks are disabled. + """ + + self._check_closed() + assert callback is None or callable(callback) + self._send_callback = callback + + def on_send_stream( + self, + callback: Callable[[ZMQStream, Sequence[Any], zmq.MessageTracker | None], Any], + ): + """Same as on_send, but callback will get this stream as first argument + + Callback will be passed three arguments:: + + callback(stream, msg, status) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_send() + else: + self.on_send(lambda msg, status: callback(self, msg, status)) + + def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs): + """Send a message, optionally also register a new callback for sends. + See zmq.socket.send for details. + """ + return self.send_multipart( + [msg], flags=flags, copy=copy, track=track, callback=callback, **kwargs + ) + + def send_multipart( + self, + msg: Sequence[Any], + flags: int = 0, + copy: bool = True, + track: bool = False, + callback: Callable | None = None, + **kwargs: Any, + ) -> None: + """Send a multipart message, optionally also register a new callback for sends. + See zmq.socket.send_multipart for details. + """ + kwargs.update(dict(flags=flags, copy=copy, track=track)) + self._send_queue.put((msg, kwargs)) + callback = callback or self._send_callback + if callback is not None: + self.on_send(callback) + else: + # noop callback + self.on_send(lambda *args: None) + self._add_io_state(zmq.POLLOUT) + + def send_string( + self, + u: str, + flags: int = 0, + encoding: str = 'utf-8', + callback: Callable | None = None, + **kwargs: Any, + ): + """Send a unicode message with an encoding. + See zmq.socket.send_unicode for details. + """ + if not isinstance(u, str): + raise TypeError("unicode/str objects only") + return self.send(u.encode(encoding), flags=flags, callback=callback, **kwargs) + + send_unicode = send_string + + def send_json( + self, + obj: Any, + flags: int = 0, + callback: Callable | None = None, + **kwargs: Any, + ): + """Send json-serialized version of an object. + See zmq.socket.send_json for details. + """ + msg = jsonapi.dumps(obj) + return self.send(msg, flags=flags, callback=callback, **kwargs) + + def send_pyobj( + self, + obj: Any, + flags: int = 0, + protocol: int = -1, + callback: Callable | None = None, + **kwargs: Any, + ): + """Send a Python object as a message using pickle to serialize. + + See zmq.socket.send_json for details. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags, callback=callback, **kwargs) + + def _finish_flush(self): + """callback for unsetting _flushed flag.""" + self._flushed = False + + def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: int | None = None): + """Flush pending messages. + + This method safely handles all pending incoming and/or outgoing messages, + bypassing the inner loop, passing them to the registered callbacks. + + A limit can be specified, to prevent blocking under high load. + + flush will return the first time ANY of these conditions are met: + * No more events matching the flag are pending. + * the total number of events handled reaches the limit. + + Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback + is registered, unlike normal IOLoop operation. This allows flush to be + used to remove *and ignore* incoming messages. + + Parameters + ---------- + flag : int + default=POLLIN|POLLOUT + 0MQ poll flags. + If flag|POLLIN, recv events will be flushed. + If flag|POLLOUT, send events will be flushed. + Both flags can be set at once, which is the default. + limit : None or int, optional + The maximum number of messages to send or receive. + Both send and recv count against this limit. + + Returns + ------- + int : + count of events handled (both send and recv) + """ + self._check_closed() + # unset self._flushed, so callbacks will execute, in case flush has + # already been called this iteration + already_flushed = self._flushed + self._flushed = False + # initialize counters + count = 0 + + def update_flag(): + """Update the poll flag, to prevent registering POLLOUT events + if we don't have pending sends.""" + return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT) + + flag = update_flag() + if not flag: + # nothing to do + return 0 + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + while events and (not limit or count < limit): + s, event = events[0] + if event & POLLIN: # receiving + self._handle_recv() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + if event & POLLOUT and self.sending(): + self._handle_send() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + + flag = update_flag() + if flag: + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + else: + events = [] + if count: # only bypass loop if we actually flushed something + # skip send/recv callbacks this iteration + self._flushed = True + # reregister them at the end of the loop + if not already_flushed: # don't need to do it again + self.io_loop.add_callback(self._finish_flush) + elif already_flushed: + self._flushed = True + + # update ioloop poll state, which may have changed + self._rebuild_io_state() + return count + + def set_close_callback(self, callback: Callable | None): + """Call the given callback when the stream is closed.""" + self._close_callback = callback + + def close(self, linger: int | None = None) -> None: + """Close this stream.""" + if self.socket is not None: + if self.socket.closed: + # fallback on raw fd for closed sockets + # hopefully this happened promptly after close, + # otherwise somebody else may have the FD + warnings.warn( + f"Unregistering FD {self._fd} after closing socket. " + "This could result in unregistering handlers for the wrong socket. " + "Please use stream.close() instead of closing the socket directly.", + stacklevel=2, + ) + self.io_loop.remove_handler(self._fd) + else: + self.io_loop.remove_handler(self.socket) + self.socket.close(linger) + self.socket = None # type: ignore + if self._close_callback: + self._run_callback(self._close_callback) + + def receiving(self) -> bool: + """Returns True if we are currently receiving from the stream.""" + return self._recv_callback is not None + + def sending(self) -> bool: + """Returns True if we are currently sending to the stream.""" + return not self._send_queue.empty() + + def closed(self) -> bool: + if self.socket is None: + return True + if self.socket.closed: + # underlying socket has been closed, but not by us! + # trigger our cleanup + self.close() + return True + return False + + def _run_callback(self, callback, *args, **kwargs): + """Wrap running callbacks in try/except to allow us to + close our socket.""" + try: + f = callback(*args, **kwargs) + if isinstance(f, Awaitable): + f = asyncio.ensure_future(f) + else: + f = None + except Exception: + gen_log.error("Uncaught exception in ZMQStream callback", exc_info=True) + # Re-raise the exception so that IOLoop.handle_callback_exception + # can see it and log the error + raise + + if f is not None: + # handle async callbacks + def _log_error(f): + try: + f.result() + except Exception: + gen_log.error( + "Uncaught exception in ZMQStream callback", exc_info=True + ) + + f.add_done_callback(_log_error) + + def _handle_events(self, fd, events): + """This method is the actual handler for IOLoop, that gets called whenever + an event on my socket is posted. It dispatches to _handle_recv, etc.""" + if not self.socket: + gen_log.warning("Got events for closed stream %s", self) + return + try: + zmq_events = self.socket.EVENTS + except zmq.ContextTerminated: + gen_log.warning("Got events for stream %s after terminating context", self) + # trigger close check, this will unregister callbacks + self.closed() + return + except zmq.ZMQError as e: + # run close check + # shadow sockets may have been closed elsewhere, + # which should show up as ENOTSOCK here + if self.closed(): + gen_log.warning( + "Got events for stream %s attached to closed socket: %s", self, e + ) + else: + gen_log.error("Error getting events for %s: %s", self, e) + return + try: + # dispatch events: + if zmq_events & zmq.POLLIN and self.receiving(): + self._handle_recv() + if not self.socket: + return + if zmq_events & zmq.POLLOUT and self.sending(): + self._handle_send() + if not self.socket: + return + + # rebuild the poll state + self._rebuild_io_state() + except Exception: + gen_log.error("Uncaught exception in zmqstream callback", exc_info=True) + raise + + def _handle_recv(self): + """Handle a recv event.""" + if self._flushed: + return + try: + msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy) + except zmq.ZMQError as e: + if e.errno == zmq.EAGAIN: + # state changed since poll event + pass + else: + raise + else: + if self._recv_callback: + callback = self._recv_callback + self._run_callback(callback, msg) + + def _handle_send(self): + """Handle a send event.""" + if self._flushed: + return + if not self.sending(): + gen_log.error("Shouldn't have handled a send event") + return + + msg, kwargs = self._send_queue.get() + try: + status = self.socket.send_multipart(msg, **kwargs) + except zmq.ZMQError as e: + gen_log.error("SEND Error: %s", e) + status = e + if self._send_callback: + callback = self._send_callback + self._run_callback(callback, msg, status) + + def _check_closed(self): + if not self.socket: + raise OSError("Stream is closed") + + def _rebuild_io_state(self): + """rebuild io state based on self.sending() and receiving()""" + if self.socket is None: + return + state = 0 + if self.receiving(): + state |= zmq.POLLIN + if self.sending(): + state |= zmq.POLLOUT + + self._state = state + self._update_handler(state) + + def _add_io_state(self, state): + """Add io_state to poller.""" + self._state = self._state | state + self._update_handler(self._state) + + def _drop_io_state(self, state): + """Stop poller from watching an io_state.""" + self._state = self._state & (~state) + self._update_handler(self._state) + + def _update_handler(self, state): + """Update IOLoop handler with state.""" + if self.socket is None: + return + + if state & self.socket.events: + # events still exist that haven't been processed + # explicitly schedule handling to avoid missing events due to edge-triggered FDs + self.io_loop.add_callback(lambda: self._handle_events(self.socket, 0)) + + def _init_io_state(self): + """initialize the ioloop event handler""" + self.io_loop.add_handler(self.socket, self._handle_events, self.io_loop.READ) diff --git a/source/zmq/green/__init__.py b/source/zmq/green/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8543864b00d60f236c05463f62494a483c739871 --- /dev/null +++ b/source/zmq/green/__init__.py @@ -0,0 +1,48 @@ +# ----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file LICENSE.BSD, distributed as part of this software. +# ----------------------------------------------------------------------------- + +"""zmq.green - gevent compatibility with zeromq. + +Usage +----- + +Instead of importing zmq directly, do so in the following manner: + +.. + + import zmq.green as zmq + + +Any calls that would have blocked the current thread will now only block the +current green thread. + +This compatibility is accomplished by ensuring the nonblocking flag is set +before any blocking operation and the ØMQ file descriptor is polled internally +to trigger needed events. +""" + +from __future__ import annotations + +from typing import List + +import zmq as _zmq +from zmq import * +from zmq.green.core import _Context, _Socket +from zmq.green.poll import _Poller + +Context = _Context # type: ignore +Socket = _Socket # type: ignore +Poller = _Poller # type: ignore + +from zmq.green.device import device # type: ignore + +__all__: list[str] = [] +# adding `__all__` to __init__.pyi gets mypy all confused +__all__.extend(_zmq.__all__) # type: ignore diff --git a/source/zmq/green/core.py b/source/zmq/green/core.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfccc205a6500eaf3e81df77bd35f4af274fac5 --- /dev/null +++ b/source/zmq/green/core.py @@ -0,0 +1,334 @@ +# ----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file LICENSE.BSD, distributed as part of this software. +# ----------------------------------------------------------------------------- + +"""This module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq ` to be non blocking""" + +from __future__ import annotations + +import sys +import time +import warnings + +import gevent +from gevent.event import AsyncResult +from gevent.hub import get_hub + +import zmq +from zmq import Context as _original_Context +from zmq import Socket as _original_Socket + +from .poll import _Poller + +if hasattr(zmq, 'RCVTIMEO'): + TIMEOS: tuple = (zmq.RCVTIMEO, zmq.SNDTIMEO) +else: + TIMEOS = () + + +def _stop(evt): + """simple wrapper for stopping an Event, allowing for method rename in gevent 1.0""" + try: + evt.stop() + except AttributeError: + # gevent<1.0 compat + evt.cancel() + + +class _Socket(_original_Socket): + """Green version of :class:`zmq.Socket` + + The following methods are overridden: + + * send + * recv + + To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or receiving + is deferred to the hub if a ``zmq.EAGAIN`` (retry) error is raised. + + The `__state_changed` method is triggered when the zmq.FD for the socket is + marked as readable and triggers the necessary read and write events (which + are waited for in the recv and send methods). + + Some double underscore prefixes are used to minimize pollution of + :class:`zmq.Socket`'s namespace. + """ + + __in_send_multipart = False + __in_recv_multipart = False + __writable = None + __readable = None + _state_event = None + _gevent_bug_timeout = 11.6 # timeout for not trusting gevent + _debug_gevent = False # turn on if you think gevent is missing events + _poller_class = _Poller + _repr_cls = "zmq.green.Socket" + + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + self.__in_send_multipart = False + self.__in_recv_multipart = False + self.__setup_events() + + def __del__(self): + self.close() + + def close(self, linger=None): + super().close(linger) + self.__cleanup_events() + + def __cleanup_events(self): + # close the _state_event event, keeps the number of active file descriptors down + if getattr(self, '_state_event', None): + _stop(self._state_event) + self._state_event = None + # if the socket has entered a close state resume any waiting greenlets + self.__writable.set() + self.__readable.set() + + def __setup_events(self): + self.__readable = AsyncResult() + self.__writable = AsyncResult() + self.__readable.set() + self.__writable.set() + + try: + self._state_event = get_hub().loop.io( + self.getsockopt(zmq.FD), 1 + ) # read state watcher + self._state_event.start(self.__state_changed) + except AttributeError: + # for gevent<1.0 compatibility + from gevent.core import read_event + + self._state_event = read_event( + self.getsockopt(zmq.FD), self.__state_changed, persist=True + ) + + def __state_changed(self, event=None, _evtype=None): + if self.closed: + self.__cleanup_events() + return + try: + # avoid triggering __state_changed from inside __state_changed + events = super().getsockopt(zmq.EVENTS) + except zmq.ZMQError as exc: + self.__writable.set_exception(exc) + self.__readable.set_exception(exc) + else: + if events & zmq.POLLOUT: + self.__writable.set() + if events & zmq.POLLIN: + self.__readable.set() + + def _wait_write(self): + assert self.__writable.ready(), "Only one greenlet can be waiting on this event" + self.__writable = AsyncResult() + # timeout is because libzmq cannot be trusted to properly signal a new send event: + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__writable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if ( + self._debug_gevent + and timeout + and toc - tic > dt + and self.getsockopt(zmq.EVENTS) & zmq.POLLOUT + ): + print( + f"BUG: gevent may have missed a libzmq send event on {self.FD}!", + file=sys.stderr, + ) + finally: + if timeout: + timeout.close() + self.__writable.set() + + def _wait_read(self): + assert self.__readable.ready(), "Only one greenlet can be waiting on this event" + self.__readable = AsyncResult() + # timeout is because libzmq cannot always be trusted to play nice with libevent. + # I can only confirm that this actually happens for send, but lets be symmetrical + # with our dirty hacks. + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__readable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if ( + self._debug_gevent + and timeout + and toc - tic > dt + and self.getsockopt(zmq.EVENTS) & zmq.POLLIN + ): + print( + f"BUG: gevent may have missed a libzmq recv event on {self.FD}!", + file=sys.stderr, + ) + finally: + if timeout: + timeout.close() + self.__readable.set() + + def send(self, data, flags=0, copy=True, track=False, **kwargs): + """send, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + + # if we're given the NOBLOCK flag act as normal and let the EAGAIN get raised + if flags & zmq.NOBLOCK: + try: + msg = super().send(data, flags, copy, track, **kwargs) + finally: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # ensure the zmq.NOBLOCK flag is part of flags + flags |= zmq.NOBLOCK + while True: # Attempt to complete this operation indefinitely, blocking the current greenlet + try: + # attempt the actual call + msg = super().send(data, flags, copy, track) + except zmq.ZMQError as e: + # if the raised ZMQError is not EAGAIN, reraise + if e.errno != zmq.EAGAIN: + if not self.__in_send_multipart: + self.__state_changed() + raise + else: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # defer to the event loop until we're notified the socket is writable + self._wait_write() + + def recv(self, flags=0, copy=True, track=False): + """recv, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + if flags & zmq.NOBLOCK: + try: + msg = super().recv(flags, copy, track) + finally: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + + flags |= zmq.NOBLOCK + while True: + try: + msg = super().recv(flags, copy, track) + except zmq.ZMQError as e: + if e.errno != zmq.EAGAIN: + if not self.__in_recv_multipart: + self.__state_changed() + raise + else: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + self._wait_read() + + def recv_into(self, buffer, /, *, nbytes=0, flags=0): + """recv_into, which will only block current greenlet""" + if flags & zmq.DONTWAIT: + return super().recv_into(buffer, nbytes=nbytes, flags=flags) + flags |= zmq.DONTWAIT + while True: + try: + recvd = super().recv_into(buffer, nbytes=nbytes, flags=flags) + except zmq.ZMQError as e: + if e.errno != zmq.EAGAIN: + self.__state_changed() + raise + else: + self.__state_changed() + return recvd + self._wait_read() + + def send_multipart(self, *args, **kwargs): + """wrap send_multipart to prevent state_changed on each partial send""" + self.__in_send_multipart = True + try: + msg = super().send_multipart(*args, **kwargs) + finally: + self.__in_send_multipart = False + self.__state_changed() + return msg + + def recv_multipart(self, *args, **kwargs): + """wrap recv_multipart to prevent state_changed on each partial recv""" + self.__in_recv_multipart = True + try: + msg = super().recv_multipart(*args, **kwargs) + finally: + self.__in_recv_multipart = False + self.__state_changed() + return msg + + def get(self, opt): + """trigger state_changed on getsockopt(EVENTS)""" + if opt in TIMEOS: + warnings.warn( + "TIMEO socket options have no effect in zmq.green", UserWarning + ) + optval = super().get(opt) + if opt == zmq.EVENTS: + self.__state_changed() + return optval + + def set(self, opt, val): + """set socket option""" + if opt in TIMEOS: + warnings.warn( + "TIMEO socket options have no effect in zmq.green", UserWarning + ) + return super().set(opt, val) + + +class _Context(_original_Context[_Socket]): + """Replacement for :class:`zmq.Context` + + Ensures that the greened Socket above is used in calls to `socket`. + """ + + _socket_class = _Socket + _repr_cls = "zmq.green.Context" + + # avoid sharing instance with base Context class + _instance = None diff --git a/source/zmq/green/device.py b/source/zmq/green/device.py new file mode 100644 index 0000000000000000000000000000000000000000..f9beae39eff8f023997aa075afd2a1baf80ac196 --- /dev/null +++ b/source/zmq/green/device.py @@ -0,0 +1,34 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import zmq +from zmq.green import Poller + + +def device(device_type, isocket, osocket): + """Start a zeromq device (gevent-compatible). + + Unlike the true zmq.device, this does not release the GIL. + + Parameters + ---------- + device_type : (QUEUE, FORWARDER, STREAMER) + The type of device to start (ignored). + isocket : Socket + The Socket instance for the incoming traffic. + osocket : Socket + The Socket instance for the outbound traffic. + """ + p = Poller() + if osocket == -1: + osocket = isocket + p.register(isocket, zmq.POLLIN) + p.register(osocket, zmq.POLLIN) + + while True: + events = dict(p.poll()) + if isocket in events: + osocket.send_multipart(isocket.recv_multipart()) + if osocket in events: + isocket.send_multipart(osocket.recv_multipart()) diff --git a/source/zmq/green/eventloop/__init__.py b/source/zmq/green/eventloop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ef0272c4e93a6f7a9955e9ac16f9a6ddf148e2 --- /dev/null +++ b/source/zmq/green/eventloop/__init__.py @@ -0,0 +1,3 @@ +from zmq.green.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop'] diff --git a/source/zmq/green/eventloop/ioloop.py b/source/zmq/green/eventloop/ioloop.py new file mode 100644 index 0000000000000000000000000000000000000000..50e9151469570311c2fbc6af9fbb8d04b2d594f4 --- /dev/null +++ b/source/zmq/green/eventloop/ioloop.py @@ -0,0 +1 @@ +from zmq.eventloop.ioloop import * # noqa diff --git a/source/zmq/green/eventloop/zmqstream.py b/source/zmq/green/eventloop/zmqstream.py new file mode 100644 index 0000000000000000000000000000000000000000..c06c2ab50e1b7e4ef1501e78a7ac85f2fa43d6c3 --- /dev/null +++ b/source/zmq/green/eventloop/zmqstream.py @@ -0,0 +1,11 @@ +from zmq.eventloop import zmqstream +from zmq.green.eventloop.ioloop import IOLoop + + +class ZMQStream(zmqstream.ZMQStream): + def __init__(self, socket, io_loop=None): + io_loop = io_loop or IOLoop.instance() + super().__init__(socket, io_loop=io_loop) + + +__all__ = ["ZMQStream"] diff --git a/source/zmq/green/poll.py b/source/zmq/green/poll.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa1bedd461aac26c26978b29e6e92ec0242dbf0 --- /dev/null +++ b/source/zmq/green/poll.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import gevent +from gevent import select + +import zmq +from zmq import Poller as _original_Poller + + +class _Poller(_original_Poller): + """Replacement for :class:`zmq.Poller` + + Ensures that the greened Poller below is used in calls to + :meth:`zmq.Poller.poll`. + """ + + _gevent_bug_timeout = 1.33 # minimum poll interval, for working around gevent bug + + def _get_descriptors(self): + """Returns three elements tuple with socket descriptors ready + for gevent.select.select + """ + rlist = [] + wlist = [] + xlist = [] + + for socket, flags in self.sockets: + if isinstance(socket, zmq.Socket): + rlist.append(socket.getsockopt(zmq.FD)) + continue + elif isinstance(socket, int): + fd = socket + elif hasattr(socket, 'fileno'): + try: + fd = int(socket.fileno()) + except Exception: + raise ValueError('fileno() must return an valid integer fd') + else: + raise TypeError( + 'Socket must be a 0MQ socket, an integer fd ' + f'or have a fileno() method: {socket!r}' + ) + + if flags & zmq.POLLIN: + rlist.append(fd) + if flags & zmq.POLLOUT: + wlist.append(fd) + if flags & zmq.POLLERR: + xlist.append(fd) + + return (rlist, wlist, xlist) + + def poll(self, timeout=-1): + """Overridden method to ensure that the green version of + Poller is used. + + Behaves the same as :meth:`zmq.core.Poller.poll` + """ + + if timeout is None: + timeout = -1 + + if timeout < 0: + timeout = -1 + + rlist = None + wlist = None + xlist = None + + if timeout > 0: + tout = gevent.Timeout.start_new(timeout / 1000.0) + else: + tout = None + + try: + # Loop until timeout or events available + rlist, wlist, xlist = self._get_descriptors() + while True: + events = super().poll(0) + if events or timeout == 0: + return events + + # wait for activity on sockets in a green way + # set a minimum poll frequency, + # because gevent < 1.0 cannot be trusted to catch edge-triggered FD events + _bug_timeout = gevent.Timeout.start_new(self._gevent_bug_timeout) + try: + select.select(rlist, wlist, xlist) + except gevent.Timeout as t: + if t is not _bug_timeout: + raise + finally: + _bug_timeout.cancel() + + except gevent.Timeout as t: + if t is not tout: + raise + return [] + finally: + if timeout > 0: + tout.cancel() diff --git a/source/zmq/log/__init__.py b/source/zmq/log/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/zmq/log/__main__.py b/source/zmq/log/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..98c6b97c47381d413fa861702a926e1d57d39a8a --- /dev/null +++ b/source/zmq/log/__main__.py @@ -0,0 +1,135 @@ +"""pyzmq log watcher. + +Easily view log messages published by the PUBHandler in zmq.log.handlers + +Designed to be run as an executable module - try this to see options: + python -m zmq.log -h + +Subscribes to the '' (empty string) topic by default which means it will work +out-of-the-box with a PUBHandler object instantiated with default settings. +If you change the root topic with PUBHandler.setRootTopic() you must pass +the value to this script with the --topic argument. + +Note that the default formats for the PUBHandler object selectively include +the log level in the message. This creates redundancy in this script as it +always prints the topic of the message, which includes the log level. +Consider overriding the default formats with PUBHandler.setFormat() to +avoid this issue. + +""" + +# encoding: utf-8 + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import argparse +from datetime import datetime +from typing import Dict + +import zmq + +parser = argparse.ArgumentParser('ZMQ Log Watcher') +parser.add_argument('zmq_pub_url', type=str, help='URL to a ZMQ publisher socket.') +parser.add_argument( + '-t', + '--topic', + type=str, + default='', + help='Only receive messages that start with this topic.', +) +parser.add_argument( + '--timestamp', action='store_true', help='Append local time to the log messages.' +) +parser.add_argument( + '--separator', + type=str, + default=' | ', + help='String to print between topic and message.', +) +parser.add_argument( + '--dateformat', + type=str, + default='%Y-%d-%m %H:%M', + help='Set alternative date format for use with --timestamp.', +) +parser.add_argument( + '--align', + action='store_true', + default=False, + help='Try to align messages by the width of their topics.', +) +parser.add_argument( + '--color', + action='store_true', + default=False, + help='Color the output based on the error level. Requires the colorama module.', +) +args = parser.parse_args() + + +if args.color: + import colorama + + colorama.init() + colors = { + 'DEBUG': colorama.Fore.LIGHTCYAN_EX, + 'INFO': colorama.Fore.LIGHTWHITE_EX, + 'WARNING': colorama.Fore.YELLOW, + 'ERROR': colorama.Fore.LIGHTRED_EX, + 'CRITICAL': colorama.Fore.LIGHTRED_EX, + '__RESET__': colorama.Fore.RESET, + } +else: + colors = {} + + +ctx = zmq.Context() +sub = ctx.socket(zmq.SUB) +sub.subscribe(args.topic.encode("utf8")) +sub.connect(args.zmq_pub_url) + +topic_widths: Dict[int, int] = {} + +while True: + try: + if sub.poll(10, zmq.POLLIN): + topic, msg = sub.recv_multipart() + topics = topic.decode('utf8').strip().split('.') + + if args.align: + topics.extend(' ' for extra in range(len(topics), len(topic_widths))) + aligned_parts = [] + for key, part in enumerate(topics): + topic_widths[key] = max(len(part), topic_widths.get(key, 0)) + fmt = ''.join(('{:<', str(topic_widths[key]), '}')) + aligned_parts.append(fmt.format(part)) + + if len(topics) == 1: + level = topics[0] + else: + level = topics[1] + + fields = { + 'msg': msg.decode('utf8').strip(), + 'ts': ( + datetime.now().strftime(args.dateformat) + ' ' + if args.timestamp + else '' + ), + 'aligned': ( + '.'.join(aligned_parts) + if args.align + else topic.decode('utf8').strip() + ), + 'color': colors.get(level, ''), + 'color_rst': colors.get('__RESET__', ''), + 'sep': args.separator, + } + print('{ts}{color}{aligned}{sep}{msg}{color_rst}'.format(**fields)) + except KeyboardInterrupt: + break + +sub.disconnect(args.zmq_pub_url) +if args.color: + print(colorama.Fore.RESET) diff --git a/source/zmq/log/handlers.py b/source/zmq/log/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..8138d2eed96c676261688f74ba9fe91dbc29d9c0 --- /dev/null +++ b/source/zmq/log/handlers.py @@ -0,0 +1,232 @@ +"""pyzmq logging handlers. + +This mainly defines the PUBHandler object for publishing logging messages over +a zmq.PUB socket. + +The PUBHandler can be used with the regular logging module, as in:: + + >>> import logging + >>> handler = PUBHandler('tcp://127.0.0.1:12345') + >>> handler.root_topic = 'foo' + >>> logger = logging.getLogger('foobar') + >>> logger.setLevel(logging.DEBUG) + >>> logger.addHandler(handler) + +Or using ``dictConfig``, as in:: + + >>> from logging.config import dictConfig + >>> socket = Context.instance().socket(PUB) + >>> socket.connect('tcp://127.0.0.1:12345') + >>> dictConfig({ + >>> 'version': 1, + >>> 'handlers': { + >>> 'zmq': { + >>> 'class': 'zmq.log.handlers.PUBHandler', + >>> 'level': logging.DEBUG, + >>> 'root_topic': 'foo', + >>> 'interface_or_socket': socket + >>> } + >>> }, + >>> 'root': { + >>> 'level': 'DEBUG', + >>> 'handlers': ['zmq'], + >>> } + >>> }) + + +After this point, all messages logged by ``logger`` will be published on the +PUB socket. + +Code adapted from StarCluster: + + https://github.com/jtriley/StarCluster/blob/StarCluster-0.91/starcluster/logger.py +""" + +from __future__ import annotations + +import logging +from copy import copy + +import zmq + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +TOPIC_DELIM = "::" # delimiter for splitting topics on the receiving end. + + +class PUBHandler(logging.Handler): + """A basic logging handler that emits log messages through a PUB socket. + + Takes a PUB socket already bound to interfaces or an interface to bind to. + + Example:: + + sock = context.socket(zmq.PUB) + sock.bind('inproc://log') + handler = PUBHandler(sock) + + Or:: + + handler = PUBHandler('inproc://loc') + + These are equivalent. + + Log messages handled by this handler are broadcast with ZMQ topics + ``this.root_topic`` comes first, followed by the log level + (DEBUG,INFO,etc.), followed by any additional subtopics specified in the + message by: log.debug("subtopic.subsub::the real message") + """ + + ctx: zmq.Context + socket: zmq.Socket + + def __init__( + self, + interface_or_socket: str | zmq.Socket, + context: zmq.Context | None = None, + root_topic: str = '', + ) -> None: + logging.Handler.__init__(self) + self.root_topic = root_topic + self.formatters = { + logging.DEBUG: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n" + ), + logging.INFO: logging.Formatter("%(message)s\n"), + logging.WARN: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n" + ), + logging.ERROR: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s - %(exc_info)s\n" + ), + logging.CRITICAL: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n" + ), + } + if isinstance(interface_or_socket, zmq.Socket): + self.socket = interface_or_socket + self.ctx = self.socket.context + else: + self.ctx = context or zmq.Context() + self.socket = self.ctx.socket(zmq.PUB) + self.socket.bind(interface_or_socket) + + @property + def root_topic(self) -> str: + return self._root_topic + + @root_topic.setter + def root_topic(self, value: str): + self.setRootTopic(value) + + def setRootTopic(self, root_topic: str): + """Set the root topic for this handler. + + This value is prepended to all messages published by this handler, and it + defaults to the empty string ''. When you subscribe to this socket, you must + set your subscription to an empty string, or to at least the first letter of + the binary representation of this string to ensure you receive any messages + from this handler. + + If you use the default empty string root topic, messages will begin with + the binary representation of the log level string (INFO, WARN, etc.). + Note that ZMQ SUB sockets can have multiple subscriptions. + """ + if isinstance(root_topic, bytes): + root_topic = root_topic.decode("utf8") + self._root_topic = root_topic + + def setFormatter(self, fmt, level=logging.NOTSET): + """Set the Formatter for this handler. + + If no level is provided, the same format is used for all levels. This + will overwrite all selective formatters set in the object constructor. + """ + if level == logging.NOTSET: + for fmt_level in self.formatters.keys(): + self.formatters[fmt_level] = fmt + else: + self.formatters[level] = fmt + + def format(self, record): + """Format a record.""" + return self.formatters[record.levelno].format(record) + + def emit(self, record): + """Emit a log message on my socket.""" + + # LogRecord.getMessage explicitly allows msg to be anything _castable_ to a str + try: + topic, msg = str(record.msg).split(TOPIC_DELIM, 1) + except ValueError: + topic = "" + else: + # copy to avoid mutating LogRecord in-place + record = copy(record) + record.msg = msg + + try: + bmsg = self.format(record).encode("utf8") + except Exception: + self.handleError(record) + return + + topic_list = [] + + if self.root_topic: + topic_list.append(self.root_topic) + + topic_list.append(record.levelname) + + if topic: + topic_list.append(topic) + + btopic = '.'.join(topic_list).encode("utf8", "replace") + + self.socket.send_multipart([btopic, bmsg]) + + +class TopicLogger(logging.Logger): + """A simple wrapper that takes an additional argument to log methods. + + All the regular methods exist, but instead of one msg argument, two + arguments: topic, msg are passed. + + That is:: + + logger.debug('msg') + + Would become:: + + logger.debug('topic.sub', 'msg') + """ + + def log(self, level, topic, msg, *args, **kwargs): + """Log 'msg % args' with level and topic. + + To pass exception information, use the keyword argument exc_info + with a True value:: + + logger.log(level, "zmq.fun", "We have a %s", + "mysterious problem", exc_info=1) + """ + logging.Logger.log(self, level, f'{topic}{TOPIC_DELIM}{msg}', *args, **kwargs) + + +# Generate the methods of TopicLogger, since they are just adding a +# topic prefix to a message. +for name in "debug warn warning error critical fatal".split(): + try: + meth = getattr(logging.Logger, name) + except AttributeError: + # some methods are missing, e.g. Logger.warn was removed from Python 3.13 + continue + setattr( + TopicLogger, + name, + lambda self, level, topic, msg, *args, **kwargs: meth( + self, level, topic + TOPIC_DELIM + msg, *args, **kwargs + ), + ) diff --git a/source/zmq/py.typed b/source/zmq/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/zmq/ssh/__init__.py b/source/zmq/ssh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57f09568223c48babf9c9d7745218f7e33290eaa --- /dev/null +++ b/source/zmq/ssh/__init__.py @@ -0,0 +1 @@ +from zmq.ssh.tunnel import * diff --git a/source/zmq/ssh/forward.py b/source/zmq/ssh/forward.py new file mode 100644 index 0000000000000000000000000000000000000000..074a98f23c43930a47f2aa6b59259a849df2d5e9 --- /dev/null +++ b/source/zmq/ssh/forward.py @@ -0,0 +1,95 @@ +# +# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1. +# Original Copyright (C) 2003-2007 Robey Pointer +# Edits Copyright (C) 2010 The IPython Team +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, see . + +""" +Sample script showing how to do local port forwarding over paramiko. + +This script connects to the requested SSH server and sets up local port +forwarding (the openssh -L option) from a local port through a tunneled +connection to a destination reachable from the SSH server machine. +""" + +import logging +import select +import socketserver + +logger = logging.getLogger('ssh') + + +class ForwardServer(socketserver.ThreadingTCPServer): + daemon_threads = True + allow_reuse_address = True + + +class Handler(socketserver.BaseRequestHandler): + def handle(self): + try: + chan = self.ssh_transport.open_channel( + 'direct-tcpip', + (self.chain_host, self.chain_port), + self.request.getpeername(), + ) + except Exception as e: + logger.debug( + 'Incoming request to %s:%d failed: %r', + self.chain_host, + self.chain_port, + e, + ) + return + if chan is None: + logger.debug( + 'Incoming request to %s:%d was rejected by the SSH server.', + self.chain_host, + self.chain_port, + ) + return + + logger.debug( + f'Connected! Tunnel open {self.request.getpeername()!r} -> {chan.getpeername()!r} -> {(self.chain_host, self.chain_port)!r}' + ) + while True: + r, w, x = select.select([self.request, chan], [], []) + if self.request in r: + data = self.request.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in r: + data = chan.recv(1024) + if len(data) == 0: + break + self.request.send(data) + chan.close() + self.request.close() + logger.debug('Tunnel closed ') + + +def forward_tunnel(local_port, remote_host, remote_port, transport): + # this is a little convoluted, but lets me configure things for the Handler + # object. (SocketServer doesn't give Handlers any way to access the outer + # server normally.) + class SubHander(Handler): + chain_host = remote_host + chain_port = remote_port + ssh_transport = transport + + ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() + + +__all__ = ['forward_tunnel'] diff --git a/source/zmq/ssh/tunnel.py b/source/zmq/ssh/tunnel.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9c88e8a1acd88ad3aca4495d9632591cb7d336 --- /dev/null +++ b/source/zmq/ssh/tunnel.py @@ -0,0 +1,430 @@ +"""Basic ssh tunnel utilities, and convenience functions for tunneling +zeromq connections. +""" + +# Copyright (C) 2010-2011 IPython Development Team +# Copyright (C) 2011- PyZMQ Developers +# +# Redistributed from IPython under the terms of the BSD License. + +import atexit +import os +import re +import signal +import socket +import sys +import warnings +from getpass import getpass, getuser +from multiprocessing import Process + +try: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + import paramiko + + SSHException = paramiko.ssh_exception.SSHException +except ImportError: + paramiko = None # type: ignore + + class SSHException(Exception): # type: ignore + pass + +else: + from .forward import forward_tunnel + +try: + import pexpect +except ImportError: + pexpect = None + + +class MaxRetryExceeded(Exception): + pass + + +def select_random_ports(n): + """Select and return n random ports that are available.""" + ports = [] + sockets = [] + for i in range(n): + sock = socket.socket() + sock.bind(('', 0)) + ports.append(sock.getsockname()[1]) + sockets.append(sock) + for sock in sockets: + sock.close() + return ports + + +# ----------------------------------------------------------------------------- +# Check for passwordless login +# ----------------------------------------------------------------------------- +_password_pat = re.compile(rb'pass(word|phrase)', re.IGNORECASE) + + +def try_passwordless_ssh(server, keyfile, paramiko=None): + """Attempt to make an ssh connection without a password. + This is mainly used for requiring password input only once + when many tunnels may be connected to the same server. + + If paramiko is None, the default for the platform is chosen. + """ + if paramiko is None: + paramiko = sys.platform == 'win32' + if not paramiko: + f = _try_passwordless_openssh + else: + f = _try_passwordless_paramiko + return f(server, keyfile) + + +def _try_passwordless_openssh(server, keyfile): + """Try passwordless login with shell ssh command.""" + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko") + cmd = 'ssh -f ' + server + if keyfile: + cmd += ' -i ' + keyfile + cmd += ' exit' + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + p = pexpect.spawn(cmd, env=env) + + MAX_RETRY = 10 + + for _ in range(MAX_RETRY): + try: + i = p.expect([ssh_newkey, _password_pat], timeout=0.1) + if i == 0: + raise SSHException( + 'The authenticity of the host can\'t be established.' + ) + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + return True + else: + return False + + raise MaxRetryExceeded(f"Failed after {MAX_RETRY} attempts") + + +def _try_passwordless_paramiko(server, keyfile): + """Try passwordless login with paramiko.""" + if paramiko is None: + msg = "Paramiko unavailable, " + if sys.platform == 'win32': + msg += "Paramiko is required for ssh tunneled connections on Windows." + else: + msg += "use OpenSSH." + raise ImportError(msg) + username, server, port = _split_server(server) + client = paramiko.SSHClient() + known_hosts = os.path.expanduser("~/.ssh/known_hosts") + try: + client.load_host_keys(known_hosts) + except FileNotFoundError: + pass + + policy_name = os.environ.get("PYZMQ_PARAMIKO_HOST_KEY_POLICY", None) + if policy_name: + policy = getattr(paramiko, f"{policy_name}Policy") + client.set_missing_host_key_policy(policy()) + try: + client.connect( + server, port, username=username, key_filename=keyfile, look_for_keys=True + ) + except paramiko.AuthenticationException: + return False + else: + client.close() + return True + + +def tunnel_connection( + socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60 +): + """Connect a socket to an address via an ssh tunnel. + + This is a wrapper for socket.connect(addr), when addr is not accessible + from the local machine. It simply creates an ssh tunnel using the remaining args, + and calls socket.connect('tcp://localhost:lport') where lport is the randomly + selected local port of the tunnel. + + """ + new_url, tunnel = open_tunnel( + addr, + server, + keyfile=keyfile, + password=password, + paramiko=paramiko, + timeout=timeout, + ) + socket.connect(new_url) + return tunnel + + +def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60): + """Open a tunneled connection from a 0MQ url. + + For use inside tunnel_connection. + + Returns + ------- + + (url, tunnel) : (str, object) + The 0MQ url that has been forwarded, and the tunnel object + """ + + lport = select_random_ports(1)[0] + transport, addr = addr.split('://') + ip, rport = addr.split(':') + rport = int(rport) + if paramiko is None: + paramiko = sys.platform == 'win32' + if paramiko: + tunnelf = paramiko_tunnel + else: + tunnelf = openssh_tunnel + + tunnel = tunnelf( + lport, + rport, + server, + remoteip=ip, + keyfile=keyfile, + password=password, + timeout=timeout, + ) + return f'tcp://127.0.0.1:{lport}', tunnel + + +def openssh_tunnel( + lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60 +): + """Create an ssh tunnel using command-line ssh that connects port lport + on this machine to localhost:rport on server. The tunnel + will automatically close when not in use, remaining open + for a minimum of timeout seconds for an initial connection. + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + keyfile and password may be specified, but ssh config is checked for defaults. + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to private key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + """ + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko_tunnel") + ssh = "ssh " + if keyfile: + ssh += "-i " + keyfile + + if ':' in server: + server, port = server.split(':') + ssh += f" -p {port}" + + cmd = f"{ssh} -O check {server}" + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")]) + cmd = f"{ssh} -O forward -L 127.0.0.1:{lport}:{remoteip}:{rport} {server}" + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1)) + return pid + cmd = f"{ssh} -f -S none -L 127.0.0.1:{lport}:{remoteip}:{rport} {server} sleep {timeout}" + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + tunnel = pexpect.spawn(cmd, env=env) + failed = False + MAX_RETRY = 10 + for _ in range(MAX_RETRY): + try: + i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1) + if i == 0: + raise SSHException( + 'The authenticity of the host can\'t be established.' + ) + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + if tunnel.exitstatus: + print(tunnel.exitstatus) + print(tunnel.before) + print(tunnel.after) + raise RuntimeError(f"tunnel '{cmd}' failed to start") + else: + return tunnel.pid + else: + if failed: + print("Password rejected, try again") + password = None + if password is None: + password = getpass(f"{server}'s password: ") + tunnel.sendline(password) + failed = True + raise MaxRetryExceeded(f"Failed after {MAX_RETRY} attempts") + + +def _stop_tunnel(cmd): + pexpect.run(cmd) + + +def _split_server(server): + if '@' in server: + username, server = server.split('@', 1) + else: + username = getuser() + if ':' in server: + server, port = server.split(':') + port = int(port) + else: + port = 22 + return username, server, port + + +def paramiko_tunnel( + lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60 +): + """launch a tunner with paramiko in a subprocess. This should only be used + when shell ssh is unavailable (e.g. Windows). + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + If you are familiar with ssh tunnels, this creates the tunnel: + + ssh server -L localhost:lport:remoteip:rport + + keyfile and password may be specified, but ssh config is checked for defaults. + + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to private key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + + """ + if paramiko is None: + raise ImportError("Paramiko not available") + + if password is None: + if not _try_passwordless_paramiko(server, keyfile): + password = getpass(f"{server}'s password: ") + + p = Process( + target=_paramiko_tunnel, + args=(lport, rport, server, remoteip), + kwargs=dict(keyfile=keyfile, password=password), + ) + p.daemon = True + p.start() + return p + + +def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): + """Function for actually starting a paramiko tunnel, to be passed + to multiprocessing.Process(target=this), and not called directly. + """ + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + + try: + client.connect( + server, + port, + username=username, + key_filename=keyfile, + look_for_keys=True, + password=password, + ) + # except paramiko.AuthenticationException: + # if password is None: + # password = getpass("%s@%s's password: "%(username, server)) + # client.connect(server, port, username=username, password=password) + # else: + # raise + except Exception as e: + print(f'*** Failed to connect to {server}:{port}: {e!r}') + sys.exit(1) + + # Don't let SIGINT kill the tunnel subprocess + signal.signal(signal.SIGINT, signal.SIG_IGN) + + try: + forward_tunnel(lport, remoteip, rport, client.get_transport()) + except KeyboardInterrupt: + print('SIGINT: Port forwarding stopped cleanly') + sys.exit(0) + except Exception as e: + print(f"Port forwarding stopped uncleanly: {e}") + sys.exit(255) + + +if sys.platform == 'win32': + ssh_tunnel = paramiko_tunnel +else: + ssh_tunnel = openssh_tunnel + + +__all__ = [ + 'tunnel_connection', + 'ssh_tunnel', + 'openssh_tunnel', + 'paramiko_tunnel', + 'try_passwordless_ssh', +] diff --git a/source/zmq/sugar/__init__.py b/source/zmq/sugar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88e755682c1096272830c01e37ef71d822ec1b05 --- /dev/null +++ b/source/zmq/sugar/__init__.py @@ -0,0 +1,39 @@ +"""pure-Python sugar wrappers for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +from zmq import error +from zmq.backend import proxy +from zmq.constants import DeviceType +from zmq.sugar import context, frame, poll, socket, tracker, version + + +def device(device_type: DeviceType, frontend: socket.Socket, backend: socket.Socket): + """Deprecated alias for zmq.proxy + + .. deprecated:: libzmq-3.2 + .. deprecated:: 13.0 + """ + + return proxy(frontend, backend) + + +__all__ = ["device"] +for submod in (context, error, frame, poll, socket, tracker, version): + __all__.extend(submod.__all__) + +from zmq.error import * # noqa +from zmq.sugar.context import * # noqa +from zmq.sugar.frame import * # noqa +from zmq.sugar.poll import * # noqa +from zmq.sugar.socket import * # noqa + +# deprecated: +from zmq.sugar.stopwatch import Stopwatch # noqa +from zmq.sugar.tracker import * # noqa +from zmq.sugar.version import * # noqa + +__all__.append('Stopwatch') diff --git a/source/zmq/sugar/__init__.pyi b/source/zmq/sugar/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..732f605017a063e9b66403921623a468f1ea0abe --- /dev/null +++ b/source/zmq/sugar/__init__.pyi @@ -0,0 +1,10 @@ +from zmq.error import * + +from . import constants as constants +from .constants import * +from .context import * +from .frame import * +from .poll import * +from .socket import * +from .tracker import * +from .version import * diff --git a/source/zmq/sugar/attrsettr.py b/source/zmq/sugar/attrsettr.py new file mode 100644 index 0000000000000000000000000000000000000000..844fce606e53def45930d2e3a19873b4e8b4aeac --- /dev/null +++ b/source/zmq/sugar/attrsettr.py @@ -0,0 +1,79 @@ +"""Mixin for mapping set/getattr to self.set/get""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import errno +from typing import TypeVar, Union + +from .. import constants + +T = TypeVar("T") +OptValT = Union[str, bytes, int] + + +class AttributeSetter: + def __setattr__(self, key: str, value: OptValT) -> None: + """set zmq options by attribute""" + + if key in self.__dict__: + object.__setattr__(self, key, value) + return + # regular setattr only allowed for class-defined attributes + for cls in self.__class__.mro(): + if key in cls.__dict__ or key in getattr(cls, "__annotations__", {}): + object.__setattr__(self, key, value) + return + + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError( + f"{self.__class__.__name__} has no such option: {upper_key}" + ) + else: + self._set_attr_opt(upper_key, opt, value) + + def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None: + """override if setattr should do something other than call self.set""" + self.set(opt, value) + + def __getattr__(self, key: str) -> OptValT: + """get zmq options by attribute""" + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError( + f"{self.__class__.__name__} has no such option: {upper_key}" + ) from None + else: + from zmq import ZMQError + + try: + return self._get_attr_opt(upper_key, opt) + except ZMQError as e: + # EINVAL will be raised on access for write-only attributes. + # Turn that into an AttributeError + # necessary for mocking + if e.errno in {errno.EINVAL, errno.EFAULT}: + raise AttributeError(f"{key} attribute is write-only") + else: + raise + + def _get_attr_opt(self, name, opt) -> OptValT: + """override if getattr should do something other than call self.get""" + return self.get(opt) + + def get(self, opt: int) -> OptValT: + """Override in subclass""" + raise NotImplementedError("override in subclass") + + def set(self, opt: int, val: OptValT) -> None: + """Override in subclass""" + raise NotImplementedError("override in subclass") + + +__all__ = ['AttributeSetter'] diff --git a/source/zmq/sugar/context.py b/source/zmq/sugar/context.py new file mode 100644 index 0000000000000000000000000000000000000000..21bc2cdda9d9319b03a6a0b45d791a6c5e2c3e8a --- /dev/null +++ b/source/zmq/sugar/context.py @@ -0,0 +1,420 @@ +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +import atexit +import os +from threading import Lock +from typing import Any, Callable, Generic, TypeVar, overload +from warnings import warn +from weakref import WeakSet + +import zmq +from zmq._typing import TypeAlias +from zmq.backend import Context as ContextBase +from zmq.constants import ContextOption, Errno, SocketOption +from zmq.error import ZMQError +from zmq.utils.interop import cast_int_addr + +from .attrsettr import AttributeSetter, OptValT +from .socket import Socket, SyncSocket + +# notice when exiting, to avoid triggering term on exit +_exiting = False + + +def _notice_atexit() -> None: + global _exiting + _exiting = True + + +atexit.register(_notice_atexit) + +_ContextType = TypeVar('_ContextType', bound='Context') +_SocketType = TypeVar('_SocketType', bound='Socket', covariant=True) + + +class Context(ContextBase, AttributeSetter, Generic[_SocketType]): + """Create a zmq Context + + A zmq Context creates sockets via its ``ctx.socket`` method. + + .. versionchanged:: 24 + + When using a Context as a context manager (``with zmq.Context()``), + or deleting a context without closing it first, + ``ctx.destroy()`` is called, + closing any leftover sockets, + instead of `ctx.term()` which requires sockets to be closed first. + + This prevents hangs caused by `ctx.term()` if sockets are left open, + but means that unclean destruction of contexts + (with sockets left open) is not safe + if sockets are managed in other threads. + + .. versionadded:: 25 + + Contexts can now be shadowed by passing another Context. + This helps in creating an async copy of a sync context or vice versa:: + + ctx = zmq.Context(async_ctx) + + Which previously had to be:: + + ctx = zmq.Context.shadow(async_ctx.underlying) + """ + + sockopts: dict[int, Any] + _instance: Any = None + _instance_lock = Lock() + _instance_pid: int | None = None + _shadow = False + _shadow_obj = None + _warn_destroy_close = False + _sockets: WeakSet + # mypy doesn't like a default value here + _socket_class: type[_SocketType] = Socket # type: ignore + + @overload + def __init__(self: SyncContext, io_threads: int = 1): ... + + @overload + def __init__(self: SyncContext, io_threads: Context, /): ... + + @overload + def __init__(self: SyncContext, *, shadow: Context | int): ... + + def __init__( + self: SyncContext, + io_threads: int | Context = 1, + shadow: Context | int = 0, + ) -> None: + if isinstance(io_threads, Context): + # allow positional shadow `zmq.Context(zmq.asyncio.Context())` + # this s + shadow = io_threads + io_threads = 1 + + shadow_address: int = 0 + if shadow: + self._shadow = True + # hold a reference to the shadow object + self._shadow_obj = shadow + if not isinstance(shadow, int): + try: + shadow = shadow.underlying + except AttributeError: + pass + shadow_address = cast_int_addr(shadow) + else: + self._shadow = False + super().__init__(io_threads=io_threads, shadow=shadow_address) + self.sockopts = {} + self._sockets = WeakSet() + + def __del__(self) -> None: + """Deleting a Context without closing it destroys it and all sockets. + + .. versionchanged:: 24 + Switch from threadsafe `term()` which hangs in the event of open sockets + to less safe `destroy()` which + warns about any leftover sockets and closes them. + """ + + # Calling locals() here conceals issue #1167 on Windows CPython 3.5.4. + locals() + + if not self._shadow and not _exiting and not self.closed: + self._warn_destroy_close = True + if warn is not None and getattr(self, "_sockets", None) is not None: + # warn can be None during process teardown + warn( + f"Unclosed context {self}", + ResourceWarning, + stacklevel=2, + source=self, + ) + self.destroy() + + _repr_cls = "zmq.Context" + + def __repr__(self) -> str: + cls = self.__class__ + # look up _repr_cls on exact class, not inherited + _repr_cls = cls.__dict__.get("_repr_cls", None) + if _repr_cls is None: + _repr_cls = f"{cls.__module__}.{cls.__name__}" + + closed = ' closed' if self.closed else '' + if getattr(self, "_sockets", None): + n_sockets = len(self._sockets) + s = 's' if n_sockets > 1 else '' + sockets = f"{n_sockets} socket{s}" + else: + sockets = "" + return f"<{_repr_cls}({sockets}) at {hex(id(self))}{closed}>" + + def __enter__(self: _ContextType) -> _ContextType: + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + # warn about any leftover sockets before closing them + self._warn_destroy_close = True + self.destroy() + + def __copy__(self: _ContextType, memo: Any = None) -> _ContextType: + """Copying a Context creates a shadow copy""" + return self.__class__.shadow(self.underlying) + + __deepcopy__ = __copy__ + + @classmethod + def shadow(cls: type[_ContextType], address: int | zmq.Context) -> _ContextType: + """Shadow an existing libzmq context + + address is a zmq.Context or an integer (or FFI pointer) + representing the address of the libzmq context. + + .. versionadded:: 14.1 + + .. versionadded:: 25 + Support for shadowing `zmq.Context` objects, + instead of just integer addresses. + """ + return cls(shadow=address) + + @classmethod + def shadow_pyczmq(cls: type[_ContextType], ctx: Any) -> _ContextType: + """Shadow an existing pyczmq context + + ctx is the FFI `zctx_t *` pointer + + .. versionadded:: 14.1 + """ + from pyczmq import zctx # type: ignore + + from zmq.utils.interop import cast_int_addr + + underlying = zctx.underlying(ctx) + address = cast_int_addr(underlying) + return cls(shadow=address) + + # static method copied from tornado IOLoop.instance + @classmethod + def instance(cls: type[_ContextType], io_threads: int = 1) -> _ContextType: + """Returns a global Context instance. + + Most single-process applications have a single, global Context. + Use this method instead of passing around Context instances + throughout your code. + + A common pattern for classes that depend on Contexts is to use + a default argument to enable programs with multiple Contexts + but not require the argument for simpler applications:: + + class MyClass(object): + def __init__(self, context=None): + self.context = context or Context.instance() + + .. versionchanged:: 18.1 + + When called in a subprocess after forking, + a new global instance is created instead of inheriting + a Context that won't work from the parent process. + """ + if ( + cls._instance is None + or cls._instance_pid != os.getpid() + or cls._instance.closed + ): + with cls._instance_lock: + if ( + cls._instance is None + or cls._instance_pid != os.getpid() + or cls._instance.closed + ): + cls._instance = cls(io_threads=io_threads) + cls._instance_pid = os.getpid() + return cls._instance + + def term(self) -> None: + """Close or terminate the context. + + Context termination is performed in the following steps: + + - Any blocking operations currently in progress on sockets open within context shall + raise :class:`zmq.ContextTerminated`. + With the exception of socket.close(), any further operations on sockets open within this context + shall raise :class:`zmq.ContextTerminated`. + - After interrupting all blocking calls, term shall block until the following conditions are satisfied: + - All sockets open within context have been closed. + - For each socket within context, all messages sent on the socket have either been + physically transferred to a network peer, + or the socket's linger period set with the zmq.LINGER socket option has expired. + + For further details regarding socket linger behaviour refer to libzmq documentation for ZMQ_LINGER. + + This can be called to close the context by hand. If this is not called, + the context will automatically be closed when it is garbage collected, + in which case you may see a ResourceWarning about the unclosed context. + """ + super().term() + + # ------------------------------------------------------------------------- + # Hooks for ctxopt completion + # ------------------------------------------------------------------------- + + def __dir__(self) -> list[str]: + keys = dir(self.__class__) + keys.extend(ContextOption.__members__) + return keys + + # ------------------------------------------------------------------------- + # Creating Sockets + # ------------------------------------------------------------------------- + + def _add_socket(self, socket: Any) -> None: + """Add a weakref to a socket for Context.destroy / reference counting""" + self._sockets.add(socket) + + def _rm_socket(self, socket: Any) -> None: + """Remove a socket for Context.destroy / reference counting""" + # allow _sockets to be None in case of process teardown + if getattr(self, "_sockets", None) is not None: + self._sockets.discard(socket) + + def destroy(self, linger: int | None = None) -> None: + """Close all sockets associated with this context and then terminate + the context. + + .. warning:: + + destroy involves calling :meth:`Socket.close`, which is **NOT** threadsafe. + If there are active sockets in other threads, this must not be called. + + Parameters + ---------- + + linger : int, optional + If specified, set LINGER on sockets prior to closing them. + """ + if self.closed: + return + + sockets: list[_SocketType] = list(getattr(self, "_sockets", None) or []) + for s in sockets: + if s and not s.closed: + if self._warn_destroy_close and warn is not None: + # warn can be None during process teardown + warn( + f"Destroying context with unclosed socket {s}", + ResourceWarning, + stacklevel=3, + source=s, + ) + if linger is not None: + s.setsockopt(SocketOption.LINGER, linger) + s.close() + + self.term() + + def socket( + self: _ContextType, + socket_type: int, + socket_class: Callable[[_ContextType, int], _SocketType] | None = None, + **kwargs: Any, + ) -> _SocketType: + """Create a Socket associated with this Context. + + Parameters + ---------- + socket_type : int + The socket type, which can be any of the 0MQ socket types: + REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc. + + socket_class: zmq.Socket + The socket class to instantiate, if different from the default for this Context. + e.g. for creating an asyncio socket attached to a default Context or vice versa. + + .. versionadded:: 25 + + kwargs: + will be passed to the __init__ method of the socket class. + """ + if self.closed: + raise ZMQError(Errno.ENOTSUP) + if socket_class is None: + socket_class = self._socket_class + s: _SocketType = ( + socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame + self, socket_type, **kwargs + ) + ) + for opt, value in self.sockopts.items(): + try: + s.setsockopt(opt, value) + except ZMQError: + # ignore ZMQErrors, which are likely for socket options + # that do not apply to a particular socket type, e.g. + # SUBSCRIBE for non-SUB sockets. + pass + self._add_socket(s) + return s + + def setsockopt(self, opt: int, value: Any) -> None: + """set default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + self.sockopts[opt] = value + + def getsockopt(self, opt: int) -> OptValT: + """get default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + return self.sockopts[opt] + + def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None: + """set default sockopts as attributes""" + if name in ContextOption.__members__: + return self.set(opt, value) + elif name in SocketOption.__members__: + self.sockopts[opt] = value + else: + raise AttributeError(f"No such context or socket option: {name}") + + def _get_attr_opt(self, name: str, opt: int) -> OptValT: + """get default sockopts as attributes""" + if name in ContextOption.__members__: + return self.get(opt) + else: + if opt not in self.sockopts: + raise AttributeError(name) + else: + return self.sockopts[opt] + + def __delattr__(self, key: str) -> None: + """delete default sockopts as attributes""" + if key in self.__dict__: + self.__dict__.pop(key) + return + key = key.upper() + try: + opt = getattr(SocketOption, key) + except AttributeError: + raise AttributeError(f"No such socket option: {key!r}") + else: + if opt not in self.sockopts: + raise AttributeError(key) + else: + del self.sockopts[opt] + + +SyncContext: TypeAlias = Context[SyncSocket] + + +__all__ = ['Context', 'SyncContext'] diff --git a/source/zmq/sugar/frame.py b/source/zmq/sugar/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..3239d357e1c67bc0bc41b3e01b479f3c893544f2 --- /dev/null +++ b/source/zmq/sugar/frame.py @@ -0,0 +1,134 @@ +"""0MQ Frame pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.backend import Frame as FrameBase + +from .attrsettr import AttributeSetter + + +def _draft(v, feature): + zmq.error._check_version(v, feature) + if not zmq.DRAFT_API: + raise RuntimeError( + f"libzmq and pyzmq must be built with draft support for {feature}" + ) + + +class Frame(FrameBase, AttributeSetter): + """ + A zmq message Frame class for non-copying send/recvs and access to message properties. + + A ``zmq.Frame`` wraps an underlying ``zmq_msg_t``. + + Message *properties* can be accessed by treating a Frame like a dictionary (``frame["User-Id"]``). + + .. versionadded:: 14.4, libzmq 4 + + Frames created by ``recv(copy=False)`` can be used to access message properties and attributes, + such as the CURVE User-Id. + + For example:: + + frames = socket.recv_multipart(copy=False) + user_id = frames[0]["User-Id"] + + This class is used if you want to do non-copying send and recvs. + When you pass a chunk of bytes to this class, e.g. ``Frame(buf)``, the + ref-count of `buf` is increased by two: once because the Frame saves `buf` as + an instance attribute and another because a ZMQ message is created that + points to the buffer of `buf`. This second ref-count increase makes sure + that `buf` lives until all messages that use it have been sent. + Once 0MQ sends all the messages and it doesn't need the buffer of ``buf``, + 0MQ will call ``Py_DECREF(s)``. + + Parameters + ---------- + + data : object, optional + any object that provides the buffer interface will be used to + construct the 0MQ message data. + track : bool + whether a MessageTracker_ should be created to track this object. + Tracking a message has a cost at creation, because it creates a threadsafe + Event object. + copy : bool + default: use copy_threshold + Whether to create a copy of the data to pass to libzmq + or share the memory with libzmq. + If unspecified, copy_threshold is used. + copy_threshold: int + default: :const:`zmq.COPY_THRESHOLD` + If copy is unspecified, messages smaller than this many bytes + will be copied and messages larger than this will be shared with libzmq. + """ + + def __getitem__(self, key): + # map Frame['User-Id'] to Frame.get('User-Id') + return self.get(key) + + def __repr__(self): + """Return the str form of the message.""" + nbytes = len(self) + msg_suffix = "" + if nbytes > 16: + msg_bytes = bytes(memoryview(self.buffer)[:12]) + if nbytes >= 1e9: + unit = "GB" + n = nbytes // 1e9 + elif nbytes >= 2**20: + unit = "MB" + n = nbytes // 1e6 + elif nbytes >= 1e3: + unit = "kB" + n = nbytes // 1e3 + else: + unit = "B" + n = nbytes + msg_suffix = f'...{n:.0f}{unit}' + else: + msg_bytes = self.bytes + + _module = self.__class__.__module__ + if _module == "zmq.sugar.frame": + _module = "zmq" + return f"<{_module}.{self.__class__.__name__}({msg_bytes!r}{msg_suffix})>" + + @property + def group(self): + """The RADIO-DISH group of the message. + + Requires libzmq >= 4.2 and pyzmq built with draft APIs enabled. + + .. versionadded:: 17 + """ + _draft((4, 2), "RADIO-DISH") + return self.get('group') + + @group.setter + def group(self, group): + _draft((4, 2), "RADIO-DISH") + self.set('group', group) + + @property + def routing_id(self): + """The CLIENT-SERVER routing id of the message. + + Requires libzmq >= 4.2 and pyzmq built with draft APIs enabled. + + .. versionadded:: 17 + """ + _draft((4, 2), "CLIENT-SERVER") + return self.get('routing_id') + + @routing_id.setter + def routing_id(self, routing_id): + _draft((4, 2), "CLIENT-SERVER") + self.set('routing_id', routing_id) + + +# keep deprecated alias +Message = Frame +__all__ = ['Frame', 'Message'] diff --git a/source/zmq/sugar/poll.py b/source/zmq/sugar/poll.py new file mode 100644 index 0000000000000000000000000000000000000000..27baad46e75626d68a0691b6c89ba0fbd7c353e2 --- /dev/null +++ b/source/zmq/sugar/poll.py @@ -0,0 +1,172 @@ +"""0MQ polling related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +from typing import Any + +from zmq.backend import zmq_poll +from zmq.constants import POLLERR, POLLIN, POLLOUT + +# ----------------------------------------------------------------------------- +# Polling related methods +# ----------------------------------------------------------------------------- + + +class Poller: + """A stateful poll interface that mirrors Python's built-in poll.""" + + sockets: list[tuple[Any, int]] + _map: dict + + def __init__(self) -> None: + self.sockets = [] + self._map = {} + + def __contains__(self, socket: Any) -> bool: + return socket in self._map + + def register(self, socket: Any, flags: int = POLLIN | POLLOUT): + """p.register(socket, flags=POLLIN|POLLOUT) + + Register a 0MQ socket or native fd for I/O monitoring. + + register(s,0) is equivalent to unregister(s). + + Parameters + ---------- + socket : zmq.Socket or native socket + A zmq.Socket or any Python object having a ``fileno()`` + method that returns a valid file descriptor. + flags : int + The events to watch for. Can be POLLIN, POLLOUT or POLLIN|POLLOUT. + If `flags=0`, socket will be unregistered. + """ + if flags: + if socket in self._map: + idx = self._map[socket] + self.sockets[idx] = (socket, flags) + else: + idx = len(self.sockets) + self.sockets.append((socket, flags)) + self._map[socket] = idx + elif socket in self._map: + # uregister sockets registered with no events + self.unregister(socket) + else: + # ignore new sockets with no events + pass + + def modify(self, socket, flags=POLLIN | POLLOUT): + """Modify the flags for an already registered 0MQ socket or native fd.""" + self.register(socket, flags) + + def unregister(self, socket: Any): + """Remove a 0MQ socket or native fd for I/O monitoring. + + Parameters + ---------- + socket : Socket + The socket instance to stop polling. + """ + idx = self._map.pop(socket) + self.sockets.pop(idx) + # shift indices after deletion + for socket, flags in self.sockets[idx:]: + self._map[socket] -= 1 + + def poll(self, timeout: int | None = None) -> list[tuple[Any, int]]: + """Poll the registered 0MQ or native fds for I/O. + + If there are currently events ready to be processed, this function will return immediately. + Otherwise, this function will return as soon the first event is available or after timeout + milliseconds have elapsed. + + Parameters + ---------- + timeout : int + The timeout in milliseconds. If None, no `timeout` (infinite). This + is in milliseconds to be compatible with ``select.poll()``. + + Returns + ------- + events : list + The list of events that are ready to be processed. + This is a list of tuples of the form ``(socket, event_mask)``, where the 0MQ Socket + or integer fd is the first element, and the poll event mask (POLLIN, POLLOUT) is the second. + It is common to call ``events = dict(poller.poll())``, + which turns the list of tuples into a mapping of ``socket : event_mask``. + """ + if timeout is None or timeout < 0: + timeout = -1 + elif isinstance(timeout, float): + timeout = int(timeout) + return zmq_poll(self.sockets, timeout=timeout) + + +def select( + rlist: list, wlist: list, xlist: list, timeout: float | None = None +) -> tuple[list, list, list]: + """select(rlist, wlist, xlist, timeout=None) -> (rlist, wlist, xlist) + + Return the result of poll as a lists of sockets ready for r/w/exception. + + This has the same interface as Python's built-in ``select.select()`` function. + + Parameters + ---------- + timeout : float, optional + The timeout in seconds. If None, no timeout (infinite). This is in seconds to be + compatible with ``select.select()``. + rlist : list + sockets/FDs to be polled for read events + wlist : list + sockets/FDs to be polled for write events + xlist : list + sockets/FDs to be polled for error events + + Returns + ------- + rlist: list + list of sockets or FDs that are readable + wlist: list + list of sockets or FDs that are writable + xlist: list + list of sockets or FDs that had error events (rare) + """ + if timeout is None: + timeout = -1 + # Convert from sec -> ms for zmq_poll. + # zmq_poll accepts 3.x style timeout in ms + timeout = int(timeout * 1000.0) + if timeout < 0: + timeout = -1 + sockets = [] + for s in set(rlist + wlist + xlist): + flags = 0 + if s in rlist: + flags |= POLLIN + if s in wlist: + flags |= POLLOUT + if s in xlist: + flags |= POLLERR + sockets.append((s, flags)) + return_sockets = zmq_poll(sockets, timeout) + rlist, wlist, xlist = [], [], [] + for s, flags in return_sockets: + if flags & POLLIN: + rlist.append(s) + if flags & POLLOUT: + wlist.append(s) + if flags & POLLERR: + xlist.append(s) + return rlist, wlist, xlist + + +# ----------------------------------------------------------------------------- +# Symbols to export +# ----------------------------------------------------------------------------- + +__all__ = ['Poller', 'select'] diff --git a/source/zmq/sugar/socket.py b/source/zmq/sugar/socket.py new file mode 100644 index 0000000000000000000000000000000000000000..168fcaa116174f679a1d3a317b783690e1ac4967 --- /dev/null +++ b/source/zmq/sugar/socket.py @@ -0,0 +1,1125 @@ +"""0MQ Socket pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +import errno +import pickle +import random +import sys +from typing import ( + Any, + Callable, + Generic, + List, + Literal, + Sequence, + TypeVar, + Union, + cast, + overload, +) +from warnings import warn + +import zmq +from zmq._typing import TypeAlias +from zmq.backend import Socket as SocketBase +from zmq.error import ZMQBindError, ZMQError +from zmq.utils import jsonapi +from zmq.utils.interop import cast_int_addr + +from ..constants import SocketOption, SocketType, _OptType +from .attrsettr import AttributeSetter +from .poll import Poller + +try: + DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL +except AttributeError: + DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL + +_SocketType = TypeVar("_SocketType", bound="Socket") + +_JSONType: TypeAlias = "int | str | bool | list[_JSONType] | dict[str, _JSONType]" + + +class _SocketContext(Generic[_SocketType]): + """Context Manager for socket bind/unbind""" + + socket: _SocketType + kind: str + addr: str + + def __repr__(self): + return f"" + + def __init__( + self: _SocketContext[_SocketType], socket: _SocketType, kind: str, addr: str + ): + assert kind in {"bind", "connect"} + self.socket = socket + self.kind = kind + self.addr = addr + + def __enter__(self: _SocketContext[_SocketType]) -> _SocketType: + return self.socket + + def __exit__(self, *args): + if self.socket.closed: + return + if self.kind == "bind": + self.socket.unbind(self.addr) + elif self.kind == "connect": + self.socket.disconnect(self.addr) + + +SocketReturnType = TypeVar("SocketReturnType") + + +class Socket(SocketBase, AttributeSetter, Generic[SocketReturnType]): + """The ZMQ socket object + + To create a Socket, first create a Context:: + + ctx = zmq.Context.instance() + + then call ``ctx.socket(socket_type)``:: + + s = ctx.socket(zmq.ROUTER) + + .. versionadded:: 25 + + Sockets can now be shadowed by passing another Socket. + This helps in creating an async copy of a sync socket or vice versa:: + + s = zmq.Socket(async_socket) + + Which previously had to be:: + + s = zmq.Socket.shadow(async_socket.underlying) + """ + + _shadow = False + _shadow_obj = None + _monitor_socket = None + _type_name = 'UNKNOWN' + + @overload + def __init__( + self: Socket[bytes], + ctx_or_socket: zmq.Context, + socket_type: int, + *, + copy_threshold: int | None = None, + ): ... + + @overload + def __init__( + self: Socket[bytes], + *, + shadow: Socket | int, + copy_threshold: int | None = None, + ): ... + + @overload + def __init__( + self: Socket[bytes], + ctx_or_socket: Socket, + ): ... + + def __init__( + self: Socket[bytes], + ctx_or_socket: zmq.Context | Socket | None = None, + socket_type: int = 0, + *, + shadow: Socket | int = 0, + copy_threshold: int | None = None, + ): + shadow_context: zmq.Context | None = None + if isinstance(ctx_or_socket, zmq.Socket): + # positional Socket(other_socket) + shadow = ctx_or_socket + ctx_or_socket = None + + shadow_address: int = 0 + + if shadow: + self._shadow = True + # hold a reference to the shadow object + self._shadow_obj = shadow + if not isinstance(shadow, int): + if isinstance(shadow, zmq.Socket): + shadow_context = shadow.context + try: + shadow = cast(int, shadow.underlying) + except AttributeError: + pass + shadow_address = cast_int_addr(shadow) + else: + self._shadow = False + + super().__init__( + ctx_or_socket, + socket_type, + shadow=shadow_address, + copy_threshold=copy_threshold, + ) + if self._shadow_obj and shadow_context: + # keep self.context reference if shadowing a Socket object + self.context = shadow_context + + try: + socket_type = cast(int, self.get(zmq.TYPE)) + except Exception: + pass + else: + try: + self.__dict__["type"] = stype = SocketType(socket_type) + except ValueError: + self._type_name = str(socket_type) + else: + self._type_name = stype.name + + def __del__(self): + if not self._shadow and not self.closed: + if warn is not None: + # warn can be None during process teardown + warn( + f"Unclosed socket {self}", + ResourceWarning, + stacklevel=2, + source=self, + ) + self.close() + + _repr_cls = "zmq.Socket" + + def __repr__(self): + cls = self.__class__ + # look up _repr_cls on exact class, not inherited + _repr_cls = cls.__dict__.get("_repr_cls", None) + if _repr_cls is None: + _repr_cls = f"{cls.__module__}.{cls.__name__}" + + closed = ' closed' if self._closed else '' + + return f"<{_repr_cls}(zmq.{self._type_name}) at {hex(id(self))}{closed}>" + + # socket as context manager: + def __enter__(self: _SocketType) -> _SocketType: + """Sockets are context managers + + .. versionadded:: 14.4 + """ + return self + + def __exit__(self, *args, **kwargs): + self.close() + + # ------------------------------------------------------------------------- + # Socket creation + # ------------------------------------------------------------------------- + + def __copy__(self: _SocketType, memo=None) -> _SocketType: + """Copying a Socket creates a shadow copy""" + return self.__class__.shadow(self.underlying) + + __deepcopy__ = __copy__ + + @classmethod + def shadow(cls: type[_SocketType], address: int | zmq.Socket) -> _SocketType: + """Shadow an existing libzmq socket + + address is a zmq.Socket or an integer (or FFI pointer) + representing the address of the libzmq socket. + + .. versionadded:: 14.1 + + .. versionadded:: 25 + Support for shadowing `zmq.Socket` objects, + instead of just integer addresses. + """ + return cls(shadow=address) + + def close(self, linger=None) -> None: + """ + Close the socket. + + If linger is specified, LINGER sockopt will be set prior to closing. + + Note: closing a zmq Socket may not close the underlying sockets + if there are undelivered messages. + Only after all messages are delivered or discarded by reaching the socket's LINGER timeout + (default: forever) + will the underlying sockets be closed. + + This can be called to close the socket by hand. If this is not + called, the socket will automatically be closed when it is + garbage collected, + in which case you may see a ResourceWarning about the unclosed socket. + """ + if self.context: + self.context._rm_socket(self) + super().close(linger=linger) + + # ------------------------------------------------------------------------- + # Connect/Bind context managers + # ------------------------------------------------------------------------- + + def _connect_cm(self: _SocketType, addr: str) -> _SocketContext[_SocketType]: + """Context manager to disconnect on exit + + .. versionadded:: 20.0 + """ + return _SocketContext(self, 'connect', addr) + + def _bind_cm(self: _SocketType, addr: str) -> _SocketContext[_SocketType]: + """Context manager to unbind on exit + + .. versionadded:: 20.0 + """ + try: + # retrieve last_endpoint + # to support binding on random ports via + # `socket.bind('tcp://127.0.0.1:0')` + addr = cast(bytes, self.get(zmq.LAST_ENDPOINT)).decode("utf8") + except (AttributeError, ZMQError, UnicodeDecodeError): + pass + return _SocketContext(self, 'bind', addr) + + def bind(self: _SocketType, addr: str) -> _SocketContext[_SocketType]: + """s.bind(addr) + + Bind the socket to an address. + + This causes the socket to listen on a network port. Sockets on the + other side of this connection will use ``Socket.connect(addr)`` to + connect to this socket. + + Returns a context manager which will call unbind on exit. + + .. versionadded:: 20.0 + Can be used as a context manager. + + .. versionadded:: 26.0 + binding to port 0 can be used as a context manager + for binding to a random port. + The URL can be retrieved as `socket.last_endpoint`. + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported include + tcp, udp, pgm, epgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + + """ + try: + super().bind(addr) + except ZMQError as e: + e.strerror += f" (addr={addr!r})" + raise + return self._bind_cm(addr) + + def connect(self: _SocketType, addr: str) -> _SocketContext[_SocketType]: + """s.connect(addr) + + Connect to a remote 0MQ socket. + + Returns a context manager which will call disconnect on exit. + + .. versionadded:: 20.0 + Can be used as a context manager. + + Parameters + ---------- + addr : str + The address string. This has the form 'protocol://interface:port', + for example 'tcp://127.0.0.1:5555'. Protocols supported are + tcp, udp, pgm, inproc and ipc. If the address is unicode, it is + encoded to utf-8 first. + + """ + try: + super().connect(addr) + except ZMQError as e: + e.strerror += f" (addr={addr!r})" + raise + return self._connect_cm(addr) + + # ------------------------------------------------------------------------- + # Deprecated aliases + # ------------------------------------------------------------------------- + + @property + def socket_type(self) -> int: + warn("Socket.socket_type is deprecated, use Socket.type", DeprecationWarning) + return cast(int, self.type) + + # ------------------------------------------------------------------------- + # Hooks for sockopt completion + # ------------------------------------------------------------------------- + + def __dir__(self): + keys = dir(self.__class__) + keys.extend(SocketOption.__members__) + return keys + + # ------------------------------------------------------------------------- + # Getting/Setting options + # ------------------------------------------------------------------------- + setsockopt = SocketBase.set + getsockopt = SocketBase.get + + def __setattr__(self, key, value): + """Override to allow setting zmq.[UN]SUBSCRIBE even though we have a subscribe method""" + if key in self.__dict__: + object.__setattr__(self, key, value) + return + _key = key.lower() + if _key in ('subscribe', 'unsubscribe'): + if isinstance(value, str): + value = value.encode('utf8') + if _key == 'subscribe': + self.set(zmq.SUBSCRIBE, value) + else: + self.set(zmq.UNSUBSCRIBE, value) + return + super().__setattr__(key, value) + + def fileno(self) -> int: + """Return edge-triggered file descriptor for this socket. + + This is a read-only edge-triggered file descriptor for both read and write events on this socket. + It is important that all available events be consumed when an event is detected, + otherwise the read event will not trigger again. + + .. versionadded:: 17.0 + """ + return self.FD + + def subscribe(self, topic: str | bytes) -> None: + """Subscribe to a topic + + Only for SUB sockets. + + .. versionadded:: 15.3 + """ + if isinstance(topic, str): + topic = topic.encode('utf8') + self.set(zmq.SUBSCRIBE, topic) + + def unsubscribe(self, topic: str | bytes) -> None: + """Unsubscribe from a topic + + Only for SUB sockets. + + .. versionadded:: 15.3 + """ + if isinstance(topic, str): + topic = topic.encode('utf8') + self.set(zmq.UNSUBSCRIBE, topic) + + def set_string(self, option: int, optval: str, encoding='utf-8') -> None: + """Set socket options with a unicode object. + + This is simply a wrapper for setsockopt to protect from encoding ambiguity. + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The name of the option to set. Can be any of: SUBSCRIBE, + UNSUBSCRIBE, IDENTITY + optval : str + The value of the option to set. + encoding : str + The encoding to be used, default is utf8 + """ + if not isinstance(optval, str): + raise TypeError(f"strings only, not {type(optval)}: {optval!r}") + return self.set(option, optval.encode(encoding)) + + setsockopt_unicode = setsockopt_string = set_string + + def get_string(self, option: int, encoding='utf-8') -> str: + """Get the value of a socket option. + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The option to retrieve. + + Returns + ------- + optval : str + The value of the option as a unicode string. + """ + if SocketOption(option)._opt_type != _OptType.bytes: + raise TypeError(f"option {option} will not return a string to be decoded") + return cast(bytes, self.get(option)).decode(encoding) + + getsockopt_unicode = getsockopt_string = get_string + + def bind_to_random_port( + self: _SocketType, + addr: str, + min_port: int = 49152, + max_port: int = 65536, + max_tries: int = 100, + ) -> int: + """Bind this socket to a random port in a range. + + If the port range is unspecified, the system will choose the port. + + Parameters + ---------- + addr : str + The address string without the port to pass to ``Socket.bind()``. + min_port : int, optional + The minimum port in the range of ports to try (inclusive). + max_port : int, optional + The maximum port in the range of ports to try (exclusive). + max_tries : int, optional + The maximum number of bind attempts to make. + + Returns + ------- + port : int + The port the socket was bound to. + + Raises + ------ + ZMQBindError + if `max_tries` reached before successful bind + """ + if min_port == 49152 and max_port == 65536: + # if LAST_ENDPOINT is supported, and min_port / max_port weren't specified, + # we can bind to port 0 and let the OS do the work + self.bind(f"{addr}:*") + url = cast(bytes, self.last_endpoint).decode('ascii', 'replace') + _, port_s = url.rsplit(':', 1) + return int(port_s) + + for i in range(max_tries): + try: + port = random.randrange(min_port, max_port) + self.bind(f'{addr}:{port}') + except ZMQError as exception: + en = exception.errno + if en == zmq.EADDRINUSE: + continue + elif sys.platform == 'win32' and en == errno.EACCES: + continue + else: + raise + else: + return port + raise ZMQBindError("Could not bind socket to random port.") + + def get_hwm(self) -> int: + """Get the High Water Mark. + + On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM + """ + # return sndhwm, fallback on rcvhwm + try: + return cast(int, self.get(zmq.SNDHWM)) + except zmq.ZMQError: + pass + + return cast(int, self.get(zmq.RCVHWM)) + + def set_hwm(self, value: int) -> None: + """Set the High Water Mark. + + On libzmq ≥ 3, this sets both SNDHWM and RCVHWM + + + .. warning:: + + New values only take effect for subsequent socket + bind/connects. + """ + raised = None + try: + self.sndhwm = value + except Exception as e: + raised = e + try: + self.rcvhwm = value + except Exception as e: + raised = e + + if raised: + raise raised + + hwm = property( + get_hwm, + set_hwm, + None, + """Property for High Water Mark. + + Setting hwm sets both SNDHWM and RCVHWM as appropriate. + It gets SNDHWM if available, otherwise RCVHWM. + """, + ) + + # ------------------------------------------------------------------------- + # Sending and receiving messages + # ------------------------------------------------------------------------- + + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + *, + track: Literal[True], + routing_id: int | None = ..., + group: str | None = ..., + ) -> zmq.MessageTracker: ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + *, + track: Literal[False], + routing_id: int | None = ..., + group: str | None = ..., + ) -> None: ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + *, + copy: bool = ..., + routing_id: int | None = ..., + group: str | None = ..., + ) -> None: ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + track: bool = ..., + routing_id: int | None = ..., + group: str | None = ..., + ) -> zmq.MessageTracker | None: ... + + def send( + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: int | None = None, + group: str | None = None, + ) -> zmq.MessageTracker | None: + """Send a single zmq message frame on this socket. + + This queues the message to be sent by the IO thread at a later time. + + With flags=NOBLOCK, this raises :class:`ZMQError` if the queue is full; + otherwise, this waits until space is available. + See :class:`Poller` for more general non-blocking I/O. + + Parameters + ---------- + data : bytes, Frame, memoryview + The content of the message. This can be any object that provides + the Python buffer API (i.e. `memoryview(data)` can be called). + flags : int + 0, NOBLOCK, SNDMORE, or NOBLOCK|SNDMORE. + copy : bool + Should the message be sent in a copying or non-copying manner. + track : bool + Should the message be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + routing_id : int + For use with SERVER sockets + group : str + For use with RADIO sockets + + Returns + ------- + None : if `copy` or not track + None if message was sent, raises an exception otherwise. + MessageTracker : if track and not copy + a MessageTracker object, whose `done` property will + be False until the send is completed. + + Raises + ------ + TypeError + If a unicode object is passed + ValueError + If `track=True`, but an untracked Frame is passed. + ZMQError + If the send does not succeed for any reason (including + if NOBLOCK is set and the outgoing queue is full). + + + .. versionchanged:: 17.0 + + DRAFT support for routing_id and group arguments. + """ + if routing_id is not None: + if not isinstance(data, zmq.Frame): + data = zmq.Frame( + data, + track=track, + copy=copy or None, + copy_threshold=self.copy_threshold, + ) + data.routing_id = routing_id + if group is not None: + if not isinstance(data, zmq.Frame): + data = zmq.Frame( + data, + track=track, + copy=copy or None, + copy_threshold=self.copy_threshold, + ) + data.group = group + return super().send(data, flags=flags, copy=copy, track=track) + + def send_multipart( + self, + msg_parts: Sequence, + flags: int = 0, + copy: bool = True, + track: bool = False, + **kwargs, + ): + """Send a sequence of buffers as a multipart message. + + The zmq.SNDMORE flag is added to all msg parts before the last. + + Parameters + ---------- + msg_parts : iterable + A sequence of objects to send as a multipart message. Each element + can be any sendable object (Frame, bytes, buffer-providers) + flags : int, optional + Any valid flags for :func:`Socket.send`. + SNDMORE is added automatically for frames before the last. + copy : bool, optional + Should the frame(s) be sent in a copying or non-copying manner. + If copy=False, frames smaller than self.copy_threshold bytes + will be copied anyway. + track : bool, optional + Should the frame(s) be tracked for notification that ZMQ has + finished with it (ignored if copy=True). + + Returns + ------- + None : if copy or not track + MessageTracker : if track and not copy + a MessageTracker object, whose `done` property will + be False until the last send is completed. + """ + # typecheck parts before sending: + for i, msg in enumerate(msg_parts): + if isinstance(msg, (zmq.Frame, bytes, memoryview)): + continue + try: + memoryview(msg) + except Exception: + rmsg = repr(msg) + if len(rmsg) > 32: + rmsg = rmsg[:32] + '...' + raise TypeError( + f"Frame {i} ({rmsg}) does not support the buffer interface." + ) + for msg in msg_parts[:-1]: + self.send(msg, zmq.SNDMORE | flags, copy=copy, track=track) + # Send the last part without the extra SNDMORE flag. + return self.send(msg_parts[-1], flags, copy=copy, track=track) + + @overload + def recv_multipart( + self, flags: int = ..., *, copy: Literal[True], track: bool = ... + ) -> list[bytes]: ... + + @overload + def recv_multipart( + self, flags: int = ..., *, copy: Literal[False], track: bool = ... + ) -> list[zmq.Frame]: ... + + @overload + def recv_multipart(self, flags: int = ..., *, track: bool = ...) -> list[bytes]: ... + + @overload + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> list[zmq.Frame] | list[bytes]: ... + + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> list[zmq.Frame] | list[bytes]: + """Receive a multipart message as a list of bytes or Frame objects + + Parameters + ---------- + flags : int, optional + Any valid flags for :func:`Socket.recv`. + copy : bool, optional + Should the message frame(s) be received in a copying or non-copying manner? + If False a Frame object is returned for each part, if True a copy of + the bytes is made for each frame. + track : bool, optional + Should the message frame(s) be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + + Returns + ------- + msg_parts : list + A list of frames in the multipart message; either Frames or bytes, + depending on `copy`. + + Raises + ------ + ZMQError + for any of the reasons :func:`~Socket.recv` might fail + """ + parts = [self.recv(flags, copy=copy, track=track)] + # have first part already, only loop while more to receive + while self.getsockopt(zmq.RCVMORE): + part = self.recv(flags, copy=copy, track=track) + parts.append(part) + # cast List[Union] to Union[List] + # how do we get mypy to recognize that return type is invariant on `copy`? + return cast(Union[List[zmq.Frame], List[bytes]], parts) + + def _deserialize( + self, + recvd: bytes, + load: Callable[[bytes], Any], + ) -> Any: + """Deserialize a received message + + Override in subclass (e.g. Futures) if recvd is not the raw bytes. + + The default implementation expects bytes and returns the deserialized message immediately. + + Parameters + ---------- + + load: callable + Callable that deserializes bytes + recvd: + The object returned by self.recv + + """ + return load(recvd) + + def send_serialized(self, msg, serialize, flags=0, copy=True, **kwargs): + """Send a message with a custom serialization function. + + .. versionadded:: 17 + + Parameters + ---------- + msg : The message to be sent. Can be any object serializable by `serialize`. + serialize : callable + The serialization function to use. + serialize(msg) should return an iterable of sendable message frames + (e.g. bytes objects), which will be passed to send_multipart. + flags : int, optional + Any valid flags for :func:`Socket.send`. + copy : bool, optional + Whether to copy the frames. + + """ + frames = serialize(msg) + return self.send_multipart(frames, flags=flags, copy=copy, **kwargs) + + def recv_serialized(self, deserialize, flags=0, copy=True): + """Receive a message with a custom deserialization function. + + .. versionadded:: 17 + + Parameters + ---------- + deserialize : callable + The deserialization function to use. + deserialize will be called with one argument: the list of frames + returned by recv_multipart() and can return any object. + flags : int, optional + Any valid flags for :func:`Socket.recv`. + copy : bool, optional + Whether to recv bytes or Frame objects. + + Returns + ------- + obj : object + The object returned by the deserialization function. + + Raises + ------ + ZMQError + for any of the reasons :func:`~Socket.recv` might fail + """ + frames = self.recv_multipart(flags=flags, copy=copy) + return self._deserialize(frames, deserialize) + + def send_string( + self, + u: str, + flags: int = 0, + copy: bool = True, + encoding: str = 'utf-8', + **kwargs, + ) -> zmq.Frame | None: + """Send a Python unicode string as a message with an encoding. + + 0MQ communicates with raw bytes, so you must encode/decode + text (str) around 0MQ. + + Parameters + ---------- + u : str + The unicode string to send. + flags : int, optional + Any valid flags for :func:`Socket.send`. + encoding : str + The encoding to be used + """ + if not isinstance(u, str): + raise TypeError("str objects only") + return self.send(u.encode(encoding), flags=flags, copy=copy, **kwargs) + + send_unicode = send_string + + def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str: + """Receive a unicode string, as sent by send_string. + + Parameters + ---------- + flags : int + Any valid flags for :func:`Socket.recv`. + encoding : str + The encoding to be used + + Returns + ------- + s : str + The Python unicode string that arrives as encoded bytes. + + Raises + ------ + ZMQError + for any of the reasons :func:`Socket.recv` might fail + """ + msg = self.recv(flags=flags) + return self._deserialize(msg, lambda buf: buf.decode(encoding)) + + recv_unicode = recv_string + + def send_pyobj( + self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs + ) -> zmq.Frame | None: + """ + Send a Python object as a message using pickle to serialize. + + .. warning:: + + Never deserialize an untrusted message with pickle, + which can involve arbitrary code execution. + Make sure to authenticate the sources of messages + before unpickling them, e.g. with transport-level security + (e.g. CURVE, ZAP, or IPC permissions) + or signed messages. + + Parameters + ---------- + obj : Python object + The Python object to send. + flags : int + Any valid flags for :func:`Socket.send`. + protocol : int + The pickle protocol number to use. The default is pickle.DEFAULT_PROTOCOL + where defined, and pickle.HIGHEST_PROTOCOL elsewhere. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags=flags, **kwargs) + + def recv_pyobj(self, flags: int = 0) -> Any: + """ + Receive a Python object as a message using UNSAFE pickle to serialize. + + .. warning:: + + Never deserialize an untrusted message with pickle, + which can involve arbitrary code execution. + Make sure to authenticate the sources of messages + before unpickling them, e.g. with transport-level security + (such as CURVE or IPC permissions) + or authenticating messages themselves before deserializing. + + Parameters + ---------- + flags : int + Any valid flags for :func:`Socket.recv`. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + + Raises + ------ + ZMQError + for any of the reasons :func:`~Socket.recv` might fail + """ + msg = self.recv(flags) + return self._deserialize(msg, pickle.loads) + + def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None: + """Send a Python object as a message using json to serialize. + + Keyword arguments are passed on to json.dumps + + Parameters + ---------- + obj : Python object + The Python object to send + flags : int + Any valid flags for :func:`Socket.send` + """ + send_kwargs = {} + for key in ('routing_id', 'group'): + if key in kwargs: + send_kwargs[key] = kwargs.pop(key) + msg = jsonapi.dumps(obj, **kwargs) + return self.send(msg, flags=flags, **send_kwargs) + + def recv_json(self, flags: int = 0, **kwargs) -> _JSONType: + """Receive a Python object as a message using json to serialize. + + Keyword arguments are passed on to json.loads + + Parameters + ---------- + flags : int + Any valid flags for :func:`Socket.recv`. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + + Raises + ------ + ZMQError + for any of the reasons :func:`~Socket.recv` might fail + """ + msg = self.recv(flags) + return self._deserialize(msg, lambda buf: jsonapi.loads(buf, **kwargs)) + + _poller_class = Poller + + def poll(self, timeout: int | None = None, flags: int = zmq.POLLIN) -> int: + """Poll the socket for events. + + See :class:`Poller` to wait for multiple sockets at once. + + Parameters + ---------- + timeout : int + The timeout (in milliseconds) to wait for an event. If unspecified + (or specified None), will wait forever for an event. + flags : int + default: POLLIN. + POLLIN, POLLOUT, or POLLIN|POLLOUT. The event flags to poll for. + + Returns + ------- + event_mask : int + The poll event mask (POLLIN, POLLOUT), + 0 if the timeout was reached without an event. + """ + + if self.closed: + raise ZMQError(zmq.ENOTSUP) + + p = self._poller_class() + p.register(self, flags) + evts = dict(p.poll(timeout)) + # return 0 if no events, otherwise return event bitfield + return evts.get(self, 0) + + def get_monitor_socket( + self: _SocketType, events: int | None = None, addr: str | None = None + ) -> _SocketType: + """Return a connected PAIR socket ready to receive the event notifications. + + .. versionadded:: libzmq-4.0 + .. versionadded:: 14.0 + + Parameters + ---------- + events : int + default: `zmq.EVENT_ALL` + The bitmask defining which events are wanted. + addr : str + The optional endpoint for the monitoring sockets. + + Returns + ------- + socket : zmq.Socket + The PAIR socket, connected and ready to receive messages. + """ + # safe-guard, method only available on libzmq >= 4 + if zmq.zmq_version_info() < (4,): + raise NotImplementedError( + f"get_monitor_socket requires libzmq >= 4, have {zmq.zmq_version()}" + ) + + # if already monitoring, return existing socket + if self._monitor_socket: + if self._monitor_socket.closed: + self._monitor_socket = None + else: + return self._monitor_socket + + if addr is None: + # create endpoint name from internal fd + addr = f"inproc://monitor.s-{self.FD}" + if events is None: + # use all events + events = zmq.EVENT_ALL + # attach monitoring socket + self.monitor(addr, events) + # create new PAIR socket and connect it + self._monitor_socket = self.context.socket(zmq.PAIR) + self._monitor_socket.connect(addr) + return self._monitor_socket + + def disable_monitor(self) -> None: + """Shutdown the PAIR socket (created using get_monitor_socket) + that is serving socket events. + + .. versionadded:: 14.4 + """ + self._monitor_socket = None + self.monitor(None, 0) + + +SyncSocket: TypeAlias = Socket[bytes] + +__all__ = ['Socket', 'SyncSocket'] diff --git a/source/zmq/sugar/stopwatch.py b/source/zmq/sugar/stopwatch.py new file mode 100644 index 0000000000000000000000000000000000000000..2001e670a92761a5c3a10e40d93f39e2c21a3a10 --- /dev/null +++ b/source/zmq/sugar/stopwatch.py @@ -0,0 +1,36 @@ +"""Deprecated Stopwatch implementation""" + +# Copyright (c) PyZMQ Development Team. +# Distributed under the terms of the Modified BSD License. + + +class Stopwatch: + """Deprecated zmq.Stopwatch implementation + + You can use Python's builtin timers (time.monotonic, etc.). + """ + + def __init__(self): + import warnings + + warnings.warn( + "zmq.Stopwatch is deprecated. Use stdlib time.monotonic and friends instead", + DeprecationWarning, + stacklevel=2, + ) + self._start = 0 + import time + + try: + self._monotonic = time.monotonic + except AttributeError: + self._monotonic = time.time + + def start(self): + """Start the counter""" + self._start = self._monotonic() + + def stop(self): + """Return time since start in microseconds""" + stop = self._monotonic() + return int(1e6 * (stop - self._start)) diff --git a/source/zmq/sugar/tracker.py b/source/zmq/sugar/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..973fdbd6e2740951969e0f071dd07837cad1d632 --- /dev/null +++ b/source/zmq/sugar/tracker.py @@ -0,0 +1,116 @@ +"""Tracker for zero-copy messages with 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +import time +from threading import Event + +from zmq.backend import Frame +from zmq.error import NotDone + + +class MessageTracker: + """A class for tracking if 0MQ is done using one or more messages. + + When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread + sends the message at some later time. Often you want to know when 0MQ has + actually sent the message though. This is complicated by the fact that + a single 0MQ message can be sent multiple times using different sockets. + This class allows you to track all of the 0MQ usages of a message. + + Parameters + ---------- + towatch : Event, MessageTracker, zmq.Frame + This objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + + events: set[Event] + peers: set[MessageTracker] + + def __init__(self, *towatch: tuple[MessageTracker | Event | Frame]): + """Create a message tracker to track a set of messages. + + Parameters + ---------- + *towatch : tuple of Event, MessageTracker, Message instances. + This list of objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + self.events = set() + self.peers = set() + for obj in towatch: + if isinstance(obj, Event): + self.events.add(obj) + elif isinstance(obj, MessageTracker): + self.peers.add(obj) + elif isinstance(obj, Frame): + if not obj.tracker: + raise ValueError("Not a tracked message") + self.peers.add(obj.tracker) + else: + raise TypeError(f"Require Events or Message Frames, not {type(obj)}") + + @property + def done(self): + """Is 0MQ completely done with the message(s) being tracked?""" + for evt in self.events: + if not evt.is_set(): + return False + for pm in self.peers: + if not pm.done: + return False + return True + + def wait(self, timeout: float | int = -1): + """Wait for 0MQ to be done with the message or until `timeout`. + + Parameters + ---------- + timeout : float + default: -1, which means wait forever. + Maximum time in (s) to wait before raising NotDone. + + Returns + ------- + None + if done before `timeout` + + Raises + ------ + NotDone + if `timeout` reached before I am done. + """ + tic = time.time() + remaining: float + if timeout is False or timeout < 0: + remaining = 3600 * 24 * 7 # a week + else: + remaining = timeout + for evt in self.events: + if remaining < 0: + raise NotDone + evt.wait(timeout=remaining) + if not evt.is_set(): + raise NotDone + toc = time.time() + remaining -= toc - tic + tic = toc + + for peer in self.peers: + if remaining < 0: + raise NotDone + peer.wait(timeout=remaining) + toc = time.time() + remaining -= toc - tic + tic = toc + + +_FINISHED_TRACKER = MessageTracker() + +__all__ = ['MessageTracker', '_FINISHED_TRACKER'] diff --git a/source/zmq/sugar/version.py b/source/zmq/sugar/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5659d7bc43561ca28e384fe76d54159b36df5871 --- /dev/null +++ b/source/zmq/sugar/version.py @@ -0,0 +1,67 @@ +"""PyZMQ and 0MQ version functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import re +from typing import Match, cast + +from zmq.backend import zmq_version_info + +__version__: str = "27.1.0" +_version_pat = re.compile(r"(\d+)\.(\d+)\.(\d+)(.*)") +_match = cast(Match, _version_pat.match(__version__)) +_version_groups = _match.groups() + +VERSION_MAJOR = int(_version_groups[0]) +VERSION_MINOR = int(_version_groups[1]) +VERSION_PATCH = int(_version_groups[2]) +VERSION_EXTRA = _version_groups[3].lstrip(".") + +version_info: tuple[int, int, int] | tuple[int, int, int, float] = ( + VERSION_MAJOR, + VERSION_MINOR, + VERSION_PATCH, +) + +if VERSION_EXTRA: + version_info = ( + VERSION_MAJOR, + VERSION_MINOR, + VERSION_PATCH, + float('inf'), + ) + +__revision__: str = '' + + +def pyzmq_version() -> str: + """return the version of pyzmq as a string""" + if __revision__: + return '+'.join([__version__, __revision__[:6]]) + else: + return __version__ + + +def pyzmq_version_info() -> tuple[int, int, int] | tuple[int, int, int, float]: + """return the pyzmq version as a tuple of at least three numbers + + If pyzmq is a development version, `inf` will be appended after the third integer. + """ + return version_info + + +def zmq_version() -> str: + """return the version of libzmq as a string""" + return "{}.{}.{}".format(*zmq_version_info()) + + +__all__ = [ + 'zmq_version', + 'zmq_version_info', + 'pyzmq_version', + 'pyzmq_version_info', + '__version__', + '__revision__', +] diff --git a/source/zmq/tests/__init__.py b/source/zmq/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5be08c6a8a2c01f8bcce399252a7794d646a670d --- /dev/null +++ b/source/zmq/tests/__init__.py @@ -0,0 +1,262 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import os +import platform +import signal +import sys +import time +import warnings +from functools import partial +from threading import Thread +from typing import List +from unittest import SkipTest, TestCase + +from pytest import mark + +import zmq +from zmq.utils import jsonapi + +try: + import gevent + + from zmq import green as gzmq + + have_gevent = True +except ImportError: + have_gevent = False + + +PYPY = platform.python_implementation() == 'PyPy' + +# ----------------------------------------------------------------------------- +# skip decorators (directly from unittest) +# ----------------------------------------------------------------------------- +warnings.warn( + "zmq.tests is deprecated in pyzmq 25, we recommend managing your own contexts and sockets.", + DeprecationWarning, +) + + +def _id(x): + return x + + +skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy") +require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4") + +# ----------------------------------------------------------------------------- +# Base test class +# ----------------------------------------------------------------------------- + + +def term_context(ctx, timeout): + """Terminate a context with a timeout""" + t = Thread(target=ctx.term) + t.daemon = True + t.start() + t.join(timeout=timeout) + if t.is_alive(): + # reset Context.instance, so the failure to term doesn't corrupt subsequent tests + zmq.sugar.context.Context._instance = None + raise RuntimeError( + "context could not terminate, open sockets likely remain in test" + ) + + +class BaseZMQTestCase(TestCase): + green = False + teardown_timeout = 10 + test_timeout_seconds = int(os.environ.get("ZMQ_TEST_TIMEOUT") or 60) + sockets: List[zmq.Socket] + + @property + def _is_pyzmq_test(self): + return self.__class__.__module__.split(".", 1)[0] == __name__.split(".", 1)[0] + + @property + def _should_test_timeout(self): + return ( + self._is_pyzmq_test + and hasattr(signal, 'SIGALRM') + and self.test_timeout_seconds + ) + + @property + def Context(self): + if self.green: + return gzmq.Context + else: + return zmq.Context + + def socket(self, socket_type): + s = self.context.socket(socket_type) + self.sockets.append(s) + return s + + def _alarm_timeout(self, timeout, *args): + raise TimeoutError(f"Test did not complete in {timeout} seconds") + + def setUp(self): + super().setUp() + if self.green and not have_gevent: + raise SkipTest("requires gevent") + + self.context = self.Context.instance() + self.sockets = [] + if self._should_test_timeout: + # use SIGALRM to avoid test hangs + signal.signal( + signal.SIGALRM, partial(self._alarm_timeout, self.test_timeout_seconds) + ) + signal.alarm(self.test_timeout_seconds) + + def tearDown(self): + if self._should_test_timeout: + # cancel the timeout alarm, if there was one + signal.alarm(0) + contexts = {self.context} + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close(0) + for ctx in contexts: + try: + term_context(ctx, self.teardown_timeout) + except Exception: + # reset Context.instance, so the failure to term doesn't corrupt subsequent tests + zmq.sugar.context.Context._instance = None + raise + + super().tearDown() + + def create_bound_pair( + self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1' + ): + """Create a bound socket pair using a random port.""" + s1 = self.context.socket(type1) + s1.setsockopt(zmq.LINGER, 0) + port = s1.bind_to_random_port(interface) + s2 = self.context.socket(type2) + s2.setsockopt(zmq.LINGER, 0) + s2.connect(f'{interface}:{port}') + self.sockets.extend([s1, s2]) + return s1, s2 + + def ping_pong(self, s1, s2, msg): + s1.send(msg) + msg2 = s2.recv() + s2.send(msg2) + msg3 = s1.recv() + return msg3 + + def ping_pong_json(self, s1, s2, o): + if jsonapi.jsonmod is None: + raise SkipTest("No json library") + s1.send_json(o) + o2 = s2.recv_json() + s2.send_json(o2) + o3 = s1.recv_json() + return o3 + + def ping_pong_pyobj(self, s1, s2, o): + s1.send_pyobj(o) + o2 = s2.recv_pyobj() + s2.send_pyobj(o2) + o3 = s1.recv_pyobj() + return o3 + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + try: + func(*args, **kwargs) + except zmq.ZMQError as e: + self.assertEqual( + e.errno, + errno, + f"wrong error raised, expected '{zmq.ZMQError(errno)}' \ +got '{zmq.ZMQError(e.errno)}'", + ) + else: + self.fail("Function did not raise any error") + + def _select_recv(self, multipart, socket, **kwargs): + """call recv[_multipart] in a way that raises if there is nothing to receive""" + # zmq 3.1 has a bug, where poll can return false positives, + # so we wait a little bit just in case + # See LIBZMQ-280 on JIRA + time.sleep(0.1) + + r, w, x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5)) + assert len(r) > 0, "Should have received a message" + kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0) + + recv = socket.recv_multipart if multipart else socket.recv + return recv(**kwargs) + + def recv(self, socket, **kwargs): + """call recv in a way that raises if there is nothing to receive""" + return self._select_recv(False, socket, **kwargs) + + def recv_multipart(self, socket, **kwargs): + """call recv_multipart in a way that raises if there is nothing to receive""" + return self._select_recv(True, socket, **kwargs) + + +class PollZMQTestCase(BaseZMQTestCase): + pass + + +class GreenTest: + """Mixin for making green versions of test classes""" + + green = True + teardown_timeout = 10 + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + if errno == zmq.EAGAIN: + raise SkipTest("Skipping because we're green.") + try: + func(*args, **kwargs) + except zmq.ZMQError: + e = sys.exc_info()[1] + self.assertEqual( + e.errno, + errno, + f"wrong error raised, expected '{zmq.ZMQError(errno)}' \ +got '{zmq.ZMQError(e.errno)}'", + ) + else: + self.fail("Function did not raise any error") + + def tearDown(self): + if self._should_test_timeout: + # cancel the timeout alarm, if there was one + signal.alarm(0) + contexts = {self.context} + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close() + try: + gevent.joinall( + [gevent.spawn(ctx.term) for ctx in contexts], + timeout=self.teardown_timeout, + raise_error=True, + ) + except gevent.Timeout: + raise RuntimeError( + "context could not terminate, open sockets likely remain in test" + ) + + def skip_green(self): + raise SkipTest("Skipping because we are green") + + +def skip_green(f): + def skipping_test(self, *args, **kwargs): + if self.green: + raise SkipTest("Skipping because we are green") + else: + return f(self, *args, **kwargs) + + return skipping_test diff --git a/source/zmq/utils/__init__.py b/source/zmq/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/source/zmq/utils/garbage.py b/source/zmq/utils/garbage.py new file mode 100644 index 0000000000000000000000000000000000000000..2c700313d917f2dd44a3d62b4a2a79bb0a1f9de0 --- /dev/null +++ b/source/zmq/utils/garbage.py @@ -0,0 +1,213 @@ +"""Garbage collection thread for representing zmq refcount of Python objects +used in zero-copy sends. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import atexit +import struct +import warnings +from collections import namedtuple +from os import getpid +from threading import Event, Lock, Thread + +import zmq + +gcref = namedtuple('gcref', ['obj', 'event']) + + +class GarbageCollectorThread(Thread): + """Thread in which garbage collection actually happens.""" + + def __init__(self, gc): + super().__init__() + self.gc = gc + self.daemon = True + self.pid = getpid() + self.ready = Event() + + def run(self): + # detect fork at beginning of the thread + if getpid is None or getpid() != self.pid: + self.ready.set() + return + try: + s = self.gc.context.socket(zmq.PULL) + s.linger = 0 + s.bind(self.gc.url) + finally: + self.ready.set() + + while True: + # detect fork + if getpid is None or getpid() != self.pid: + return + msg = s.recv() + if msg == b'DIE': + break + fmt = 'L' if len(msg) == 4 else 'Q' + key = struct.unpack(fmt, msg)[0] + tup = self.gc.refs.pop(key, None) + if tup and tup.event: + tup.event.set() + del tup + s.close() + + +class GarbageCollector: + """PyZMQ Garbage Collector + + Used for representing the reference held by libzmq during zero-copy sends. + This object holds a dictionary, keyed by Python id, + of the Python objects whose memory are currently in use by zeromq. + + When zeromq is done with the memory, it sends a message on an inproc PUSH socket + containing the packed size_t (32 or 64-bit unsigned int), + which is the key in the dict. + When the PULL socket in the gc thread receives that message, + the reference is popped from the dict, + and any tracker events that should be signaled fire. + """ + + refs = None + _context = None + _lock = None + url = "inproc://pyzmq.gc.01" + + def __init__(self, context=None): + super().__init__() + self.refs = {} + self.pid = None + self.thread = None + self._context = context + self._lock = Lock() + self._stay_down = False + self._push = None + self._push_mutex = None + atexit.register(self._atexit) + + @property + def context(self): + if self._context is None: + if Thread.__module__.startswith('gevent'): + # gevent has monkey-patched Thread, use green Context + from zmq import green + + self._context = green.Context() + else: + self._context = zmq.Context() + return self._context + + @context.setter + def context(self, ctx): + if self.is_alive(): + if self.refs: + warnings.warn( + "Replacing gc context while gc is running", RuntimeWarning + ) + self.stop() + self._context = ctx + + def _atexit(self): + """atexit callback + + sets _stay_down flag so that gc doesn't try to start up again in other atexit handlers + """ + self._stay_down = True + self.stop() + + def stop(self): + """stop the garbage-collection thread""" + if not self.is_alive(): + return + self._stop() + + def _clear(self): + """Clear state + + called after stop or when setting up a new subprocess + """ + self._push = None + self._push_mutex = None + self.thread = None + self.refs.clear() + self.context = None + + def _stop(self): + push = self.context.socket(zmq.PUSH) + push.connect(self.url) + push.send(b'DIE') + push.close() + if self._push: + self._push.close() + self.thread.join() + self.context.term() + self._clear() + + @property + def _push_socket(self): + """The PUSH socket for use in the zmq message destructor callback.""" + if getattr(self, "_stay_down", False): + raise RuntimeError("zmq gc socket requested during shutdown") + if not self.is_alive() or self._push is None: + self._push = self.context.socket(zmq.PUSH) + self._push.connect(self.url) + return self._push + + def start(self): + """Start a new garbage collection thread. + + Creates a new zmq Context used for garbage collection. + Under most circumstances, this will only be called once per process. + """ + if self.thread is not None and self.pid != getpid(): + # It's re-starting, must free earlier thread's context + # since a fork probably broke it + self._clear() + self.pid = getpid() + self.refs = {} + self.thread = GarbageCollectorThread(self) + self.thread.start() + self.thread.ready.wait() + + def is_alive(self): + """Is the garbage collection thread currently running? + + Includes checks for process shutdown or fork. + """ + if ( + getpid is None + or getpid() != self.pid + or self.thread is None + or not self.thread.is_alive() + ): + return False + return True + + def store(self, obj, event=None): + """store an object and (optionally) event for zero-copy""" + if not self.is_alive(): + if self._stay_down: + return 0 + # safely start the gc thread + # use lock and double check, + # so we don't start multiple threads + with self._lock: + if not self.is_alive(): + self.start() + tup = gcref(obj, event) + theid = id(tup) + self.refs[theid] = tup + return theid + + def __del__(self): + if not self.is_alive(): + return + try: + self.stop() + except Exception as e: + raise (e) + + +gc = GarbageCollector() diff --git a/source/zmq/utils/getpid_compat.h b/source/zmq/utils/getpid_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..3fe1bec4c24ac472fb9ab8940da43adc1de60257 --- /dev/null +++ b/source/zmq/utils/getpid_compat.h @@ -0,0 +1,7 @@ +#pragma once +#ifdef _WIN32 + #include + #define getpid _getpid +#else + #include +#endif diff --git a/source/zmq/utils/interop.py b/source/zmq/utils/interop.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4ffd9a813c7c0e490428123b08b9e38c85d52b --- /dev/null +++ b/source/zmq/utils/interop.py @@ -0,0 +1,29 @@ +"""Utils for interoperability with other libraries. + +Just CFFI pointer casting for now. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from typing import Any + + +def cast_int_addr(n: Any) -> int: + """Cast an address to a Python int + + This could be a Python integer or a CFFI pointer + """ + if isinstance(n, int): + return n + try: + import cffi # type: ignore + except ImportError: + pass + else: + # from pyzmq, this is an FFI void * + ffi = cffi.FFI() + if isinstance(n, ffi.CData): + return int(ffi.cast("size_t", n)) + + raise ValueError(f"Cannot cast {n!r} to int") diff --git a/source/zmq/utils/ipcmaxlen.h b/source/zmq/utils/ipcmaxlen.h new file mode 100644 index 0000000000000000000000000000000000000000..7af9a261be4be890aab7df2b570d59e09ff2a3e2 --- /dev/null +++ b/source/zmq/utils/ipcmaxlen.h @@ -0,0 +1,27 @@ +/* + +Platform-independant detection of IPC path max length + +Copyright (c) 2012 Godefroid Chapelle + +Distributed under the terms of the New BSD License. The full license is in +the file LICENSE.BSD, distributed as part of this software. + */ + +#pragma once + +#if defined(HAVE_SYS_UN_H) +#if defined _MSC_VER +#include +#else +#include +#endif +int get_ipc_path_max_len(void) { + struct sockaddr_un *dummy; + return sizeof(dummy->sun_path) - 1; +} +#else +int get_ipc_path_max_len(void) { + return 0; +} +#endif diff --git a/source/zmq/utils/jsonapi.py b/source/zmq/utils/jsonapi.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6ee0785229f449ccc9efc3e778909699866de0 --- /dev/null +++ b/source/zmq/utils/jsonapi.py @@ -0,0 +1,38 @@ +"""JSON serialize to/from utf8 bytes + +.. versionchanged:: 22.2 + Remove optional imports of different JSON implementations. + Now that we require recent Python, unconditionally use the standard library. + Custom JSON libraries can be used via custom serialization functions. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import json +from typing import Any + +# backward-compatibility, unused +jsonmod = json + + +def dumps(o: Any, **kwargs) -> bytes: + """Serialize object to JSON bytes (utf-8). + + Keyword arguments are passed along to :py:func:`json.dumps`. + """ + return json.dumps(o, **kwargs).encode("utf8") + + +def loads(s: bytes | str, **kwargs) -> dict | list | str | int | float: + """Load object from JSON bytes (utf-8). + + Keyword arguments are passed along to :py:func:`json.loads`. + """ + if isinstance(s, bytes): + s = s.decode("utf8") + return json.loads(s, **kwargs) + + +__all__ = ['dumps', 'loads'] diff --git a/source/zmq/utils/monitor.py b/source/zmq/utils/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..872e24dc2953ba17877d39c1628d3f38d53484d6 --- /dev/null +++ b/source/zmq/utils/monitor.py @@ -0,0 +1,128 @@ +"""Module holding utility and convenience functions for zmq event monitoring.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import annotations + +import struct +from typing import Awaitable, TypedDict, overload + +import zmq +import zmq.asyncio +from zmq.error import _check_version + + +class _MonitorMessage(TypedDict): + event: int + value: int + endpoint: bytes + + +def parse_monitor_message(msg: list[bytes]) -> _MonitorMessage: + """decode zmq_monitor event messages. + + Parameters + ---------- + msg : list(bytes) + zmq multipart message that has arrived on a monitor PAIR socket. + + First frame is:: + + 16 bit event id + 32 bit event value + no padding + + Second frame is the endpoint as a bytestring + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + if len(msg) != 2 or len(msg[0]) != 6: + raise RuntimeError(f"Invalid event message format: {msg}") + event_id, value = struct.unpack("=hi", msg[0]) + event: _MonitorMessage = { + 'event': zmq.Event(event_id), + 'value': zmq.Event(value), + 'endpoint': msg[1], + } + return event + + +async def _parse_monitor_msg_async( + awaitable_msg: Awaitable[list[bytes]], +) -> _MonitorMessage: + """Like parse_monitor_msg, but awaitable + + Given awaitable message, return awaitable for the parsed monitor message. + """ + + msg = await awaitable_msg + # 4.0-style event API + return parse_monitor_message(msg) + + +@overload +def recv_monitor_message( + socket: zmq.asyncio.Socket, + flags: int = 0, +) -> Awaitable[_MonitorMessage]: ... + + +@overload +def recv_monitor_message( + socket: zmq.Socket[bytes], + flags: int = 0, +) -> _MonitorMessage: ... + + +def recv_monitor_message( + socket: zmq.Socket, + flags: int = 0, +) -> _MonitorMessage | Awaitable[_MonitorMessage]: + """Receive and decode the given raw message from the monitoring socket and return a dict. + + Requires libzmq ≥ 4.0 + + The returned dict will have the following entries: + event : int + the event id as described in `libzmq.zmq_socket_monitor` + value : int + the event value associated with the event, see `libzmq.zmq_socket_monitor` + endpoint : str + the affected endpoint + + .. versionchanged:: 23.1 + Support for async sockets added. + When called with a async socket, + returns an awaitable for the monitor message. + + Parameters + ---------- + socket : zmq.Socket + The PAIR socket (created by other.get_monitor_socket()) on which to recv the message + flags : int + standard zmq recv flags + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + + _check_version((4, 0), 'libzmq event API') + # will always return a list + msg = socket.recv_multipart(flags) + + # transparently handle asyncio socket, + # returns a Future instead of a dict + if isinstance(msg, Awaitable): + return _parse_monitor_msg_async(msg) + + # 4.0-style event API + return parse_monitor_message(msg) + + +__all__ = ['parse_monitor_message', 'recv_monitor_message'] diff --git a/source/zmq/utils/mutex.h b/source/zmq/utils/mutex.h new file mode 100644 index 0000000000000000000000000000000000000000..b6275ea28b0799607b93a69679a86f51a95fbb28 --- /dev/null +++ b/source/zmq/utils/mutex.h @@ -0,0 +1,84 @@ +/* +* simplified from mutex.c from Foundation Library, in the Public Domain +* https://github.com/rampantpixels/foundation_lib/blob/master/foundation/mutex.c +* +* This file is Copyright (C) PyZMQ Developers +* Distributed under the terms of the Modified BSD License. +* +*/ + +#pragma once + +#include + +#if defined(_WIN32) +# include +#else +# include +#endif + +typedef struct { +#if defined(_WIN32) + CRITICAL_SECTION csection; +#else + pthread_mutex_t mutex; +#endif +} mutex_t; + + +static void +_mutex_initialize(mutex_t* mutex) { +#if defined(_WIN32) + InitializeCriticalSectionAndSpinCount(&mutex->csection, 4000); +#else + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE); + pthread_mutex_init(&mutex->mutex, &attr); + pthread_mutexattr_destroy(&attr); +#endif +} + +static void +_mutex_finalize(mutex_t* mutex) { +#if defined(_WIN32) + DeleteCriticalSection(&mutex->csection); +#else + pthread_mutex_destroy(&mutex->mutex); +#endif +} + +mutex_t* +mutex_allocate(void) { + mutex_t* mutex = (mutex_t*)malloc(sizeof(mutex_t)); + _mutex_initialize(mutex); + return mutex; +} + +void +mutex_deallocate(mutex_t* mutex) { + if (!mutex) + return; + _mutex_finalize(mutex); + free(mutex); +} + +int +mutex_lock(mutex_t* mutex) { +#if defined(_WIN32) + EnterCriticalSection(&mutex->csection); + return 0; +#else + return pthread_mutex_lock(&mutex->mutex); +#endif +} + +int +mutex_unlock(mutex_t* mutex) { +#if defined(_WIN32) + LeaveCriticalSection(&mutex->csection); + return 0; +#else + return pthread_mutex_unlock(&mutex->mutex); +#endif +} diff --git a/source/zmq/utils/pyversion_compat.h b/source/zmq/utils/pyversion_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..fb19dcf09d6840ac12e995a308bb96838813797a --- /dev/null +++ b/source/zmq/utils/pyversion_compat.h @@ -0,0 +1,12 @@ +#include "Python.h" + +// default to Python's own target Windows version(s) +// override by setting WINVER, _WIN32_WINNT, (maybe also NTDDI_VERSION?) macros +#ifdef Py_WINVER +#ifndef WINVER +#define WINVER Py_WINVER +#endif +#ifndef _WIN32_WINNT +#define _WIN32_WINNT Py_WINVER +#endif +#endif diff --git a/source/zmq/utils/strtypes.py b/source/zmq/utils/strtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..3d90a04809ca4f29e1457099b63c03302e837a1e --- /dev/null +++ b/source/zmq/utils/strtypes.py @@ -0,0 +1,62 @@ +"""Declare basic string types unambiguously for various Python versions. + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import warnings + +bytes = bytes +unicode = str +basestring = (str,) + + +def cast_bytes(s, encoding='utf8', errors='strict'): + """cast unicode or bytes to bytes""" + warnings.warn( + "zmq.utils.strtypes is deprecated in pyzmq 23.", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(s, bytes): + return s + elif isinstance(s, str): + return s.encode(encoding, errors) + else: + raise TypeError(f"Expected unicode or bytes, got {s!r}") + + +def cast_unicode(s, encoding='utf8', errors='strict'): + """cast bytes or unicode to unicode""" + warnings.warn( + "zmq.utils.strtypes is deprecated in pyzmq 23.", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(s, bytes): + return s.decode(encoding, errors) + elif isinstance(s, str): + return s + else: + raise TypeError(f"Expected unicode or bytes, got {s!r}") + + +# give short 'b' alias for cast_bytes, so that we can use fake b'stuff' +# to simulate b'stuff' +b = asbytes = cast_bytes +u = cast_unicode + +__all__ = [ + 'asbytes', + 'bytes', + 'unicode', + 'basestring', + 'b', + 'u', + 'cast_bytes', + 'cast_unicode', +] diff --git a/source/zmq/utils/win32.py b/source/zmq/utils/win32.py new file mode 100644 index 0000000000000000000000000000000000000000..019d429715af914d4aeb508bda342f764a74e9d0 --- /dev/null +++ b/source/zmq/utils/win32.py @@ -0,0 +1,130 @@ +"""Win32 compatibility utilities.""" + +# ----------------------------------------------------------------------------- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import os +from typing import Any, Callable + + +class allow_interrupt: + """Utility for fixing CTRL-C events on Windows. + + On Windows, the Python interpreter intercepts CTRL-C events in order to + translate them into ``KeyboardInterrupt`` exceptions. It (presumably) + does this by setting a flag in its "console control handler" and + checking it later at a convenient location in the interpreter. + + However, when the Python interpreter is blocked waiting for the ZMQ + poll operation to complete, it must wait for ZMQ's ``select()`` + operation to complete before translating the CTRL-C event into the + ``KeyboardInterrupt`` exception. + + The only way to fix this seems to be to add our own "console control + handler" and perform some application-defined operation that will + unblock the ZMQ polling operation in order to force ZMQ to pass control + back to the Python interpreter. + + This context manager performs all that Windows-y stuff, providing you + with a hook that is called when a CTRL-C event is intercepted. This + hook allows you to unblock your ZMQ poll operation immediately, which + will then result in the expected ``KeyboardInterrupt`` exception. + + Without this context manager, your ZMQ-based application will not + respond normally to CTRL-C events on Windows. If a CTRL-C event occurs + while blocked on ZMQ socket polling, the translation to a + ``KeyboardInterrupt`` exception will be delayed until the I/O completes + and control returns to the Python interpreter (this may never happen if + you use an infinite timeout). + + A no-op implementation is provided on non-Win32 systems to avoid the + application from having to conditionally use it. + + Example usage: + + .. sourcecode:: python + + def stop_my_application(): + # ... + + with allow_interrupt(stop_my_application): + # main polling loop. + + In a typical ZMQ application, you would use the "self pipe trick" to + send message to a ``PAIR`` socket in order to interrupt your blocking + socket polling operation. + + In a Tornado event loop, you can use the ``IOLoop.stop`` method to + unblock your I/O loop. + """ + + def __init__(self, action: Callable[[], Any] | None = None) -> None: + """Translate ``action`` into a CTRL-C handler. + + ``action`` is a callable that takes no arguments and returns no + value (returned value is ignored). It must *NEVER* raise an + exception. + + If unspecified, a no-op will be used. + """ + if os.name != "nt": + return + self._init_action(action) + + def _init_action(self, action): + from ctypes import WINFUNCTYPE, windll + from ctypes.wintypes import BOOL, DWORD + + kernel32 = windll.LoadLibrary('kernel32') + + # + PHANDLER_ROUTINE = WINFUNCTYPE(BOOL, DWORD) + SetConsoleCtrlHandler = self._SetConsoleCtrlHandler = ( + kernel32.SetConsoleCtrlHandler + ) + SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, BOOL) + SetConsoleCtrlHandler.restype = BOOL + + if action is None: + + def action(): + return None + + self.action = action + + @PHANDLER_ROUTINE + def handle(event): + if event == 0: # CTRL_C_EVENT + action() + # Typical C implementations would return 1 to indicate that + # the event was processed and other control handlers in the + # stack should not be executed. However, that would + # prevent the Python interpreter's handler from translating + # CTRL-C to a `KeyboardInterrupt` exception, so we pretend + # that we didn't handle it. + return 0 + + self.handle = handle + + def __enter__(self): + """Install the custom CTRL-C handler.""" + if os.name != "nt": + return + result = self._SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise OSError() + + def __exit__(self, *args): + """Remove the custom CTRL-C handler.""" + if os.name != "nt": + return + result = self._SetConsoleCtrlHandler(self.handle, 0) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise OSError() diff --git a/source/zmq/utils/z85.py b/source/zmq/utils/z85.py new file mode 100644 index 0000000000000000000000000000000000000000..016bdfe12f303f1542fe00fc92ab1a72a20a20b9 --- /dev/null +++ b/source/zmq/utils/z85.py @@ -0,0 +1,58 @@ +"""Python implementation of Z85 85-bit encoding + +Z85 encoding is a plaintext encoding for a bytestring interpreted as 32bit integers. +Since the chunks are 32bit, a bytestring must be a multiple of 4 bytes. +See ZMQ RFC 32 for details. + + +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import struct + +# Z85CHARS is the base 85 symbol table +Z85CHARS = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#" +# Z85MAP maps integers in [0,84] to the appropriate character in Z85CHARS +Z85MAP = {c: idx for idx, c in enumerate(Z85CHARS)} + +_85s = [85**i for i in range(5)][::-1] + + +def encode(rawbytes): + """encode raw bytes into Z85""" + # Accepts only byte arrays bounded to 4 bytes + if len(rawbytes) % 4: + raise ValueError(f"length must be multiple of 4, not {len(rawbytes)}") + + nvalues = len(rawbytes) // 4 + values = struct.unpack(f'>{nvalues:d}I', rawbytes) + encoded = [] + for v in values: + for offset in _85s: + encoded.append(Z85CHARS[(v // offset) % 85]) + + return bytes(encoded) + + +def decode(z85bytes): + """decode Z85 bytes to raw bytes, accepts ASCII string""" + if isinstance(z85bytes, str): + try: + z85bytes = z85bytes.encode('ascii') + except UnicodeEncodeError: + raise ValueError('string argument should contain only ASCII characters') + + if len(z85bytes) % 5: + raise ValueError(f"Z85 length must be multiple of 5, not {len(z85bytes)}") + + nvalues = len(z85bytes) // 5 + values = [] + for i in range(0, len(z85bytes), 5): + value = 0 + for j, offset in enumerate(_85s): + value += Z85MAP[z85bytes[i + j]] * offset + values.append(value) + return struct.pack(f'>{nvalues:d}I', *values) diff --git a/source/zmq/utils/zmq_compat.h b/source/zmq/utils/zmq_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..55fb58ec877a08a2126cbfebd5f39174476c8290 --- /dev/null +++ b/source/zmq/utils/zmq_compat.h @@ -0,0 +1,100 @@ +//----------------------------------------------------------------------------- +// Copyright (c) 2010 Brian Granger, Min Ragan-Kelley +// +// Distributed under the terms of the New BSD License. The full license is in +// the file LICENSE.BSD, distributed as part of this software. +//----------------------------------------------------------------------------- + +#pragma once + +#if defined(_MSC_VER) +#define pyzmq_int64_t __int64 +#define pyzmq_uint32_t unsigned __int32 +#else +#include +#define pyzmq_int64_t int64_t +#define pyzmq_uint32_t uint32_t +#endif + + +#include "zmq.h" + +#define _missing (-1) + +#if (ZMQ_VERSION >= 40303) + // libzmq >= 4.3.3 defines zmq_fd_t for us + #define ZMQ_FD_T zmq_fd_t +#else + #ifdef _WIN32 + #if defined(_MSC_VER) && _MSC_VER <= 1400 + #define ZMQ_FD_T UINT_PTR + #else + #define ZMQ_FD_T SOCKET + #endif + #else + #define ZMQ_FD_T int + #endif +#endif + +#if (ZMQ_VERSION >= 40200) + // Nothing to remove +#else + #define zmq_curve_public(z85_public_key, z85_secret_key) _missing +#endif + +// use unambiguous aliases for zmq_send/recv functions + +#if ZMQ_VERSION_MAJOR >= 4 +// nothing to remove + #if ZMQ_VERSION_MAJOR == 4 && ZMQ_VERSION_MINOR == 0 + // zmq 4.1 deprecates zmq_utils.h + // we only get zmq_curve_keypair from it + #include "zmq_utils.h" + #endif +#else + #define zmq_curve_keypair(z85_public_key, z85_secret_key) _missing +#endif + +// libzmq 4.2 draft API +#ifdef ZMQ_BUILD_DRAFT_API + #define PYZMQ_DRAFT_API 1 + #if ZMQ_VERSION >= 40200 + #define PYZMQ_DRAFT_42 + #endif + #if ZMQ_VERSION >= 40302 + #define PYZMQ_DRAFT_432 + #endif +#else + #define PYZMQ_DRAFT_API 0 +#endif + +#ifndef PYZMQ_DRAFT_42 + #define zmq_join(s, group) _missing + #define zmq_leave(s, group) _missing + #define zmq_msg_set_routing_id(msg, routing_id) _missing + #define zmq_msg_routing_id(msg) 0 + #define zmq_msg_set_group(msg, group) _missing + #define zmq_msg_group(msg) NULL + #define zmq_poller_new() NULL + #define zmq_poller_destroy(poller_p) _missing + #define zmq_poller_add(poller, socket, userdata, events) _missing + #define zmq_poller_modify(poller, socket, events) _missing + #define zmq_poller_remove(poller, socket) _missing +#endif +#ifndef PYZMQ_DRAFT_432 + #define zmq_poller_fd(poller, fd) _missing +#endif + +#if ZMQ_VERSION >= 40100 +// nothing to remove +#else + #define zmq_msg_gets(msg, prop) _missing + #define zmq_has(capability) _missing + #define zmq_proxy_steerable(in, out, mon, ctrl) _missing +#endif + +// 3.x deprecations - these symbols haven't been removed, +// but let's protect against their planned removal +#define zmq_device(device_type, isocket, osocket) _missing +#define zmq_init(io_threads) ((void*)NULL) +#define zmq_term zmq_ctx_destroy