Harmony18090 commited on
Commit
c389653
·
verified ·
1 Parent(s): cabdcdf

Add source batch 11/11

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. source/watchfiles-1.1.1.dist-info/INSTALLER +1 -0
  3. source/watchfiles-1.1.1.dist-info/METADATA +148 -0
  4. source/watchfiles-1.1.1.dist-info/RECORD +24 -0
  5. source/watchfiles-1.1.1.dist-info/WHEEL +4 -0
  6. source/watchfiles-1.1.1.dist-info/entry_points.txt +2 -0
  7. source/watchfiles-1.1.1.dist-info/licenses/LICENSE +21 -0
  8. source/websockets-16.0.dist-info/INSTALLER +1 -0
  9. source/websockets-16.0.dist-info/METADATA +179 -0
  10. source/websockets-16.0.dist-info/RECORD +108 -0
  11. source/websockets-16.0.dist-info/WHEEL +7 -0
  12. source/websockets-16.0.dist-info/entry_points.txt +2 -0
  13. source/websockets-16.0.dist-info/licenses/LICENSE +24 -0
  14. source/websockets-16.0.dist-info/top_level.txt +1 -0
  15. source/websockets/__init__.py +236 -0
  16. source/websockets/__main__.py +5 -0
  17. source/websockets/asyncio/__init__.py +0 -0
  18. source/websockets/asyncio/async_timeout.py +282 -0
  19. source/websockets/asyncio/client.py +804 -0
  20. source/websockets/asyncio/compatibility.py +30 -0
  21. source/websockets/asyncio/connection.py +1247 -0
  22. source/websockets/asyncio/messages.py +316 -0
  23. source/websockets/asyncio/router.py +219 -0
  24. source/websockets/asyncio/server.py +997 -0
  25. source/websockets/auth.py +18 -0
  26. source/websockets/cli.py +178 -0
  27. source/websockets/client.py +391 -0
  28. source/websockets/connection.py +12 -0
  29. source/websockets/datastructures.py +183 -0
  30. source/websockets/exceptions.py +473 -0
  31. source/websockets/extensions/__init__.py +4 -0
  32. source/websockets/extensions/base.py +123 -0
  33. source/websockets/extensions/permessage_deflate.py +699 -0
  34. source/websockets/frames.py +431 -0
  35. source/websockets/headers.py +586 -0
  36. source/websockets/http.py +20 -0
  37. source/websockets/http11.py +438 -0
  38. source/websockets/imports.py +100 -0
  39. source/websockets/legacy/__init__.py +11 -0
  40. source/websockets/legacy/auth.py +190 -0
  41. source/websockets/legacy/client.py +703 -0
  42. source/websockets/legacy/exceptions.py +71 -0
  43. source/websockets/legacy/framing.py +224 -0
  44. source/websockets/legacy/handshake.py +158 -0
  45. source/websockets/legacy/http.py +201 -0
  46. source/websockets/legacy/protocol.py +1635 -0
  47. source/websockets/legacy/server.py +1191 -0
  48. source/websockets/protocol.py +768 -0
  49. source/websockets/proxy.py +150 -0
  50. source/websockets/py.typed +0 -0
.gitattributes CHANGED
@@ -267,3 +267,6 @@ source/tvm_ffi/lib/libtvm_ffi.so filter=lfs diff=lfs merge=lfs -text
267
  source/tvm_ffi/lib/libtvm_ffi_testing.so filter=lfs diff=lfs merge=lfs -text
268
  source/uvloop/loop.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
269
  source/watchfiles/_rust_notify.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 
 
 
 
267
  source/tvm_ffi/lib/libtvm_ffi_testing.so filter=lfs diff=lfs merge=lfs -text
268
  source/uvloop/loop.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
269
  source/watchfiles/_rust_notify.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
270
+ source/yaml/_yaml.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
271
+ source/yarl/_quoting_c.cpython-312-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
272
+ source/zmq/backend/cython/_zmq.abi3.so filter=lfs diff=lfs merge=lfs -text
source/watchfiles-1.1.1.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
source/watchfiles-1.1.1.dist-info/METADATA ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: watchfiles
3
+ Version: 1.1.1
4
+ Classifier: Development Status :: 5 - Production/Stable
5
+ Classifier: Environment :: Console
6
+ Classifier: Programming Language :: Python
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3 :: Only
9
+ Classifier: Programming Language :: Python :: 3.9
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Classifier: Programming Language :: Python :: 3.14
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Intended Audience :: Information Technology
17
+ Classifier: Intended Audience :: System Administrators
18
+ Classifier: License :: OSI Approved :: MIT License
19
+ Classifier: Operating System :: POSIX :: Linux
20
+ Classifier: Operating System :: Microsoft :: Windows
21
+ Classifier: Operating System :: MacOS
22
+ Classifier: Environment :: MacOS X
23
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
+ Classifier: Topic :: System :: Filesystems
25
+ Classifier: Framework :: AnyIO
26
+ Requires-Dist: anyio>=3.0.0
27
+ License-File: LICENSE
28
+ Summary: Simple, modern and high performance file watching and code reload in python.
29
+ Home-Page: https://github.com/samuelcolvin/watchfiles
30
+ Author-email: Samuel Colvin <s@muelcolvin.com>
31
+ License: MIT
32
+ Requires-Python: >=3.9
33
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
34
+ Project-URL: Homepage, https://github.com/samuelcolvin/watchfiles
35
+ Project-URL: Documentation, https://watchfiles.helpmanual.io
36
+ Project-URL: Funding, https://github.com/sponsors/samuelcolvin
37
+ Project-URL: Source, https://github.com/samuelcolvin/watchfiles
38
+ Project-URL: Changelog, https://github.com/samuelcolvin/watchfiles/releases
39
+
40
+ # watchfiles
41
+
42
+ [![CI](https://github.com/samuelcolvin/watchfiles/actions/workflows/ci.yml/badge.svg)](https://github.com/samuelcolvin/watchfiles/actions/workflows/ci.yml?query=branch%3Amain)
43
+ [![Coverage](https://codecov.io/gh/samuelcolvin/watchfiles/branch/main/graph/badge.svg)](https://codecov.io/gh/samuelcolvin/watchfiles)
44
+ [![pypi](https://img.shields.io/pypi/v/watchfiles.svg)](https://pypi.python.org/pypi/watchfiles)
45
+ [![CondaForge](https://img.shields.io/conda/v/conda-forge/watchfiles.svg)](https://anaconda.org/conda-forge/watchfiles)
46
+ [![license](https://img.shields.io/github/license/samuelcolvin/watchfiles.svg)](https://github.com/samuelcolvin/watchfiles/blob/main/LICENSE)
47
+
48
+ Simple, modern and high performance file watching and code reload in python.
49
+
50
+ ---
51
+
52
+ **Documentation**: [watchfiles.helpmanual.io](https://watchfiles.helpmanual.io)
53
+
54
+ **Source Code**: [github.com/samuelcolvin/watchfiles](https://github.com/samuelcolvin/watchfiles)
55
+
56
+ ---
57
+
58
+ Underlying file system notifications are handled by the [Notify](https://github.com/notify-rs/notify) rust library.
59
+
60
+ This package was previously named "watchgod",
61
+ see [the migration guide](https://watchfiles.helpmanual.io/migrating/) for more information.
62
+
63
+ ## Installation
64
+
65
+ **watchfiles** requires Python 3.9 - 3.14.
66
+
67
+ ```bash
68
+ pip install watchfiles
69
+ ```
70
+
71
+ Binaries are available for most architectures on Linux, MacOS and Windows ([learn more](https://watchfiles.helpmanual.io/#installation)).
72
+
73
+ Otherwise, you can install from source which requires Rust stable to be installed.
74
+
75
+ ## Usage
76
+
77
+ Here are some examples of what **watchfiles** can do:
78
+
79
+ ### `watch` Usage
80
+
81
+ ```py
82
+ from watchfiles import watch
83
+
84
+ for changes in watch('./path/to/dir'):
85
+ print(changes)
86
+ ```
87
+ See [`watch` docs](https://watchfiles.helpmanual.io/api/watch/#watchfiles.watch) for more details.
88
+
89
+ ### `awatch` Usage
90
+
91
+ ```py
92
+ import asyncio
93
+ from watchfiles import awatch
94
+
95
+ async def main():
96
+ async for changes in awatch('/path/to/dir'):
97
+ print(changes)
98
+
99
+ asyncio.run(main())
100
+ ```
101
+ See [`awatch` docs](https://watchfiles.helpmanual.io/api/watch/#watchfiles.awatch) for more details.
102
+
103
+ ### `run_process` Usage
104
+
105
+ ```py
106
+ from watchfiles import run_process
107
+
108
+ def foobar(a, b, c):
109
+ ...
110
+
111
+ if __name__ == '__main__':
112
+ run_process('./path/to/dir', target=foobar, args=(1, 2, 3))
113
+ ```
114
+ See [`run_process` docs](https://watchfiles.helpmanual.io/api/run_process/#watchfiles.run_process) for more details.
115
+
116
+ ### `arun_process` Usage
117
+
118
+ ```py
119
+ import asyncio
120
+ from watchfiles import arun_process
121
+
122
+ def foobar(a, b, c):
123
+ ...
124
+
125
+ async def main():
126
+ await arun_process('./path/to/dir', target=foobar, args=(1, 2, 3))
127
+
128
+ if __name__ == '__main__':
129
+ asyncio.run(main())
130
+ ```
131
+ See [`arun_process` docs](https://watchfiles.helpmanual.io/api/run_process/#watchfiles.arun_process) for more details.
132
+
133
+ ## CLI
134
+
135
+ **watchfiles** also comes with a CLI for running and reloading code. To run `some command` when files in `src` change:
136
+
137
+ ```
138
+ watchfiles "some command" src
139
+ ```
140
+
141
+ For more information, see [the CLI docs](https://watchfiles.helpmanual.io/cli/).
142
+
143
+ Or run
144
+
145
+ ```bash
146
+ watchfiles --help
147
+ ```
148
+
source/watchfiles-1.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ../../bin/watchfiles,sha256=UmgepAyVu9Gw-Yp6nEG9ks2cXHYq9nd7hBmVixDPM7s,211
2
+ watchfiles-1.1.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
3
+ watchfiles-1.1.1.dist-info/METADATA,sha256=h34wYtQyezaYEn9GQWg8z4d9JCViFLT4vmKd_ip6WF8,4874
4
+ watchfiles-1.1.1.dist-info/RECORD,,
5
+ watchfiles-1.1.1.dist-info/WHEEL,sha256=AUS7tHOBvWg1bDsPcHg1j3P_rKxqebEdeR--lIGHkyI,129
6
+ watchfiles-1.1.1.dist-info/entry_points.txt,sha256=s1Dpa2d_KKBy-jKREWW60Z3GoRZ3JpCEo_9iYDt6hOQ,48
7
+ watchfiles-1.1.1.dist-info/licenses/LICENSE,sha256=T9eDVbZ84md-3p-29jolDzd7t-IgiBNqX0aZrbS8Bp8,1091
8
+ watchfiles/__init__.py,sha256=IRlM9KOSedMzF1fvLr7yEHPVS-UFERNThlB-tmWI8yU,364
9
+ watchfiles/__main__.py,sha256=JgErYkiskih8Y6oRwowALtR-rwQhAAdqOYWjQraRIPI,59
10
+ watchfiles/__pycache__/__init__.cpython-312.pyc,,
11
+ watchfiles/__pycache__/__main__.cpython-312.pyc,,
12
+ watchfiles/__pycache__/cli.cpython-312.pyc,,
13
+ watchfiles/__pycache__/filters.cpython-312.pyc,,
14
+ watchfiles/__pycache__/main.cpython-312.pyc,,
15
+ watchfiles/__pycache__/run.cpython-312.pyc,,
16
+ watchfiles/__pycache__/version.cpython-312.pyc,,
17
+ watchfiles/_rust_notify.cpython-312-x86_64-linux-gnu.so,sha256=sJsIMMJW0QyNqKUFF2eg4YaVUywMgJiAjdubGmyjAo0,1124288
18
+ watchfiles/_rust_notify.pyi,sha256=q5FQkXgBJEFPt9RCf7my4wP5RM1FwSVpqf221csyebg,4753
19
+ watchfiles/cli.py,sha256=DHMI0LfT7hOrWai_Y4RP_vvTvVdtcDaioixXLiv2pG4,7707
20
+ watchfiles/filters.py,sha256=U0zXGOeg9dMHkT51-56BKpRrWIu95lPq0HDR_ZB4oDE,5139
21
+ watchfiles/main.py,sha256=-pbJBFBA34VEXMt8VGcaPTQHAjsGhPf7Psu1gP_HnKk,15235
22
+ watchfiles/py.typed,sha256=MS4Na3to9VTGPy_8wBQM_6mNKaX4qIpi5-w7_LZB-8I,69
23
+ watchfiles/run.py,sha256=TLXb2y_xYx-t3xyszVQWHoGyG7RCb107Q0NoIcSWmjQ,15348
24
+ watchfiles/version.py,sha256=NRWUnkZ32DamsNKV20EetagIGTLDMMUnqDWVGFFA2WQ,85
source/watchfiles-1.1.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: maturin (1.9.6)
3
+ Root-Is-Purelib: false
4
+ Tag: cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64
source/watchfiles-1.1.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ watchfiles=watchfiles.cli:cli
source/watchfiles-1.1.1.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2017 to present Samuel Colvin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
source/websockets-16.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
source/websockets-16.0.dist-info/METADATA ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: websockets
3
+ Version: 16.0
4
+ Summary: An implementation of the WebSocket Protocol (RFC 6455 & 7692)
5
+ Author-email: Aymeric Augustin <aymeric.augustin@m4x.org>
6
+ License-Expression: BSD-3-Clause
7
+ Project-URL: Homepage, https://github.com/python-websockets/websockets
8
+ Project-URL: Changelog, https://websockets.readthedocs.io/en/stable/project/changelog.html
9
+ Project-URL: Documentation, https://websockets.readthedocs.io/
10
+ Project-URL: Funding, https://tidelift.com/subscription/pkg/pypi-websockets?utm_source=pypi-websockets&utm_medium=referral&utm_campaign=readme
11
+ Project-URL: Tracker, https://github.com/python-websockets/websockets/issues
12
+ Keywords: WebSocket
13
+ Classifier: Development Status :: 5 - Production/Stable
14
+ Classifier: Environment :: Web Environment
15
+ Classifier: Intended Audience :: Developers
16
+ Classifier: Operating System :: OS Independent
17
+ Classifier: Programming Language :: Python
18
+ Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Programming Language :: Python :: 3.13
23
+ Classifier: Programming Language :: Python :: 3.14
24
+ Requires-Python: >=3.10
25
+ Description-Content-Type: text/x-rst
26
+ License-File: LICENSE
27
+ Dynamic: description
28
+ Dynamic: description-content-type
29
+ Dynamic: license-file
30
+
31
+ .. image:: logo/horizontal.svg
32
+ :width: 480px
33
+ :alt: websockets
34
+
35
+ |licence| |version| |pyversions| |tests| |docs| |openssf|
36
+
37
+ .. |licence| image:: https://img.shields.io/pypi/l/websockets.svg
38
+ :target: https://pypi.python.org/pypi/websockets
39
+
40
+ .. |version| image:: https://img.shields.io/pypi/v/websockets.svg
41
+ :target: https://pypi.python.org/pypi/websockets
42
+
43
+ .. |pyversions| image:: https://img.shields.io/pypi/pyversions/websockets.svg
44
+ :target: https://pypi.python.org/pypi/websockets
45
+
46
+ .. |tests| image:: https://img.shields.io/github/checks-status/python-websockets/websockets/main?label=tests
47
+ :target: https://github.com/python-websockets/websockets/actions/workflows/tests.yml
48
+
49
+ .. |docs| image:: https://img.shields.io/readthedocs/websockets.svg
50
+ :target: https://websockets.readthedocs.io/
51
+
52
+ .. |openssf| image:: https://bestpractices.coreinfrastructure.org/projects/6475/badge
53
+ :target: https://bestpractices.coreinfrastructure.org/projects/6475
54
+
55
+ What is ``websockets``?
56
+ -----------------------
57
+
58
+ websockets is a library for building WebSocket_ servers and clients in Python
59
+ with a focus on correctness, simplicity, robustness, and performance.
60
+
61
+ .. _WebSocket: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API
62
+
63
+ Built on top of ``asyncio``, Python's standard asynchronous I/O framework, the
64
+ default implementation provides an elegant coroutine-based API.
65
+
66
+ An implementation on top of ``threading`` and a Sans-I/O implementation are also
67
+ available.
68
+
69
+ `Documentation is available on Read the Docs. <https://websockets.readthedocs.io/>`_
70
+
71
+ .. copy-pasted because GitHub doesn't support the include directive
72
+
73
+ Here's an echo server with the ``asyncio`` API:
74
+
75
+ .. code:: python
76
+
77
+ #!/usr/bin/env python
78
+
79
+ import asyncio
80
+ from websockets.asyncio.server import serve
81
+
82
+ async def echo(websocket):
83
+ async for message in websocket:
84
+ await websocket.send(message)
85
+
86
+ async def main():
87
+ async with serve(echo, "localhost", 8765) as server:
88
+ await server.serve_forever()
89
+
90
+ asyncio.run(main())
91
+
92
+ Here's how a client sends and receives messages with the ``threading`` API:
93
+
94
+ .. code:: python
95
+
96
+ #!/usr/bin/env python
97
+
98
+ from websockets.sync.client import connect
99
+
100
+ def hello():
101
+ with connect("ws://localhost:8765") as websocket:
102
+ websocket.send("Hello world!")
103
+ message = websocket.recv()
104
+ print(f"Received: {message}")
105
+
106
+ hello()
107
+
108
+
109
+ Does that look good?
110
+
111
+ `Get started with the tutorial! <https://websockets.readthedocs.io/en/stable/intro/index.html>`_
112
+
113
+ Why should I use ``websockets``?
114
+ --------------------------------
115
+
116
+ The development of ``websockets`` is shaped by four principles:
117
+
118
+ 1. **Correctness**: ``websockets`` is heavily tested for compliance with
119
+ :rfc:`6455`. Continuous integration fails under 100% branch coverage.
120
+
121
+ 2. **Simplicity**: all you need to understand is ``msg = await ws.recv()`` and
122
+ ``await ws.send(msg)``. ``websockets`` takes care of managing connections
123
+ so you can focus on your application.
124
+
125
+ 3. **Robustness**: ``websockets`` is built for production. For example, it was
126
+ the only library to `handle backpressure correctly`_ before the issue
127
+ became widely known in the Python community.
128
+
129
+ 4. **Performance**: memory usage is optimized and configurable. A C extension
130
+ accelerates expensive operations. It's pre-compiled for Linux, macOS and
131
+ Windows and packaged in the wheel format for each system and Python version.
132
+
133
+ Documentation is a first class concern in the project. Head over to `Read the
134
+ Docs`_ and see for yourself.
135
+
136
+ .. _Read the Docs: https://websockets.readthedocs.io/
137
+ .. _handle backpressure correctly: https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#websocket-servers
138
+
139
+ Why shouldn't I use ``websockets``?
140
+ -----------------------------------
141
+
142
+ * If you prefer callbacks over coroutines: ``websockets`` was created to
143
+ provide the best coroutine-based API to manage WebSocket connections in
144
+ Python. Pick another library for a callback-based API.
145
+
146
+ * If you're looking for a mixed HTTP / WebSocket library: ``websockets`` aims
147
+ at being an excellent implementation of :rfc:`6455`: The WebSocket Protocol
148
+ and :rfc:`7692`: Compression Extensions for WebSocket. Its support for HTTP
149
+ is minimal — just enough for an HTTP health check.
150
+
151
+ If you want to do both in the same server, look at HTTP + WebSocket servers
152
+ that build on top of ``websockets`` to support WebSocket connections, like
153
+ uvicorn_ or Sanic_.
154
+
155
+ .. _uvicorn: https://www.uvicorn.org/
156
+ .. _Sanic: https://sanic.dev/en/
157
+
158
+ What else?
159
+ ----------
160
+
161
+ Bug reports, patches and suggestions are welcome!
162
+
163
+ To report a security vulnerability, please use the `Tidelift security
164
+ contact`_. Tidelift will coordinate the fix and disclosure.
165
+
166
+ .. _Tidelift security contact: https://tidelift.com/security
167
+
168
+ For anything else, please open an issue_ or send a `pull request`_.
169
+
170
+ .. _issue: https://github.com/python-websockets/websockets/issues/new
171
+ .. _pull request: https://github.com/python-websockets/websockets/compare/
172
+
173
+ Participants must uphold the `Contributor Covenant code of conduct`_.
174
+
175
+ .. _Contributor Covenant code of conduct: https://github.com/python-websockets/websockets/blob/main/CODE_OF_CONDUCT.md
176
+
177
+ ``websockets`` is released under the `BSD license`_.
178
+
179
+ .. _BSD license: https://github.com/python-websockets/websockets/blob/main/LICENSE
source/websockets-16.0.dist-info/RECORD ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ../../bin/websockets,sha256=jIwwGFqaK2AvxQf01v-BJ3EMdDyCxBvpgWufffO9SyU,213
2
+ websockets-16.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
3
+ websockets-16.0.dist-info/METADATA,sha256=JcDvWo8DVSw5uoDAFbk9N8fJXuRJvnrcLXVBFyBjwN8,6799
4
+ websockets-16.0.dist-info/RECORD,,
5
+ websockets-16.0.dist-info/WHEEL,sha256=mX4U4odf6w47aVjwZUmTYd1MF9BbrhVLKlaWSvZwHEk,186
6
+ websockets-16.0.dist-info/entry_points.txt,sha256=Dnhn4dm5EsI4ZMAsHldGF6CwBXZrGXnR7cnK2-XR7zY,51
7
+ websockets-16.0.dist-info/licenses/LICENSE,sha256=PWoMBQ2L7FL6utUC5F-yW9ArytvXDeo01Ee2oP9Obag,1514
8
+ websockets-16.0.dist-info/top_level.txt,sha256=CMpdKklxKsvZgCgyltxUWOHibZXZ1uYIVpca9xsQ8Hk,11
9
+ websockets/__init__.py,sha256=AC2Hq92uSc_WOo9_xvITpGshJ7Dy0Md5m2_ywsdSt_Y,7058
10
+ websockets/__main__.py,sha256=wu5N2wk8mvBgyvr2ghmQf4prezAe0_i-p123VVreyYc,62
11
+ websockets/__pycache__/__init__.cpython-312.pyc,,
12
+ websockets/__pycache__/__main__.cpython-312.pyc,,
13
+ websockets/__pycache__/auth.cpython-312.pyc,,
14
+ websockets/__pycache__/cli.cpython-312.pyc,,
15
+ websockets/__pycache__/client.cpython-312.pyc,,
16
+ websockets/__pycache__/connection.cpython-312.pyc,,
17
+ websockets/__pycache__/datastructures.cpython-312.pyc,,
18
+ websockets/__pycache__/exceptions.cpython-312.pyc,,
19
+ websockets/__pycache__/frames.cpython-312.pyc,,
20
+ websockets/__pycache__/headers.cpython-312.pyc,,
21
+ websockets/__pycache__/http.cpython-312.pyc,,
22
+ websockets/__pycache__/http11.cpython-312.pyc,,
23
+ websockets/__pycache__/imports.cpython-312.pyc,,
24
+ websockets/__pycache__/protocol.cpython-312.pyc,,
25
+ websockets/__pycache__/proxy.cpython-312.pyc,,
26
+ websockets/__pycache__/server.cpython-312.pyc,,
27
+ websockets/__pycache__/streams.cpython-312.pyc,,
28
+ websockets/__pycache__/typing.cpython-312.pyc,,
29
+ websockets/__pycache__/uri.cpython-312.pyc,,
30
+ websockets/__pycache__/utils.cpython-312.pyc,,
31
+ websockets/__pycache__/version.cpython-312.pyc,,
32
+ websockets/asyncio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
+ websockets/asyncio/__pycache__/__init__.cpython-312.pyc,,
34
+ websockets/asyncio/__pycache__/async_timeout.cpython-312.pyc,,
35
+ websockets/asyncio/__pycache__/client.cpython-312.pyc,,
36
+ websockets/asyncio/__pycache__/compatibility.cpython-312.pyc,,
37
+ websockets/asyncio/__pycache__/connection.cpython-312.pyc,,
38
+ websockets/asyncio/__pycache__/messages.cpython-312.pyc,,
39
+ websockets/asyncio/__pycache__/router.cpython-312.pyc,,
40
+ websockets/asyncio/__pycache__/server.cpython-312.pyc,,
41
+ websockets/asyncio/async_timeout.py,sha256=N-6Mubyiaoh66PAXGvCzhgxCM-7V2XiRnH32Xi6J6TE,8971
42
+ websockets/asyncio/client.py,sha256=e4xlgtzb3v29M2vN-UDiyoUtThg--d5GqKg3lt2pDdE,30850
43
+ websockets/asyncio/compatibility.py,sha256=gkenDDhzNbm6_iXV5Edvbvp6uHZYdrTvGNjt8P_JtyQ,786
44
+ websockets/asyncio/connection.py,sha256=87RdVURijJk8V-ShWAWfTEyhW5Z1YUXKV8ezUzxt5L0,49099
45
+ websockets/asyncio/messages.py,sha256=u2M5WKY9xPyw8G3nKoXfdO5K41hrTnf4MdizVHzgdM4,11129
46
+ websockets/asyncio/router.py,sha256=S-69vszK-SqUCcZbXXPOnux-eH2fTHYC2JNh7tOtmmA,7520
47
+ websockets/asyncio/server.py,sha256=wQ9oBc0WBOIzbXKDYJ8UhXRTeoXrSfLu6CWCrUl-vck,37941
48
+ websockets/auth.py,sha256=U_Jwmn59ZRQ6EecpOvMizQCG_ZbAvgUf1ik7haZRC3c,568
49
+ websockets/cli.py,sha256=YnegH59z93JxSVIGiXiWhR3ktgI6k1_pf_BRLanxKrQ,5336
50
+ websockets/client.py,sha256=fljI5k5oQ-Sfm53MCoyTlr2jFtOOIuO13H9bbtpBPes,13789
51
+ websockets/connection.py,sha256=OLiMVkNd25_86sB8Q7CrCwBoXy9nA0OCgdgLRA8WUR8,323
52
+ websockets/datastructures.py,sha256=Uq2CpjmXak9_pPWcOqh36rzJMo8eCi2lVPTFWDvK5sA,5518
53
+ websockets/exceptions.py,sha256=bgaMdqQGGZosAEULeCB30XW2YnwomWa3c8YOrEfeOoY,12859
54
+ websockets/extensions/__init__.py,sha256=QkZsxaJVllVSp1uhdD5uPGibdbx_091GrVVfS5LXcpw,98
55
+ websockets/extensions/__pycache__/__init__.cpython-312.pyc,,
56
+ websockets/extensions/__pycache__/base.cpython-312.pyc,,
57
+ websockets/extensions/__pycache__/permessage_deflate.cpython-312.pyc,,
58
+ websockets/extensions/base.py,sha256=JNfyk543C7VuPH0QOobiqKoGrzjJILje6sz5ILvOPl4,2903
59
+ websockets/extensions/permessage_deflate.py,sha256=AkuhkAKFo5lqJQMXnckbSs9b2KBBrOFsE1DHIcbLL3k,25770
60
+ websockets/frames.py,sha256=5IK4GZpl8ukr0bZ_UA_jjjztK09yYQAl9m5NVmGLiK0,12889
61
+ websockets/headers.py,sha256=yQnPljVZwV1_V-pOSRKNLG_u827wFC1h72cciojcQ8M,16046
62
+ websockets/http.py,sha256=T1tNLmbkFCneXQ6qepBmsVVDXyP9i500IVzTJTeBMR4,659
63
+ websockets/http11.py,sha256=T8ai5BcBGkV0n9It63oDeNpmtQMyg8Cpav5rf_yT0r4,15619
64
+ websockets/imports.py,sha256=T_B9TUmHoceKMQ-PNphdQQAH2XdxAxwSQNeQEgqILkE,2795
65
+ websockets/legacy/__init__.py,sha256=wQ5zRIENGUS_5eKNAX9CRE7x1TwKapKimrQFFWN9Sxs,276
66
+ websockets/legacy/__pycache__/__init__.cpython-312.pyc,,
67
+ websockets/legacy/__pycache__/auth.cpython-312.pyc,,
68
+ websockets/legacy/__pycache__/client.cpython-312.pyc,,
69
+ websockets/legacy/__pycache__/exceptions.cpython-312.pyc,,
70
+ websockets/legacy/__pycache__/framing.cpython-312.pyc,,
71
+ websockets/legacy/__pycache__/handshake.cpython-312.pyc,,
72
+ websockets/legacy/__pycache__/http.cpython-312.pyc,,
73
+ websockets/legacy/__pycache__/protocol.cpython-312.pyc,,
74
+ websockets/legacy/__pycache__/server.cpython-312.pyc,,
75
+ websockets/legacy/auth.py,sha256=DcQcCSeVeP93JcH8vFWE0HIJL-X-f23LZ0DsJpav1So,6531
76
+ websockets/legacy/client.py,sha256=fV2mbiU9rciXhJfAEKVSm0GztJDUbDpRQ-K5EMbkuQ0,26815
77
+ websockets/legacy/exceptions.py,sha256=ViEjpoT09fzx_Zqf0aNGDVtRDNjXaOw0gdCta3LkjFc,1924
78
+ websockets/legacy/framing.py,sha256=r9P1wiXv_1XuAVQw8SOPkuE9d4eZ0r_JowAkz9-WV4w,6366
79
+ websockets/legacy/handshake.py,sha256=2Nzr5AN2xvDC5EdNP-kB3lOcrAaUNlYuj_-hr_jv7pM,5285
80
+ websockets/legacy/http.py,sha256=cOCQmDWhIKQmm8UWGXPW7CDZg03wjogCsb0LP9oetNQ,7061
81
+ websockets/legacy/protocol.py,sha256=ajtVXDb-lEm9BN0NF3iEaTI_b1q5fBCKTB9wvUoGOxY,63632
82
+ websockets/legacy/server.py,sha256=7mwY-yD0ljNF93oPYumTWD7OIVbCWtaEOw1FFJBhIAM,45251
83
+ websockets/protocol.py,sha256=vTqjPIg2HmO-bSxsczuEmWMxPTxPXU1hmVUjqnahV44,27247
84
+ websockets/proxy.py,sha256=oFrbEYtasYWv-WDcniObD9nBR5Q5qkHpyCVLngx7WMQ,4969
85
+ websockets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
86
+ websockets/server.py,sha256=E4SWBA8WZRmAOpsUm-oCqacBGZre9e0iDmDIrfpV21Q,21790
87
+ websockets/speedups.c,sha256=u_dncR4M38EX6He_fzb1TY6D3Hke67ZpoHLLhZZ0hvQ,5920
88
+ websockets/speedups.cpython-312-x86_64-linux-gnu.so,sha256=F8FiVerlQi_Z0YSsuY_ASEHvWcddXkyyRa3ylkV80B0,38048
89
+ websockets/speedups.pyi,sha256=unjvBNg-uW4c7z-9OW4WiSzZk_QH2bLEcjYAMuoSgBI,102
90
+ websockets/streams.py,sha256=pXqga7ttjuF6lChWYiWLSfUlt3FCaQpEX1ae_jvcCeQ,4071
91
+ websockets/sync/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
92
+ websockets/sync/__pycache__/__init__.cpython-312.pyc,,
93
+ websockets/sync/__pycache__/client.cpython-312.pyc,,
94
+ websockets/sync/__pycache__/connection.cpython-312.pyc,,
95
+ websockets/sync/__pycache__/messages.cpython-312.pyc,,
96
+ websockets/sync/__pycache__/router.cpython-312.pyc,,
97
+ websockets/sync/__pycache__/server.cpython-312.pyc,,
98
+ websockets/sync/__pycache__/utils.cpython-312.pyc,,
99
+ websockets/sync/client.py,sha256=_2Erytw1f3f9O_u2jLtS1oNV4HsHUi_h3lGvT9ZEaDQ,22108
100
+ websockets/sync/connection.py,sha256=1pJYEMRHLWIN7538vJcIeFVnvSXVrD0n1xrfX7wDNSc,41868
101
+ websockets/sync/messages.py,sha256=yZV1zhY07ZD0vRF5b1yDa7ug0rbA5UDOCCCQmWwAcds,12858
102
+ websockets/sync/router.py,sha256=BqKSAKNZYtRWiOxol9qYeyfgyXRrMNJ6FrTTZLNcXMg,7172
103
+ websockets/sync/server.py,sha256=s07HNK_2s1kLN62Uqc77uvND0z7C0YTXGePsCiBtXaE,27655
104
+ websockets/sync/utils.py,sha256=TtW-ncYFvJmiSW2gO86ngE2BVsnnBdL-4H88kWNDYbg,1107
105
+ websockets/typing.py,sha256=A6xh4m65pRzKAbuOs0kFuGhL4DWIIko-ppS4wvJVc0Q,1946
106
+ websockets/uri.py,sha256=2fFMw-AbKJ5HVHNCuw1Rx1MnkCkNWRpogxWhhM30EU4,3125
107
+ websockets/utils.py,sha256=AwhS4UmlbKv7meAaR7WNbUqD5JFoStOP1bAyo9sRMus,1197
108
+ websockets/version.py,sha256=IhaztWxysdY-pd-0nOubnnPduvySSvdoBwrQdJKtZ2g,3202
source/websockets-16.0.dist-info/WHEEL ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: false
4
+ Tag: cp312-cp312-manylinux_2_5_x86_64
5
+ Tag: cp312-cp312-manylinux1_x86_64
6
+ Tag: cp312-cp312-manylinux_2_28_x86_64
7
+
source/websockets-16.0.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ websockets = websockets.cli:main
source/websockets-16.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) Aymeric Augustin and contributors
2
+
3
+ Redistribution and use in source and binary forms, with or without
4
+ modification, are permitted provided that the following conditions are met:
5
+
6
+ * Redistributions of source code must retain the above copyright notice,
7
+ this list of conditions and the following disclaimer.
8
+ * Redistributions in binary form must reproduce the above copyright notice,
9
+ this list of conditions and the following disclaimer in the documentation
10
+ and/or other materials provided with the distribution.
11
+ * Neither the name of the copyright holder nor the names of its contributors
12
+ may be used to endorse or promote products derived from this software
13
+ without specific prior written permission.
14
+
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
source/websockets-16.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ websockets
source/websockets/__init__.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ # Importing the typing module would conflict with websockets.typing.
4
+ from typing import TYPE_CHECKING
5
+
6
+ from .imports import lazy_import
7
+ from .version import version as __version__ # noqa: F401
8
+
9
+
10
+ __all__ = [
11
+ # .asyncio.client
12
+ "connect",
13
+ "unix_connect",
14
+ "ClientConnection",
15
+ # .asyncio.router
16
+ "route",
17
+ "unix_route",
18
+ "Router",
19
+ # .asyncio.server
20
+ "basic_auth",
21
+ "broadcast",
22
+ "serve",
23
+ "unix_serve",
24
+ "ServerConnection",
25
+ "Server",
26
+ # .client
27
+ "ClientProtocol",
28
+ # .datastructures
29
+ "Headers",
30
+ "HeadersLike",
31
+ "MultipleValuesError",
32
+ # .exceptions
33
+ "ConcurrencyError",
34
+ "ConnectionClosed",
35
+ "ConnectionClosedError",
36
+ "ConnectionClosedOK",
37
+ "DuplicateParameter",
38
+ "InvalidHandshake",
39
+ "InvalidHeader",
40
+ "InvalidHeaderFormat",
41
+ "InvalidHeaderValue",
42
+ "InvalidMessage",
43
+ "InvalidOrigin",
44
+ "InvalidParameterName",
45
+ "InvalidParameterValue",
46
+ "InvalidProxy",
47
+ "InvalidProxyMessage",
48
+ "InvalidProxyStatus",
49
+ "InvalidState",
50
+ "InvalidStatus",
51
+ "InvalidUpgrade",
52
+ "InvalidURI",
53
+ "NegotiationError",
54
+ "PayloadTooBig",
55
+ "ProtocolError",
56
+ "ProxyError",
57
+ "SecurityError",
58
+ "WebSocketException",
59
+ # .frames
60
+ "Close",
61
+ "CloseCode",
62
+ "Frame",
63
+ "Opcode",
64
+ # .http11
65
+ "Request",
66
+ "Response",
67
+ # .protocol
68
+ "Protocol",
69
+ "Side",
70
+ "State",
71
+ # .server
72
+ "ServerProtocol",
73
+ # .typing
74
+ "Data",
75
+ "ExtensionName",
76
+ "ExtensionParameter",
77
+ "LoggerLike",
78
+ "StatusLike",
79
+ "Origin",
80
+ "Subprotocol",
81
+ ]
82
+
83
+ # When type checking, import non-deprecated aliases eagerly. Else, import on demand.
84
+ if TYPE_CHECKING:
85
+ from .asyncio.client import ClientConnection, connect, unix_connect
86
+ from .asyncio.router import Router, route, unix_route
87
+ from .asyncio.server import (
88
+ Server,
89
+ ServerConnection,
90
+ basic_auth,
91
+ broadcast,
92
+ serve,
93
+ unix_serve,
94
+ )
95
+ from .client import ClientProtocol
96
+ from .datastructures import Headers, HeadersLike, MultipleValuesError
97
+ from .exceptions import (
98
+ ConcurrencyError,
99
+ ConnectionClosed,
100
+ ConnectionClosedError,
101
+ ConnectionClosedOK,
102
+ DuplicateParameter,
103
+ InvalidHandshake,
104
+ InvalidHeader,
105
+ InvalidHeaderFormat,
106
+ InvalidHeaderValue,
107
+ InvalidMessage,
108
+ InvalidOrigin,
109
+ InvalidParameterName,
110
+ InvalidParameterValue,
111
+ InvalidProxy,
112
+ InvalidProxyMessage,
113
+ InvalidProxyStatus,
114
+ InvalidState,
115
+ InvalidStatus,
116
+ InvalidUpgrade,
117
+ InvalidURI,
118
+ NegotiationError,
119
+ PayloadTooBig,
120
+ ProtocolError,
121
+ ProxyError,
122
+ SecurityError,
123
+ WebSocketException,
124
+ )
125
+ from .frames import Close, CloseCode, Frame, Opcode
126
+ from .http11 import Request, Response
127
+ from .protocol import Protocol, Side, State
128
+ from .server import ServerProtocol
129
+ from .typing import (
130
+ Data,
131
+ ExtensionName,
132
+ ExtensionParameter,
133
+ LoggerLike,
134
+ Origin,
135
+ StatusLike,
136
+ Subprotocol,
137
+ )
138
+ else:
139
+ lazy_import(
140
+ globals(),
141
+ aliases={
142
+ # .asyncio.client
143
+ "connect": ".asyncio.client",
144
+ "unix_connect": ".asyncio.client",
145
+ "ClientConnection": ".asyncio.client",
146
+ # .asyncio.router
147
+ "route": ".asyncio.router",
148
+ "unix_route": ".asyncio.router",
149
+ "Router": ".asyncio.router",
150
+ # .asyncio.server
151
+ "basic_auth": ".asyncio.server",
152
+ "broadcast": ".asyncio.server",
153
+ "serve": ".asyncio.server",
154
+ "unix_serve": ".asyncio.server",
155
+ "ServerConnection": ".asyncio.server",
156
+ "Server": ".asyncio.server",
157
+ # .client
158
+ "ClientProtocol": ".client",
159
+ # .datastructures
160
+ "Headers": ".datastructures",
161
+ "HeadersLike": ".datastructures",
162
+ "MultipleValuesError": ".datastructures",
163
+ # .exceptions
164
+ "ConcurrencyError": ".exceptions",
165
+ "ConnectionClosed": ".exceptions",
166
+ "ConnectionClosedError": ".exceptions",
167
+ "ConnectionClosedOK": ".exceptions",
168
+ "DuplicateParameter": ".exceptions",
169
+ "InvalidHandshake": ".exceptions",
170
+ "InvalidHeader": ".exceptions",
171
+ "InvalidHeaderFormat": ".exceptions",
172
+ "InvalidHeaderValue": ".exceptions",
173
+ "InvalidMessage": ".exceptions",
174
+ "InvalidOrigin": ".exceptions",
175
+ "InvalidParameterName": ".exceptions",
176
+ "InvalidParameterValue": ".exceptions",
177
+ "InvalidProxy": ".exceptions",
178
+ "InvalidProxyMessage": ".exceptions",
179
+ "InvalidProxyStatus": ".exceptions",
180
+ "InvalidState": ".exceptions",
181
+ "InvalidStatus": ".exceptions",
182
+ "InvalidUpgrade": ".exceptions",
183
+ "InvalidURI": ".exceptions",
184
+ "NegotiationError": ".exceptions",
185
+ "PayloadTooBig": ".exceptions",
186
+ "ProtocolError": ".exceptions",
187
+ "ProxyError": ".exceptions",
188
+ "SecurityError": ".exceptions",
189
+ "WebSocketException": ".exceptions",
190
+ # .frames
191
+ "Close": ".frames",
192
+ "CloseCode": ".frames",
193
+ "Frame": ".frames",
194
+ "Opcode": ".frames",
195
+ # .http11
196
+ "Request": ".http11",
197
+ "Response": ".http11",
198
+ # .protocol
199
+ "Protocol": ".protocol",
200
+ "Side": ".protocol",
201
+ "State": ".protocol",
202
+ # .server
203
+ "ServerProtocol": ".server",
204
+ # .typing
205
+ "Data": ".typing",
206
+ "ExtensionName": ".typing",
207
+ "ExtensionParameter": ".typing",
208
+ "LoggerLike": ".typing",
209
+ "Origin": ".typing",
210
+ "StatusLike": ".typing",
211
+ "Subprotocol": ".typing",
212
+ },
213
+ deprecated_aliases={
214
+ # deprecated in 9.0 - 2021-09-01
215
+ "framing": ".legacy",
216
+ "handshake": ".legacy",
217
+ "parse_uri": ".uri",
218
+ "WebSocketURI": ".uri",
219
+ # deprecated in 14.0 - 2024-11-09
220
+ # .legacy.auth
221
+ "BasicAuthWebSocketServerProtocol": ".legacy.auth",
222
+ "basic_auth_protocol_factory": ".legacy.auth",
223
+ # .legacy.client
224
+ "WebSocketClientProtocol": ".legacy.client",
225
+ # .legacy.exceptions
226
+ "AbortHandshake": ".legacy.exceptions",
227
+ "InvalidStatusCode": ".legacy.exceptions",
228
+ "RedirectHandshake": ".legacy.exceptions",
229
+ "WebSocketProtocolError": ".legacy.exceptions",
230
+ # .legacy.protocol
231
+ "WebSocketCommonProtocol": ".legacy.protocol",
232
+ # .legacy.server
233
+ "WebSocketServer": ".legacy.server",
234
+ "WebSocketServerProtocol": ".legacy.server",
235
+ },
236
+ )
source/websockets/__main__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .cli import main
2
+
3
+
4
+ if __name__ == "__main__":
5
+ main()
source/websockets/asyncio/__init__.py ADDED
File without changes
source/websockets/asyncio/async_timeout.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
2
+ # Licensed under the Apache License (Apache-2.0)
3
+
4
+ import asyncio
5
+ import enum
6
+ import sys
7
+ import warnings
8
+ from types import TracebackType
9
+ from typing import Optional, Type
10
+
11
+
12
+ if sys.version_info >= (3, 11):
13
+ from typing import final
14
+ else:
15
+ # From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
16
+ # Licensed under the Python Software Foundation License (PSF-2.0)
17
+
18
+ # @final exists in 3.8+, but we backport it for all versions
19
+ # before 3.11 to keep support for the __final__ attribute.
20
+ # See https://bugs.python.org/issue46342
21
+ def final(f):
22
+ """This decorator can be used to indicate to type checkers that
23
+ the decorated method cannot be overridden, and decorated class
24
+ cannot be subclassed. For example:
25
+
26
+ class Base:
27
+ @final
28
+ def done(self) -> None:
29
+ ...
30
+ class Sub(Base):
31
+ def done(self) -> None: # Error reported by type checker
32
+ ...
33
+ @final
34
+ class Leaf:
35
+ ...
36
+ class Other(Leaf): # Error reported by type checker
37
+ ...
38
+
39
+ There is no runtime checking of these properties. The decorator
40
+ sets the ``__final__`` attribute to ``True`` on the decorated object
41
+ to allow runtime introspection.
42
+ """
43
+ try:
44
+ f.__final__ = True
45
+ except (AttributeError, TypeError):
46
+ # Skip the attribute silently if it is not writable.
47
+ # AttributeError happens if the object has __slots__ or a
48
+ # read-only property, TypeError if it's a builtin class.
49
+ pass
50
+ return f
51
+
52
+ # End https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py
53
+
54
+
55
+ if sys.version_info >= (3, 11):
56
+
57
+ def _uncancel_task(task: "asyncio.Task[object]") -> None:
58
+ task.uncancel()
59
+
60
+ else:
61
+
62
+ def _uncancel_task(task: "asyncio.Task[object]") -> None:
63
+ pass
64
+
65
+
66
+ __version__ = "4.0.3"
67
+
68
+
69
+ __all__ = ("timeout", "timeout_at", "Timeout")
70
+
71
+
72
+ def timeout(delay: Optional[float]) -> "Timeout":
73
+ """timeout context manager.
74
+
75
+ Useful in cases when you want to apply timeout logic around block
76
+ of code or in cases when asyncio.wait_for is not suitable. For example:
77
+
78
+ >>> async with timeout(0.001):
79
+ ... async with aiohttp.get('https://github.com') as r:
80
+ ... await r.text()
81
+
82
+
83
+ delay - value in seconds or None to disable timeout logic
84
+ """
85
+ loop = asyncio.get_running_loop()
86
+ if delay is not None:
87
+ deadline = loop.time() + delay # type: Optional[float]
88
+ else:
89
+ deadline = None
90
+ return Timeout(deadline, loop)
91
+
92
+
93
+ def timeout_at(deadline: Optional[float]) -> "Timeout":
94
+ """Schedule the timeout at absolute time.
95
+
96
+ deadline argument points on the time in the same clock system
97
+ as loop.time().
98
+
99
+ Please note: it is not POSIX time but a time with
100
+ undefined starting base, e.g. the time of the system power on.
101
+
102
+ >>> async with timeout_at(loop.time() + 10):
103
+ ... async with aiohttp.get('https://github.com') as r:
104
+ ... await r.text()
105
+
106
+
107
+ """
108
+ loop = asyncio.get_running_loop()
109
+ return Timeout(deadline, loop)
110
+
111
+
112
+ class _State(enum.Enum):
113
+ INIT = "INIT"
114
+ ENTER = "ENTER"
115
+ TIMEOUT = "TIMEOUT"
116
+ EXIT = "EXIT"
117
+
118
+
119
+ @final
120
+ class Timeout:
121
+ # Internal class, please don't instantiate it directly
122
+ # Use timeout() and timeout_at() public factories instead.
123
+ #
124
+ # Implementation note: `async with timeout()` is preferred
125
+ # over `with timeout()`.
126
+ # While technically the Timeout class implementation
127
+ # doesn't need to be async at all,
128
+ # the `async with` statement explicitly points that
129
+ # the context manager should be used from async function context.
130
+ #
131
+ # This design allows to avoid many silly misusages.
132
+ #
133
+ # TimeoutError is raised immediately when scheduled
134
+ # if the deadline is passed.
135
+ # The purpose is to time out as soon as possible
136
+ # without waiting for the next await expression.
137
+
138
+ __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler", "_task")
139
+
140
+ def __init__(
141
+ self, deadline: Optional[float], loop: asyncio.AbstractEventLoop
142
+ ) -> None:
143
+ self._loop = loop
144
+ self._state = _State.INIT
145
+
146
+ self._task: Optional["asyncio.Task[object]"] = None
147
+ self._timeout_handler = None # type: Optional[asyncio.Handle]
148
+ if deadline is None:
149
+ self._deadline = None # type: Optional[float]
150
+ else:
151
+ self.update(deadline)
152
+
153
+ def __enter__(self) -> "Timeout":
154
+ warnings.warn(
155
+ "with timeout() is deprecated, use async with timeout() instead",
156
+ DeprecationWarning,
157
+ stacklevel=2,
158
+ )
159
+ self._do_enter()
160
+ return self
161
+
162
+ def __exit__(
163
+ self,
164
+ exc_type: Optional[Type[BaseException]],
165
+ exc_val: Optional[BaseException],
166
+ exc_tb: Optional[TracebackType],
167
+ ) -> Optional[bool]:
168
+ self._do_exit(exc_type)
169
+ return None
170
+
171
+ async def __aenter__(self) -> "Timeout":
172
+ self._do_enter()
173
+ return self
174
+
175
+ async def __aexit__(
176
+ self,
177
+ exc_type: Optional[Type[BaseException]],
178
+ exc_val: Optional[BaseException],
179
+ exc_tb: Optional[TracebackType],
180
+ ) -> Optional[bool]:
181
+ self._do_exit(exc_type)
182
+ return None
183
+
184
+ @property
185
+ def expired(self) -> bool:
186
+ """Is timeout expired during execution?"""
187
+ return self._state == _State.TIMEOUT
188
+
189
+ @property
190
+ def deadline(self) -> Optional[float]:
191
+ return self._deadline
192
+
193
+ def reject(self) -> None:
194
+ """Reject scheduled timeout if any."""
195
+ # cancel is maybe better name but
196
+ # task.cancel() raises CancelledError in asyncio world.
197
+ if self._state not in (_State.INIT, _State.ENTER):
198
+ raise RuntimeError(f"invalid state {self._state.value}")
199
+ self._reject()
200
+
201
+ def _reject(self) -> None:
202
+ self._task = None
203
+ if self._timeout_handler is not None:
204
+ self._timeout_handler.cancel()
205
+ self._timeout_handler = None
206
+
207
+ def shift(self, delay: float) -> None:
208
+ """Advance timeout on delay seconds.
209
+
210
+ The delay can be negative.
211
+
212
+ Raise RuntimeError if shift is called when deadline is not scheduled
213
+ """
214
+ deadline = self._deadline
215
+ if deadline is None:
216
+ raise RuntimeError("cannot shift timeout if deadline is not scheduled")
217
+ self.update(deadline + delay)
218
+
219
+ def update(self, deadline: float) -> None:
220
+ """Set deadline to absolute value.
221
+
222
+ deadline argument points on the time in the same clock system
223
+ as loop.time().
224
+
225
+ If new deadline is in the past the timeout is raised immediately.
226
+
227
+ Please note: it is not POSIX time but a time with
228
+ undefined starting base, e.g. the time of the system power on.
229
+ """
230
+ if self._state == _State.EXIT:
231
+ raise RuntimeError("cannot reschedule after exit from context manager")
232
+ if self._state == _State.TIMEOUT:
233
+ raise RuntimeError("cannot reschedule expired timeout")
234
+ if self._timeout_handler is not None:
235
+ self._timeout_handler.cancel()
236
+ self._deadline = deadline
237
+ if self._state != _State.INIT:
238
+ self._reschedule()
239
+
240
+ def _reschedule(self) -> None:
241
+ assert self._state == _State.ENTER
242
+ deadline = self._deadline
243
+ if deadline is None:
244
+ return
245
+
246
+ now = self._loop.time()
247
+ if self._timeout_handler is not None:
248
+ self._timeout_handler.cancel()
249
+
250
+ self._task = asyncio.current_task()
251
+ if deadline <= now:
252
+ self._timeout_handler = self._loop.call_soon(self._on_timeout)
253
+ else:
254
+ self._timeout_handler = self._loop.call_at(deadline, self._on_timeout)
255
+
256
+ def _do_enter(self) -> None:
257
+ if self._state != _State.INIT:
258
+ raise RuntimeError(f"invalid state {self._state.value}")
259
+ self._state = _State.ENTER
260
+ self._reschedule()
261
+
262
+ def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
263
+ if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT:
264
+ assert self._task is not None
265
+ _uncancel_task(self._task)
266
+ self._timeout_handler = None
267
+ self._task = None
268
+ raise asyncio.TimeoutError
269
+ # timeout has not expired
270
+ self._state = _State.EXIT
271
+ self._reject()
272
+ return None
273
+
274
+ def _on_timeout(self) -> None:
275
+ assert self._task is not None
276
+ self._task.cancel()
277
+ self._state = _State.TIMEOUT
278
+ # drop the reference early
279
+ self._timeout_handler = None
280
+
281
+
282
+ # End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
source/websockets/asyncio/client.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ import socket
7
+ import ssl as ssl_module
8
+ import traceback
9
+ import urllib.parse
10
+ from collections.abc import AsyncIterator, Generator, Sequence
11
+ from types import TracebackType
12
+ from typing import Any, Callable, Literal, cast
13
+
14
+ from ..client import ClientProtocol, backoff
15
+ from ..datastructures import HeadersLike
16
+ from ..exceptions import (
17
+ InvalidMessage,
18
+ InvalidProxyMessage,
19
+ InvalidProxyStatus,
20
+ InvalidStatus,
21
+ ProxyError,
22
+ SecurityError,
23
+ )
24
+ from ..extensions.base import ClientExtensionFactory
25
+ from ..extensions.permessage_deflate import enable_client_permessage_deflate
26
+ from ..headers import validate_subprotocols
27
+ from ..http11 import USER_AGENT, Response
28
+ from ..protocol import CONNECTING, Event
29
+ from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request
30
+ from ..streams import StreamReader
31
+ from ..typing import LoggerLike, Origin, Subprotocol
32
+ from ..uri import WebSocketURI, parse_uri
33
+ from .compatibility import TimeoutError, asyncio_timeout
34
+ from .connection import Connection
35
+
36
+
37
+ __all__ = ["connect", "unix_connect", "ClientConnection"]
38
+
39
+ MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
40
+
41
+
42
+ class ClientConnection(Connection):
43
+ """
44
+ :mod:`asyncio` implementation of a WebSocket client connection.
45
+
46
+ :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
47
+ for receiving and sending messages.
48
+
49
+ It supports asynchronous iteration to receive messages::
50
+
51
+ async for message in websocket:
52
+ await process(message)
53
+
54
+ The iterator exits normally when the connection is closed with code
55
+ 1000 (OK) or 1001 (going away) or without a close code. It raises a
56
+ :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
57
+ closed with any other code.
58
+
59
+ The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
60
+ and ``write_limit`` arguments have the same meaning as in :func:`connect`.
61
+
62
+ Args:
63
+ protocol: Sans-I/O connection.
64
+
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ protocol: ClientProtocol,
70
+ *,
71
+ ping_interval: float | None = 20,
72
+ ping_timeout: float | None = 20,
73
+ close_timeout: float | None = 10,
74
+ max_queue: int | None | tuple[int | None, int | None] = 16,
75
+ write_limit: int | tuple[int, int | None] = 2**15,
76
+ ) -> None:
77
+ self.protocol: ClientProtocol
78
+ super().__init__(
79
+ protocol,
80
+ ping_interval=ping_interval,
81
+ ping_timeout=ping_timeout,
82
+ close_timeout=close_timeout,
83
+ max_queue=max_queue,
84
+ write_limit=write_limit,
85
+ )
86
+ self.response_rcvd: asyncio.Future[None] = self.loop.create_future()
87
+
88
+ async def handshake(
89
+ self,
90
+ additional_headers: HeadersLike | None = None,
91
+ user_agent_header: str | None = USER_AGENT,
92
+ ) -> None:
93
+ """
94
+ Perform the opening handshake.
95
+
96
+ """
97
+ async with self.send_context(expected_state=CONNECTING):
98
+ self.request = self.protocol.connect()
99
+ if additional_headers is not None:
100
+ self.request.headers.update(additional_headers)
101
+ if user_agent_header is not None:
102
+ self.request.headers.setdefault("User-Agent", user_agent_header)
103
+ self.protocol.send_request(self.request)
104
+
105
+ await asyncio.wait(
106
+ [self.response_rcvd, self.connection_lost_waiter],
107
+ return_when=asyncio.FIRST_COMPLETED,
108
+ )
109
+
110
+ # self.protocol.handshake_exc is set when the connection is lost before
111
+ # receiving a response, when the response cannot be parsed, or when the
112
+ # response fails the handshake.
113
+
114
+ if self.protocol.handshake_exc is not None:
115
+ raise self.protocol.handshake_exc
116
+
117
+ def process_event(self, event: Event) -> None:
118
+ """
119
+ Process one incoming event.
120
+
121
+ """
122
+ # First event - handshake response.
123
+ if self.response is None:
124
+ assert isinstance(event, Response)
125
+ self.response = event
126
+ self.response_rcvd.set_result(None)
127
+ # Later events - frames.
128
+ else:
129
+ super().process_event(event)
130
+
131
+
132
+ def process_exception(exc: Exception) -> Exception | None:
133
+ """
134
+ Determine whether a connection error is retryable or fatal.
135
+
136
+ When reconnecting automatically with ``async for ... in connect(...)``, if a
137
+ connection attempt fails, :func:`process_exception` is called to determine
138
+ whether to retry connecting or to raise the exception.
139
+
140
+ This function defines the default behavior, which is to retry on:
141
+
142
+ * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
143
+ errors;
144
+ * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
145
+ 502, 503, or 504: server or proxy errors.
146
+
147
+ All other exceptions are considered fatal.
148
+
149
+ You can change this behavior with the ``process_exception`` argument of
150
+ :func:`connect`.
151
+
152
+ Return :obj:`None` if the exception is retryable i.e. when the error could
153
+ be transient and trying to reconnect with the same parameters could succeed.
154
+ The exception will be logged at the ``INFO`` level.
155
+
156
+ Return an exception, either ``exc`` or a new exception, if the exception is
157
+ fatal i.e. when trying to reconnect will most likely produce the same error.
158
+ That exception will be raised, breaking out of the retry loop.
159
+
160
+ """
161
+ # This catches python-socks' ProxyConnectionError and ProxyTimeoutError.
162
+ # Remove asyncio.TimeoutError when dropping Python < 3.11.
163
+ if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)):
164
+ return None
165
+ if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
166
+ return None
167
+ if isinstance(exc, InvalidStatus) and exc.response.status_code in [
168
+ 500, # Internal Server Error
169
+ 502, # Bad Gateway
170
+ 503, # Service Unavailable
171
+ 504, # Gateway Timeout
172
+ ]:
173
+ return None
174
+ return exc
175
+
176
+
177
+ # This is spelled in lower case because it's exposed as a callable in the API.
178
+ class connect:
179
+ """
180
+ Connect to the WebSocket server at ``uri``.
181
+
182
+ This coroutine returns a :class:`ClientConnection` instance, which you can
183
+ use to send and receive messages.
184
+
185
+ :func:`connect` may be used as an asynchronous context manager::
186
+
187
+ from websockets.asyncio.client import connect
188
+
189
+ async with connect(...) as websocket:
190
+ ...
191
+
192
+ The connection is closed automatically when exiting the context.
193
+
194
+ :func:`connect` can be used as an infinite asynchronous iterator to
195
+ reconnect automatically on errors::
196
+
197
+ async for websocket in connect(...):
198
+ try:
199
+ ...
200
+ except websockets.exceptions.ConnectionClosed:
201
+ continue
202
+
203
+ If the connection fails with a transient error, it is retried with
204
+ exponential backoff. If it fails with a fatal error, the exception is
205
+ raised, breaking out of the loop.
206
+
207
+ The connection is closed automatically after each iteration of the loop.
208
+
209
+ Args:
210
+ uri: URI of the WebSocket server.
211
+ origin: Value of the ``Origin`` header, for servers that require it.
212
+ extensions: List of supported extensions, in order in which they
213
+ should be negotiated and run.
214
+ subprotocols: List of supported subprotocols, in order of decreasing
215
+ preference.
216
+ compression: The "permessage-deflate" extension is enabled by default.
217
+ Set ``compression`` to :obj:`None` to disable it. See the
218
+ :doc:`compression guide <../../topics/compression>` for details.
219
+ additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
220
+ to the handshake request.
221
+ user_agent_header: Value of the ``User-Agent`` request header.
222
+ It defaults to ``"Python/x.y.z websockets/X.Y"``.
223
+ Setting it to :obj:`None` removes the header.
224
+ proxy: If a proxy is configured, it is used by default. Set ``proxy``
225
+ to :obj:`None` to disable the proxy or to the address of a proxy
226
+ to override the system configuration. See the :doc:`proxy docs
227
+ <../../topics/proxies>` for details.
228
+ process_exception: When reconnecting automatically, tell whether an
229
+ error is transient or fatal. The default behavior is defined by
230
+ :func:`process_exception`. Refer to its documentation for details.
231
+ open_timeout: Timeout for opening the connection in seconds.
232
+ :obj:`None` disables the timeout.
233
+ ping_interval: Interval between keepalive pings in seconds.
234
+ :obj:`None` disables keepalive.
235
+ ping_timeout: Timeout for keepalive pings in seconds.
236
+ :obj:`None` disables timeouts.
237
+ close_timeout: Timeout for closing the connection in seconds.
238
+ :obj:`None` disables the timeout.
239
+ max_size: Maximum size of incoming messages in bytes.
240
+ :obj:`None` disables the limit. You may pass a ``(max_message_size,
241
+ max_fragment_size)`` tuple to set different limits for messages and
242
+ fragments when you expect long messages sent in short fragments.
243
+ max_queue: High-water mark of the buffer where frames are received.
244
+ It defaults to 16 frames. The low-water mark defaults to ``max_queue
245
+ // 4``. You may pass a ``(high, low)`` tuple to set the high-water
246
+ and low-water marks. If you want to disable flow control entirely,
247
+ you may set it to ``None``, although that's a bad idea.
248
+ write_limit: High-water mark of write buffer in bytes. It is passed to
249
+ :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
250
+ to 32 KiB. You may pass a ``(high, low)`` tuple to set the
251
+ high-water and low-water marks.
252
+ logger: Logger for this client.
253
+ It defaults to ``logging.getLogger("websockets.client")``.
254
+ See the :doc:`logging guide <../../topics/logging>` for details.
255
+ create_connection: Factory for the :class:`ClientConnection` managing
256
+ the connection. Set it to a wrapper or a subclass to customize
257
+ connection handling.
258
+
259
+ Any other keyword arguments are passed to the event loop's
260
+ :meth:`~asyncio.loop.create_connection` method.
261
+
262
+ For example:
263
+
264
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings.
265
+ When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS
266
+ context is created with :func:`~ssl.create_default_context`.
267
+
268
+ * You can set ``server_hostname`` to override the host name from ``uri`` in
269
+ the TLS handshake.
270
+
271
+ * You can set ``host`` and ``port`` to connect to a different host and port
272
+ from those found in ``uri``. This only changes the destination of the TCP
273
+ connection. The host name from ``uri`` is still used in the TLS handshake
274
+ for secure connections and in the ``Host`` header.
275
+
276
+ * You can set ``sock`` to provide a preexisting TCP socket. You may call
277
+ :func:`socket.create_connection` (not to be confused with the event loop's
278
+ :meth:`~asyncio.loop.create_connection` method) to create a suitable
279
+ client socket and customize it.
280
+
281
+ When using a proxy:
282
+
283
+ * Prefix keyword arguments with ``proxy_`` for configuring TLS between the
284
+ client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``,
285
+ ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``.
286
+ * Use the standard keyword arguments for configuring TLS between the proxy
287
+ and the WebSocket server: ``ssl``, ``server_hostname``,
288
+ ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``.
289
+ * Other keyword arguments are used only for connecting to the proxy.
290
+
291
+ Raises:
292
+ InvalidURI: If ``uri`` isn't a valid WebSocket URI.
293
+ InvalidProxy: If ``proxy`` isn't a valid proxy.
294
+ OSError: If the TCP connection fails.
295
+ InvalidHandshake: If the opening handshake fails.
296
+ TimeoutError: If the opening handshake times out.
297
+
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ uri: str,
303
+ *,
304
+ # WebSocket
305
+ origin: Origin | None = None,
306
+ extensions: Sequence[ClientExtensionFactory] | None = None,
307
+ subprotocols: Sequence[Subprotocol] | None = None,
308
+ compression: str | None = "deflate",
309
+ # HTTP
310
+ additional_headers: HeadersLike | None = None,
311
+ user_agent_header: str | None = USER_AGENT,
312
+ proxy: str | Literal[True] | None = True,
313
+ process_exception: Callable[[Exception], Exception | None] = process_exception,
314
+ # Timeouts
315
+ open_timeout: float | None = 10,
316
+ ping_interval: float | None = 20,
317
+ ping_timeout: float | None = 20,
318
+ close_timeout: float | None = 10,
319
+ # Limits
320
+ max_size: int | None | tuple[int | None, int | None] = 2**20,
321
+ max_queue: int | None | tuple[int | None, int | None] = 16,
322
+ write_limit: int | tuple[int, int | None] = 2**15,
323
+ # Logging
324
+ logger: LoggerLike | None = None,
325
+ # Escape hatch for advanced customization
326
+ create_connection: type[ClientConnection] | None = None,
327
+ # Other keyword arguments are passed to loop.create_connection
328
+ **kwargs: Any,
329
+ ) -> None:
330
+ self.uri = uri
331
+
332
+ if subprotocols is not None:
333
+ validate_subprotocols(subprotocols)
334
+
335
+ if compression == "deflate":
336
+ extensions = enable_client_permessage_deflate(extensions)
337
+ elif compression is not None:
338
+ raise ValueError(f"unsupported compression: {compression}")
339
+
340
+ if logger is None:
341
+ logger = logging.getLogger("websockets.client")
342
+
343
+ if create_connection is None:
344
+ create_connection = ClientConnection
345
+
346
+ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
347
+ # This is a protocol in the Sans-I/O implementation of websockets.
348
+ protocol = ClientProtocol(
349
+ uri,
350
+ origin=origin,
351
+ extensions=extensions,
352
+ subprotocols=subprotocols,
353
+ max_size=max_size,
354
+ logger=logger,
355
+ )
356
+ # This is a connection in websockets and a protocol in asyncio.
357
+ connection = create_connection(
358
+ protocol,
359
+ ping_interval=ping_interval,
360
+ ping_timeout=ping_timeout,
361
+ close_timeout=close_timeout,
362
+ max_queue=max_queue,
363
+ write_limit=write_limit,
364
+ )
365
+ return connection
366
+
367
+ self.proxy = proxy
368
+ self.protocol_factory = protocol_factory
369
+ self.additional_headers = additional_headers
370
+ self.user_agent_header = user_agent_header
371
+ self.process_exception = process_exception
372
+ self.open_timeout = open_timeout
373
+ self.logger = logger
374
+ self.connection_kwargs = kwargs
375
+
376
+ async def create_connection(self) -> ClientConnection:
377
+ """Create TCP or Unix connection."""
378
+ loop = asyncio.get_running_loop()
379
+ kwargs = self.connection_kwargs.copy()
380
+
381
+ ws_uri = parse_uri(self.uri)
382
+
383
+ proxy = self.proxy
384
+ if kwargs.get("unix", False):
385
+ proxy = None
386
+ if kwargs.get("sock") is not None:
387
+ proxy = None
388
+ if proxy is True:
389
+ proxy = get_proxy(ws_uri)
390
+
391
+ def factory() -> ClientConnection:
392
+ return self.protocol_factory(ws_uri)
393
+
394
+ if ws_uri.secure:
395
+ kwargs.setdefault("ssl", True)
396
+ kwargs.setdefault("server_hostname", ws_uri.host)
397
+ if kwargs.get("ssl") is None:
398
+ raise ValueError("ssl=None is incompatible with a wss:// URI")
399
+ else:
400
+ if kwargs.get("ssl") is not None:
401
+ raise ValueError("ssl argument is incompatible with a ws:// URI")
402
+
403
+ if kwargs.pop("unix", False):
404
+ _, connection = await loop.create_unix_connection(factory, **kwargs)
405
+ elif proxy is not None:
406
+ proxy_parsed = parse_proxy(proxy)
407
+ if proxy_parsed.scheme[:5] == "socks":
408
+ # Connect to the server through the proxy.
409
+ sock = await connect_socks_proxy(
410
+ proxy_parsed,
411
+ ws_uri,
412
+ local_addr=kwargs.pop("local_addr", None),
413
+ )
414
+ # Initialize WebSocket connection via the proxy.
415
+ _, connection = await loop.create_connection(
416
+ factory,
417
+ sock=sock,
418
+ **kwargs,
419
+ )
420
+ elif proxy_parsed.scheme[:4] == "http":
421
+ # Split keyword arguments between the proxy and the server.
422
+ all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {}
423
+ for key, value in all_kwargs.items():
424
+ if key.startswith("ssl") or key == "server_hostname":
425
+ kwargs[key] = value
426
+ elif key.startswith("proxy_"):
427
+ proxy_kwargs[key[6:]] = value
428
+ else:
429
+ proxy_kwargs[key] = value
430
+ # Validate the proxy_ssl argument.
431
+ if proxy_parsed.scheme == "https":
432
+ proxy_kwargs.setdefault("ssl", True)
433
+ if proxy_kwargs.get("ssl") is None:
434
+ raise ValueError(
435
+ "proxy_ssl=None is incompatible with an https:// proxy"
436
+ )
437
+ else:
438
+ if proxy_kwargs.get("ssl") is not None:
439
+ raise ValueError(
440
+ "proxy_ssl argument is incompatible with an http:// proxy"
441
+ )
442
+ # Connect to the server through the proxy.
443
+ transport = await connect_http_proxy(
444
+ proxy_parsed,
445
+ ws_uri,
446
+ user_agent_header=self.user_agent_header,
447
+ **proxy_kwargs,
448
+ )
449
+ # Initialize WebSocket connection via the proxy.
450
+ connection = factory()
451
+ transport.set_protocol(connection)
452
+ ssl = kwargs.pop("ssl", None)
453
+ if ssl is True:
454
+ ssl = ssl_module.create_default_context()
455
+ if ssl is not None:
456
+ new_transport = await loop.start_tls(
457
+ transport, connection, ssl, **kwargs
458
+ )
459
+ assert new_transport is not None # help mypy
460
+ transport = new_transport
461
+ connection.connection_made(transport)
462
+ else:
463
+ raise AssertionError("unsupported proxy")
464
+ else:
465
+ # Connect to the server directly.
466
+ if kwargs.get("sock") is None:
467
+ kwargs.setdefault("host", ws_uri.host)
468
+ kwargs.setdefault("port", ws_uri.port)
469
+ # Initialize WebSocket connection.
470
+ _, connection = await loop.create_connection(factory, **kwargs)
471
+ return connection
472
+
473
+ def process_redirect(self, exc: Exception) -> Exception | str:
474
+ """
475
+ Determine whether a connection error is a redirect that can be followed.
476
+
477
+ Return the new URI if it's a valid redirect. Else, return an exception.
478
+
479
+ """
480
+ if not (
481
+ isinstance(exc, InvalidStatus)
482
+ and exc.response.status_code
483
+ in [
484
+ 300, # Multiple Choices
485
+ 301, # Moved Permanently
486
+ 302, # Found
487
+ 303, # See Other
488
+ 307, # Temporary Redirect
489
+ 308, # Permanent Redirect
490
+ ]
491
+ and "Location" in exc.response.headers
492
+ ):
493
+ return exc
494
+
495
+ old_ws_uri = parse_uri(self.uri)
496
+ new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
497
+ new_ws_uri = parse_uri(new_uri)
498
+
499
+ # If connect() received a socket, it is closed and cannot be reused.
500
+ if self.connection_kwargs.get("sock") is not None:
501
+ return ValueError(
502
+ f"cannot follow redirect to {new_uri} with a preexisting socket"
503
+ )
504
+
505
+ # TLS downgrade is forbidden.
506
+ if old_ws_uri.secure and not new_ws_uri.secure:
507
+ return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
508
+
509
+ # Apply restrictions to cross-origin redirects.
510
+ if (
511
+ old_ws_uri.secure != new_ws_uri.secure
512
+ or old_ws_uri.host != new_ws_uri.host
513
+ or old_ws_uri.port != new_ws_uri.port
514
+ ):
515
+ # Cross-origin redirects on Unix sockets don't quite make sense.
516
+ if self.connection_kwargs.get("unix", False):
517
+ return ValueError(
518
+ f"cannot follow cross-origin redirect to {new_uri} "
519
+ f"with a Unix socket"
520
+ )
521
+
522
+ # Cross-origin redirects when host and port are overridden are ill-defined.
523
+ if (
524
+ self.connection_kwargs.get("host") is not None
525
+ or self.connection_kwargs.get("port") is not None
526
+ ):
527
+ return ValueError(
528
+ f"cannot follow cross-origin redirect to {new_uri} "
529
+ f"with an explicit host or port"
530
+ )
531
+
532
+ return new_uri
533
+
534
+ # ... = await connect(...)
535
+
536
+ def __await__(self) -> Generator[Any, None, ClientConnection]:
537
+ # Create a suitable iterator by calling __await__ on a coroutine.
538
+ return self.__await_impl__().__await__()
539
+
540
+ async def __await_impl__(self) -> ClientConnection:
541
+ try:
542
+ async with asyncio_timeout(self.open_timeout):
543
+ for _ in range(MAX_REDIRECTS):
544
+ self.connection = await self.create_connection()
545
+ try:
546
+ await self.connection.handshake(
547
+ self.additional_headers,
548
+ self.user_agent_header,
549
+ )
550
+ except asyncio.CancelledError:
551
+ self.connection.transport.abort()
552
+ raise
553
+ except Exception as exc:
554
+ # Always close the connection even though keep-alive is
555
+ # the default in HTTP/1.1 because create_connection ties
556
+ # opening the network connection with initializing the
557
+ # protocol. In the current design of connect(), there is
558
+ # no easy way to reuse the network connection that works
559
+ # in every case nor to reinitialize the protocol.
560
+ self.connection.transport.abort()
561
+
562
+ uri_or_exc = self.process_redirect(exc)
563
+ # Response is a valid redirect; follow it.
564
+ if isinstance(uri_or_exc, str):
565
+ self.uri = uri_or_exc
566
+ continue
567
+ # Response isn't a valid redirect; raise the exception.
568
+ if uri_or_exc is exc:
569
+ raise
570
+ else:
571
+ raise uri_or_exc from exc
572
+
573
+ else:
574
+ self.connection.start_keepalive()
575
+ return self.connection
576
+ else:
577
+ raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
578
+
579
+ except TimeoutError as exc:
580
+ # Re-raise exception with an informative error message.
581
+ raise TimeoutError("timed out during opening handshake") from exc
582
+
583
+ # ... = yield from connect(...) - remove when dropping Python < 3.11
584
+
585
+ __iter__ = __await__
586
+
587
+ # async with connect(...) as ...: ...
588
+
589
+ async def __aenter__(self) -> ClientConnection:
590
+ return await self
591
+
592
+ async def __aexit__(
593
+ self,
594
+ exc_type: type[BaseException] | None,
595
+ exc_value: BaseException | None,
596
+ traceback: TracebackType | None,
597
+ ) -> None:
598
+ await self.connection.close()
599
+
600
+ # async for ... in connect(...):
601
+
602
+ async def __aiter__(self) -> AsyncIterator[ClientConnection]:
603
+ delays: Generator[float] | None = None
604
+ while True:
605
+ try:
606
+ async with self as protocol:
607
+ yield protocol
608
+ except Exception as exc:
609
+ # Determine whether the exception is retryable or fatal.
610
+ # The API of process_exception is "return an exception or None";
611
+ # "raise an exception" is also supported because it's a frequent
612
+ # mistake. It isn't documented in order to keep the API simple.
613
+ try:
614
+ new_exc = self.process_exception(exc)
615
+ except Exception as raised_exc:
616
+ new_exc = raised_exc
617
+
618
+ # The connection failed with a fatal error.
619
+ # Raise the exception and exit the loop.
620
+ if new_exc is exc:
621
+ raise
622
+ if new_exc is not None:
623
+ raise new_exc from exc
624
+
625
+ # The connection failed with a retryable error.
626
+ # Start or continue backoff and reconnect.
627
+ if delays is None:
628
+ delays = backoff()
629
+ delay = next(delays)
630
+ self.logger.info(
631
+ "connect failed; reconnecting in %.1f seconds: %s",
632
+ delay,
633
+ traceback.format_exception_only(exc)[0].strip(),
634
+ )
635
+ await asyncio.sleep(delay)
636
+ continue
637
+
638
+ else:
639
+ # The connection succeeded. Reset backoff.
640
+ delays = None
641
+
642
+
643
+ def unix_connect(
644
+ path: str | None = None,
645
+ uri: str | None = None,
646
+ **kwargs: Any,
647
+ ) -> connect:
648
+ """
649
+ Connect to a WebSocket server listening on a Unix socket.
650
+
651
+ This function accepts the same keyword arguments as :func:`connect`.
652
+
653
+ It's only available on Unix.
654
+
655
+ It's mainly useful for debugging servers listening on Unix sockets.
656
+
657
+ Args:
658
+ path: File system path to the Unix socket.
659
+ uri: URI of the WebSocket server. ``uri`` defaults to
660
+ ``ws://localhost/`` or, when a ``ssl`` argument is provided, to
661
+ ``wss://localhost/``.
662
+
663
+ """
664
+ if uri is None:
665
+ if kwargs.get("ssl") is None:
666
+ uri = "ws://localhost/"
667
+ else:
668
+ uri = "wss://localhost/"
669
+ return connect(uri=uri, unix=True, path=path, **kwargs)
670
+
671
+
672
+ try:
673
+ from python_socks import ProxyType
674
+ from python_socks.async_.asyncio import Proxy as SocksProxy
675
+
676
+ except ImportError:
677
+
678
+ async def connect_socks_proxy(
679
+ proxy: Proxy,
680
+ ws_uri: WebSocketURI,
681
+ **kwargs: Any,
682
+ ) -> socket.socket:
683
+ raise ImportError("connecting through a SOCKS proxy requires python-socks")
684
+
685
+ else:
686
+ SOCKS_PROXY_TYPES = {
687
+ "socks5h": ProxyType.SOCKS5,
688
+ "socks5": ProxyType.SOCKS5,
689
+ "socks4a": ProxyType.SOCKS4,
690
+ "socks4": ProxyType.SOCKS4,
691
+ }
692
+
693
+ SOCKS_PROXY_RDNS = {
694
+ "socks5h": True,
695
+ "socks5": False,
696
+ "socks4a": True,
697
+ "socks4": False,
698
+ }
699
+
700
+ async def connect_socks_proxy(
701
+ proxy: Proxy,
702
+ ws_uri: WebSocketURI,
703
+ **kwargs: Any,
704
+ ) -> socket.socket:
705
+ """Connect via a SOCKS proxy and return the socket."""
706
+ socks_proxy = SocksProxy(
707
+ SOCKS_PROXY_TYPES[proxy.scheme],
708
+ proxy.host,
709
+ proxy.port,
710
+ proxy.username,
711
+ proxy.password,
712
+ SOCKS_PROXY_RDNS[proxy.scheme],
713
+ )
714
+ # connect() is documented to raise OSError.
715
+ # socks_proxy.connect() doesn't raise TimeoutError; it gets canceled.
716
+ # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
717
+ try:
718
+ return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
719
+ except OSError:
720
+ raise
721
+ except Exception as exc:
722
+ raise ProxyError("failed to connect to SOCKS proxy") from exc
723
+
724
+
725
+ class HTTPProxyConnection(asyncio.Protocol):
726
+ def __init__(
727
+ self,
728
+ ws_uri: WebSocketURI,
729
+ proxy: Proxy,
730
+ user_agent_header: str | None = None,
731
+ ):
732
+ self.ws_uri = ws_uri
733
+ self.proxy = proxy
734
+ self.user_agent_header = user_agent_header
735
+
736
+ self.reader = StreamReader()
737
+ self.parser = Response.parse(
738
+ self.reader.read_line,
739
+ self.reader.read_exact,
740
+ self.reader.read_to_eof,
741
+ proxy=True,
742
+ )
743
+
744
+ loop = asyncio.get_running_loop()
745
+ self.response: asyncio.Future[Response] = loop.create_future()
746
+
747
+ def run_parser(self) -> None:
748
+ try:
749
+ next(self.parser)
750
+ except StopIteration as exc:
751
+ response = exc.value
752
+ if 200 <= response.status_code < 300:
753
+ self.response.set_result(response)
754
+ else:
755
+ self.response.set_exception(InvalidProxyStatus(response))
756
+ except Exception as exc:
757
+ proxy_exc = InvalidProxyMessage(
758
+ "did not receive a valid HTTP response from proxy"
759
+ )
760
+ proxy_exc.__cause__ = exc
761
+ self.response.set_exception(proxy_exc)
762
+
763
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
764
+ transport = cast(asyncio.Transport, transport)
765
+ self.transport = transport
766
+ self.transport.write(
767
+ prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header)
768
+ )
769
+
770
+ def data_received(self, data: bytes) -> None:
771
+ self.reader.feed_data(data)
772
+ self.run_parser()
773
+
774
+ def eof_received(self) -> None:
775
+ self.reader.feed_eof()
776
+ self.run_parser()
777
+
778
+ def connection_lost(self, exc: Exception | None) -> None:
779
+ self.reader.feed_eof()
780
+ if exc is not None:
781
+ self.response.set_exception(exc)
782
+
783
+
784
+ async def connect_http_proxy(
785
+ proxy: Proxy,
786
+ ws_uri: WebSocketURI,
787
+ user_agent_header: str | None = None,
788
+ **kwargs: Any,
789
+ ) -> asyncio.Transport:
790
+ transport, protocol = await asyncio.get_running_loop().create_connection(
791
+ lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header),
792
+ proxy.host,
793
+ proxy.port,
794
+ **kwargs,
795
+ )
796
+
797
+ try:
798
+ # This raises exceptions if the connection to the proxy fails.
799
+ await protocol.response
800
+ except Exception:
801
+ transport.close()
802
+ raise
803
+
804
+ return transport
source/websockets/asyncio/compatibility.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+
6
+ __all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout", "asyncio_timeout_at"]
7
+
8
+
9
+ if sys.version_info[:2] >= (3, 11):
10
+ TimeoutError = TimeoutError
11
+ aiter = aiter
12
+ anext = anext
13
+ from asyncio import (
14
+ timeout as asyncio_timeout, # noqa: F401
15
+ timeout_at as asyncio_timeout_at, # noqa: F401
16
+ )
17
+
18
+ else: # Python < 3.11
19
+ from asyncio import TimeoutError
20
+
21
+ def aiter(async_iterable):
22
+ return type(async_iterable).__aiter__(async_iterable)
23
+
24
+ async def anext(async_iterator):
25
+ return await type(async_iterator).__anext__(async_iterator)
26
+
27
+ from .async_timeout import (
28
+ timeout as asyncio_timeout, # noqa: F401
29
+ timeout_at as asyncio_timeout_at, # noqa: F401
30
+ )
source/websockets/asyncio/connection.py ADDED
@@ -0,0 +1,1247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import collections
5
+ import contextlib
6
+ import logging
7
+ import random
8
+ import struct
9
+ import sys
10
+ import traceback
11
+ import uuid
12
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
13
+ from types import TracebackType
14
+ from typing import Any, Literal, cast, overload
15
+
16
+ from ..exceptions import (
17
+ ConcurrencyError,
18
+ ConnectionClosed,
19
+ ConnectionClosedOK,
20
+ ProtocolError,
21
+ )
22
+ from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode
23
+ from ..http11 import Request, Response
24
+ from ..protocol import CLOSED, OPEN, Event, Protocol, State
25
+ from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol
26
+ from .compatibility import (
27
+ TimeoutError,
28
+ aiter,
29
+ anext,
30
+ asyncio_timeout,
31
+ asyncio_timeout_at,
32
+ )
33
+ from .messages import Assembler
34
+
35
+
36
+ __all__ = ["Connection"]
37
+
38
+
39
+ class Connection(asyncio.Protocol):
40
+ """
41
+ :mod:`asyncio` implementation of a WebSocket connection.
42
+
43
+ :class:`Connection` provides APIs shared between WebSocket servers and
44
+ clients.
45
+
46
+ You shouldn't use it directly. Instead, use
47
+ :class:`~websockets.asyncio.client.ClientConnection` or
48
+ :class:`~websockets.asyncio.server.ServerConnection`.
49
+
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ protocol: Protocol,
55
+ *,
56
+ ping_interval: float | None = 20,
57
+ ping_timeout: float | None = 20,
58
+ close_timeout: float | None = 10,
59
+ max_queue: int | None | tuple[int | None, int | None] = 16,
60
+ write_limit: int | tuple[int, int | None] = 2**15,
61
+ ) -> None:
62
+ self.protocol = protocol
63
+ self.ping_interval = ping_interval
64
+ self.ping_timeout = ping_timeout
65
+ self.close_timeout = close_timeout
66
+ if isinstance(max_queue, int) or max_queue is None:
67
+ self.max_queue_high, self.max_queue_low = max_queue, None
68
+ else:
69
+ self.max_queue_high, self.max_queue_low = max_queue
70
+ if isinstance(write_limit, int):
71
+ self.write_limit_high, self.write_limit_low = write_limit, None
72
+ else:
73
+ self.write_limit_high, self.write_limit_low = write_limit
74
+
75
+ # Inject reference to this instance in the protocol's logger.
76
+ self.protocol.logger = logging.LoggerAdapter(
77
+ self.protocol.logger,
78
+ {"websocket": self},
79
+ )
80
+
81
+ # Copy attributes from the protocol for convenience.
82
+ self.id: uuid.UUID = self.protocol.id
83
+ """Unique identifier of the connection. Useful in logs."""
84
+ self.logger: LoggerLike = self.protocol.logger
85
+ """Logger for this connection."""
86
+ self.debug = self.protocol.debug
87
+
88
+ # HTTP handshake request and response.
89
+ self.request: Request | None = None
90
+ """Opening handshake request."""
91
+ self.response: Response | None = None
92
+ """Opening handshake response."""
93
+
94
+ # Event loop running this connection.
95
+ self.loop = asyncio.get_running_loop()
96
+
97
+ # Assembler turning frames into messages and serializing reads.
98
+ self.recv_messages: Assembler # initialized in connection_made
99
+
100
+ # Deadline for the closing handshake.
101
+ self.close_deadline: float | None = None
102
+
103
+ # Whether we are busy sending a fragmented message.
104
+ self.send_in_progress: asyncio.Future[None] | None = None
105
+
106
+ # Mapping of ping IDs to pong waiters, in chronological order.
107
+ self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {}
108
+
109
+ self.latency: float = 0.0
110
+ """
111
+ Latency of the connection, in seconds.
112
+
113
+ Latency is defined as the round-trip time of the connection. It is
114
+ measured by sending a Ping frame and waiting for a matching Pong frame.
115
+ Before the first measurement, :attr:`latency` is ``0.0``.
116
+
117
+ By default, websockets enables a :ref:`keepalive <keepalive>` mechanism
118
+ that sends Ping frames automatically at regular intervals. You can also
119
+ send Ping frames and measure latency with :meth:`ping`.
120
+ """
121
+
122
+ # Task that sends keepalive pings. None when ping_interval is None.
123
+ self.keepalive_task: asyncio.Task[None] | None = None
124
+
125
+ # Exception raised while reading from the connection, to be chained to
126
+ # ConnectionClosed in order to show why the TCP connection dropped.
127
+ self.recv_exc: BaseException | None = None
128
+
129
+ # Completed when the TCP connection is closed and the WebSocket
130
+ # connection state becomes CLOSED.
131
+ self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future()
132
+
133
+ # Adapted from asyncio.FlowControlMixin.
134
+ self.paused: bool = False
135
+ self.drain_waiters: collections.deque[asyncio.Future[None]] = (
136
+ collections.deque()
137
+ )
138
+
139
+ # Public attributes
140
+
141
+ @property
142
+ def local_address(self) -> Any:
143
+ """
144
+ Local address of the connection.
145
+
146
+ For IPv4 connections, this is a ``(host, port)`` tuple.
147
+
148
+ The format of the address depends on the address family.
149
+ See :meth:`~socket.socket.getsockname`.
150
+
151
+ """
152
+ return self.transport.get_extra_info("sockname")
153
+
154
+ @property
155
+ def remote_address(self) -> Any:
156
+ """
157
+ Remote address of the connection.
158
+
159
+ For IPv4 connections, this is a ``(host, port)`` tuple.
160
+
161
+ The format of the address depends on the address family.
162
+ See :meth:`~socket.socket.getpeername`.
163
+
164
+ """
165
+ return self.transport.get_extra_info("peername")
166
+
167
+ @property
168
+ def state(self) -> State:
169
+ """
170
+ State of the WebSocket connection, defined in :rfc:`6455`.
171
+
172
+ This attribute is provided for completeness. Typical applications
173
+ shouldn't check its value. Instead, they should call :meth:`~recv` or
174
+ :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed`
175
+ exceptions.
176
+
177
+ """
178
+ return self.protocol.state
179
+
180
+ @property
181
+ def subprotocol(self) -> Subprotocol | None:
182
+ """
183
+ Subprotocol negotiated during the opening handshake.
184
+
185
+ :obj:`None` if no subprotocol was negotiated.
186
+
187
+ """
188
+ return self.protocol.subprotocol
189
+
190
+ @property
191
+ def close_code(self) -> int | None:
192
+ """
193
+ State of the WebSocket connection, defined in :rfc:`6455`.
194
+
195
+ This attribute is provided for completeness. Typical applications
196
+ shouldn't check its value. Instead, they should inspect attributes
197
+ of :exc:`~websockets.exceptions.ConnectionClosed` exceptions.
198
+
199
+ """
200
+ return self.protocol.close_code
201
+
202
+ @property
203
+ def close_reason(self) -> str | None:
204
+ """
205
+ State of the WebSocket connection, defined in :rfc:`6455`.
206
+
207
+ This attribute is provided for completeness. Typical applications
208
+ shouldn't check its value. Instead, they should inspect attributes
209
+ of :exc:`~websockets.exceptions.ConnectionClosed` exceptions.
210
+
211
+ """
212
+ return self.protocol.close_reason
213
+
214
+ # Public methods
215
+
216
+ async def __aenter__(self) -> Connection:
217
+ return self
218
+
219
+ async def __aexit__(
220
+ self,
221
+ exc_type: type[BaseException] | None,
222
+ exc_value: BaseException | None,
223
+ traceback: TracebackType | None,
224
+ ) -> None:
225
+ if exc_type is None:
226
+ await self.close()
227
+ else:
228
+ await self.close(CloseCode.INTERNAL_ERROR)
229
+
230
+ async def __aiter__(self) -> AsyncIterator[Data]:
231
+ """
232
+ Iterate on incoming messages.
233
+
234
+ The iterator calls :meth:`recv` and yields messages asynchronously in an
235
+ infinite loop.
236
+
237
+ It exits when the connection is closed normally. It raises a
238
+ :exc:`~websockets.exceptions.ConnectionClosedError` exception after a
239
+ protocol error or a network failure.
240
+
241
+ """
242
+ try:
243
+ while True:
244
+ yield await self.recv()
245
+ except ConnectionClosedOK:
246
+ return
247
+
248
+ @overload
249
+ async def recv(self, decode: Literal[True]) -> str: ...
250
+
251
+ @overload
252
+ async def recv(self, decode: Literal[False]) -> bytes: ...
253
+
254
+ @overload
255
+ async def recv(self, decode: bool | None = None) -> Data: ...
256
+
257
+ async def recv(self, decode: bool | None = None) -> Data:
258
+ """
259
+ Receive the next message.
260
+
261
+ When the connection is closed, :meth:`recv` raises
262
+ :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
263
+ :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure
264
+ and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
265
+ error or a network failure. This is how you detect the end of the
266
+ message stream.
267
+
268
+ Canceling :meth:`recv` is safe. There's no risk of losing data. The next
269
+ invocation of :meth:`recv` will return the next message.
270
+
271
+ This makes it possible to enforce a timeout by wrapping :meth:`recv` in
272
+ :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`.
273
+
274
+ When the message is fragmented, :meth:`recv` waits until all fragments
275
+ are received, reassembles them, and returns the whole message.
276
+
277
+ Args:
278
+ decode: Set this flag to override the default behavior of returning
279
+ :class:`str` or :class:`bytes`. See below for details.
280
+
281
+ Returns:
282
+ A string (:class:`str`) for a Text_ frame or a bytestring
283
+ (:class:`bytes`) for a Binary_ frame.
284
+
285
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
286
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
287
+
288
+ You may override this behavior with the ``decode`` argument:
289
+
290
+ * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
291
+ return a bytestring (:class:`bytes`). This improves performance
292
+ when decoding isn't needed, for example if the message contains
293
+ JSON and you're using a JSON library that expects a bytestring.
294
+ * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and
295
+ return strings (:class:`str`). This may be useful for servers that
296
+ send binary frames instead of text frames.
297
+
298
+ Raises:
299
+ ConnectionClosed: When the connection is closed.
300
+ ConcurrencyError: If two coroutines call :meth:`recv` or
301
+ :meth:`recv_streaming` concurrently.
302
+
303
+ """
304
+ try:
305
+ return await self.recv_messages.get(decode)
306
+ except EOFError:
307
+ pass
308
+ # fallthrough
309
+ except ConcurrencyError:
310
+ raise ConcurrencyError(
311
+ "cannot call recv while another coroutine "
312
+ "is already running recv or recv_streaming"
313
+ ) from None
314
+ except UnicodeDecodeError as exc:
315
+ async with self.send_context():
316
+ self.protocol.fail(
317
+ CloseCode.INVALID_DATA,
318
+ f"{exc.reason} at position {exc.start}",
319
+ )
320
+ # fallthrough
321
+
322
+ # Wait for the protocol state to be CLOSED before accessing close_exc.
323
+ await asyncio.shield(self.connection_lost_waiter)
324
+ raise self.protocol.close_exc from self.recv_exc
325
+
326
+ @overload
327
+ def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ...
328
+
329
+ @overload
330
+ def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...
331
+
332
+ @overload
333
+ def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
334
+
335
+ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
336
+ """
337
+ Receive the next message frame by frame.
338
+
339
+ This method is designed for receiving fragmented messages. It returns an
340
+ asynchronous iterator that yields each fragment as it is received. This
341
+ iterator must be fully consumed. Else, future calls to :meth:`recv` or
342
+ :meth:`recv_streaming` will raise
343
+ :exc:`~websockets.exceptions.ConcurrencyError`, making the connection
344
+ unusable.
345
+
346
+ :meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
347
+
348
+ Canceling :meth:`recv_streaming` before receiving the first frame is
349
+ safe. Canceling it after receiving one or more frames leaves the
350
+ iterator in a partially consumed state, making the connection unusable.
351
+ Instead, you should close the connection with :meth:`close`.
352
+
353
+ Args:
354
+ decode: Set this flag to override the default behavior of returning
355
+ :class:`str` or :class:`bytes`. See below for details.
356
+
357
+ Returns:
358
+ An iterator of strings (:class:`str`) for a Text_ frame or
359
+ bytestrings (:class:`bytes`) for a Binary_ frame.
360
+
361
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
362
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
363
+
364
+ You may override this behavior with the ``decode`` argument:
365
+
366
+ * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
367
+ yield bytestrings (:class:`bytes`). This improves performance
368
+ when decoding isn't needed.
369
+ * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and
370
+ yield strings (:class:`str`). This may be useful for servers that
371
+ send binary frames instead of text frames.
372
+
373
+ Raises:
374
+ ConnectionClosed: When the connection is closed.
375
+ ConcurrencyError: If two coroutines call :meth:`recv` or
376
+ :meth:`recv_streaming` concurrently.
377
+
378
+ """
379
+ try:
380
+ async for frame in self.recv_messages.get_iter(decode):
381
+ yield frame
382
+ return
383
+ except EOFError:
384
+ pass
385
+ # fallthrough
386
+ except ConcurrencyError:
387
+ raise ConcurrencyError(
388
+ "cannot call recv_streaming while another coroutine "
389
+ "is already running recv or recv_streaming"
390
+ ) from None
391
+ except UnicodeDecodeError as exc:
392
+ async with self.send_context():
393
+ self.protocol.fail(
394
+ CloseCode.INVALID_DATA,
395
+ f"{exc.reason} at position {exc.start}",
396
+ )
397
+ # fallthrough
398
+
399
+ # Wait for the protocol state to be CLOSED before accessing close_exc.
400
+ await asyncio.shield(self.connection_lost_waiter)
401
+ raise self.protocol.close_exc from self.recv_exc
402
+
403
+ async def send(
404
+ self,
405
+ message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike],
406
+ text: bool | None = None,
407
+ ) -> None:
408
+ """
409
+ Send a message.
410
+
411
+ A string (:class:`str`) is sent as a Text_ frame. A bytestring or
412
+ bytes-like object (:class:`bytes`, :class:`bytearray`, or
413
+ :class:`memoryview`) is sent as a Binary_ frame.
414
+
415
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
416
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
417
+
418
+ You may override this behavior with the ``text`` argument:
419
+
420
+ * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object
421
+ (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a
422
+ Text_ frame. This improves performance when the message is already
423
+ UTF-8 encoded, for example if the message contains JSON and you're
424
+ using a JSON library that produces a bytestring.
425
+ * Set ``text=False`` to send a string (:class:`str`) in a Binary_
426
+ frame. This may be useful for servers that expect binary frames
427
+ instead of text frames.
428
+
429
+ :meth:`send` also accepts an iterable or asynchronous iterable of
430
+ strings, bytestrings, or bytes-like objects to enable fragmentation_.
431
+ Each item is treated as a message fragment and sent in its own frame.
432
+ All items must be of the same type, or else :meth:`send` will raise a
433
+ :exc:`TypeError` and the connection will be closed.
434
+
435
+ .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
436
+
437
+ :meth:`send` rejects dict-like objects because this is often an error.
438
+ (If you really want to send the keys of a dict-like object as fragments,
439
+ call its :meth:`~dict.keys` method and pass the result to :meth:`send`.)
440
+
441
+ Canceling :meth:`send` is discouraged. Instead, you should close the
442
+ connection with :meth:`close`. Indeed, there are only two situations
443
+ where :meth:`send` may yield control to the event loop and then get
444
+ canceled; in both cases, :meth:`close` has the same effect and the
445
+ effect is more obvious:
446
+
447
+ 1. The write buffer is full. If you don't want to wait until enough
448
+ data is sent, your only alternative is to close the connection.
449
+ :meth:`close` will likely time out then abort the TCP connection.
450
+ 2. ``message`` is an asynchronous iterator that yields control.
451
+ Stopping in the middle of a fragmented message will cause a
452
+ protocol error and the connection will be closed.
453
+
454
+ When the connection is closed, :meth:`send` raises
455
+ :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
456
+ raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
457
+ connection closure and
458
+ :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
459
+ error or a network failure.
460
+
461
+ Args:
462
+ message: Message to send.
463
+
464
+ Raises:
465
+ ConnectionClosed: When the connection is closed.
466
+ TypeError: If ``message`` doesn't have a supported type.
467
+
468
+ """
469
+ # While sending a fragmented message, prevent sending other messages
470
+ # until all fragments are sent.
471
+ while self.send_in_progress is not None:
472
+ await asyncio.shield(self.send_in_progress)
473
+
474
+ # Unfragmented message -- this case must be handled first because
475
+ # strings and bytes-like objects are iterable.
476
+
477
+ if isinstance(message, str):
478
+ async with self.send_context():
479
+ if text is False:
480
+ self.protocol.send_binary(message.encode())
481
+ else:
482
+ self.protocol.send_text(message.encode())
483
+
484
+ elif isinstance(message, BytesLike):
485
+ async with self.send_context():
486
+ if text is True:
487
+ self.protocol.send_text(message)
488
+ else:
489
+ self.protocol.send_binary(message)
490
+
491
+ # Catch a common mistake -- passing a dict to send().
492
+
493
+ elif isinstance(message, Mapping):
494
+ raise TypeError("data is a dict-like object")
495
+
496
+ # Fragmented message -- regular iterator.
497
+
498
+ elif isinstance(message, Iterable):
499
+ chunks = iter(message)
500
+ try:
501
+ chunk = next(chunks)
502
+ except StopIteration:
503
+ return
504
+
505
+ assert self.send_in_progress is None
506
+ self.send_in_progress = self.loop.create_future()
507
+ try:
508
+ # First fragment.
509
+ if isinstance(chunk, str):
510
+ async with self.send_context():
511
+ if text is False:
512
+ self.protocol.send_binary(chunk.encode(), fin=False)
513
+ else:
514
+ self.protocol.send_text(chunk.encode(), fin=False)
515
+ encode = True
516
+ elif isinstance(chunk, BytesLike):
517
+ async with self.send_context():
518
+ if text is True:
519
+ self.protocol.send_text(chunk, fin=False)
520
+ else:
521
+ self.protocol.send_binary(chunk, fin=False)
522
+ encode = False
523
+ else:
524
+ raise TypeError("iterable must contain bytes or str")
525
+
526
+ # Other fragments
527
+ for chunk in chunks:
528
+ if isinstance(chunk, str) and encode:
529
+ async with self.send_context():
530
+ self.protocol.send_continuation(chunk.encode(), fin=False)
531
+ elif isinstance(chunk, BytesLike) and not encode:
532
+ async with self.send_context():
533
+ self.protocol.send_continuation(chunk, fin=False)
534
+ else:
535
+ raise TypeError("iterable must contain uniform types")
536
+
537
+ # Final fragment.
538
+ async with self.send_context():
539
+ self.protocol.send_continuation(b"", fin=True)
540
+
541
+ except Exception:
542
+ # We're half-way through a fragmented message and we can't
543
+ # complete it. This makes the connection unusable.
544
+ async with self.send_context():
545
+ self.protocol.fail(
546
+ CloseCode.INTERNAL_ERROR,
547
+ "error in fragmented message",
548
+ )
549
+ raise
550
+
551
+ finally:
552
+ self.send_in_progress.set_result(None)
553
+ self.send_in_progress = None
554
+
555
+ # Fragmented message -- async iterator.
556
+
557
+ elif isinstance(message, AsyncIterable):
558
+ achunks = aiter(message)
559
+ try:
560
+ chunk = await anext(achunks)
561
+ except StopAsyncIteration:
562
+ return
563
+
564
+ assert self.send_in_progress is None
565
+ self.send_in_progress = self.loop.create_future()
566
+ try:
567
+ # First fragment.
568
+ if isinstance(chunk, str):
569
+ if text is False:
570
+ async with self.send_context():
571
+ self.protocol.send_binary(chunk.encode(), fin=False)
572
+ else:
573
+ async with self.send_context():
574
+ self.protocol.send_text(chunk.encode(), fin=False)
575
+ encode = True
576
+ elif isinstance(chunk, BytesLike):
577
+ if text is True:
578
+ async with self.send_context():
579
+ self.protocol.send_text(chunk, fin=False)
580
+ else:
581
+ async with self.send_context():
582
+ self.protocol.send_binary(chunk, fin=False)
583
+ encode = False
584
+ else:
585
+ raise TypeError("async iterable must contain bytes or str")
586
+
587
+ # Other fragments
588
+ async for chunk in achunks:
589
+ if isinstance(chunk, str) and encode:
590
+ async with self.send_context():
591
+ self.protocol.send_continuation(chunk.encode(), fin=False)
592
+ elif isinstance(chunk, BytesLike) and not encode:
593
+ async with self.send_context():
594
+ self.protocol.send_continuation(chunk, fin=False)
595
+ else:
596
+ raise TypeError("async iterable must contain uniform types")
597
+
598
+ # Final fragment.
599
+ async with self.send_context():
600
+ self.protocol.send_continuation(b"", fin=True)
601
+
602
+ except Exception:
603
+ # We're half-way through a fragmented message and we can't
604
+ # complete it. This makes the connection unusable.
605
+ async with self.send_context():
606
+ self.protocol.fail(
607
+ CloseCode.INTERNAL_ERROR,
608
+ "error in fragmented message",
609
+ )
610
+ raise
611
+
612
+ finally:
613
+ self.send_in_progress.set_result(None)
614
+ self.send_in_progress = None
615
+
616
+ else:
617
+ raise TypeError("data must be str, bytes, iterable, or async iterable")
618
+
619
+ async def close(
620
+ self,
621
+ code: CloseCode | int = CloseCode.NORMAL_CLOSURE,
622
+ reason: str = "",
623
+ ) -> None:
624
+ """
625
+ Perform the closing handshake.
626
+
627
+ :meth:`close` waits for the other end to complete the handshake and
628
+ for the TCP connection to terminate.
629
+
630
+ :meth:`close` is idempotent: it doesn't do anything once the
631
+ connection is closed.
632
+
633
+ Args:
634
+ code: WebSocket close code.
635
+ reason: WebSocket close reason.
636
+
637
+ """
638
+ try:
639
+ # The context manager takes care of waiting for the TCP connection
640
+ # to terminate after calling a method that sends a close frame.
641
+ async with self.send_context():
642
+ if self.send_in_progress is not None:
643
+ self.protocol.fail(
644
+ CloseCode.INTERNAL_ERROR,
645
+ "close during fragmented message",
646
+ )
647
+ else:
648
+ self.protocol.send_close(code, reason)
649
+ except ConnectionClosed:
650
+ # Ignore ConnectionClosed exceptions raised from send_context().
651
+ # They mean that the connection is closed, which was the goal.
652
+ pass
653
+
654
+ async def wait_closed(self) -> None:
655
+ """
656
+ Wait until the connection is closed.
657
+
658
+ :meth:`wait_closed` waits for the closing handshake to complete and for
659
+ the TCP connection to terminate.
660
+
661
+ """
662
+ await asyncio.shield(self.connection_lost_waiter)
663
+
664
+ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
665
+ """
666
+ Send a Ping_.
667
+
668
+ .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
669
+
670
+ A ping may serve as a keepalive or as a check that the remote endpoint
671
+ received all messages up to this point
672
+
673
+ Args:
674
+ data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
675
+ If ``data`` is :obj:`None`, the payload is four random bytes.
676
+
677
+ Returns:
678
+ A future that will be completed when the corresponding pong is
679
+ received. You can ignore it if you don't intend to wait. The result
680
+ of the future is the latency of the connection in seconds.
681
+
682
+ ::
683
+
684
+ pong_received = await ws.ping()
685
+ # only if you want to wait for the corresponding pong
686
+ latency = await pong_received
687
+
688
+ Raises:
689
+ ConnectionClosed: When the connection is closed.
690
+ ConcurrencyError: If another ping was sent with the same data and
691
+ the corresponding pong wasn't received yet.
692
+
693
+ """
694
+ if isinstance(data, BytesLike):
695
+ data = bytes(data)
696
+ elif isinstance(data, str):
697
+ data = data.encode()
698
+ elif data is not None:
699
+ raise TypeError("data must be str or bytes-like")
700
+
701
+ async with self.send_context():
702
+ # Protect against duplicates if a payload is explicitly set.
703
+ if data in self.pending_pings:
704
+ raise ConcurrencyError("already waiting for a pong with the same data")
705
+
706
+ # Generate a unique random payload otherwise.
707
+ while data is None or data in self.pending_pings:
708
+ data = struct.pack("!I", random.getrandbits(32))
709
+
710
+ pong_received = self.loop.create_future()
711
+ ping_timestamp = self.loop.time()
712
+ # The event loop's default clock is time.monotonic(). Its resolution
713
+ # is a bit low on Windows (~16ms). This is improved in Python 3.13.
714
+ self.pending_pings[data] = (pong_received, ping_timestamp)
715
+ self.protocol.send_ping(data)
716
+ return pong_received
717
+
718
+ async def pong(self, data: DataLike = b"") -> None:
719
+ """
720
+ Send a Pong_.
721
+
722
+ .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
723
+
724
+ An unsolicited pong may serve as a unidirectional heartbeat.
725
+
726
+ Args:
727
+ data: Payload of the pong. A :class:`str` will be encoded to UTF-8.
728
+
729
+ Raises:
730
+ ConnectionClosed: When the connection is closed.
731
+
732
+ """
733
+ if isinstance(data, BytesLike):
734
+ data = bytes(data)
735
+ elif isinstance(data, str):
736
+ data = data.encode()
737
+ else:
738
+ raise TypeError("data must be str or bytes-like")
739
+
740
+ async with self.send_context():
741
+ self.protocol.send_pong(data)
742
+
743
+ # Private methods
744
+
745
+ def process_event(self, event: Event) -> None:
746
+ """
747
+ Process one incoming event.
748
+
749
+ This method is overridden in subclasses to handle the handshake.
750
+
751
+ """
752
+ assert isinstance(event, Frame)
753
+ if event.opcode in DATA_OPCODES:
754
+ self.recv_messages.put(event)
755
+
756
+ if event.opcode is Opcode.PONG:
757
+ self.acknowledge_pings(bytes(event.data))
758
+
759
+ def acknowledge_pings(self, data: bytes) -> None:
760
+ """
761
+ Acknowledge pings when receiving a pong.
762
+
763
+ """
764
+ # Ignore unsolicited pong.
765
+ if data not in self.pending_pings:
766
+ return
767
+
768
+ pong_timestamp = self.loop.time()
769
+
770
+ # Sending a pong for only the most recent ping is legal.
771
+ # Acknowledge all previous pings too in that case.
772
+ ping_id = None
773
+ ping_ids = []
774
+ for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items():
775
+ ping_ids.append(ping_id)
776
+ latency = pong_timestamp - ping_timestamp
777
+ if not pong_received.done():
778
+ pong_received.set_result(latency)
779
+ if ping_id == data:
780
+ self.latency = latency
781
+ break
782
+ else:
783
+ raise AssertionError("solicited pong not found in pings")
784
+
785
+ # Remove acknowledged pings from self.pending_pings.
786
+ for ping_id in ping_ids:
787
+ del self.pending_pings[ping_id]
788
+
789
+ def terminate_pending_pings(self) -> None:
790
+ """
791
+ Raise ConnectionClosed in pending pings when the connection is closed.
792
+
793
+ """
794
+ assert self.protocol.state is CLOSED
795
+ exc = self.protocol.close_exc
796
+
797
+ for pong_received, _ping_timestamp in self.pending_pings.values():
798
+ if not pong_received.done():
799
+ pong_received.set_exception(exc)
800
+ # If the exception is never retrieved, it will be logged when ping
801
+ # is garbage-collected. This is confusing for users.
802
+ # Given that ping is done (with an exception), canceling it does
803
+ # nothing, but it prevents logging the exception.
804
+ pong_received.cancel()
805
+
806
+ self.pending_pings.clear()
807
+
808
+ async def keepalive(self) -> None:
809
+ """
810
+ Send a Ping frame and wait for a Pong frame at regular intervals.
811
+
812
+ """
813
+ assert self.ping_interval is not None
814
+ latency = 0.0
815
+ try:
816
+ while True:
817
+ # If self.ping_timeout > latency > self.ping_interval,
818
+ # pings will be sent immediately after receiving pongs.
819
+ # The period will be longer than self.ping_interval.
820
+ await asyncio.sleep(self.ping_interval - latency)
821
+
822
+ # This cannot raise ConnectionClosed when the connection is
823
+ # closing because ping(), via send_context(), waits for the
824
+ # connection to be closed before raising ConnectionClosed.
825
+ # However, connection_lost() cancels keepalive_task before
826
+ # it gets a chance to resume excuting.
827
+ pong_received = await self.ping()
828
+ if self.debug:
829
+ self.logger.debug("% sent keepalive ping")
830
+
831
+ if self.ping_timeout is not None:
832
+ try:
833
+ async with asyncio_timeout(self.ping_timeout):
834
+ # connection_lost cancels keepalive immediately
835
+ # after setting a ConnectionClosed exception on
836
+ # pong_received. A CancelledError is raised here,
837
+ # not a ConnectionClosed exception.
838
+ latency = await pong_received
839
+ if self.debug:
840
+ self.logger.debug("% received keepalive pong")
841
+ except asyncio.TimeoutError:
842
+ if self.debug:
843
+ self.logger.debug("- timed out waiting for keepalive pong")
844
+ async with self.send_context():
845
+ self.protocol.fail(
846
+ CloseCode.INTERNAL_ERROR,
847
+ "keepalive ping timeout",
848
+ )
849
+ raise AssertionError(
850
+ "send_context() should wait for connection_lost(), "
851
+ "which cancels keepalive()"
852
+ )
853
+ except Exception:
854
+ self.logger.error("keepalive ping failed", exc_info=True)
855
+
856
+ def start_keepalive(self) -> None:
857
+ """
858
+ Run :meth:`keepalive` in a task, unless keepalive is disabled.
859
+
860
+ """
861
+ if self.ping_interval is not None:
862
+ self.keepalive_task = self.loop.create_task(self.keepalive())
863
+
864
+ @contextlib.asynccontextmanager
865
+ async def send_context(
866
+ self,
867
+ *,
868
+ expected_state: State = OPEN, # CONNECTING during the opening handshake
869
+ ) -> AsyncIterator[None]:
870
+ """
871
+ Create a context for writing to the connection from user code.
872
+
873
+ On entry, :meth:`send_context` checks that the connection is open; on
874
+ exit, it writes outgoing data to the socket::
875
+
876
+ async with self.send_context():
877
+ self.protocol.send_text(message.encode())
878
+
879
+ When the connection isn't open on entry, when the connection is expected
880
+ to close on exit, or when an unexpected error happens, terminating the
881
+ connection, :meth:`send_context` waits until the connection is closed
882
+ then raises :exc:`~websockets.exceptions.ConnectionClosed`.
883
+
884
+ """
885
+ # Should we wait until the connection is closed?
886
+ wait_for_close = False
887
+ # Should we close the transport and raise ConnectionClosed?
888
+ raise_close_exc = False
889
+ # What exception should we chain ConnectionClosed to?
890
+ original_exc: BaseException | None = None
891
+
892
+ if self.protocol.state is expected_state:
893
+ # Let the caller interact with the protocol.
894
+ try:
895
+ yield
896
+ except (ProtocolError, ConcurrencyError):
897
+ # The protocol state wasn't changed. Exit immediately.
898
+ raise
899
+ except Exception as exc:
900
+ self.logger.error("unexpected internal error", exc_info=True)
901
+ # This branch should never run. It's a safety net in case of
902
+ # bugs. Since we don't know what happened, we will close the
903
+ # connection and raise the exception to the caller.
904
+ wait_for_close = False
905
+ raise_close_exc = True
906
+ original_exc = exc
907
+ else:
908
+ # Check if the connection is expected to close soon.
909
+ if self.protocol.close_expected():
910
+ wait_for_close = True
911
+ # Set the close deadline based on the close timeout.
912
+ # Since we tested earlier that protocol.state is OPEN
913
+ # (or CONNECTING), self.close_deadline is still None.
914
+ assert self.close_deadline is None
915
+ if self.close_timeout is not None:
916
+ self.close_deadline = self.loop.time() + self.close_timeout
917
+ # Write outgoing data to the socket with flow control.
918
+ try:
919
+ self.send_data()
920
+ await self.drain()
921
+ except Exception as exc:
922
+ if self.debug:
923
+ self.logger.debug(
924
+ "! error while sending data",
925
+ exc_info=True,
926
+ )
927
+ # While the only expected exception here is OSError,
928
+ # other exceptions would be treated identically.
929
+ wait_for_close = False
930
+ raise_close_exc = True
931
+ original_exc = exc
932
+
933
+ else: # self.protocol.state is not expected_state
934
+ # Minor layering violation: we assume that the connection
935
+ # will be closing soon if it isn't in the expected state.
936
+ wait_for_close = True
937
+ # Calculate close_deadline if it wasn't set yet.
938
+ if self.close_deadline is None:
939
+ if self.close_timeout is not None:
940
+ self.close_deadline = self.loop.time() + self.close_timeout
941
+ raise_close_exc = True
942
+
943
+ # If the connection is expected to close soon and the close timeout
944
+ # elapses, close the socket to terminate the connection.
945
+ if wait_for_close:
946
+ try:
947
+ async with asyncio_timeout_at(self.close_deadline):
948
+ await asyncio.shield(self.connection_lost_waiter)
949
+ except TimeoutError:
950
+ # There's no risk of overwriting another error because
951
+ # original_exc is never set when wait_for_close is True.
952
+ assert original_exc is None
953
+ original_exc = TimeoutError("timed out while closing connection")
954
+ # Set recv_exc before closing the transport in order to get
955
+ # proper exception reporting.
956
+ raise_close_exc = True
957
+ self.set_recv_exc(original_exc)
958
+
959
+ # If an error occurred, close the transport to terminate the connection and
960
+ # raise an exception.
961
+ if raise_close_exc:
962
+ self.transport.abort()
963
+ # Wait for the protocol state to be CLOSED before accessing close_exc.
964
+ await asyncio.shield(self.connection_lost_waiter)
965
+ raise self.protocol.close_exc from original_exc
966
+
967
+ def send_data(self) -> None:
968
+ """
969
+ Send outgoing data.
970
+
971
+ """
972
+ for data in self.protocol.data_to_send():
973
+ if data:
974
+ self.transport.write(data)
975
+ else:
976
+ # Half-close the TCP connection when possible i.e. no TLS.
977
+ if self.transport.can_write_eof():
978
+ if self.debug:
979
+ self.logger.debug("x half-closing TCP connection")
980
+ # write_eof() doesn't document which exceptions it raises.
981
+ # OSError is plausible. uvloop can raise RuntimeError here.
982
+ try:
983
+ self.transport.write_eof()
984
+ except Exception: # pragma: no cover
985
+ pass
986
+ # Else, close the TCP connection.
987
+ else: # pragma: no cover
988
+ if self.debug:
989
+ self.logger.debug("x closing TCP connection")
990
+ self.transport.close()
991
+
992
+ def set_recv_exc(self, exc: BaseException | None) -> None:
993
+ """
994
+ Set recv_exc, if not set yet.
995
+
996
+ This method must be called only from connection callbacks.
997
+
998
+ """
999
+ if self.recv_exc is None:
1000
+ self.recv_exc = exc
1001
+
1002
+ # asyncio.Protocol methods
1003
+
1004
+ # Connection callbacks
1005
+
1006
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
1007
+ transport = cast(asyncio.Transport, transport)
1008
+ self.recv_messages = Assembler(
1009
+ self.max_queue_high,
1010
+ self.max_queue_low,
1011
+ pause=transport.pause_reading,
1012
+ resume=transport.resume_reading,
1013
+ )
1014
+ transport.set_write_buffer_limits(
1015
+ self.write_limit_high,
1016
+ self.write_limit_low,
1017
+ )
1018
+ self.transport = transport
1019
+
1020
+ def connection_lost(self, exc: Exception | None) -> None:
1021
+ # Calling protocol.receive_eof() is safe because it's idempotent.
1022
+ # This guarantees that the protocol state becomes CLOSED.
1023
+ self.protocol.receive_eof()
1024
+ assert self.protocol.state is CLOSED
1025
+
1026
+ self.set_recv_exc(exc)
1027
+
1028
+ # Abort recv() and pending pings with a ConnectionClosed exception.
1029
+ self.recv_messages.close()
1030
+ self.terminate_pending_pings()
1031
+
1032
+ if self.keepalive_task is not None:
1033
+ self.keepalive_task.cancel()
1034
+
1035
+ # If self.connection_lost_waiter isn't pending, that's a bug, because:
1036
+ # - it's set only here in connection_lost() which is called only once;
1037
+ # - it must never be canceled.
1038
+ self.connection_lost_waiter.set_result(None)
1039
+
1040
+ # Adapted from asyncio.streams.FlowControlMixin
1041
+ if self.paused: # pragma: no cover
1042
+ self.paused = False
1043
+ for waiter in self.drain_waiters:
1044
+ if not waiter.done():
1045
+ if exc is None:
1046
+ waiter.set_result(None)
1047
+ else:
1048
+ waiter.set_exception(exc)
1049
+
1050
+ # Flow control callbacks
1051
+
1052
+ def pause_writing(self) -> None: # pragma: no cover
1053
+ # Adapted from asyncio.streams.FlowControlMixin
1054
+ assert not self.paused
1055
+ self.paused = True
1056
+
1057
+ def resume_writing(self) -> None: # pragma: no cover
1058
+ # Adapted from asyncio.streams.FlowControlMixin
1059
+ assert self.paused
1060
+ self.paused = False
1061
+ for waiter in self.drain_waiters:
1062
+ if not waiter.done():
1063
+ waiter.set_result(None)
1064
+
1065
+ async def drain(self) -> None: # pragma: no cover
1066
+ # We don't check if the connection is closed because we call drain()
1067
+ # immediately after write() and write() would fail in that case.
1068
+
1069
+ # Adapted from asyncio.streams.StreamWriter
1070
+ # Yield to the event loop so that connection_lost() may be called.
1071
+ if self.transport.is_closing():
1072
+ await asyncio.sleep(0)
1073
+
1074
+ # Adapted from asyncio.streams.FlowControlMixin
1075
+ if self.paused:
1076
+ waiter = self.loop.create_future()
1077
+ self.drain_waiters.append(waiter)
1078
+ try:
1079
+ await waiter
1080
+ finally:
1081
+ self.drain_waiters.remove(waiter)
1082
+
1083
+ # Streaming protocol callbacks
1084
+
1085
+ def data_received(self, data: bytes) -> None:
1086
+ # Feed incoming data to the protocol.
1087
+ self.protocol.receive_data(data)
1088
+
1089
+ # This isn't expected to raise an exception.
1090
+ events = self.protocol.events_received()
1091
+
1092
+ # Write outgoing data to the transport.
1093
+ try:
1094
+ self.send_data()
1095
+ except Exception as exc:
1096
+ if self.debug:
1097
+ self.logger.debug("! error while sending data", exc_info=True)
1098
+ self.set_recv_exc(exc)
1099
+
1100
+ # If needed, set the close deadline based on the close timeout.
1101
+ if self.protocol.close_expected():
1102
+ if self.close_deadline is None:
1103
+ if self.close_timeout is not None:
1104
+ self.close_deadline = self.loop.time() + self.close_timeout
1105
+
1106
+ # If self.send_data raised an exception, then events are lost.
1107
+ # Given that automatic responses write small amounts of data,
1108
+ # this should be uncommon, so we don't handle the edge case.
1109
+
1110
+ for event in events:
1111
+ # This isn't expected to raise an exception.
1112
+ self.process_event(event)
1113
+
1114
+ def eof_received(self) -> None:
1115
+ # Feed the end of the data stream to the protocol.
1116
+ self.protocol.receive_eof()
1117
+
1118
+ # This isn't expected to raise an exception.
1119
+ events = self.protocol.events_received()
1120
+
1121
+ # There is no error handling because send_data() can only write
1122
+ # the end of the data stream and it handles errors by itself.
1123
+ self.send_data()
1124
+
1125
+ # This code path is triggered when receiving an HTTP response
1126
+ # without a Content-Length header. This is the only case where
1127
+ # reading until EOF generates an event; all other events have
1128
+ # a known length. Ignore for coverage measurement because tests
1129
+ # are in test_client.py rather than test_connection.py.
1130
+ for event in events: # pragma: no cover
1131
+ # This isn't expected to raise an exception.
1132
+ self.process_event(event)
1133
+
1134
+ # The WebSocket protocol has its own closing handshake: endpoints close
1135
+ # the TCP or TLS connection after sending and receiving a close frame.
1136
+ # As a consequence, they never need to write after receiving EOF, so
1137
+ # there's no reason to keep the transport open by returning True.
1138
+ # Besides, that doesn't work on TLS connections.
1139
+
1140
+
1141
+ # broadcast() is defined in the connection module even though it's primarily
1142
+ # used by servers and documented in the server module because it works with
1143
+ # client connections too and because it's easier to test together with the
1144
+ # Connection class.
1145
+
1146
+
1147
+ def broadcast(
1148
+ connections: Iterable[Connection],
1149
+ message: DataLike,
1150
+ raise_exceptions: bool = False,
1151
+ ) -> None:
1152
+ """
1153
+ Broadcast a message to several WebSocket connections.
1154
+
1155
+ A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like
1156
+ object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent
1157
+ as a Binary_ frame.
1158
+
1159
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
1160
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
1161
+
1162
+ :func:`broadcast` pushes the message synchronously to all connections even
1163
+ if their write buffers are overflowing. There's no backpressure.
1164
+
1165
+ If you broadcast messages faster than a connection can handle them, messages
1166
+ will pile up in its write buffer until the connection times out. Keep
1167
+ ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage
1168
+ from slow connections.
1169
+
1170
+ Unlike :meth:`~websockets.asyncio.connection.Connection.send`,
1171
+ :func:`broadcast` doesn't support sending fragmented messages. Indeed,
1172
+ fragmentation is useful for sending large messages without buffering them in
1173
+ memory, while :func:`broadcast` buffers one copy per connection as fast as
1174
+ possible.
1175
+
1176
+ :func:`broadcast` skips connections that aren't open in order to avoid
1177
+ errors on connections where the closing handshake is in progress.
1178
+
1179
+ :func:`broadcast` ignores failures to write the message on some connections.
1180
+ It continues writing to other connections. On Python 3.11 and above, you may
1181
+ set ``raise_exceptions`` to :obj:`True` to record failures and raise all
1182
+ exceptions in a :pep:`654` :exc:`ExceptionGroup`.
1183
+
1184
+ While :func:`broadcast` makes more sense for servers, it works identically
1185
+ with clients, if you have a use case for opening connections to many servers
1186
+ and broadcasting a message to them.
1187
+
1188
+ Args:
1189
+ websockets: WebSocket connections to which the message will be sent.
1190
+ message: Message to send.
1191
+ raise_exceptions: Whether to raise an exception in case of failures.
1192
+
1193
+ Raises:
1194
+ TypeError: If ``message`` doesn't have a supported type.
1195
+
1196
+ """
1197
+ if isinstance(message, str):
1198
+ send_method = "send_text"
1199
+ message = message.encode()
1200
+ elif isinstance(message, BytesLike):
1201
+ send_method = "send_binary"
1202
+ else:
1203
+ raise TypeError("data must be str or bytes")
1204
+
1205
+ if raise_exceptions:
1206
+ if sys.version_info[:2] < (3, 11): # pragma: no cover
1207
+ raise ValueError("raise_exceptions requires at least Python 3.11")
1208
+ exceptions: list[Exception] = []
1209
+
1210
+ for connection in connections:
1211
+ exception: Exception
1212
+
1213
+ if connection.protocol.state is not OPEN:
1214
+ continue
1215
+
1216
+ if connection.send_in_progress is not None:
1217
+ if raise_exceptions:
1218
+ exception = ConcurrencyError("sending a fragmented message")
1219
+ exceptions.append(exception)
1220
+ else:
1221
+ connection.logger.warning(
1222
+ "skipped broadcast: sending a fragmented message",
1223
+ )
1224
+ continue
1225
+
1226
+ try:
1227
+ # Call connection.protocol.send_text or send_binary.
1228
+ # Either way, message is already converted to bytes.
1229
+ getattr(connection.protocol, send_method)(message)
1230
+ connection.send_data()
1231
+ except Exception as write_exception:
1232
+ if raise_exceptions:
1233
+ exception = RuntimeError("failed to write message")
1234
+ exception.__cause__ = write_exception
1235
+ exceptions.append(exception)
1236
+ else:
1237
+ connection.logger.warning(
1238
+ "skipped broadcast: failed to write message: %s",
1239
+ traceback.format_exception_only(write_exception)[0].strip(),
1240
+ )
1241
+
1242
+ if raise_exceptions and exceptions:
1243
+ raise ExceptionGroup("skipped broadcast", exceptions)
1244
+
1245
+
1246
+ # Pretend that broadcast is actually defined in the server module.
1247
+ broadcast.__module__ = "websockets.asyncio.server"
source/websockets/asyncio/messages.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import codecs
5
+ import collections
6
+ from collections.abc import AsyncIterator, Iterable
7
+ from typing import Any, Callable, Generic, Literal, TypeVar, overload
8
+
9
+ from ..exceptions import ConcurrencyError
10
+ from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
11
+ from ..typing import Data
12
+
13
+
14
+ __all__ = ["Assembler"]
15
+
16
+ UTF8Decoder = codecs.getincrementaldecoder("utf-8")
17
+
18
+ T = TypeVar("T")
19
+
20
+
21
+ class SimpleQueue(Generic[T]):
22
+ """
23
+ Simplified version of :class:`asyncio.Queue`.
24
+
25
+ Provides only the subset of functionality needed by :class:`Assembler`.
26
+
27
+ """
28
+
29
+ def __init__(self) -> None:
30
+ self.loop = asyncio.get_running_loop()
31
+ self.get_waiter: asyncio.Future[None] | None = None
32
+ self.queue: collections.deque[T] = collections.deque()
33
+
34
+ def __len__(self) -> int:
35
+ return len(self.queue)
36
+
37
+ def put(self, item: T) -> None:
38
+ """Put an item into the queue."""
39
+ self.queue.append(item)
40
+ if self.get_waiter is not None and not self.get_waiter.done():
41
+ self.get_waiter.set_result(None)
42
+
43
+ async def get(self, block: bool = True) -> T:
44
+ """Remove and return an item from the queue, waiting if necessary."""
45
+ if not self.queue:
46
+ if not block:
47
+ raise EOFError("stream of frames ended")
48
+ assert self.get_waiter is None, "cannot call get() concurrently"
49
+ self.get_waiter = self.loop.create_future()
50
+ try:
51
+ await self.get_waiter
52
+ finally:
53
+ self.get_waiter.cancel()
54
+ self.get_waiter = None
55
+ return self.queue.popleft()
56
+
57
+ def reset(self, items: Iterable[T]) -> None:
58
+ """Put back items into an empty, idle queue."""
59
+ assert self.get_waiter is None, "cannot reset() while get() is running"
60
+ assert not self.queue, "cannot reset() while queue isn't empty"
61
+ self.queue.extend(items)
62
+
63
+ def abort(self) -> None:
64
+ """Close the queue, raising EOFError in get() if necessary."""
65
+ if self.get_waiter is not None and not self.get_waiter.done():
66
+ self.get_waiter.set_exception(EOFError("stream of frames ended"))
67
+
68
+
69
+ class Assembler:
70
+ """
71
+ Assemble messages from frames.
72
+
73
+ :class:`Assembler` expects only data frames. The stream of frames must
74
+ respect the protocol; if it doesn't, the behavior is undefined.
75
+
76
+ Args:
77
+ pause: Called when the buffer of frames goes above the high water mark;
78
+ should pause reading from the network.
79
+ resume: Called when the buffer of frames goes below the low water mark;
80
+ should resume reading from the network.
81
+
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ high: int | None = None,
87
+ low: int | None = None,
88
+ pause: Callable[[], Any] = lambda: None,
89
+ resume: Callable[[], Any] = lambda: None,
90
+ ) -> None:
91
+ # Queue of incoming frames.
92
+ self.frames: SimpleQueue[Frame] = SimpleQueue()
93
+
94
+ # We cannot put a hard limit on the size of the queue because a single
95
+ # call to Protocol.data_received() could produce thousands of frames,
96
+ # which must be buffered. Instead, we pause reading when the buffer goes
97
+ # above the high limit and we resume when it goes under the low limit.
98
+ if high is not None and low is None:
99
+ low = high // 4
100
+ if high is None and low is not None:
101
+ high = low * 4
102
+ if high is not None and low is not None:
103
+ if low < 0:
104
+ raise ValueError("low must be positive or equal to zero")
105
+ if high < low:
106
+ raise ValueError("high must be greater than or equal to low")
107
+ self.high, self.low = high, low
108
+ self.pause = pause
109
+ self.resume = resume
110
+ self.paused = False
111
+
112
+ # This flag prevents concurrent calls to get() by user code.
113
+ self.get_in_progress = False
114
+
115
+ # This flag marks the end of the connection.
116
+ self.closed = False
117
+
118
+ @overload
119
+ async def get(self, decode: Literal[True]) -> str: ...
120
+
121
+ @overload
122
+ async def get(self, decode: Literal[False]) -> bytes: ...
123
+
124
+ @overload
125
+ async def get(self, decode: bool | None = None) -> Data: ...
126
+
127
+ async def get(self, decode: bool | None = None) -> Data:
128
+ """
129
+ Read the next message.
130
+
131
+ :meth:`get` returns a single :class:`str` or :class:`bytes`.
132
+
133
+ If the message is fragmented, :meth:`get` waits until the last frame is
134
+ received, then it reassembles the message and returns it. To receive
135
+ messages frame by frame, use :meth:`get_iter` instead.
136
+
137
+ Args:
138
+ decode: :obj:`False` disables UTF-8 decoding of text frames and
139
+ returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
140
+ binary frames and returns :class:`str`.
141
+
142
+ Raises:
143
+ EOFError: If the stream of frames has ended.
144
+ UnicodeDecodeError: If a text frame contains invalid UTF-8.
145
+ ConcurrencyError: If two coroutines run :meth:`get` or
146
+ :meth:`get_iter` concurrently.
147
+
148
+ """
149
+ if self.get_in_progress:
150
+ raise ConcurrencyError("get() or get_iter() is already running")
151
+ self.get_in_progress = True
152
+
153
+ # Locking with get_in_progress prevents concurrent execution
154
+ # until get() fetches a complete message or is canceled.
155
+
156
+ try:
157
+ # Fetch the first frame.
158
+ frame = await self.frames.get(not self.closed)
159
+ self.maybe_resume()
160
+ assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
161
+ if decode is None:
162
+ decode = frame.opcode is OP_TEXT
163
+ frames = [frame]
164
+
165
+ # Fetch subsequent frames for fragmented messages.
166
+ while not frame.fin:
167
+ try:
168
+ frame = await self.frames.get(not self.closed)
169
+ except asyncio.CancelledError:
170
+ # Put frames already received back into the queue
171
+ # so that future calls to get() can return them.
172
+ self.frames.reset(frames)
173
+ raise
174
+ self.maybe_resume()
175
+ assert frame.opcode is OP_CONT
176
+ frames.append(frame)
177
+
178
+ finally:
179
+ self.get_in_progress = False
180
+
181
+ # This converts frame.data to bytes when it's a bytearray.
182
+ data = b"".join(frame.data for frame in frames)
183
+ if decode:
184
+ return data.decode()
185
+ else:
186
+ return data
187
+
188
+ @overload
189
+ def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...
190
+
191
+ @overload
192
+ def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...
193
+
194
+ @overload
195
+ def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
196
+
197
+ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
198
+ """
199
+ Stream the next message.
200
+
201
+ Iterating the return value of :meth:`get_iter` asynchronously yields a
202
+ :class:`str` or :class:`bytes` for each frame in the message.
203
+
204
+ The iterator must be fully consumed before calling :meth:`get_iter` or
205
+ :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
206
+
207
+ This method only makes sense for fragmented messages. If messages aren't
208
+ fragmented, use :meth:`get` instead.
209
+
210
+ Args:
211
+ decode: :obj:`False` disables UTF-8 decoding of text frames and
212
+ returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
213
+ binary frames and returns :class:`str`.
214
+
215
+ Raises:
216
+ EOFError: If the stream of frames has ended.
217
+ UnicodeDecodeError: If a text frame contains invalid UTF-8.
218
+ ConcurrencyError: If two coroutines run :meth:`get` or
219
+ :meth:`get_iter` concurrently.
220
+
221
+ """
222
+ if self.get_in_progress:
223
+ raise ConcurrencyError("get() or get_iter() is already running")
224
+ self.get_in_progress = True
225
+
226
+ # Locking with get_in_progress prevents concurrent execution
227
+ # until get_iter() fetches a complete message or is canceled.
228
+
229
+ # If get_iter() raises an exception e.g. in decoder.decode(),
230
+ # get_in_progress remains set and the connection becomes unusable.
231
+
232
+ # Yield the first frame.
233
+ try:
234
+ frame = await self.frames.get(not self.closed)
235
+ except asyncio.CancelledError:
236
+ self.get_in_progress = False
237
+ raise
238
+ self.maybe_resume()
239
+ assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
240
+ if decode is None:
241
+ decode = frame.opcode is OP_TEXT
242
+ if decode:
243
+ decoder = UTF8Decoder()
244
+ yield decoder.decode(frame.data, frame.fin)
245
+ else:
246
+ # Convert to bytes when frame.data is a bytearray.
247
+ yield bytes(frame.data)
248
+
249
+ # Yield subsequent frames for fragmented messages.
250
+ while not frame.fin:
251
+ # We cannot handle asyncio.CancelledError because we don't buffer
252
+ # previous fragments — we're streaming them. Canceling get_iter()
253
+ # here will leave the assembler in a stuck state. Future calls to
254
+ # get() or get_iter() will raise ConcurrencyError.
255
+ frame = await self.frames.get(not self.closed)
256
+ self.maybe_resume()
257
+ assert frame.opcode is OP_CONT
258
+ if decode:
259
+ yield decoder.decode(frame.data, frame.fin)
260
+ else:
261
+ # Convert to bytes when frame.data is a bytearray.
262
+ yield bytes(frame.data)
263
+
264
+ self.get_in_progress = False
265
+
266
+ def put(self, frame: Frame) -> None:
267
+ """
268
+ Add ``frame`` to the next message.
269
+
270
+ Raises:
271
+ EOFError: If the stream of frames has ended.
272
+
273
+ """
274
+ if self.closed:
275
+ raise EOFError("stream of frames ended")
276
+
277
+ self.frames.put(frame)
278
+ self.maybe_pause()
279
+
280
+ def maybe_pause(self) -> None:
281
+ """Pause the writer if queue is above the high water mark."""
282
+ # Skip if flow control is disabled.
283
+ if self.high is None:
284
+ return
285
+
286
+ # Check for "> high" to support high = 0.
287
+ if len(self.frames) > self.high and not self.paused:
288
+ self.paused = True
289
+ self.pause()
290
+
291
+ def maybe_resume(self) -> None:
292
+ """Resume the writer if queue is below the low water mark."""
293
+ # Skip if flow control is disabled.
294
+ if self.low is None:
295
+ return
296
+
297
+ # Check for "<= low" to support low = 0.
298
+ if len(self.frames) <= self.low and self.paused:
299
+ self.paused = False
300
+ self.resume()
301
+
302
+ def close(self) -> None:
303
+ """
304
+ End the stream of frames.
305
+
306
+ Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
307
+ or :meth:`put` is safe. They will raise :exc:`EOFError`.
308
+
309
+ """
310
+ if self.closed:
311
+ return
312
+
313
+ self.closed = True
314
+
315
+ # Unblock get() or get_iter().
316
+ self.frames.abort()
source/websockets/asyncio/router.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import http
4
+ import ssl as ssl_module
5
+ import urllib.parse
6
+ from typing import Any, Awaitable, Callable, Literal
7
+
8
+ from ..http11 import Request, Response
9
+ from .server import Server, ServerConnection, serve
10
+
11
+
12
+ __all__ = ["route", "unix_route", "Router"]
13
+
14
+
15
+ try:
16
+ from werkzeug.exceptions import NotFound
17
+ from werkzeug.routing import Map, RequestRedirect
18
+
19
+ except ImportError:
20
+
21
+ def route(
22
+ url_map: Map,
23
+ *args: Any,
24
+ server_name: str | None = None,
25
+ ssl: ssl_module.SSLContext | Literal[True] | None = None,
26
+ create_router: type[Router] | None = None,
27
+ **kwargs: Any,
28
+ ) -> Awaitable[Server]:
29
+ raise ImportError("route() requires werkzeug")
30
+
31
+ def unix_route(
32
+ url_map: Map,
33
+ path: str | None = None,
34
+ **kwargs: Any,
35
+ ) -> Awaitable[Server]:
36
+ raise ImportError("unix_route() requires werkzeug")
37
+
38
+ else:
39
+
40
+ def route(
41
+ url_map: Map,
42
+ *args: Any,
43
+ server_name: str | None = None,
44
+ ssl: ssl_module.SSLContext | Literal[True] | None = None,
45
+ create_router: type[Router] | None = None,
46
+ **kwargs: Any,
47
+ ) -> Awaitable[Server]:
48
+ """
49
+ Create a WebSocket server dispatching connections to different handlers.
50
+
51
+ This feature requires the third-party library `werkzeug`_:
52
+
53
+ .. code-block:: console
54
+
55
+ $ pip install werkzeug
56
+
57
+ .. _werkzeug: https://werkzeug.palletsprojects.com/
58
+
59
+ :func:`route` accepts the same arguments as
60
+ :func:`~websockets.sync.server.serve`, except as described below.
61
+
62
+ The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns
63
+ to connection handlers. In addition to the connection, handlers receive
64
+ parameters captured in the URL as keyword arguments.
65
+
66
+ Here's an example::
67
+
68
+
69
+ from websockets.asyncio.router import route
70
+ from werkzeug.routing import Map, Rule
71
+
72
+ async def channel_handler(websocket, channel_id):
73
+ ...
74
+
75
+ url_map = Map([
76
+ Rule("/channel/<uuid:channel_id>", endpoint=channel_handler),
77
+ ...
78
+ ])
79
+
80
+ # set this future to exit the server
81
+ stop = asyncio.get_running_loop().create_future()
82
+
83
+ async with route(url_map, ...) as server:
84
+ await stop
85
+
86
+
87
+ Refer to the documentation of :mod:`werkzeug.routing` for details.
88
+
89
+ If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map,
90
+ when the server runs behind a reverse proxy that modifies the ``Host``
91
+ header or terminates TLS, you need additional configuration:
92
+
93
+ * Set ``server_name`` to the name of the server as seen by clients. When
94
+ not provided, websockets uses the value of the ``Host`` header.
95
+
96
+ * Set ``ssl=True`` to generate ``wss://`` URIs without enabling TLS.
97
+ Under the hood, this bind the URL map with a ``url_scheme`` of
98
+ ``wss://`` instead of ``ws://``.
99
+
100
+ There is no need to specify ``websocket=True`` in each rule. It is added
101
+ automatically.
102
+
103
+ Args:
104
+ url_map: Mapping of URL patterns to connection handlers.
105
+ server_name: Name of the server as seen by clients. If :obj:`None`,
106
+ websockets uses the value of the ``Host`` header.
107
+ ssl: Configuration for enabling TLS on the connection. Set it to
108
+ :obj:`True` if a reverse proxy terminates TLS connections.
109
+ create_router: Factory for the :class:`Router` dispatching requests to
110
+ handlers. Set it to a wrapper or a subclass to customize routing.
111
+
112
+ """
113
+ url_scheme = "ws" if ssl is None else "wss"
114
+ if ssl is not True and ssl is not None:
115
+ kwargs["ssl"] = ssl
116
+
117
+ if create_router is None:
118
+ create_router = Router
119
+
120
+ router = create_router(url_map, server_name, url_scheme)
121
+
122
+ _process_request: (
123
+ Callable[
124
+ [ServerConnection, Request],
125
+ Awaitable[Response | None] | Response | None,
126
+ ]
127
+ | None
128
+ ) = kwargs.pop("process_request", None)
129
+ if _process_request is None:
130
+ process_request: Callable[
131
+ [ServerConnection, Request],
132
+ Awaitable[Response | None] | Response | None,
133
+ ] = router.route_request
134
+ else:
135
+
136
+ async def process_request(
137
+ connection: ServerConnection, request: Request
138
+ ) -> Response | None:
139
+ response = _process_request(connection, request)
140
+ if isinstance(response, Awaitable):
141
+ response = await response
142
+ if response is not None:
143
+ return response
144
+ return router.route_request(connection, request)
145
+
146
+ return serve(router.handler, *args, process_request=process_request, **kwargs)
147
+
148
+ def unix_route(
149
+ url_map: Map,
150
+ path: str | None = None,
151
+ **kwargs: Any,
152
+ ) -> Awaitable[Server]:
153
+ """
154
+ Create a WebSocket Unix server dispatching connections to different handlers.
155
+
156
+ :func:`unix_route` combines the behaviors of :func:`route` and
157
+ :func:`~websockets.asyncio.server.unix_serve`.
158
+
159
+ Args:
160
+ url_map: Mapping of URL patterns to connection handlers.
161
+ path: File system path to the Unix socket.
162
+
163
+ """
164
+ return route(url_map, unix=True, path=path, **kwargs)
165
+
166
+
167
+ class Router:
168
+ """WebSocket router supporting :func:`route`."""
169
+
170
+ def __init__(
171
+ self,
172
+ url_map: Map,
173
+ server_name: str | None = None,
174
+ url_scheme: str = "ws",
175
+ ) -> None:
176
+ self.url_map = url_map
177
+ self.server_name = server_name
178
+ self.url_scheme = url_scheme
179
+ for rule in self.url_map.iter_rules():
180
+ rule.websocket = True
181
+
182
+ def get_server_name(self, connection: ServerConnection, request: Request) -> str:
183
+ if self.server_name is None:
184
+ return request.headers["Host"]
185
+ else:
186
+ return self.server_name
187
+
188
+ def redirect(self, connection: ServerConnection, url: str) -> Response:
189
+ response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}")
190
+ response.headers["Location"] = url
191
+ return response
192
+
193
+ def not_found(self, connection: ServerConnection) -> Response:
194
+ return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found")
195
+
196
+ def route_request(
197
+ self, connection: ServerConnection, request: Request
198
+ ) -> Response | None:
199
+ """Route incoming request."""
200
+ url_map_adapter = self.url_map.bind(
201
+ server_name=self.get_server_name(connection, request),
202
+ url_scheme=self.url_scheme,
203
+ )
204
+ try:
205
+ parsed = urllib.parse.urlparse(request.path)
206
+ handler, kwargs = url_map_adapter.match(
207
+ path_info=parsed.path,
208
+ query_args=parsed.query,
209
+ )
210
+ except RequestRedirect as redirect:
211
+ return self.redirect(connection, redirect.new_url)
212
+ except NotFound:
213
+ return self.not_found(connection)
214
+ connection.handler, connection.handler_kwargs = handler, kwargs
215
+ return None
216
+
217
+ async def handler(self, connection: ServerConnection) -> None:
218
+ """Handle a connection."""
219
+ return await connection.handler(connection, **connection.handler_kwargs)
source/websockets/asyncio/server.py ADDED
@@ -0,0 +1,997 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import hmac
5
+ import http
6
+ import logging
7
+ import re
8
+ import socket
9
+ import sys
10
+ from collections.abc import Awaitable, Generator, Iterable, Sequence
11
+ from types import TracebackType
12
+ from typing import Any, Callable, Mapping, cast
13
+
14
+ from ..exceptions import InvalidHeader
15
+ from ..extensions.base import ServerExtensionFactory
16
+ from ..extensions.permessage_deflate import enable_server_permessage_deflate
17
+ from ..frames import CloseCode
18
+ from ..headers import (
19
+ build_www_authenticate_basic,
20
+ parse_authorization_basic,
21
+ validate_subprotocols,
22
+ )
23
+ from ..http11 import SERVER, Request, Response
24
+ from ..protocol import CONNECTING, OPEN, Event
25
+ from ..server import ServerProtocol
26
+ from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
27
+ from .compatibility import asyncio_timeout
28
+ from .connection import Connection, broadcast
29
+
30
+
31
+ __all__ = [
32
+ "broadcast",
33
+ "serve",
34
+ "unix_serve",
35
+ "ServerConnection",
36
+ "Server",
37
+ "basic_auth",
38
+ ]
39
+
40
+
41
+ class ServerConnection(Connection):
42
+ """
43
+ :mod:`asyncio` implementation of a WebSocket server connection.
44
+
45
+ :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
46
+ receiving and sending messages.
47
+
48
+ It supports asynchronous iteration to receive messages::
49
+
50
+ async for message in websocket:
51
+ await process(message)
52
+
53
+ The iterator exits normally when the connection is closed with code
54
+ 1000 (OK) or 1001 (going away) or without a close code. It raises a
55
+ :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
56
+ closed with any other code.
57
+
58
+ The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
59
+ and ``write_limit`` arguments have the same meaning as in :func:`serve`.
60
+
61
+ Args:
62
+ protocol: Sans-I/O connection.
63
+ server: Server that manages this connection.
64
+
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ protocol: ServerProtocol,
70
+ server: Server,
71
+ *,
72
+ ping_interval: float | None = 20,
73
+ ping_timeout: float | None = 20,
74
+ close_timeout: float | None = 10,
75
+ max_queue: int | None | tuple[int | None, int | None] = 16,
76
+ write_limit: int | tuple[int, int | None] = 2**15,
77
+ ) -> None:
78
+ self.protocol: ServerProtocol
79
+ super().__init__(
80
+ protocol,
81
+ ping_interval=ping_interval,
82
+ ping_timeout=ping_timeout,
83
+ close_timeout=close_timeout,
84
+ max_queue=max_queue,
85
+ write_limit=write_limit,
86
+ )
87
+ self.server = server
88
+ self.request_rcvd: asyncio.Future[None] = self.loop.create_future()
89
+ self.username: str # see basic_auth()
90
+ self.handler: Callable[[ServerConnection], Awaitable[None]] # see route()
91
+ self.handler_kwargs: Mapping[str, Any] # see route()
92
+
93
+ def respond(self, status: StatusLike, text: str) -> Response:
94
+ """
95
+ Create a plain text HTTP response.
96
+
97
+ ``process_request`` and ``process_response`` may call this method to
98
+ return an HTTP response instead of performing the WebSocket opening
99
+ handshake.
100
+
101
+ You can modify the response before returning it, for example by changing
102
+ HTTP headers.
103
+
104
+ Args:
105
+ status: HTTP status code.
106
+ text: HTTP response body; it will be encoded to UTF-8.
107
+
108
+ Returns:
109
+ HTTP response to send to the client.
110
+
111
+ """
112
+ return self.protocol.reject(status, text)
113
+
114
+ async def handshake(
115
+ self,
116
+ process_request: (
117
+ Callable[
118
+ [ServerConnection, Request],
119
+ Awaitable[Response | None] | Response | None,
120
+ ]
121
+ | None
122
+ ) = None,
123
+ process_response: (
124
+ Callable[
125
+ [ServerConnection, Request, Response],
126
+ Awaitable[Response | None] | Response | None,
127
+ ]
128
+ | None
129
+ ) = None,
130
+ server_header: str | None = SERVER,
131
+ ) -> None:
132
+ """
133
+ Perform the opening handshake.
134
+
135
+ """
136
+ await asyncio.wait(
137
+ [self.request_rcvd, self.connection_lost_waiter],
138
+ return_when=asyncio.FIRST_COMPLETED,
139
+ )
140
+
141
+ if self.request is not None:
142
+ async with self.send_context(expected_state=CONNECTING):
143
+ response = None
144
+
145
+ if process_request is not None:
146
+ try:
147
+ response = process_request(self, self.request)
148
+ if isinstance(response, Awaitable):
149
+ response = await response
150
+ except Exception as exc:
151
+ self.protocol.handshake_exc = exc
152
+ response = self.protocol.reject(
153
+ http.HTTPStatus.INTERNAL_SERVER_ERROR,
154
+ (
155
+ "Failed to open a WebSocket connection.\n"
156
+ "See server log for more information.\n"
157
+ ),
158
+ )
159
+
160
+ if response is None:
161
+ if self.server.is_serving():
162
+ self.response = self.protocol.accept(self.request)
163
+ else:
164
+ self.response = self.protocol.reject(
165
+ http.HTTPStatus.SERVICE_UNAVAILABLE,
166
+ "Server is shutting down.\n",
167
+ )
168
+ else:
169
+ assert isinstance(response, Response) # help mypy
170
+ self.response = response
171
+
172
+ if server_header:
173
+ self.response.headers["Server"] = server_header
174
+
175
+ response = None
176
+
177
+ if process_response is not None:
178
+ try:
179
+ response = process_response(self, self.request, self.response)
180
+ if isinstance(response, Awaitable):
181
+ response = await response
182
+ except Exception as exc:
183
+ self.protocol.handshake_exc = exc
184
+ response = self.protocol.reject(
185
+ http.HTTPStatus.INTERNAL_SERVER_ERROR,
186
+ (
187
+ "Failed to open a WebSocket connection.\n"
188
+ "See server log for more information.\n"
189
+ ),
190
+ )
191
+
192
+ if response is not None:
193
+ assert isinstance(response, Response) # help mypy
194
+ self.response = response
195
+
196
+ self.protocol.send_response(self.response)
197
+
198
+ # self.protocol.handshake_exc is set when the connection is lost before
199
+ # receiving a request, when the request cannot be parsed, or when the
200
+ # handshake fails, including when process_request or process_response
201
+ # raises an exception.
202
+
203
+ # It isn't set when process_request or process_response sends an HTTP
204
+ # response that rejects the handshake.
205
+
206
+ if self.protocol.handshake_exc is not None:
207
+ raise self.protocol.handshake_exc
208
+
209
+ def process_event(self, event: Event) -> None:
210
+ """
211
+ Process one incoming event.
212
+
213
+ """
214
+ # First event - handshake request.
215
+ if self.request is None:
216
+ assert isinstance(event, Request)
217
+ self.request = event
218
+ self.request_rcvd.set_result(None)
219
+ # Later events - frames.
220
+ else:
221
+ super().process_event(event)
222
+
223
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
224
+ super().connection_made(transport)
225
+ self.server.start_connection_handler(self)
226
+
227
+
228
+ class Server:
229
+ """
230
+ WebSocket server returned by :func:`serve`.
231
+
232
+ This class mirrors the API of :class:`asyncio.Server`.
233
+
234
+ It keeps track of WebSocket connections in order to close them properly
235
+ when shutting down.
236
+
237
+ Args:
238
+ handler: Connection handler. It receives the WebSocket connection,
239
+ which is a :class:`ServerConnection`, in argument.
240
+ process_request: Intercept the request during the opening handshake.
241
+ Return an HTTP response to force the response. Return :obj:`None` to
242
+ continue normally. When you force an HTTP 101 Continue response, the
243
+ handshake is successful. Else, the connection is aborted.
244
+ ``process_request`` may be a function or a coroutine.
245
+ process_response: Intercept the response during the opening handshake.
246
+ Modify the response or return a new HTTP response to force the
247
+ response. Return :obj:`None` to continue normally. When you force an
248
+ HTTP 101 Continue response, the handshake is successful. Else, the
249
+ connection is aborted. ``process_response`` may be a function or a
250
+ coroutine.
251
+ server_header: Value of the ``Server`` response header.
252
+ It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
253
+ :obj:`None` removes the header.
254
+ open_timeout: Timeout for opening connections in seconds.
255
+ :obj:`None` disables the timeout.
256
+ logger: Logger for this server.
257
+ It defaults to ``logging.getLogger("websockets.server")``.
258
+ See the :doc:`logging guide <../../topics/logging>` for details.
259
+
260
+ """
261
+
262
+ def __init__(
263
+ self,
264
+ handler: Callable[[ServerConnection], Awaitable[None]],
265
+ *,
266
+ process_request: (
267
+ Callable[
268
+ [ServerConnection, Request],
269
+ Awaitable[Response | None] | Response | None,
270
+ ]
271
+ | None
272
+ ) = None,
273
+ process_response: (
274
+ Callable[
275
+ [ServerConnection, Request, Response],
276
+ Awaitable[Response | None] | Response | None,
277
+ ]
278
+ | None
279
+ ) = None,
280
+ server_header: str | None = SERVER,
281
+ open_timeout: float | None = 10,
282
+ logger: LoggerLike | None = None,
283
+ ) -> None:
284
+ self.loop = asyncio.get_running_loop()
285
+ self.handler = handler
286
+ self.process_request = process_request
287
+ self.process_response = process_response
288
+ self.server_header = server_header
289
+ self.open_timeout = open_timeout
290
+ if logger is None:
291
+ logger = logging.getLogger("websockets.server")
292
+ self.logger = logger
293
+
294
+ # Keep track of active connections.
295
+ self.handlers: dict[ServerConnection, asyncio.Task[None]] = {}
296
+
297
+ # Task responsible for closing the server and terminating connections.
298
+ self.close_task: asyncio.Task[None] | None = None
299
+
300
+ # Completed when the server is closed and connections are terminated.
301
+ self.closed_waiter: asyncio.Future[None] = self.loop.create_future()
302
+
303
+ @property
304
+ def connections(self) -> set[ServerConnection]:
305
+ """
306
+ Set of active connections.
307
+
308
+ This property contains all connections that completed the opening
309
+ handshake successfully and didn't start the closing handshake yet.
310
+ It can be useful in combination with :func:`~broadcast`.
311
+
312
+ """
313
+ return {connection for connection in self.handlers if connection.state is OPEN}
314
+
315
+ def wrap(self, server: asyncio.Server) -> None:
316
+ """
317
+ Attach to a given :class:`asyncio.Server`.
318
+
319
+ Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
320
+ custom ``Server`` class, the easiest solution that doesn't rely on
321
+ private :mod:`asyncio` APIs is to:
322
+
323
+ - instantiate a :class:`Server`
324
+ - give the protocol factory a reference to that instance
325
+ - call :meth:`~asyncio.loop.create_server` with the factory
326
+ - attach the resulting :class:`asyncio.Server` with this method
327
+
328
+ """
329
+ self.server = server
330
+ for sock in server.sockets:
331
+ if sock.family == socket.AF_INET:
332
+ name = "%s:%d" % sock.getsockname()
333
+ elif sock.family == socket.AF_INET6:
334
+ name = "[%s]:%d" % sock.getsockname()[:2]
335
+ elif sock.family == socket.AF_UNIX:
336
+ name = sock.getsockname()
337
+ # In the unlikely event that someone runs websockets over a
338
+ # protocol other than IP or Unix sockets, avoid crashing.
339
+ else: # pragma: no cover
340
+ name = str(sock.getsockname())
341
+ self.logger.info("server listening on %s", name)
342
+
343
+ async def conn_handler(self, connection: ServerConnection) -> None:
344
+ """
345
+ Handle the lifecycle of a WebSocket connection.
346
+
347
+ Since this method doesn't have a caller that can handle exceptions,
348
+ it attempts to log relevant ones.
349
+
350
+ It guarantees that the TCP connection is closed before exiting.
351
+
352
+ """
353
+ try:
354
+ async with asyncio_timeout(self.open_timeout):
355
+ try:
356
+ await connection.handshake(
357
+ self.process_request,
358
+ self.process_response,
359
+ self.server_header,
360
+ )
361
+ except asyncio.CancelledError:
362
+ connection.transport.abort()
363
+ raise
364
+ except Exception:
365
+ connection.logger.error("opening handshake failed", exc_info=True)
366
+ connection.transport.abort()
367
+ return
368
+
369
+ if connection.protocol.state is not OPEN:
370
+ # process_request or process_response rejected the handshake.
371
+ connection.transport.abort()
372
+ return
373
+
374
+ try:
375
+ connection.start_keepalive()
376
+ await self.handler(connection)
377
+ except Exception:
378
+ connection.logger.error("connection handler failed", exc_info=True)
379
+ await connection.close(CloseCode.INTERNAL_ERROR)
380
+ else:
381
+ await connection.close()
382
+
383
+ except TimeoutError:
384
+ # When the opening handshake times out, there's nothing to log.
385
+ pass
386
+
387
+ except Exception: # pragma: no cover
388
+ # Don't leak connections on unexpected errors.
389
+ connection.transport.abort()
390
+
391
+ finally:
392
+ # Registration is tied to the lifecycle of conn_handler() because
393
+ # the server waits for connection handlers to terminate, even if
394
+ # all connections are already closed.
395
+ del self.handlers[connection]
396
+
397
+ def start_connection_handler(self, connection: ServerConnection) -> None:
398
+ """
399
+ Register a connection with this server.
400
+
401
+ """
402
+ # The connection must be registered in self.handlers immediately.
403
+ # If it was registered in conn_handler(), a race condition could
404
+ # happen when closing the server after scheduling conn_handler()
405
+ # but before it starts executing.
406
+ self.handlers[connection] = self.loop.create_task(self.conn_handler(connection))
407
+
408
+ def close(
409
+ self,
410
+ close_connections: bool = True,
411
+ code: CloseCode | int = CloseCode.GOING_AWAY,
412
+ reason: str = "",
413
+ ) -> None:
414
+ """
415
+ Close the server.
416
+
417
+ * Close the underlying :class:`asyncio.Server`.
418
+ * When ``close_connections`` is :obj:`True`, which is the default, close
419
+ existing connections. Specifically:
420
+
421
+ * Reject opening WebSocket connections with an HTTP 503 (service
422
+ unavailable) error. This happens when the server accepted the TCP
423
+ connection but didn't complete the opening handshake before closing.
424
+ * Close open WebSocket connections with code 1001 (going away).
425
+ ``code`` and ``reason`` can be customized, for example to use code
426
+ 1012 (service restart).
427
+
428
+ * Wait until all connection handlers terminate.
429
+
430
+ :meth:`close` is idempotent.
431
+
432
+ """
433
+ if self.close_task is None:
434
+ self.close_task = self.get_loop().create_task(
435
+ self._close(close_connections, code, reason)
436
+ )
437
+
438
+ async def _close(
439
+ self,
440
+ close_connections: bool = True,
441
+ code: CloseCode | int = CloseCode.GOING_AWAY,
442
+ reason: str = "",
443
+ ) -> None:
444
+ """
445
+ Implementation of :meth:`close`.
446
+
447
+ This calls :meth:`~asyncio.Server.close` on the underlying
448
+ :class:`asyncio.Server` object to stop accepting new connections and
449
+ then closes open connections.
450
+
451
+ """
452
+ self.logger.info("server closing")
453
+
454
+ # Stop accepting new connections.
455
+ self.server.close()
456
+
457
+ # Wait until all accepted connections reach connection_made() and call
458
+ # register(). See https://github.com/python/cpython/issues/79033 for
459
+ # details. This workaround can be removed when dropping Python < 3.11.
460
+ await asyncio.sleep(0)
461
+
462
+ # After server.close(), handshake() closes OPENING connections with an
463
+ # HTTP 503 error.
464
+
465
+ if close_connections:
466
+ # Close OPEN connections with code 1001 by default.
467
+ close_tasks = [
468
+ asyncio.create_task(connection.close(code, reason))
469
+ for connection in self.handlers
470
+ if connection.protocol.state is not CONNECTING
471
+ ]
472
+ # asyncio.wait doesn't accept an empty first argument.
473
+ if close_tasks:
474
+ await asyncio.wait(close_tasks)
475
+
476
+ # Wait until all TCP connections are closed.
477
+ await self.server.wait_closed()
478
+
479
+ # Wait until all connection handlers terminate.
480
+ # asyncio.wait doesn't accept an empty first argument.
481
+ if self.handlers:
482
+ await asyncio.wait(self.handlers.values())
483
+
484
+ # Tell wait_closed() to return.
485
+ self.closed_waiter.set_result(None)
486
+
487
+ self.logger.info("server closed")
488
+
489
+ async def wait_closed(self) -> None:
490
+ """
491
+ Wait until the server is closed.
492
+
493
+ When :meth:`wait_closed` returns, all TCP connections are closed and
494
+ all connection handlers have returned.
495
+
496
+ To ensure a fast shutdown, a connection handler should always be
497
+ awaiting at least one of:
498
+
499
+ * :meth:`~ServerConnection.recv`: when the connection is closed,
500
+ it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
501
+ * :meth:`~ServerConnection.wait_closed`: when the connection is
502
+ closed, it returns.
503
+
504
+ Then the connection handler is immediately notified of the shutdown;
505
+ it can clean up and exit.
506
+
507
+ """
508
+ await asyncio.shield(self.closed_waiter)
509
+
510
+ def get_loop(self) -> asyncio.AbstractEventLoop:
511
+ """
512
+ See :meth:`asyncio.Server.get_loop`.
513
+
514
+ """
515
+ return self.server.get_loop()
516
+
517
+ def is_serving(self) -> bool: # pragma: no cover
518
+ """
519
+ See :meth:`asyncio.Server.is_serving`.
520
+
521
+ """
522
+ return self.server.is_serving()
523
+
524
+ async def start_serving(self) -> None: # pragma: no cover
525
+ """
526
+ See :meth:`asyncio.Server.start_serving`.
527
+
528
+ Typical use::
529
+
530
+ server = await serve(..., start_serving=False)
531
+ # perform additional setup here...
532
+ # ... then start the server
533
+ await server.start_serving()
534
+
535
+ """
536
+ await self.server.start_serving()
537
+
538
+ async def serve_forever(self) -> None: # pragma: no cover
539
+ """
540
+ See :meth:`asyncio.Server.serve_forever`.
541
+
542
+ Typical use::
543
+
544
+ server = await serve(...)
545
+ # this coroutine doesn't return
546
+ # canceling it stops the server
547
+ await server.serve_forever()
548
+
549
+ This is an alternative to using :func:`serve` as an asynchronous context
550
+ manager. Shutdown is triggered by canceling :meth:`serve_forever`
551
+ instead of exiting a :func:`serve` context.
552
+
553
+ """
554
+ await self.server.serve_forever()
555
+
556
+ @property
557
+ def sockets(self) -> tuple[socket.socket, ...]:
558
+ """
559
+ See :attr:`asyncio.Server.sockets`.
560
+
561
+ """
562
+ return self.server.sockets
563
+
564
+ async def __aenter__(self) -> Server: # pragma: no cover
565
+ return self
566
+
567
+ async def __aexit__(
568
+ self,
569
+ exc_type: type[BaseException] | None,
570
+ exc_value: BaseException | None,
571
+ traceback: TracebackType | None,
572
+ ) -> None: # pragma: no cover
573
+ self.close()
574
+ await self.wait_closed()
575
+
576
+
577
+ # This is spelled in lower case because it's exposed as a callable in the API.
578
+ class serve:
579
+ """
580
+ Create a WebSocket server listening on ``host`` and ``port``.
581
+
582
+ Whenever a client connects, the server creates a :class:`ServerConnection`,
583
+ performs the opening handshake, and delegates to the ``handler`` coroutine.
584
+
585
+ The handler receives the :class:`ServerConnection` instance, which you can
586
+ use to send and receive messages.
587
+
588
+ Once the handler completes, either normally or with an exception, the server
589
+ performs the closing handshake and closes the connection.
590
+
591
+ This coroutine returns a :class:`Server` whose API mirrors
592
+ :class:`asyncio.Server`. Treat it as an asynchronous context manager to
593
+ ensure that the server will be closed::
594
+
595
+ from websockets.asyncio.server import serve
596
+
597
+ def handler(websocket):
598
+ ...
599
+
600
+ # set this future to exit the server
601
+ stop = asyncio.get_running_loop().create_future()
602
+
603
+ async with serve(handler, host, port):
604
+ await stop
605
+
606
+ Alternatively, call :meth:`~Server.serve_forever` to serve requests and
607
+ cancel it to stop the server::
608
+
609
+ server = await serve(handler, host, port)
610
+ await server.serve_forever()
611
+
612
+ Args:
613
+ handler: Connection handler. It receives the WebSocket connection,
614
+ which is a :class:`ServerConnection`, in argument.
615
+ host: Network interfaces the server binds to.
616
+ See :meth:`~asyncio.loop.create_server` for details.
617
+ port: TCP port the server listens on.
618
+ See :meth:`~asyncio.loop.create_server` for details.
619
+ origins: Acceptable values of the ``Origin`` header, for defending
620
+ against Cross-Site WebSocket Hijacking attacks. Values can be
621
+ :class:`str` to test for an exact match or regular expressions
622
+ compiled by :func:`re.compile` to test against a pattern. Include
623
+ :obj:`None` in the list if the lack of an origin is acceptable.
624
+ extensions: List of supported extensions, in order in which they
625
+ should be negotiated and run.
626
+ subprotocols: List of supported subprotocols, in order of decreasing
627
+ preference.
628
+ select_subprotocol: Callback for selecting a subprotocol among
629
+ those supported by the client and the server. It receives a
630
+ :class:`ServerConnection` (not a
631
+ :class:`~websockets.server.ServerProtocol`!) instance and a list of
632
+ subprotocols offered by the client. Other than the first argument,
633
+ it has the same behavior as the
634
+ :meth:`ServerProtocol.select_subprotocol
635
+ <websockets.server.ServerProtocol.select_subprotocol>` method.
636
+ compression: The "permessage-deflate" extension is enabled by default.
637
+ Set ``compression`` to :obj:`None` to disable it. See the
638
+ :doc:`compression guide <../../topics/compression>` for details.
639
+ process_request: Intercept the request during the opening handshake.
640
+ Return an HTTP response to force the response or :obj:`None` to
641
+ continue normally. When you force an HTTP 101 Continue response, the
642
+ handshake is successful. Else, the connection is aborted.
643
+ ``process_request`` may be a function or a coroutine.
644
+ process_response: Intercept the response during the opening handshake.
645
+ Return an HTTP response to force the response or :obj:`None` to
646
+ continue normally. When you force an HTTP 101 Continue response, the
647
+ handshake is successful. Else, the connection is aborted.
648
+ ``process_response`` may be a function or a coroutine.
649
+ server_header: Value of the ``Server`` response header.
650
+ It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
651
+ :obj:`None` removes the header.
652
+ open_timeout: Timeout for opening connections in seconds.
653
+ :obj:`None` disables the timeout.
654
+ ping_interval: Interval between keepalive pings in seconds.
655
+ :obj:`None` disables keepalive.
656
+ ping_timeout: Timeout for keepalive pings in seconds.
657
+ :obj:`None` disables timeouts.
658
+ close_timeout: Timeout for closing connections in seconds.
659
+ :obj:`None` disables the timeout.
660
+ max_size: Maximum size of incoming messages in bytes.
661
+ :obj:`None` disables the limit. You may pass a ``(max_message_size,
662
+ max_fragment_size)`` tuple to set different limits for messages and
663
+ fragments when you expect long messages sent in short fragments.
664
+ max_queue: High-water mark of the buffer where frames are received.
665
+ It defaults to 16 frames. The low-water mark defaults to ``max_queue
666
+ // 4``. You may pass a ``(high, low)`` tuple to set the high-water
667
+ and low-water marks. If you want to disable flow control entirely,
668
+ you may set it to ``None``, although that's a bad idea.
669
+ write_limit: High-water mark of write buffer in bytes. It is passed to
670
+ :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
671
+ to 32 KiB. You may pass a ``(high, low)`` tuple to set the
672
+ high-water and low-water marks.
673
+ logger: Logger for this server.
674
+ It defaults to ``logging.getLogger("websockets.server")``. See the
675
+ :doc:`logging guide <../../topics/logging>` for details.
676
+ create_connection: Factory for the :class:`ServerConnection` managing
677
+ the connection. Set it to a wrapper or a subclass to customize
678
+ connection handling.
679
+
680
+ Any other keyword arguments are passed to the event loop's
681
+ :meth:`~asyncio.loop.create_server` method.
682
+
683
+ For example:
684
+
685
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.
686
+
687
+ * You can set ``sock`` to provide a preexisting TCP socket. You may call
688
+ :func:`socket.create_server` (not to be confused with the event loop's
689
+ :meth:`~asyncio.loop.create_server` method) to create a suitable server
690
+ socket and customize it.
691
+
692
+ * You can set ``start_serving`` to ``False`` to start accepting connections
693
+ only after you call :meth:`~Server.start_serving()` or
694
+ :meth:`~Server.serve_forever()`.
695
+
696
+ """
697
+
698
+ def __init__(
699
+ self,
700
+ handler: Callable[[ServerConnection], Awaitable[None]],
701
+ host: str | None = None,
702
+ port: int | None = None,
703
+ *,
704
+ # WebSocket
705
+ origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
706
+ extensions: Sequence[ServerExtensionFactory] | None = None,
707
+ subprotocols: Sequence[Subprotocol] | None = None,
708
+ select_subprotocol: (
709
+ Callable[
710
+ [ServerConnection, Sequence[Subprotocol]],
711
+ Subprotocol | None,
712
+ ]
713
+ | None
714
+ ) = None,
715
+ compression: str | None = "deflate",
716
+ # HTTP
717
+ process_request: (
718
+ Callable[
719
+ [ServerConnection, Request],
720
+ Awaitable[Response | None] | Response | None,
721
+ ]
722
+ | None
723
+ ) = None,
724
+ process_response: (
725
+ Callable[
726
+ [ServerConnection, Request, Response],
727
+ Awaitable[Response | None] | Response | None,
728
+ ]
729
+ | None
730
+ ) = None,
731
+ server_header: str | None = SERVER,
732
+ # Timeouts
733
+ open_timeout: float | None = 10,
734
+ ping_interval: float | None = 20,
735
+ ping_timeout: float | None = 20,
736
+ close_timeout: float | None = 10,
737
+ # Limits
738
+ max_size: int | None | tuple[int | None, int | None] = 2**20,
739
+ max_queue: int | None | tuple[int | None, int | None] = 16,
740
+ write_limit: int | tuple[int, int | None] = 2**15,
741
+ # Logging
742
+ logger: LoggerLike | None = None,
743
+ # Escape hatch for advanced customization
744
+ create_connection: type[ServerConnection] | None = None,
745
+ # Other keyword arguments are passed to loop.create_server
746
+ **kwargs: Any,
747
+ ) -> None:
748
+ if subprotocols is not None:
749
+ validate_subprotocols(subprotocols)
750
+
751
+ if compression == "deflate":
752
+ extensions = enable_server_permessage_deflate(extensions)
753
+ elif compression is not None:
754
+ raise ValueError(f"unsupported compression: {compression}")
755
+
756
+ if create_connection is None:
757
+ create_connection = ServerConnection
758
+
759
+ self.server = Server(
760
+ handler,
761
+ process_request=process_request,
762
+ process_response=process_response,
763
+ server_header=server_header,
764
+ open_timeout=open_timeout,
765
+ logger=logger,
766
+ )
767
+
768
+ if kwargs.get("ssl") is not None:
769
+ kwargs.setdefault("ssl_handshake_timeout", open_timeout)
770
+ if sys.version_info[:2] >= (3, 11): # pragma: no branch
771
+ kwargs.setdefault("ssl_shutdown_timeout", close_timeout)
772
+
773
+ def factory() -> ServerConnection:
774
+ """
775
+ Create an asyncio protocol for managing a WebSocket connection.
776
+
777
+ """
778
+ # Create a closure to give select_subprotocol access to connection.
779
+ protocol_select_subprotocol: (
780
+ Callable[
781
+ [ServerProtocol, Sequence[Subprotocol]],
782
+ Subprotocol | None,
783
+ ]
784
+ | None
785
+ ) = None
786
+ if select_subprotocol is not None:
787
+
788
+ def protocol_select_subprotocol(
789
+ protocol: ServerProtocol,
790
+ subprotocols: Sequence[Subprotocol],
791
+ ) -> Subprotocol | None:
792
+ # mypy doesn't know that select_subprotocol is immutable.
793
+ assert select_subprotocol is not None
794
+ # Ensure this function is only used in the intended context.
795
+ assert protocol is connection.protocol
796
+ return select_subprotocol(connection, subprotocols)
797
+
798
+ # This is a protocol in the Sans-I/O implementation of websockets.
799
+ protocol = ServerProtocol(
800
+ origins=origins,
801
+ extensions=extensions,
802
+ subprotocols=subprotocols,
803
+ select_subprotocol=protocol_select_subprotocol,
804
+ max_size=max_size,
805
+ logger=logger,
806
+ )
807
+ # This is a connection in websockets and a protocol in asyncio.
808
+ connection = create_connection(
809
+ protocol,
810
+ self.server,
811
+ ping_interval=ping_interval,
812
+ ping_timeout=ping_timeout,
813
+ close_timeout=close_timeout,
814
+ max_queue=max_queue,
815
+ write_limit=write_limit,
816
+ )
817
+ return connection
818
+
819
+ loop = asyncio.get_running_loop()
820
+ if kwargs.pop("unix", False):
821
+ self.create_server = loop.create_unix_server(factory, **kwargs)
822
+ else:
823
+ # mypy cannot tell that kwargs must provide sock when port is None.
824
+ self.create_server = loop.create_server(factory, host, port, **kwargs) # type: ignore[arg-type]
825
+
826
+ # async with serve(...) as ...: ...
827
+
828
+ async def __aenter__(self) -> Server:
829
+ return await self
830
+
831
+ async def __aexit__(
832
+ self,
833
+ exc_type: type[BaseException] | None,
834
+ exc_value: BaseException | None,
835
+ traceback: TracebackType | None,
836
+ ) -> None:
837
+ self.server.close()
838
+ await self.server.wait_closed()
839
+
840
+ # ... = await serve(...)
841
+
842
+ def __await__(self) -> Generator[Any, None, Server]:
843
+ # Create a suitable iterator by calling __await__ on a coroutine.
844
+ return self.__await_impl__().__await__()
845
+
846
+ async def __await_impl__(self) -> Server:
847
+ server = await self.create_server
848
+ self.server.wrap(server)
849
+ return self.server
850
+
851
+ # ... = yield from serve(...) - remove when dropping Python < 3.11
852
+
853
+ __iter__ = __await__
854
+
855
+
856
+ def unix_serve(
857
+ handler: Callable[[ServerConnection], Awaitable[None]],
858
+ path: str | None = None,
859
+ **kwargs: Any,
860
+ ) -> Awaitable[Server]:
861
+ """
862
+ Create a WebSocket server listening on a Unix socket.
863
+
864
+ This function is identical to :func:`serve`, except the ``host`` and
865
+ ``port`` arguments are replaced by ``path``. It's only available on Unix.
866
+
867
+ It's useful for deploying a server behind a reverse proxy such as nginx.
868
+
869
+ Args:
870
+ handler: Connection handler. It receives the WebSocket connection,
871
+ which is a :class:`ServerConnection`, in argument.
872
+ path: File system path to the Unix socket.
873
+
874
+ """
875
+ return serve(handler, unix=True, path=path, **kwargs)
876
+
877
+
878
+ def is_credentials(credentials: Any) -> bool:
879
+ try:
880
+ username, password = credentials
881
+ except (TypeError, ValueError):
882
+ return False
883
+ else:
884
+ return isinstance(username, str) and isinstance(password, str)
885
+
886
+
887
+ def basic_auth(
888
+ realm: str = "",
889
+ credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
890
+ check_credentials: Callable[[str, str], Awaitable[bool] | bool] | None = None,
891
+ ) -> Callable[[ServerConnection, Request], Awaitable[Response | None]]:
892
+ """
893
+ Factory for ``process_request`` to enforce HTTP Basic Authentication.
894
+
895
+ :func:`basic_auth` is designed to integrate with :func:`serve` as follows::
896
+
897
+ from websockets.asyncio.server import basic_auth, serve
898
+
899
+ async with serve(
900
+ ...,
901
+ process_request=basic_auth(
902
+ realm="my dev server",
903
+ credentials=("hello", "iloveyou"),
904
+ ),
905
+ ):
906
+
907
+ If authentication succeeds, the connection's ``username`` attribute is set.
908
+ If it fails, the server responds with an HTTP 401 Unauthorized status.
909
+
910
+ One of ``credentials`` or ``check_credentials`` must be provided; not both.
911
+
912
+ Args:
913
+ realm: Scope of protection. It should contain only ASCII characters
914
+ because the encoding of non-ASCII characters is undefined. Refer to
915
+ section 2.2 of :rfc:`7235` for details.
916
+ credentials: Hard coded authorized credentials. It can be a
917
+ ``(username, password)`` pair or a list of such pairs.
918
+ check_credentials: Function or coroutine that verifies credentials.
919
+ It receives ``username`` and ``password`` arguments and returns
920
+ whether they're valid.
921
+ Raises:
922
+ TypeError: If ``credentials`` or ``check_credentials`` is wrong.
923
+ ValueError: If ``credentials`` and ``check_credentials`` are both
924
+ provided or both not provided.
925
+
926
+ """
927
+ if (credentials is None) == (check_credentials is None):
928
+ raise ValueError("provide either credentials or check_credentials")
929
+
930
+ if credentials is not None:
931
+ if is_credentials(credentials):
932
+ credentials_list = [cast(tuple[str, str], credentials)]
933
+ elif isinstance(credentials, Iterable):
934
+ credentials_list = list(cast(Iterable[tuple[str, str]], credentials))
935
+ if not all(is_credentials(item) for item in credentials_list):
936
+ raise TypeError(f"invalid credentials argument: {credentials}")
937
+ else:
938
+ raise TypeError(f"invalid credentials argument: {credentials}")
939
+
940
+ credentials_dict = dict(credentials_list)
941
+
942
+ def check_credentials(username: str, password: str) -> bool:
943
+ try:
944
+ expected_password = credentials_dict[username]
945
+ except KeyError:
946
+ return False
947
+ return hmac.compare_digest(expected_password, password)
948
+
949
+ assert check_credentials is not None # help mypy
950
+
951
+ async def process_request(
952
+ connection: ServerConnection,
953
+ request: Request,
954
+ ) -> Response | None:
955
+ """
956
+ Perform HTTP Basic Authentication.
957
+
958
+ If it succeeds, set the connection's ``username`` attribute and return
959
+ :obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
960
+
961
+ """
962
+ try:
963
+ authorization = request.headers["Authorization"]
964
+ except KeyError:
965
+ response = connection.respond(
966
+ http.HTTPStatus.UNAUTHORIZED,
967
+ "Missing credentials\n",
968
+ )
969
+ response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
970
+ return response
971
+
972
+ try:
973
+ username, password = parse_authorization_basic(authorization)
974
+ except InvalidHeader:
975
+ response = connection.respond(
976
+ http.HTTPStatus.UNAUTHORIZED,
977
+ "Unsupported credentials\n",
978
+ )
979
+ response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
980
+ return response
981
+
982
+ valid_credentials = check_credentials(username, password)
983
+ if isinstance(valid_credentials, Awaitable):
984
+ valid_credentials = await valid_credentials
985
+
986
+ if not valid_credentials:
987
+ response = connection.respond(
988
+ http.HTTPStatus.UNAUTHORIZED,
989
+ "Invalid credentials\n",
990
+ )
991
+ response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
992
+ return response
993
+
994
+ connection.username = username
995
+ return None
996
+
997
+ return process_request
source/websockets/auth.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+
5
+
6
+ with warnings.catch_warnings():
7
+ # Suppress redundant DeprecationWarning raised by websockets.legacy.
8
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
9
+ from .legacy.auth import *
10
+ from .legacy.auth import __all__ # noqa: F401
11
+
12
+
13
+ warnings.warn( # deprecated in 14.0 - 2024-11-09
14
+ "websockets.auth, an alias for websockets.legacy.auth, is deprecated; "
15
+ "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html "
16
+ "for upgrade instructions",
17
+ DeprecationWarning,
18
+ )
source/websockets/cli.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import asyncio
5
+ import os
6
+ import sys
7
+ from typing import Generator
8
+
9
+ from .asyncio.client import ClientConnection, connect
10
+ from .asyncio.messages import SimpleQueue
11
+ from .exceptions import ConnectionClosed
12
+ from .frames import Close
13
+ from .streams import StreamReader
14
+ from .version import version as websockets_version
15
+
16
+
17
+ __all__ = ["main"]
18
+
19
+
20
+ def print_during_input(string: str) -> None:
21
+ sys.stdout.write(
22
+ # Save cursor position
23
+ "\N{ESC}7"
24
+ # Add a new line
25
+ "\N{LINE FEED}"
26
+ # Move cursor up
27
+ "\N{ESC}[A"
28
+ # Insert blank line, scroll last line down
29
+ "\N{ESC}[L"
30
+ # Print string in the inserted blank line
31
+ f"{string}\N{LINE FEED}"
32
+ # Restore cursor position
33
+ "\N{ESC}8"
34
+ # Move cursor down
35
+ "\N{ESC}[B"
36
+ )
37
+ sys.stdout.flush()
38
+
39
+
40
+ def print_over_input(string: str) -> None:
41
+ sys.stdout.write(
42
+ # Move cursor to beginning of line
43
+ "\N{CARRIAGE RETURN}"
44
+ # Delete current line
45
+ "\N{ESC}[K"
46
+ # Print string
47
+ f"{string}\N{LINE FEED}"
48
+ )
49
+ sys.stdout.flush()
50
+
51
+
52
+ class ReadLines(asyncio.Protocol):
53
+ def __init__(self) -> None:
54
+ self.reader = StreamReader()
55
+ self.messages: SimpleQueue[str] = SimpleQueue()
56
+
57
+ def parse(self) -> Generator[None, None, None]:
58
+ while True:
59
+ sys.stdout.write("> ")
60
+ sys.stdout.flush()
61
+ line = yield from self.reader.read_line(sys.maxsize)
62
+ self.messages.put(line.decode().rstrip("\r\n"))
63
+
64
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
65
+ self.parser = self.parse()
66
+ next(self.parser)
67
+
68
+ def data_received(self, data: bytes) -> None:
69
+ self.reader.feed_data(data)
70
+ next(self.parser)
71
+
72
+ def eof_received(self) -> None:
73
+ self.reader.feed_eof()
74
+ # next(self.parser) isn't useful and would raise EOFError.
75
+
76
+ def connection_lost(self, exc: Exception | None) -> None:
77
+ self.reader.discard()
78
+ self.messages.abort()
79
+
80
+
81
+ async def print_incoming_messages(websocket: ClientConnection) -> None:
82
+ async for message in websocket:
83
+ if isinstance(message, str):
84
+ print_during_input("< " + message)
85
+ else:
86
+ print_during_input("< (binary) " + message.hex())
87
+
88
+
89
+ async def send_outgoing_messages(
90
+ websocket: ClientConnection,
91
+ messages: SimpleQueue[str],
92
+ ) -> None:
93
+ while True:
94
+ try:
95
+ message = await messages.get()
96
+ except EOFError:
97
+ break
98
+ try:
99
+ await websocket.send(message)
100
+ except ConnectionClosed: # pragma: no cover
101
+ break
102
+
103
+
104
+ async def interactive_client(uri: str) -> None:
105
+ try:
106
+ websocket = await connect(uri)
107
+ except Exception as exc:
108
+ print(f"Failed to connect to {uri}: {exc}.")
109
+ sys.exit(1)
110
+ else:
111
+ print(f"Connected to {uri}.")
112
+
113
+ loop = asyncio.get_running_loop()
114
+ transport, protocol = await loop.connect_read_pipe(ReadLines, sys.stdin)
115
+ incoming = asyncio.create_task(
116
+ print_incoming_messages(websocket),
117
+ )
118
+ outgoing = asyncio.create_task(
119
+ send_outgoing_messages(websocket, protocol.messages),
120
+ )
121
+ try:
122
+ await asyncio.wait(
123
+ [incoming, outgoing],
124
+ # Clean up and exit when the server closes the connection
125
+ # or the user enters EOT (^D), whichever happens first.
126
+ return_when=asyncio.FIRST_COMPLETED,
127
+ )
128
+ # asyncio.run() cancels the main task when the user triggers SIGINT (^C).
129
+ # https://docs.python.org/3/library/asyncio-runner.html#handling-keyboard-interruption
130
+ # Clean up and exit without re-raising CancelledError to prevent Python
131
+ # from raising KeyboardInterrupt and displaying a stack track.
132
+ except asyncio.CancelledError: # pragma: no cover
133
+ pass
134
+ finally:
135
+ incoming.cancel()
136
+ outgoing.cancel()
137
+ transport.close()
138
+
139
+ await websocket.close()
140
+ assert websocket.close_code is not None and websocket.close_reason is not None
141
+ close_status = Close(websocket.close_code, websocket.close_reason)
142
+ print_over_input(f"Connection closed: {close_status}.")
143
+
144
+
145
+ def main(argv: list[str] | None = None) -> None:
146
+ parser = argparse.ArgumentParser(
147
+ prog="websockets",
148
+ description="Interactive WebSocket client.",
149
+ add_help=False,
150
+ )
151
+ group = parser.add_mutually_exclusive_group()
152
+ group.add_argument("--version", action="store_true")
153
+ group.add_argument("uri", metavar="<uri>", nargs="?")
154
+ args = parser.parse_args(argv)
155
+
156
+ if args.version:
157
+ print(f"websockets {websockets_version}")
158
+ return
159
+
160
+ if args.uri is None:
161
+ parser.print_usage()
162
+ sys.exit(2)
163
+
164
+ # Enable VT100 to support ANSI escape codes in Command Prompt on Windows.
165
+ # See https://github.com/python/cpython/issues/74261 for why this works.
166
+ if sys.platform == "win32":
167
+ os.system("")
168
+
169
+ try:
170
+ import readline # noqa: F401
171
+ except ImportError: # readline isn't available on all platforms
172
+ pass
173
+
174
+ # Remove the try/except block when dropping Python < 3.11.
175
+ try:
176
+ asyncio.run(interactive_client(args.uri))
177
+ except KeyboardInterrupt: # pragma: no cover
178
+ pass
source/websockets/client.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ import warnings
6
+ from collections.abc import Generator, Sequence
7
+ from typing import Any
8
+
9
+ from .datastructures import Headers, MultipleValuesError
10
+ from .exceptions import (
11
+ InvalidHandshake,
12
+ InvalidHeader,
13
+ InvalidHeaderValue,
14
+ InvalidMessage,
15
+ InvalidStatus,
16
+ InvalidUpgrade,
17
+ NegotiationError,
18
+ )
19
+ from .extensions import ClientExtensionFactory, Extension
20
+ from .headers import (
21
+ build_authorization_basic,
22
+ build_extension,
23
+ build_host,
24
+ build_subprotocol,
25
+ parse_connection,
26
+ parse_extension,
27
+ parse_subprotocol,
28
+ parse_upgrade,
29
+ )
30
+ from .http11 import Request, Response
31
+ from .imports import lazy_import
32
+ from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
33
+ from .typing import (
34
+ ConnectionOption,
35
+ ExtensionHeader,
36
+ LoggerLike,
37
+ Origin,
38
+ Subprotocol,
39
+ UpgradeProtocol,
40
+ )
41
+ from .uri import WebSocketURI
42
+ from .utils import accept_key, generate_key
43
+
44
+
45
+ __all__ = ["ClientProtocol"]
46
+
47
+
48
+ class ClientProtocol(Protocol):
49
+ """
50
+ Sans-I/O implementation of a WebSocket client connection.
51
+
52
+ Args:
53
+ uri: URI of the WebSocket server, parsed
54
+ with :func:`~websockets.uri.parse_uri`.
55
+ origin: Value of the ``Origin`` header. This is useful when connecting
56
+ to a server that validates the ``Origin`` header to defend against
57
+ Cross-Site WebSocket Hijacking attacks.
58
+ extensions: List of supported extensions, in order in which they
59
+ should be tried.
60
+ subprotocols: List of supported subprotocols, in order of decreasing
61
+ preference.
62
+ state: Initial state of the WebSocket connection.
63
+ max_size: Maximum size of incoming messages in bytes.
64
+ :obj:`None` disables the limit. You may pass a ``(max_message_size,
65
+ max_fragment_size)`` tuple to set different limits for messages and
66
+ fragments when you expect long messages sent in short fragments.
67
+ logger: Logger for this connection;
68
+ defaults to ``logging.getLogger("websockets.client")``;
69
+ see the :doc:`logging guide <../../topics/logging>` for details.
70
+
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ uri: WebSocketURI,
76
+ *,
77
+ origin: Origin | None = None,
78
+ extensions: Sequence[ClientExtensionFactory] | None = None,
79
+ subprotocols: Sequence[Subprotocol] | None = None,
80
+ state: State = CONNECTING,
81
+ max_size: int | None | tuple[int | None, int | None] = 2**20,
82
+ logger: LoggerLike | None = None,
83
+ ) -> None:
84
+ super().__init__(
85
+ side=CLIENT,
86
+ state=state,
87
+ max_size=max_size,
88
+ logger=logger,
89
+ )
90
+ self.uri = uri
91
+ self.origin = origin
92
+ self.available_extensions = extensions
93
+ self.available_subprotocols = subprotocols
94
+ self.key = generate_key()
95
+
96
+ def connect(self) -> Request:
97
+ """
98
+ Create a handshake request to open a connection.
99
+
100
+ You must send the handshake request with :meth:`send_request`.
101
+
102
+ You can modify it before sending it, for example to add HTTP headers.
103
+
104
+ Returns:
105
+ WebSocket handshake request event to send to the server.
106
+
107
+ """
108
+ headers = Headers()
109
+ headers["Host"] = build_host(self.uri.host, self.uri.port, self.uri.secure)
110
+ if self.uri.user_info:
111
+ headers["Authorization"] = build_authorization_basic(*self.uri.user_info)
112
+ if self.origin is not None:
113
+ headers["Origin"] = self.origin
114
+ headers["Upgrade"] = "websocket"
115
+ headers["Connection"] = "Upgrade"
116
+ headers["Sec-WebSocket-Key"] = self.key
117
+ headers["Sec-WebSocket-Version"] = "13"
118
+ if self.available_extensions is not None:
119
+ headers["Sec-WebSocket-Extensions"] = build_extension(
120
+ [
121
+ (extension_factory.name, extension_factory.get_request_params())
122
+ for extension_factory in self.available_extensions
123
+ ]
124
+ )
125
+ if self.available_subprotocols is not None:
126
+ headers["Sec-WebSocket-Protocol"] = build_subprotocol(
127
+ self.available_subprotocols
128
+ )
129
+ return Request(self.uri.resource_name, headers)
130
+
131
+ def process_response(self, response: Response) -> None:
132
+ """
133
+ Check a handshake response.
134
+
135
+ Args:
136
+ request: WebSocket handshake response received from the server.
137
+
138
+ Raises:
139
+ InvalidHandshake: If the handshake response is invalid.
140
+
141
+ """
142
+
143
+ if response.status_code != 101:
144
+ raise InvalidStatus(response)
145
+
146
+ headers = response.headers
147
+
148
+ connection: list[ConnectionOption] = sum(
149
+ [parse_connection(value) for value in headers.get_all("Connection")], []
150
+ )
151
+ if not any(value.lower() == "upgrade" for value in connection):
152
+ raise InvalidUpgrade(
153
+ "Connection", ", ".join(connection) if connection else None
154
+ )
155
+
156
+ upgrade: list[UpgradeProtocol] = sum(
157
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
158
+ )
159
+ # For compatibility with non-strict implementations, ignore case when
160
+ # checking the Upgrade header. It's supposed to be 'WebSocket'.
161
+ if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
162
+ raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
163
+
164
+ try:
165
+ s_w_accept = headers["Sec-WebSocket-Accept"]
166
+ except KeyError:
167
+ raise InvalidHeader("Sec-WebSocket-Accept") from None
168
+ except MultipleValuesError:
169
+ raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from None
170
+ if s_w_accept != accept_key(self.key):
171
+ raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
172
+
173
+ self.extensions = self.process_extensions(headers)
174
+ self.subprotocol = self.process_subprotocol(headers)
175
+
176
+ def process_extensions(self, headers: Headers) -> list[Extension]:
177
+ """
178
+ Handle the Sec-WebSocket-Extensions HTTP response header.
179
+
180
+ Check that each extension is supported, as well as its parameters.
181
+
182
+ :rfc:`6455` leaves the rules up to the specification of each
183
+ extension.
184
+
185
+ To provide this level of flexibility, for each extension accepted by
186
+ the server, we check for a match with each extension available in the
187
+ client configuration. If no match is found, an exception is raised.
188
+
189
+ If several variants of the same extension are accepted by the server,
190
+ it may be configured several times, which won't make sense in general.
191
+ Extensions must implement their own requirements. For this purpose,
192
+ the list of previously accepted extensions is provided.
193
+
194
+ Other requirements, for example related to mandatory extensions or the
195
+ order of extensions, may be implemented by overriding this method.
196
+
197
+ Args:
198
+ headers: WebSocket handshake response headers.
199
+
200
+ Returns:
201
+ List of accepted extensions.
202
+
203
+ Raises:
204
+ InvalidHandshake: To abort the handshake.
205
+
206
+ """
207
+ accepted_extensions: list[Extension] = []
208
+
209
+ extensions = headers.get_all("Sec-WebSocket-Extensions")
210
+
211
+ if extensions:
212
+ if self.available_extensions is None:
213
+ raise NegotiationError("no extensions supported")
214
+
215
+ parsed_extensions: list[ExtensionHeader] = sum(
216
+ [parse_extension(header_value) for header_value in extensions], []
217
+ )
218
+
219
+ for name, response_params in parsed_extensions:
220
+ for extension_factory in self.available_extensions:
221
+ # Skip non-matching extensions based on their name.
222
+ if extension_factory.name != name:
223
+ continue
224
+
225
+ # Skip non-matching extensions based on their params.
226
+ try:
227
+ extension = extension_factory.process_response_params(
228
+ response_params, accepted_extensions
229
+ )
230
+ except NegotiationError:
231
+ continue
232
+
233
+ # Add matching extension to the final list.
234
+ accepted_extensions.append(extension)
235
+
236
+ # Break out of the loop once we have a match.
237
+ break
238
+
239
+ # If we didn't break from the loop, no extension in our list
240
+ # matched what the server sent. Fail the connection.
241
+ else:
242
+ raise NegotiationError(
243
+ f"Unsupported extension: "
244
+ f"name = {name}, params = {response_params}"
245
+ )
246
+
247
+ return accepted_extensions
248
+
249
+ def process_subprotocol(self, headers: Headers) -> Subprotocol | None:
250
+ """
251
+ Handle the Sec-WebSocket-Protocol HTTP response header.
252
+
253
+ If provided, check that it contains exactly one supported subprotocol.
254
+
255
+ Args:
256
+ headers: WebSocket handshake response headers.
257
+
258
+ Returns:
259
+ Subprotocol, if one was selected.
260
+
261
+ """
262
+ subprotocol: Subprotocol | None = None
263
+
264
+ subprotocols = headers.get_all("Sec-WebSocket-Protocol")
265
+
266
+ if subprotocols:
267
+ if self.available_subprotocols is None:
268
+ raise NegotiationError("no subprotocols supported")
269
+
270
+ parsed_subprotocols: Sequence[Subprotocol] = sum(
271
+ [parse_subprotocol(header_value) for header_value in subprotocols], []
272
+ )
273
+ if len(parsed_subprotocols) > 1:
274
+ raise InvalidHeader(
275
+ "Sec-WebSocket-Protocol",
276
+ f"multiple values: {', '.join(parsed_subprotocols)}",
277
+ )
278
+
279
+ subprotocol = parsed_subprotocols[0]
280
+ if subprotocol not in self.available_subprotocols:
281
+ raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
282
+
283
+ return subprotocol
284
+
285
+ def send_request(self, request: Request) -> None:
286
+ """
287
+ Send a handshake request to the server.
288
+
289
+ Args:
290
+ request: WebSocket handshake request event.
291
+
292
+ """
293
+ if self.debug:
294
+ self.logger.debug("> GET %s HTTP/1.1", request.path)
295
+ for key, value in request.headers.raw_items():
296
+ self.logger.debug("> %s: %s", key, value)
297
+
298
+ self.writes.append(request.serialize())
299
+
300
+ def parse(self) -> Generator[None]:
301
+ if self.state is CONNECTING:
302
+ try:
303
+ response = yield from Response.parse(
304
+ self.reader.read_line,
305
+ self.reader.read_exact,
306
+ self.reader.read_to_eof,
307
+ )
308
+ except Exception as exc:
309
+ self.handshake_exc = InvalidMessage(
310
+ "did not receive a valid HTTP response"
311
+ )
312
+ self.handshake_exc.__cause__ = exc
313
+ self.send_eof()
314
+ self.parser = self.discard()
315
+ next(self.parser) # start coroutine
316
+ yield
317
+
318
+ if self.debug:
319
+ code, phrase = response.status_code, response.reason_phrase
320
+ self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
321
+ for key, value in response.headers.raw_items():
322
+ self.logger.debug("< %s: %s", key, value)
323
+ if response.body:
324
+ self.logger.debug("< [body] (%d bytes)", len(response.body))
325
+
326
+ try:
327
+ self.process_response(response)
328
+ except InvalidHandshake as exc:
329
+ response._exception = exc
330
+ self.events.append(response)
331
+ self.handshake_exc = exc
332
+ self.send_eof()
333
+ self.parser = self.discard()
334
+ next(self.parser) # start coroutine
335
+ yield
336
+
337
+ assert self.state is CONNECTING
338
+ self.state = OPEN
339
+ self.events.append(response)
340
+
341
+ yield from super().parse()
342
+
343
+
344
+ class ClientConnection(ClientProtocol):
345
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
346
+ warnings.warn( # deprecated in 11.0 - 2023-04-02
347
+ "ClientConnection was renamed to ClientProtocol",
348
+ DeprecationWarning,
349
+ )
350
+ super().__init__(*args, **kwargs)
351
+
352
+
353
+ BACKOFF_INITIAL_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
354
+ BACKOFF_MIN_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
355
+ BACKOFF_MAX_DELAY = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
356
+ BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
357
+
358
+
359
+ def backoff(
360
+ initial_delay: float = BACKOFF_INITIAL_DELAY,
361
+ min_delay: float = BACKOFF_MIN_DELAY,
362
+ max_delay: float = BACKOFF_MAX_DELAY,
363
+ factor: float = BACKOFF_FACTOR,
364
+ ) -> Generator[float]:
365
+ """
366
+ Generate a series of backoff delays between reconnection attempts.
367
+
368
+ Yields:
369
+ How many seconds to wait before retrying to connect.
370
+
371
+ """
372
+ # Add a random initial delay between 0 and 5 seconds.
373
+ # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
374
+ yield random.random() * initial_delay
375
+ delay = min_delay
376
+ while delay < max_delay:
377
+ yield delay
378
+ delay *= factor
379
+ while True:
380
+ yield max_delay
381
+
382
+
383
+ lazy_import(
384
+ globals(),
385
+ deprecated_aliases={
386
+ # deprecated in 14.0 - 2024-11-09
387
+ "WebSocketClientProtocol": ".legacy.client",
388
+ "connect": ".legacy.client",
389
+ "unix_connect": ".legacy.client",
390
+ },
391
+ )
source/websockets/connection.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+
5
+ from .protocol import SEND_EOF, Protocol as Connection, Side, State # noqa: F401
6
+
7
+
8
+ warnings.warn( # deprecated in 11.0 - 2023-04-02
9
+ "websockets.connection was renamed to websockets.protocol "
10
+ "and Connection was renamed to Protocol",
11
+ DeprecationWarning,
12
+ )
source/websockets/datastructures.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable, Iterator, Mapping, MutableMapping
4
+ from typing import Any, Protocol
5
+
6
+
7
+ __all__ = [
8
+ "Headers",
9
+ "HeadersLike",
10
+ "MultipleValuesError",
11
+ ]
12
+
13
+
14
+ class MultipleValuesError(LookupError):
15
+ """
16
+ Exception raised when :class:`Headers` has multiple values for a key.
17
+
18
+ """
19
+
20
+ def __str__(self) -> str:
21
+ # Implement the same logic as KeyError_str in Objects/exceptions.c.
22
+ if len(self.args) == 1:
23
+ return repr(self.args[0])
24
+ return super().__str__()
25
+
26
+
27
+ class Headers(MutableMapping[str, str]):
28
+ """
29
+ Efficient data structure for manipulating HTTP headers.
30
+
31
+ A :class:`list` of ``(name, values)`` is inefficient for lookups.
32
+
33
+ A :class:`dict` doesn't suffice because header names are case-insensitive
34
+ and multiple occurrences of headers with the same name are possible.
35
+
36
+ :class:`Headers` stores HTTP headers in a hybrid data structure to provide
37
+ efficient insertions and lookups while preserving the original data.
38
+
39
+ In order to account for multiple values with minimal hassle,
40
+ :class:`Headers` follows this logic:
41
+
42
+ - When getting a header with ``headers[name]``:
43
+ - if there's no value, :exc:`KeyError` is raised;
44
+ - if there's exactly one value, it's returned;
45
+ - if there's more than one value, :exc:`MultipleValuesError` is raised.
46
+
47
+ - When setting a header with ``headers[name] = value``, the value is
48
+ appended to the list of values for that header.
49
+
50
+ - When deleting a header with ``del headers[name]``, all values for that
51
+ header are removed (this is slow).
52
+
53
+ Other methods for manipulating headers are consistent with this logic.
54
+
55
+ As long as no header occurs multiple times, :class:`Headers` behaves like
56
+ :class:`dict`, except keys are lower-cased to provide case-insensitivity.
57
+
58
+ Two methods support manipulating multiple values explicitly:
59
+
60
+ - :meth:`get_all` returns a list of all values for a header;
61
+ - :meth:`raw_items` returns an iterator of ``(name, values)`` pairs.
62
+
63
+ """
64
+
65
+ __slots__ = ["_dict", "_list"]
66
+
67
+ # Like dict, Headers accepts an optional "mapping or iterable" argument.
68
+ def __init__(self, *args: HeadersLike, **kwargs: str) -> None:
69
+ self._dict: dict[str, list[str]] = {}
70
+ self._list: list[tuple[str, str]] = []
71
+ self.update(*args, **kwargs)
72
+
73
+ def __str__(self) -> str:
74
+ return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n"
75
+
76
+ def __repr__(self) -> str:
77
+ return f"{self.__class__.__name__}({self._list!r})"
78
+
79
+ def copy(self) -> Headers:
80
+ copy = self.__class__()
81
+ copy._dict = self._dict.copy()
82
+ copy._list = self._list.copy()
83
+ return copy
84
+
85
+ def serialize(self) -> bytes:
86
+ # Since headers only contain ASCII characters, we can keep this simple.
87
+ return str(self).encode()
88
+
89
+ # Collection methods
90
+
91
+ def __contains__(self, key: object) -> bool:
92
+ return isinstance(key, str) and key.lower() in self._dict
93
+
94
+ def __iter__(self) -> Iterator[str]:
95
+ return iter(self._dict)
96
+
97
+ def __len__(self) -> int:
98
+ return len(self._dict)
99
+
100
+ # MutableMapping methods
101
+
102
+ def __getitem__(self, key: str) -> str:
103
+ value = self._dict[key.lower()]
104
+ if len(value) == 1:
105
+ return value[0]
106
+ else:
107
+ raise MultipleValuesError(key)
108
+
109
+ def __setitem__(self, key: str, value: str) -> None:
110
+ self._dict.setdefault(key.lower(), []).append(value)
111
+ self._list.append((key, value))
112
+
113
+ def __delitem__(self, key: str) -> None:
114
+ key_lower = key.lower()
115
+ self._dict.__delitem__(key_lower)
116
+ # This is inefficient. Fortunately deleting HTTP headers is uncommon.
117
+ self._list = [(k, v) for k, v in self._list if k.lower() != key_lower]
118
+
119
+ def __eq__(self, other: Any) -> bool:
120
+ if not isinstance(other, Headers):
121
+ return NotImplemented
122
+ return self._dict == other._dict
123
+
124
+ def clear(self) -> None:
125
+ """
126
+ Remove all headers.
127
+
128
+ """
129
+ self._dict = {}
130
+ self._list = []
131
+
132
+ def update(self, *args: HeadersLike, **kwargs: str) -> None:
133
+ """
134
+ Update from a :class:`Headers` instance and/or keyword arguments.
135
+
136
+ """
137
+ args = tuple(
138
+ arg.raw_items() if isinstance(arg, Headers) else arg for arg in args
139
+ )
140
+ super().update(*args, **kwargs)
141
+
142
+ # Methods for handling multiple values
143
+
144
+ def get_all(self, key: str) -> list[str]:
145
+ """
146
+ Return the (possibly empty) list of all values for a header.
147
+
148
+ Args:
149
+ key: Header name.
150
+
151
+ """
152
+ return self._dict.get(key.lower(), [])
153
+
154
+ def raw_items(self) -> Iterator[tuple[str, str]]:
155
+ """
156
+ Return an iterator of all values as ``(name, value)`` pairs.
157
+
158
+ """
159
+ return iter(self._list)
160
+
161
+
162
+ # copy of _typeshed.SupportsKeysAndGetItem.
163
+ class SupportsKeysAndGetItem(Protocol): # pragma: no cover
164
+ """
165
+ Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods.
166
+
167
+ """
168
+
169
+ def keys(self) -> Iterable[str]: ...
170
+
171
+ def __getitem__(self, key: str) -> str: ...
172
+
173
+
174
+ HeadersLike = (
175
+ Headers | Mapping[str, str] | Iterable[tuple[str, str]] | SupportsKeysAndGetItem
176
+ )
177
+ """
178
+ Types accepted where :class:`Headers` is expected.
179
+
180
+ In addition to :class:`Headers` itself, this includes dict-like types where both
181
+ keys and values are :class:`str`.
182
+
183
+ """
source/websockets/exceptions.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ :mod:`websockets.exceptions` defines the following hierarchy of exceptions.
3
+
4
+ * :exc:`WebSocketException`
5
+ * :exc:`ConnectionClosed`
6
+ * :exc:`ConnectionClosedOK`
7
+ * :exc:`ConnectionClosedError`
8
+ * :exc:`InvalidURI`
9
+ * :exc:`InvalidProxy`
10
+ * :exc:`InvalidHandshake`
11
+ * :exc:`SecurityError`
12
+ * :exc:`ProxyError`
13
+ * :exc:`InvalidProxyMessage`
14
+ * :exc:`InvalidProxyStatus`
15
+ * :exc:`InvalidMessage`
16
+ * :exc:`InvalidStatus`
17
+ * :exc:`InvalidStatusCode` (legacy)
18
+ * :exc:`InvalidHeader`
19
+ * :exc:`InvalidHeaderFormat`
20
+ * :exc:`InvalidHeaderValue`
21
+ * :exc:`InvalidOrigin`
22
+ * :exc:`InvalidUpgrade`
23
+ * :exc:`NegotiationError`
24
+ * :exc:`DuplicateParameter`
25
+ * :exc:`InvalidParameterName`
26
+ * :exc:`InvalidParameterValue`
27
+ * :exc:`AbortHandshake` (legacy)
28
+ * :exc:`RedirectHandshake` (legacy)
29
+ * :exc:`ProtocolError` (Sans-I/O)
30
+ * :exc:`PayloadTooBig` (Sans-I/O)
31
+ * :exc:`InvalidState` (Sans-I/O)
32
+ * :exc:`ConcurrencyError`
33
+
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import warnings
39
+
40
+ from .imports import lazy_import
41
+
42
+
43
+ __all__ = [
44
+ "WebSocketException",
45
+ "ConnectionClosed",
46
+ "ConnectionClosedOK",
47
+ "ConnectionClosedError",
48
+ "InvalidURI",
49
+ "InvalidProxy",
50
+ "InvalidHandshake",
51
+ "SecurityError",
52
+ "ProxyError",
53
+ "InvalidProxyMessage",
54
+ "InvalidProxyStatus",
55
+ "InvalidMessage",
56
+ "InvalidStatus",
57
+ "InvalidHeader",
58
+ "InvalidHeaderFormat",
59
+ "InvalidHeaderValue",
60
+ "InvalidOrigin",
61
+ "InvalidUpgrade",
62
+ "NegotiationError",
63
+ "DuplicateParameter",
64
+ "InvalidParameterName",
65
+ "InvalidParameterValue",
66
+ "ProtocolError",
67
+ "PayloadTooBig",
68
+ "InvalidState",
69
+ "ConcurrencyError",
70
+ ]
71
+
72
+
73
+ class WebSocketException(Exception):
74
+ """
75
+ Base class for all exceptions defined by websockets.
76
+
77
+ """
78
+
79
+
80
+ class ConnectionClosed(WebSocketException):
81
+ """
82
+ Raised when trying to interact with a closed connection.
83
+
84
+ Attributes:
85
+ rcvd: If a close frame was received, its code and reason are available
86
+ in ``rcvd.code`` and ``rcvd.reason``.
87
+ sent: If a close frame was sent, its code and reason are available
88
+ in ``sent.code`` and ``sent.reason``.
89
+ rcvd_then_sent: If close frames were received and sent, this attribute
90
+ tells in which order this happened, from the perspective of this
91
+ side of the connection.
92
+
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ rcvd: frames.Close | None,
98
+ sent: frames.Close | None,
99
+ rcvd_then_sent: bool | None = None,
100
+ ) -> None:
101
+ self.rcvd = rcvd
102
+ self.sent = sent
103
+ self.rcvd_then_sent = rcvd_then_sent
104
+ assert (self.rcvd_then_sent is None) == (self.rcvd is None or self.sent is None)
105
+
106
+ def __str__(self) -> str:
107
+ if self.rcvd is None:
108
+ if self.sent is None:
109
+ return "no close frame received or sent"
110
+ else:
111
+ return f"sent {self.sent}; no close frame received"
112
+ else:
113
+ if self.sent is None:
114
+ return f"received {self.rcvd}; no close frame sent"
115
+ else:
116
+ if self.rcvd_then_sent:
117
+ return f"received {self.rcvd}; then sent {self.sent}"
118
+ else:
119
+ return f"sent {self.sent}; then received {self.rcvd}"
120
+
121
+ # code and reason attributes are provided for backwards-compatibility
122
+
123
+ @property
124
+ def code(self) -> int:
125
+ warnings.warn( # deprecated in 13.1 - 2024-09-21
126
+ "ConnectionClosed.code is deprecated; "
127
+ "use Protocol.close_code or ConnectionClosed.rcvd.code",
128
+ DeprecationWarning,
129
+ )
130
+ if self.rcvd is None:
131
+ return frames.CloseCode.ABNORMAL_CLOSURE
132
+ return self.rcvd.code
133
+
134
+ @property
135
+ def reason(self) -> str:
136
+ warnings.warn( # deprecated in 13.1 - 2024-09-21
137
+ "ConnectionClosed.reason is deprecated; "
138
+ "use Protocol.close_reason or ConnectionClosed.rcvd.reason",
139
+ DeprecationWarning,
140
+ )
141
+ if self.rcvd is None:
142
+ return ""
143
+ return self.rcvd.reason
144
+
145
+
146
+ class ConnectionClosedOK(ConnectionClosed):
147
+ """
148
+ Like :exc:`ConnectionClosed`, when the connection terminated properly.
149
+
150
+ A close code with code 1000 (OK) or 1001 (going away) or without a code was
151
+ received and sent.
152
+
153
+ """
154
+
155
+
156
+ class ConnectionClosedError(ConnectionClosed):
157
+ """
158
+ Like :exc:`ConnectionClosed`, when the connection terminated with an error.
159
+
160
+ A close frame with a code other than 1000 (OK) or 1001 (going away) was
161
+ received or sent, or the closing handshake didn't complete properly.
162
+
163
+ """
164
+
165
+
166
+ class InvalidURI(WebSocketException):
167
+ """
168
+ Raised when connecting to a URI that isn't a valid WebSocket URI.
169
+
170
+ """
171
+
172
+ def __init__(self, uri: str, msg: str) -> None:
173
+ self.uri = uri
174
+ self.msg = msg
175
+
176
+ def __str__(self) -> str:
177
+ return f"{self.uri} isn't a valid URI: {self.msg}"
178
+
179
+
180
+ class InvalidProxy(WebSocketException):
181
+ """
182
+ Raised when connecting via a proxy that isn't valid.
183
+
184
+ """
185
+
186
+ def __init__(self, proxy: str, msg: str) -> None:
187
+ self.proxy = proxy
188
+ self.msg = msg
189
+
190
+ def __str__(self) -> str:
191
+ return f"{self.proxy} isn't a valid proxy: {self.msg}"
192
+
193
+
194
+ class InvalidHandshake(WebSocketException):
195
+ """
196
+ Base class for exceptions raised when the opening handshake fails.
197
+
198
+ """
199
+
200
+
201
+ class SecurityError(InvalidHandshake):
202
+ """
203
+ Raised when a handshake request or response breaks a security rule.
204
+
205
+ Security limits can be configured with :doc:`environment variables
206
+ <../reference/variables>`.
207
+
208
+ """
209
+
210
+
211
+ class ProxyError(InvalidHandshake):
212
+ """
213
+ Raised when failing to connect to a proxy.
214
+
215
+ """
216
+
217
+
218
+ class InvalidProxyMessage(ProxyError):
219
+ """
220
+ Raised when an HTTP proxy response is malformed.
221
+
222
+ """
223
+
224
+
225
+ class InvalidProxyStatus(ProxyError):
226
+ """
227
+ Raised when an HTTP proxy rejects the connection.
228
+
229
+ """
230
+
231
+ def __init__(self, response: http11.Response) -> None:
232
+ self.response = response
233
+
234
+ def __str__(self) -> str:
235
+ return f"proxy rejected connection: HTTP {self.response.status_code:d}"
236
+
237
+
238
+ class InvalidMessage(InvalidHandshake):
239
+ """
240
+ Raised when a handshake request or response is malformed.
241
+
242
+ """
243
+
244
+
245
+ class InvalidStatus(InvalidHandshake):
246
+ """
247
+ Raised when a handshake response rejects the WebSocket upgrade.
248
+
249
+ """
250
+
251
+ def __init__(self, response: http11.Response) -> None:
252
+ self.response = response
253
+
254
+ def __str__(self) -> str:
255
+ return (
256
+ f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
257
+ )
258
+
259
+
260
+ class InvalidHeader(InvalidHandshake):
261
+ """
262
+ Raised when an HTTP header doesn't have a valid format or value.
263
+
264
+ """
265
+
266
+ def __init__(self, name: str, value: str | None = None) -> None:
267
+ self.name = name
268
+ self.value = value
269
+
270
+ def __str__(self) -> str:
271
+ if self.value is None:
272
+ return f"missing {self.name} header"
273
+ elif self.value == "":
274
+ return f"empty {self.name} header"
275
+ else:
276
+ return f"invalid {self.name} header: {self.value}"
277
+
278
+
279
+ class InvalidHeaderFormat(InvalidHeader):
280
+ """
281
+ Raised when an HTTP header cannot be parsed.
282
+
283
+ The format of the header doesn't match the grammar for that header.
284
+
285
+ """
286
+
287
+ def __init__(self, name: str, error: str, header: str, pos: int) -> None:
288
+ super().__init__(name, f"{error} at {pos} in {header}")
289
+
290
+
291
+ class InvalidHeaderValue(InvalidHeader):
292
+ """
293
+ Raised when an HTTP header has a wrong value.
294
+
295
+ The format of the header is correct but the value isn't acceptable.
296
+
297
+ """
298
+
299
+
300
+ class InvalidOrigin(InvalidHeader):
301
+ """
302
+ Raised when the Origin header in a request isn't allowed.
303
+
304
+ """
305
+
306
+ def __init__(self, origin: str | None) -> None:
307
+ super().__init__("Origin", origin)
308
+
309
+
310
+ class InvalidUpgrade(InvalidHeader):
311
+ """
312
+ Raised when the Upgrade or Connection header isn't correct.
313
+
314
+ """
315
+
316
+
317
+ class NegotiationError(InvalidHandshake):
318
+ """
319
+ Raised when negotiating an extension or a subprotocol fails.
320
+
321
+ """
322
+
323
+
324
+ class DuplicateParameter(NegotiationError):
325
+ """
326
+ Raised when a parameter name is repeated in an extension header.
327
+
328
+ """
329
+
330
+ def __init__(self, name: str) -> None:
331
+ self.name = name
332
+
333
+ def __str__(self) -> str:
334
+ return f"duplicate parameter: {self.name}"
335
+
336
+
337
+ class InvalidParameterName(NegotiationError):
338
+ """
339
+ Raised when a parameter name in an extension header is invalid.
340
+
341
+ """
342
+
343
+ def __init__(self, name: str) -> None:
344
+ self.name = name
345
+
346
+ def __str__(self) -> str:
347
+ return f"invalid parameter name: {self.name}"
348
+
349
+
350
+ class InvalidParameterValue(NegotiationError):
351
+ """
352
+ Raised when a parameter value in an extension header is invalid.
353
+
354
+ """
355
+
356
+ def __init__(self, name: str, value: str | None) -> None:
357
+ self.name = name
358
+ self.value = value
359
+
360
+ def __str__(self) -> str:
361
+ if self.value is None:
362
+ return f"missing value for parameter {self.name}"
363
+ elif self.value == "":
364
+ return f"empty value for parameter {self.name}"
365
+ else:
366
+ return f"invalid value for parameter {self.name}: {self.value}"
367
+
368
+
369
+ class ProtocolError(WebSocketException):
370
+ """
371
+ Raised when receiving or sending a frame that breaks the protocol.
372
+
373
+ The Sans-I/O implementation raises this exception when:
374
+
375
+ * receiving or sending a frame that contains invalid data;
376
+ * receiving or sending an invalid sequence of frames.
377
+
378
+ """
379
+
380
+
381
+ class PayloadTooBig(WebSocketException):
382
+ """
383
+ Raised when parsing a frame with a payload that exceeds the maximum size.
384
+
385
+ The Sans-I/O layer uses this exception internally. It doesn't bubble up to
386
+ the I/O layer.
387
+
388
+ The :meth:`~websockets.extensions.Extension.decode` method of extensions
389
+ must raise :exc:`PayloadTooBig` if decoding a frame would exceed the limit.
390
+
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ size_or_message: int | None | str,
396
+ max_size: int | None = None,
397
+ current_size: int | None = None,
398
+ ) -> None:
399
+ if isinstance(size_or_message, str):
400
+ assert max_size is None
401
+ assert current_size is None
402
+ warnings.warn( # deprecated in 14.0 - 2024-11-09
403
+ "PayloadTooBig(message) is deprecated; "
404
+ "change to PayloadTooBig(size, max_size)",
405
+ DeprecationWarning,
406
+ )
407
+ self.message: str | None = size_or_message
408
+ else:
409
+ self.message = None
410
+ self.size: int | None = size_or_message
411
+ assert max_size is not None
412
+ self.max_size: int = max_size
413
+ self.current_size: int | None = None
414
+ self.set_current_size(current_size)
415
+
416
+ def __str__(self) -> str:
417
+ if self.message is not None:
418
+ return self.message
419
+ else:
420
+ message = "frame "
421
+ if self.size is not None:
422
+ message += f"with {self.size} bytes "
423
+ if self.current_size is not None:
424
+ message += f"after reading {self.current_size} bytes "
425
+ message += f"exceeds limit of {self.max_size} bytes"
426
+ return message
427
+
428
+ def set_current_size(self, current_size: int | None) -> None:
429
+ assert self.current_size is None
430
+ if current_size is not None:
431
+ self.max_size += current_size
432
+ self.current_size = current_size
433
+
434
+
435
+ class InvalidState(WebSocketException, AssertionError):
436
+ """
437
+ Raised when sending a frame is forbidden in the current state.
438
+
439
+ Specifically, the Sans-I/O layer raises this exception when:
440
+
441
+ * sending a data frame to a connection in a state other
442
+ :attr:`~websockets.protocol.State.OPEN`;
443
+ * sending a control frame to a connection in a state other than
444
+ :attr:`~websockets.protocol.State.OPEN` or
445
+ :attr:`~websockets.protocol.State.CLOSING`.
446
+
447
+ """
448
+
449
+
450
+ class ConcurrencyError(WebSocketException, RuntimeError):
451
+ """
452
+ Raised when receiving or sending messages concurrently.
453
+
454
+ WebSocket is a connection-oriented protocol. Reads must be serialized; so
455
+ must be writes. However, reading and writing concurrently is possible.
456
+
457
+ """
458
+
459
+
460
+ # At the bottom to break import cycles created by type annotations.
461
+ from . import frames, http11 # noqa: E402
462
+
463
+
464
+ lazy_import(
465
+ globals(),
466
+ deprecated_aliases={
467
+ # deprecated in 14.0 - 2024-11-09
468
+ "AbortHandshake": ".legacy.exceptions",
469
+ "InvalidStatusCode": ".legacy.exceptions",
470
+ "RedirectHandshake": ".legacy.exceptions",
471
+ "WebSocketProtocolError": ".legacy.exceptions",
472
+ },
473
+ )
source/websockets/extensions/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import *
2
+
3
+
4
+ __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
source/websockets/extensions/base.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from ..frames import Frame
6
+ from ..typing import ExtensionName, ExtensionParameter
7
+
8
+
9
+ __all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
10
+
11
+
12
+ class Extension:
13
+ """
14
+ Base class for extensions.
15
+
16
+ """
17
+
18
+ name: ExtensionName
19
+ """Extension identifier."""
20
+
21
+ def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
22
+ """
23
+ Decode an incoming frame.
24
+
25
+ Args:
26
+ frame: Incoming frame.
27
+ max_size: Maximum payload size in bytes.
28
+
29
+ Returns:
30
+ Decoded frame.
31
+
32
+ Raises:
33
+ PayloadTooBig: If decoding the payload exceeds ``max_size``.
34
+
35
+ """
36
+ raise NotImplementedError
37
+
38
+ def encode(self, frame: Frame) -> Frame:
39
+ """
40
+ Encode an outgoing frame.
41
+
42
+ Args:
43
+ frame: Outgoing frame.
44
+
45
+ Returns:
46
+ Encoded frame.
47
+
48
+ """
49
+ raise NotImplementedError
50
+
51
+
52
+ class ClientExtensionFactory:
53
+ """
54
+ Base class for client-side extension factories.
55
+
56
+ """
57
+
58
+ name: ExtensionName
59
+ """Extension identifier."""
60
+
61
+ def get_request_params(self) -> Sequence[ExtensionParameter]:
62
+ """
63
+ Build parameters to send to the server for this extension.
64
+
65
+ Returns:
66
+ Parameters to send to the server.
67
+
68
+ """
69
+ raise NotImplementedError
70
+
71
+ def process_response_params(
72
+ self,
73
+ params: Sequence[ExtensionParameter],
74
+ accepted_extensions: Sequence[Extension],
75
+ ) -> Extension:
76
+ """
77
+ Process parameters received from the server.
78
+
79
+ Args:
80
+ params: Parameters received from the server for this extension.
81
+ accepted_extensions: List of previously accepted extensions.
82
+
83
+ Returns:
84
+ An extension instance.
85
+
86
+ Raises:
87
+ NegotiationError: If parameters aren't acceptable.
88
+
89
+ """
90
+ raise NotImplementedError
91
+
92
+
93
+ class ServerExtensionFactory:
94
+ """
95
+ Base class for server-side extension factories.
96
+
97
+ """
98
+
99
+ name: ExtensionName
100
+ """Extension identifier."""
101
+
102
+ def process_request_params(
103
+ self,
104
+ params: Sequence[ExtensionParameter],
105
+ accepted_extensions: Sequence[Extension],
106
+ ) -> tuple[list[ExtensionParameter], Extension]:
107
+ """
108
+ Process parameters received from the client.
109
+
110
+ Args:
111
+ params: Parameters received from the client for this extension.
112
+ accepted_extensions: List of previously accepted extensions.
113
+
114
+ Returns:
115
+ To accept the offer, parameters to send to the client for this
116
+ extension and an extension instance.
117
+
118
+ Raises:
119
+ NegotiationError: To reject the offer, if parameters received from
120
+ the client aren't acceptable.
121
+
122
+ """
123
+ raise NotImplementedError
source/websockets/extensions/permessage_deflate.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import zlib
4
+ from collections.abc import Sequence
5
+ from typing import Any, Literal
6
+
7
+ from .. import frames
8
+ from ..exceptions import (
9
+ DuplicateParameter,
10
+ InvalidParameterName,
11
+ InvalidParameterValue,
12
+ NegotiationError,
13
+ PayloadTooBig,
14
+ ProtocolError,
15
+ )
16
+ from ..typing import BytesLike, ExtensionName, ExtensionParameter
17
+ from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
18
+
19
+
20
+ __all__ = [
21
+ "PerMessageDeflate",
22
+ "ClientPerMessageDeflateFactory",
23
+ "enable_client_permessage_deflate",
24
+ "ServerPerMessageDeflateFactory",
25
+ "enable_server_permessage_deflate",
26
+ ]
27
+
28
+ _EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
29
+
30
+ _MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
31
+
32
+
33
+ class PerMessageDeflate(Extension):
34
+ """
35
+ Per-Message Deflate extension.
36
+
37
+ """
38
+
39
+ name = ExtensionName("permessage-deflate")
40
+
41
+ def __init__(
42
+ self,
43
+ remote_no_context_takeover: bool,
44
+ local_no_context_takeover: bool,
45
+ remote_max_window_bits: int,
46
+ local_max_window_bits: int,
47
+ compress_settings: dict[Any, Any] | None = None,
48
+ ) -> None:
49
+ """
50
+ Configure the Per-Message Deflate extension.
51
+
52
+ """
53
+ if compress_settings is None:
54
+ compress_settings = {}
55
+
56
+ assert remote_no_context_takeover in [False, True]
57
+ assert local_no_context_takeover in [False, True]
58
+ assert 8 <= remote_max_window_bits <= 15
59
+ assert 8 <= local_max_window_bits <= 15
60
+ assert "wbits" not in compress_settings
61
+
62
+ self.remote_no_context_takeover = remote_no_context_takeover
63
+ self.local_no_context_takeover = local_no_context_takeover
64
+ self.remote_max_window_bits = remote_max_window_bits
65
+ self.local_max_window_bits = local_max_window_bits
66
+ self.compress_settings = compress_settings
67
+
68
+ if not self.remote_no_context_takeover:
69
+ self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
70
+
71
+ if not self.local_no_context_takeover:
72
+ self.encoder = zlib.compressobj(
73
+ wbits=-self.local_max_window_bits,
74
+ **self.compress_settings,
75
+ )
76
+
77
+ # To handle continuation frames properly, we must keep track of
78
+ # whether that initial frame was encoded.
79
+ self.decode_cont_data = False
80
+ # There's no need for self.encode_cont_data because we always encode
81
+ # outgoing frames, so it would always be True.
82
+
83
+ def __repr__(self) -> str:
84
+ return (
85
+ f"PerMessageDeflate("
86
+ f"remote_no_context_takeover={self.remote_no_context_takeover}, "
87
+ f"local_no_context_takeover={self.local_no_context_takeover}, "
88
+ f"remote_max_window_bits={self.remote_max_window_bits}, "
89
+ f"local_max_window_bits={self.local_max_window_bits})"
90
+ )
91
+
92
+ def decode(
93
+ self,
94
+ frame: frames.Frame,
95
+ *,
96
+ max_size: int | None = None,
97
+ ) -> frames.Frame:
98
+ """
99
+ Decode an incoming frame.
100
+
101
+ """
102
+ # Skip control frames.
103
+ if frame.opcode in frames.CTRL_OPCODES:
104
+ return frame
105
+
106
+ # Handle continuation data frames:
107
+ # - skip if the message isn't encoded
108
+ # - reset "decode continuation data" flag if it's a final frame
109
+ if frame.opcode is frames.OP_CONT:
110
+ if not self.decode_cont_data:
111
+ return frame
112
+ if frame.fin:
113
+ self.decode_cont_data = False
114
+
115
+ # Handle text and binary data frames:
116
+ # - skip if the message isn't encoded
117
+ # - unset the rsv1 flag on the first frame of a compressed message
118
+ # - set "decode continuation data" flag if it's a non-final frame
119
+ else:
120
+ if not frame.rsv1:
121
+ return frame
122
+ if not frame.fin:
123
+ self.decode_cont_data = True
124
+
125
+ # Re-initialize per-message decoder.
126
+ if self.remote_no_context_takeover:
127
+ self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
128
+
129
+ # Uncompress data. Protect against zip bombs by preventing zlib from
130
+ # decompressing more than max_length bytes (except when the limit is
131
+ # disabled with max_size = None).
132
+ data: BytesLike
133
+ if frame.fin and len(frame.data) < 2044:
134
+ # Profiling shows that appending four bytes, which makes a copy, is
135
+ # faster than calling decompress() again when data is less than 2kB.
136
+ data = bytes(frame.data) + _EMPTY_UNCOMPRESSED_BLOCK
137
+ else:
138
+ data = frame.data
139
+ max_length = 0 if max_size is None else max_size
140
+ try:
141
+ data = self.decoder.decompress(data, max_length)
142
+ if self.decoder.unconsumed_tail:
143
+ assert max_size is not None # help mypy
144
+ raise PayloadTooBig(None, max_size)
145
+ if frame.fin and len(frame.data) >= 2044:
146
+ # This cannot generate additional data.
147
+ self.decoder.decompress(_EMPTY_UNCOMPRESSED_BLOCK)
148
+ except zlib.error as exc:
149
+ raise ProtocolError("decompression failed") from exc
150
+
151
+ # Allow garbage collection of the decoder if it won't be reused.
152
+ if frame.fin and self.remote_no_context_takeover:
153
+ del self.decoder
154
+
155
+ return frames.Frame(
156
+ frame.opcode,
157
+ data,
158
+ frame.fin,
159
+ # Unset the rsv1 flag on the first frame of a compressed message.
160
+ False,
161
+ frame.rsv2,
162
+ frame.rsv3,
163
+ )
164
+
165
+ def encode(self, frame: frames.Frame) -> frames.Frame:
166
+ """
167
+ Encode an outgoing frame.
168
+
169
+ """
170
+ # Skip control frames.
171
+ if frame.opcode in frames.CTRL_OPCODES:
172
+ return frame
173
+
174
+ # Since we always encode messages, there's no "encode continuation
175
+ # data" flag similar to "decode continuation data" at this time.
176
+
177
+ if frame.opcode is not frames.OP_CONT:
178
+ # Re-initialize per-message decoder.
179
+ if self.local_no_context_takeover:
180
+ self.encoder = zlib.compressobj(
181
+ wbits=-self.local_max_window_bits,
182
+ **self.compress_settings,
183
+ )
184
+
185
+ # Compress data.
186
+ data: BytesLike
187
+ data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
188
+ if frame.fin:
189
+ # Sync flush generates between 5 or 6 bytes, ending with the bytes
190
+ # 0x00 0x00 0xff 0xff, which must be removed.
191
+ assert data[-4:] == _EMPTY_UNCOMPRESSED_BLOCK
192
+ # Making a copy is faster than memoryview(a)[:-4] until 2kB.
193
+ if len(data) < 2048:
194
+ data = data[:-4]
195
+ else:
196
+ data = memoryview(data)[:-4]
197
+
198
+ # Allow garbage collection of the encoder if it won't be reused.
199
+ if frame.fin and self.local_no_context_takeover:
200
+ del self.encoder
201
+
202
+ return frames.Frame(
203
+ frame.opcode,
204
+ data,
205
+ frame.fin,
206
+ # Set the rsv1 flag on the first frame of a compressed message.
207
+ frame.opcode is not frames.OP_CONT,
208
+ frame.rsv2,
209
+ frame.rsv3,
210
+ )
211
+
212
+
213
+ def _build_parameters(
214
+ server_no_context_takeover: bool,
215
+ client_no_context_takeover: bool,
216
+ server_max_window_bits: int | None,
217
+ client_max_window_bits: int | Literal[True] | None,
218
+ ) -> list[ExtensionParameter]:
219
+ """
220
+ Build a list of ``(name, value)`` pairs for some compression parameters.
221
+
222
+ """
223
+ params: list[ExtensionParameter] = []
224
+ if server_no_context_takeover:
225
+ params.append(("server_no_context_takeover", None))
226
+ if client_no_context_takeover:
227
+ params.append(("client_no_context_takeover", None))
228
+ if server_max_window_bits:
229
+ params.append(("server_max_window_bits", str(server_max_window_bits)))
230
+ if client_max_window_bits is True: # only in handshake requests
231
+ params.append(("client_max_window_bits", None))
232
+ elif client_max_window_bits:
233
+ params.append(("client_max_window_bits", str(client_max_window_bits)))
234
+ return params
235
+
236
+
237
+ def _extract_parameters(
238
+ params: Sequence[ExtensionParameter], *, is_server: bool
239
+ ) -> tuple[bool, bool, int | None, int | Literal[True] | None]:
240
+ """
241
+ Extract compression parameters from a list of ``(name, value)`` pairs.
242
+
243
+ If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be
244
+ provided without a value. This is only allowed in handshake requests.
245
+
246
+ """
247
+ server_no_context_takeover: bool = False
248
+ client_no_context_takeover: bool = False
249
+ server_max_window_bits: int | None = None
250
+ client_max_window_bits: int | Literal[True] | None = None
251
+
252
+ for name, value in params:
253
+ if name == "server_no_context_takeover":
254
+ if server_no_context_takeover:
255
+ raise DuplicateParameter(name)
256
+ if value is None:
257
+ server_no_context_takeover = True
258
+ else:
259
+ raise InvalidParameterValue(name, value)
260
+
261
+ elif name == "client_no_context_takeover":
262
+ if client_no_context_takeover:
263
+ raise DuplicateParameter(name)
264
+ if value is None:
265
+ client_no_context_takeover = True
266
+ else:
267
+ raise InvalidParameterValue(name, value)
268
+
269
+ elif name == "server_max_window_bits":
270
+ if server_max_window_bits is not None:
271
+ raise DuplicateParameter(name)
272
+ if value in _MAX_WINDOW_BITS_VALUES:
273
+ server_max_window_bits = int(value)
274
+ else:
275
+ raise InvalidParameterValue(name, value)
276
+
277
+ elif name == "client_max_window_bits":
278
+ if client_max_window_bits is not None:
279
+ raise DuplicateParameter(name)
280
+ if is_server and value is None: # only in handshake requests
281
+ client_max_window_bits = True
282
+ elif value in _MAX_WINDOW_BITS_VALUES:
283
+ client_max_window_bits = int(value)
284
+ else:
285
+ raise InvalidParameterValue(name, value)
286
+
287
+ else:
288
+ raise InvalidParameterName(name)
289
+
290
+ return (
291
+ server_no_context_takeover,
292
+ client_no_context_takeover,
293
+ server_max_window_bits,
294
+ client_max_window_bits,
295
+ )
296
+
297
+
298
+ class ClientPerMessageDeflateFactory(ClientExtensionFactory):
299
+ """
300
+ Client-side extension factory for the Per-Message Deflate extension.
301
+
302
+ Parameters behave as described in `section 7.1 of RFC 7692`_.
303
+
304
+ .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
305
+
306
+ Set them to :obj:`True` to include them in the negotiation offer without a
307
+ value or to an integer value to include them with this value.
308
+
309
+ Args:
310
+ server_no_context_takeover: Prevent server from using context takeover.
311
+ client_no_context_takeover: Prevent client from using context takeover.
312
+ server_max_window_bits: Maximum size of the server's LZ77 sliding window
313
+ in bits, between 8 and 15.
314
+ client_max_window_bits: Maximum size of the client's LZ77 sliding window
315
+ in bits, between 8 and 15, or :obj:`True` to indicate support without
316
+ setting a limit.
317
+ compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
318
+ excluding ``wbits``.
319
+
320
+ """
321
+
322
+ name = ExtensionName("permessage-deflate")
323
+
324
+ def __init__(
325
+ self,
326
+ server_no_context_takeover: bool = False,
327
+ client_no_context_takeover: bool = False,
328
+ server_max_window_bits: int | None = None,
329
+ client_max_window_bits: int | Literal[True] | None = True,
330
+ compress_settings: dict[str, Any] | None = None,
331
+ ) -> None:
332
+ """
333
+ Configure the Per-Message Deflate extension factory.
334
+
335
+ """
336
+ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
337
+ raise ValueError("server_max_window_bits must be between 8 and 15")
338
+ if not (
339
+ client_max_window_bits is None
340
+ or client_max_window_bits is True
341
+ or 8 <= client_max_window_bits <= 15
342
+ ):
343
+ raise ValueError("client_max_window_bits must be between 8 and 15")
344
+ if compress_settings is not None and "wbits" in compress_settings:
345
+ raise ValueError(
346
+ "compress_settings must not include wbits, "
347
+ "set client_max_window_bits instead"
348
+ )
349
+
350
+ self.server_no_context_takeover = server_no_context_takeover
351
+ self.client_no_context_takeover = client_no_context_takeover
352
+ self.server_max_window_bits = server_max_window_bits
353
+ self.client_max_window_bits = client_max_window_bits
354
+ self.compress_settings = compress_settings
355
+
356
+ def get_request_params(self) -> Sequence[ExtensionParameter]:
357
+ """
358
+ Build request parameters.
359
+
360
+ """
361
+ return _build_parameters(
362
+ self.server_no_context_takeover,
363
+ self.client_no_context_takeover,
364
+ self.server_max_window_bits,
365
+ self.client_max_window_bits,
366
+ )
367
+
368
+ def process_response_params(
369
+ self,
370
+ params: Sequence[ExtensionParameter],
371
+ accepted_extensions: Sequence[Extension],
372
+ ) -> PerMessageDeflate:
373
+ """
374
+ Process response parameters.
375
+
376
+ Return an extension instance.
377
+
378
+ """
379
+ if any(other.name == self.name for other in accepted_extensions):
380
+ raise NegotiationError(f"received duplicate {self.name}")
381
+
382
+ # Request parameters are available in instance variables.
383
+
384
+ # Load response parameters in local variables.
385
+ (
386
+ server_no_context_takeover,
387
+ client_no_context_takeover,
388
+ server_max_window_bits,
389
+ client_max_window_bits,
390
+ ) = _extract_parameters(params, is_server=False)
391
+
392
+ # After comparing the request and the response, the final
393
+ # configuration must be available in the local variables.
394
+
395
+ # server_no_context_takeover
396
+ #
397
+ # Req. Resp. Result
398
+ # ------ ------ --------------------------------------------------
399
+ # False False False
400
+ # False True True
401
+ # True False Error!
402
+ # True True True
403
+
404
+ if self.server_no_context_takeover:
405
+ if not server_no_context_takeover:
406
+ raise NegotiationError("expected server_no_context_takeover")
407
+
408
+ # client_no_context_takeover
409
+ #
410
+ # Req. Resp. Result
411
+ # ------ ------ --------------------------------------------------
412
+ # False False False
413
+ # False True True
414
+ # True False True - must change value
415
+ # True True True
416
+
417
+ if self.client_no_context_takeover:
418
+ if not client_no_context_takeover:
419
+ client_no_context_takeover = True
420
+
421
+ # server_max_window_bits
422
+
423
+ # Req. Resp. Result
424
+ # ------ ------ --------------------------------------------------
425
+ # None None None
426
+ # None 8≤M≤15 M
427
+ # 8≤N≤15 None Error!
428
+ # 8≤N≤15 8≤M≤N M
429
+ # 8≤N≤15 N<M≤15 Error!
430
+
431
+ if self.server_max_window_bits is None:
432
+ pass
433
+
434
+ else:
435
+ if server_max_window_bits is None:
436
+ raise NegotiationError("expected server_max_window_bits")
437
+ elif server_max_window_bits > self.server_max_window_bits:
438
+ raise NegotiationError("unsupported server_max_window_bits")
439
+
440
+ # client_max_window_bits
441
+
442
+ # Req. Resp. Result
443
+ # ------ ------ --------------------------------------------------
444
+ # None None None
445
+ # None 8≤M≤15 Error!
446
+ # True None None
447
+ # True 8≤M≤15 M
448
+ # 8≤N≤15 None N - must change value
449
+ # 8≤N≤15 8≤M≤N M
450
+ # 8≤N≤15 N<M≤15 Error!
451
+
452
+ if self.client_max_window_bits is None:
453
+ if client_max_window_bits is not None:
454
+ raise NegotiationError("unexpected client_max_window_bits")
455
+
456
+ elif self.client_max_window_bits is True:
457
+ pass
458
+
459
+ else:
460
+ if client_max_window_bits is None:
461
+ client_max_window_bits = self.client_max_window_bits
462
+ elif client_max_window_bits > self.client_max_window_bits:
463
+ raise NegotiationError("unsupported client_max_window_bits")
464
+
465
+ return PerMessageDeflate(
466
+ server_no_context_takeover, # remote_no_context_takeover
467
+ client_no_context_takeover, # local_no_context_takeover
468
+ server_max_window_bits or 15, # remote_max_window_bits
469
+ client_max_window_bits or 15, # local_max_window_bits
470
+ self.compress_settings,
471
+ )
472
+
473
+
474
+ def enable_client_permessage_deflate(
475
+ extensions: Sequence[ClientExtensionFactory] | None,
476
+ ) -> Sequence[ClientExtensionFactory]:
477
+ """
478
+ Enable Per-Message Deflate with default settings in client extensions.
479
+
480
+ If the extension is already present, perhaps with non-default settings,
481
+ the configuration isn't changed.
482
+
483
+ """
484
+ if extensions is None:
485
+ extensions = []
486
+ if not any(
487
+ extension_factory.name == ClientPerMessageDeflateFactory.name
488
+ for extension_factory in extensions
489
+ ):
490
+ extensions = list(extensions) + [
491
+ ClientPerMessageDeflateFactory(
492
+ compress_settings={"memLevel": 5},
493
+ )
494
+ ]
495
+ return extensions
496
+
497
+
498
+ class ServerPerMessageDeflateFactory(ServerExtensionFactory):
499
+ """
500
+ Server-side extension factory for the Per-Message Deflate extension.
501
+
502
+ Parameters behave as described in `section 7.1 of RFC 7692`_.
503
+
504
+ .. _section 7.1 of RFC 7692: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1
505
+
506
+ Set them to :obj:`True` to include them in the negotiation offer without a
507
+ value or to an integer value to include them with this value.
508
+
509
+ Args:
510
+ server_no_context_takeover: Prevent server from using context takeover.
511
+ client_no_context_takeover: Prevent client from using context takeover.
512
+ server_max_window_bits: Maximum size of the server's LZ77 sliding window
513
+ in bits, between 8 and 15.
514
+ client_max_window_bits: Maximum size of the client's LZ77 sliding window
515
+ in bits, between 8 and 15.
516
+ compress_settings: Additional keyword arguments for :func:`zlib.compressobj`,
517
+ excluding ``wbits``.
518
+ require_client_max_window_bits: Do not enable compression at all if
519
+ client doesn't advertise support for ``client_max_window_bits``;
520
+ the default behavior is to enable compression without enforcing
521
+ ``client_max_window_bits``.
522
+
523
+ """
524
+
525
+ name = ExtensionName("permessage-deflate")
526
+
527
+ def __init__(
528
+ self,
529
+ server_no_context_takeover: bool = False,
530
+ client_no_context_takeover: bool = False,
531
+ server_max_window_bits: int | None = None,
532
+ client_max_window_bits: int | None = None,
533
+ compress_settings: dict[str, Any] | None = None,
534
+ require_client_max_window_bits: bool = False,
535
+ ) -> None:
536
+ """
537
+ Configure the Per-Message Deflate extension factory.
538
+
539
+ """
540
+ if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
541
+ raise ValueError("server_max_window_bits must be between 8 and 15")
542
+ if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
543
+ raise ValueError("client_max_window_bits must be between 8 and 15")
544
+ if compress_settings is not None and "wbits" in compress_settings:
545
+ raise ValueError(
546
+ "compress_settings must not include wbits, "
547
+ "set server_max_window_bits instead"
548
+ )
549
+ if client_max_window_bits is None and require_client_max_window_bits:
550
+ raise ValueError(
551
+ "require_client_max_window_bits is enabled, "
552
+ "but client_max_window_bits isn't configured"
553
+ )
554
+
555
+ self.server_no_context_takeover = server_no_context_takeover
556
+ self.client_no_context_takeover = client_no_context_takeover
557
+ self.server_max_window_bits = server_max_window_bits
558
+ self.client_max_window_bits = client_max_window_bits
559
+ self.compress_settings = compress_settings
560
+ self.require_client_max_window_bits = require_client_max_window_bits
561
+
562
+ def process_request_params(
563
+ self,
564
+ params: Sequence[ExtensionParameter],
565
+ accepted_extensions: Sequence[Extension],
566
+ ) -> tuple[list[ExtensionParameter], PerMessageDeflate]:
567
+ """
568
+ Process request parameters.
569
+
570
+ Return response params and an extension instance.
571
+
572
+ """
573
+ if any(other.name == self.name for other in accepted_extensions):
574
+ raise NegotiationError(f"skipped duplicate {self.name}")
575
+
576
+ # Load request parameters in local variables.
577
+ (
578
+ server_no_context_takeover,
579
+ client_no_context_takeover,
580
+ server_max_window_bits,
581
+ client_max_window_bits,
582
+ ) = _extract_parameters(params, is_server=True)
583
+
584
+ # Configuration parameters are available in instance variables.
585
+
586
+ # After comparing the request and the configuration, the response must
587
+ # be available in the local variables.
588
+
589
+ # server_no_context_takeover
590
+ #
591
+ # Config Req. Resp.
592
+ # ------ ------ --------------------------------------------------
593
+ # False False False
594
+ # False True True
595
+ # True False True - must change value to True
596
+ # True True True
597
+
598
+ if self.server_no_context_takeover:
599
+ if not server_no_context_takeover:
600
+ server_no_context_takeover = True
601
+
602
+ # client_no_context_takeover
603
+ #
604
+ # Config Req. Resp.
605
+ # ------ ------ --------------------------------------------------
606
+ # False False False
607
+ # False True True (or False)
608
+ # True False True - must change value to True
609
+ # True True True (or False)
610
+
611
+ if self.client_no_context_takeover:
612
+ if not client_no_context_takeover:
613
+ client_no_context_takeover = True
614
+
615
+ # server_max_window_bits
616
+
617
+ # Config Req. Resp.
618
+ # ------ ------ --------------------------------------------------
619
+ # None None None
620
+ # None 8≤M≤15 M
621
+ # 8≤N≤15 None N - must change value
622
+ # 8≤N≤15 8≤M≤N M
623
+ # 8≤N≤15 N<M≤15 N - must change value
624
+
625
+ if self.server_max_window_bits is None:
626
+ pass
627
+
628
+ else:
629
+ if server_max_window_bits is None:
630
+ server_max_window_bits = self.server_max_window_bits
631
+ elif server_max_window_bits > self.server_max_window_bits:
632
+ server_max_window_bits = self.server_max_window_bits
633
+
634
+ # client_max_window_bits
635
+
636
+ # Config Req. Resp.
637
+ # ------ ------ --------------------------------------------------
638
+ # None None None
639
+ # None True None - must change value
640
+ # None 8≤M≤15 M (or None)
641
+ # 8≤N≤15 None None or Error!
642
+ # 8≤N≤15 True N - must change value
643
+ # 8≤N≤15 8≤M≤N M (or None)
644
+ # 8≤N≤15 N<M≤15 N
645
+
646
+ if self.client_max_window_bits is None:
647
+ if client_max_window_bits is True:
648
+ client_max_window_bits = self.client_max_window_bits
649
+
650
+ else:
651
+ if client_max_window_bits is None:
652
+ if self.require_client_max_window_bits:
653
+ raise NegotiationError("required client_max_window_bits")
654
+ elif client_max_window_bits is True:
655
+ client_max_window_bits = self.client_max_window_bits
656
+ elif self.client_max_window_bits < client_max_window_bits:
657
+ client_max_window_bits = self.client_max_window_bits
658
+
659
+ return (
660
+ _build_parameters(
661
+ server_no_context_takeover,
662
+ client_no_context_takeover,
663
+ server_max_window_bits,
664
+ client_max_window_bits,
665
+ ),
666
+ PerMessageDeflate(
667
+ client_no_context_takeover, # remote_no_context_takeover
668
+ server_no_context_takeover, # local_no_context_takeover
669
+ client_max_window_bits or 15, # remote_max_window_bits
670
+ server_max_window_bits or 15, # local_max_window_bits
671
+ self.compress_settings,
672
+ ),
673
+ )
674
+
675
+
676
+ def enable_server_permessage_deflate(
677
+ extensions: Sequence[ServerExtensionFactory] | None,
678
+ ) -> Sequence[ServerExtensionFactory]:
679
+ """
680
+ Enable Per-Message Deflate with default settings in server extensions.
681
+
682
+ If the extension is already present, perhaps with non-default settings,
683
+ the configuration isn't changed.
684
+
685
+ """
686
+ if extensions is None:
687
+ extensions = []
688
+ if not any(
689
+ ext_factory.name == ServerPerMessageDeflateFactory.name
690
+ for ext_factory in extensions
691
+ ):
692
+ extensions = list(extensions) + [
693
+ ServerPerMessageDeflateFactory(
694
+ server_max_window_bits=12,
695
+ client_max_window_bits=12,
696
+ compress_settings={"memLevel": 5},
697
+ )
698
+ ]
699
+ return extensions
source/websockets/frames.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import enum
5
+ import io
6
+ import os
7
+ import secrets
8
+ import struct
9
+ from collections.abc import Generator, Sequence
10
+ from typing import Callable
11
+
12
+ from .exceptions import PayloadTooBig, ProtocolError
13
+ from .typing import BytesLike
14
+
15
+
16
+ try:
17
+ from .speedups import apply_mask
18
+ except ImportError:
19
+ from .utils import apply_mask
20
+
21
+
22
+ __all__ = [
23
+ "Opcode",
24
+ "OP_CONT",
25
+ "OP_TEXT",
26
+ "OP_BINARY",
27
+ "OP_CLOSE",
28
+ "OP_PING",
29
+ "OP_PONG",
30
+ "DATA_OPCODES",
31
+ "CTRL_OPCODES",
32
+ "CloseCode",
33
+ "Frame",
34
+ "Close",
35
+ ]
36
+
37
+
38
+ class Opcode(enum.IntEnum):
39
+ """Opcode values for WebSocket frames."""
40
+
41
+ CONT, TEXT, BINARY = 0x00, 0x01, 0x02
42
+ CLOSE, PING, PONG = 0x08, 0x09, 0x0A
43
+
44
+
45
+ OP_CONT = Opcode.CONT
46
+ OP_TEXT = Opcode.TEXT
47
+ OP_BINARY = Opcode.BINARY
48
+ OP_CLOSE = Opcode.CLOSE
49
+ OP_PING = Opcode.PING
50
+ OP_PONG = Opcode.PONG
51
+
52
+ DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY
53
+ CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG
54
+
55
+
56
+ class CloseCode(enum.IntEnum):
57
+ """Close code values for WebSocket close frames."""
58
+
59
+ NORMAL_CLOSURE = 1000
60
+ GOING_AWAY = 1001
61
+ PROTOCOL_ERROR = 1002
62
+ UNSUPPORTED_DATA = 1003
63
+ # 1004 is reserved
64
+ NO_STATUS_RCVD = 1005
65
+ ABNORMAL_CLOSURE = 1006
66
+ INVALID_DATA = 1007
67
+ POLICY_VIOLATION = 1008
68
+ MESSAGE_TOO_BIG = 1009
69
+ MANDATORY_EXTENSION = 1010
70
+ INTERNAL_ERROR = 1011
71
+ SERVICE_RESTART = 1012
72
+ TRY_AGAIN_LATER = 1013
73
+ BAD_GATEWAY = 1014
74
+ TLS_HANDSHAKE = 1015
75
+
76
+
77
+ # See https://www.iana.org/assignments/websocket/websocket.xhtml
78
+ CLOSE_CODE_EXPLANATIONS: dict[int, str] = {
79
+ CloseCode.NORMAL_CLOSURE: "OK",
80
+ CloseCode.GOING_AWAY: "going away",
81
+ CloseCode.PROTOCOL_ERROR: "protocol error",
82
+ CloseCode.UNSUPPORTED_DATA: "unsupported data",
83
+ CloseCode.NO_STATUS_RCVD: "no status received [internal]",
84
+ CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]",
85
+ CloseCode.INVALID_DATA: "invalid frame payload data",
86
+ CloseCode.POLICY_VIOLATION: "policy violation",
87
+ CloseCode.MESSAGE_TOO_BIG: "message too big",
88
+ CloseCode.MANDATORY_EXTENSION: "mandatory extension",
89
+ CloseCode.INTERNAL_ERROR: "internal error",
90
+ CloseCode.SERVICE_RESTART: "service restart",
91
+ CloseCode.TRY_AGAIN_LATER: "try again later",
92
+ CloseCode.BAD_GATEWAY: "bad gateway",
93
+ CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]",
94
+ }
95
+
96
+
97
+ # Close code that are allowed in a close frame.
98
+ # Using a set optimizes `code in EXTERNAL_CLOSE_CODES`.
99
+ EXTERNAL_CLOSE_CODES = {
100
+ CloseCode.NORMAL_CLOSURE,
101
+ CloseCode.GOING_AWAY,
102
+ CloseCode.PROTOCOL_ERROR,
103
+ CloseCode.UNSUPPORTED_DATA,
104
+ CloseCode.INVALID_DATA,
105
+ CloseCode.POLICY_VIOLATION,
106
+ CloseCode.MESSAGE_TOO_BIG,
107
+ CloseCode.MANDATORY_EXTENSION,
108
+ CloseCode.INTERNAL_ERROR,
109
+ CloseCode.SERVICE_RESTART,
110
+ CloseCode.TRY_AGAIN_LATER,
111
+ CloseCode.BAD_GATEWAY,
112
+ }
113
+
114
+
115
+ OK_CLOSE_CODES = {
116
+ CloseCode.NORMAL_CLOSURE,
117
+ CloseCode.GOING_AWAY,
118
+ CloseCode.NO_STATUS_RCVD,
119
+ }
120
+
121
+
122
+ @dataclasses.dataclass
123
+ class Frame:
124
+ """
125
+ WebSocket frame.
126
+
127
+ Attributes:
128
+ opcode: Opcode.
129
+ data: Payload data.
130
+ fin: FIN bit.
131
+ rsv1: RSV1 bit.
132
+ rsv2: RSV2 bit.
133
+ rsv3: RSV3 bit.
134
+
135
+ Only these fields are needed. The MASK bit, payload length and masking-key
136
+ are handled on the fly when parsing and serializing frames.
137
+
138
+ """
139
+
140
+ opcode: Opcode
141
+ data: BytesLike
142
+ fin: bool = True
143
+ rsv1: bool = False
144
+ rsv2: bool = False
145
+ rsv3: bool = False
146
+
147
+ # Configure if you want to see more in logs. Should be a multiple of 3.
148
+ MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75"))
149
+
150
+ def __str__(self) -> str:
151
+ """
152
+ Return a human-readable representation of a frame.
153
+
154
+ """
155
+ coding = None
156
+ length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
157
+ non_final = "" if self.fin else "continued"
158
+
159
+ if self.opcode is OP_TEXT:
160
+ # Decoding only the beginning and the end is needlessly hard.
161
+ # Decode the entire payload then elide later if necessary.
162
+ data = repr(bytes(self.data).decode())
163
+ elif self.opcode is OP_BINARY:
164
+ # We'll show at most the first 16 bytes and the last 8 bytes.
165
+ # Encode just what we need, plus two dummy bytes to elide later.
166
+ binary = self.data
167
+ if len(binary) > self.MAX_LOG_SIZE // 3:
168
+ cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8
169
+ binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
170
+ data = " ".join(f"{byte:02x}" for byte in binary)
171
+ elif self.opcode is OP_CLOSE:
172
+ data = str(Close.parse(self.data))
173
+ elif self.data:
174
+ # We don't know if a Continuation frame contains text or binary.
175
+ # Ping and Pong frames could contain UTF-8.
176
+ # Attempt to decode as UTF-8 and display it as text; fallback to
177
+ # binary. If self.data is a memoryview, it has no decode() method,
178
+ # which raises AttributeError.
179
+ try:
180
+ data = repr(bytes(self.data).decode())
181
+ coding = "text"
182
+ except (UnicodeDecodeError, AttributeError):
183
+ binary = self.data
184
+ if len(binary) > self.MAX_LOG_SIZE // 3:
185
+ cut = (self.MAX_LOG_SIZE // 3 - 1) // 3 # by default cut = 8
186
+ binary = b"".join([binary[: 2 * cut], b"\x00\x00", binary[-cut:]])
187
+ data = " ".join(f"{byte:02x}" for byte in binary)
188
+ coding = "binary"
189
+ else:
190
+ data = "''"
191
+
192
+ if len(data) > self.MAX_LOG_SIZE:
193
+ cut = self.MAX_LOG_SIZE // 3 - 1 # by default cut = 24
194
+ data = data[: 2 * cut] + "..." + data[-cut:]
195
+
196
+ metadata = ", ".join(filter(None, [coding, length, non_final]))
197
+
198
+ return f"{self.opcode.name} {data} [{metadata}]"
199
+
200
+ @classmethod
201
+ def parse(
202
+ cls,
203
+ read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
204
+ *,
205
+ mask: bool,
206
+ max_size: int | None = None,
207
+ extensions: Sequence[extensions.Extension] | None = None,
208
+ ) -> Generator[None, None, Frame]:
209
+ """
210
+ Parse a WebSocket frame.
211
+
212
+ This is a generator-based coroutine.
213
+
214
+ Args:
215
+ read_exact: Generator-based coroutine that reads the requested
216
+ bytes or raises an exception if there isn't enough data.
217
+ mask: Whether the frame should be masked i.e. whether the read
218
+ happens on the server side.
219
+ max_size: Maximum payload size in bytes.
220
+ extensions: List of extensions, applied in reverse order.
221
+
222
+ Raises:
223
+ EOFError: If the connection is closed without a full WebSocket frame.
224
+ PayloadTooBig: If the frame's payload size exceeds ``max_size``.
225
+ ProtocolError: If the frame contains incorrect values.
226
+
227
+ """
228
+ # Read the header.
229
+ data = yield from read_exact(2)
230
+ head1, head2 = struct.unpack("!BB", data)
231
+
232
+ # While not Pythonic, this is marginally faster than calling bool().
233
+ fin = True if head1 & 0b10000000 else False
234
+ rsv1 = True if head1 & 0b01000000 else False
235
+ rsv2 = True if head1 & 0b00100000 else False
236
+ rsv3 = True if head1 & 0b00010000 else False
237
+
238
+ try:
239
+ opcode = Opcode(head1 & 0b00001111)
240
+ except ValueError as exc:
241
+ raise ProtocolError("invalid opcode") from exc
242
+
243
+ if (True if head2 & 0b10000000 else False) != mask:
244
+ raise ProtocolError("incorrect masking")
245
+
246
+ length = head2 & 0b01111111
247
+ if length == 126:
248
+ data = yield from read_exact(2)
249
+ (length,) = struct.unpack("!H", data)
250
+ elif length == 127:
251
+ data = yield from read_exact(8)
252
+ (length,) = struct.unpack("!Q", data)
253
+ if max_size is not None and length > max_size:
254
+ raise PayloadTooBig(length, max_size)
255
+ if mask:
256
+ mask_bytes = yield from read_exact(4)
257
+
258
+ # Read the data.
259
+ data = yield from read_exact(length)
260
+ if mask:
261
+ data = apply_mask(data, mask_bytes)
262
+
263
+ frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)
264
+
265
+ if extensions is None:
266
+ extensions = []
267
+ for extension in reversed(extensions):
268
+ frame = extension.decode(frame, max_size=max_size)
269
+
270
+ frame.check()
271
+
272
+ return frame
273
+
274
+ def serialize(
275
+ self,
276
+ *,
277
+ mask: bool,
278
+ extensions: Sequence[extensions.Extension] | None = None,
279
+ ) -> bytes:
280
+ """
281
+ Serialize a WebSocket frame.
282
+
283
+ Args:
284
+ mask: Whether the frame should be masked i.e. whether the write
285
+ happens on the client side.
286
+ extensions: List of extensions, applied in order.
287
+
288
+ Raises:
289
+ ProtocolError: If the frame contains incorrect values.
290
+
291
+ """
292
+ self.check()
293
+
294
+ if extensions is None:
295
+ extensions = []
296
+ for extension in extensions:
297
+ self = extension.encode(self)
298
+
299
+ output = io.BytesIO()
300
+
301
+ # Prepare the header.
302
+ head1 = (
303
+ (0b10000000 if self.fin else 0)
304
+ | (0b01000000 if self.rsv1 else 0)
305
+ | (0b00100000 if self.rsv2 else 0)
306
+ | (0b00010000 if self.rsv3 else 0)
307
+ | self.opcode
308
+ )
309
+
310
+ head2 = 0b10000000 if mask else 0
311
+
312
+ length = len(self.data)
313
+ if length < 126:
314
+ output.write(struct.pack("!BB", head1, head2 | length))
315
+ elif length < 65536:
316
+ output.write(struct.pack("!BBH", head1, head2 | 126, length))
317
+ else:
318
+ output.write(struct.pack("!BBQ", head1, head2 | 127, length))
319
+
320
+ if mask:
321
+ mask_bytes = secrets.token_bytes(4)
322
+ output.write(mask_bytes)
323
+
324
+ # Prepare the data.
325
+ data: BytesLike
326
+ if mask:
327
+ data = apply_mask(self.data, mask_bytes)
328
+ else:
329
+ data = self.data
330
+ output.write(data)
331
+
332
+ return output.getvalue()
333
+
334
+ def check(self) -> None:
335
+ """
336
+ Check that reserved bits and opcode have acceptable values.
337
+
338
+ Raises:
339
+ ProtocolError: If a reserved bit or the opcode is invalid.
340
+
341
+ """
342
+ if self.rsv1 or self.rsv2 or self.rsv3:
343
+ raise ProtocolError("reserved bits must be 0")
344
+
345
+ if self.opcode in CTRL_OPCODES:
346
+ if len(self.data) > 125:
347
+ raise ProtocolError("control frame too long")
348
+ if not self.fin:
349
+ raise ProtocolError("fragmented control frame")
350
+
351
+
352
+ @dataclasses.dataclass
353
+ class Close:
354
+ """
355
+ Code and reason for WebSocket close frames.
356
+
357
+ Attributes:
358
+ code: Close code.
359
+ reason: Close reason.
360
+
361
+ """
362
+
363
+ code: CloseCode | int
364
+ reason: str
365
+
366
+ def __str__(self) -> str:
367
+ """
368
+ Return a human-readable representation of a close code and reason.
369
+
370
+ """
371
+ if 3000 <= self.code < 4000:
372
+ explanation = "registered"
373
+ elif 4000 <= self.code < 5000:
374
+ explanation = "private use"
375
+ else:
376
+ explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown")
377
+ result = f"{self.code} ({explanation})"
378
+
379
+ if self.reason:
380
+ result = f"{result} {self.reason}"
381
+
382
+ return result
383
+
384
+ @classmethod
385
+ def parse(cls, data: BytesLike) -> Close:
386
+ """
387
+ Parse the payload of a close frame.
388
+
389
+ Args:
390
+ data: Payload of the close frame.
391
+
392
+ Raises:
393
+ ProtocolError: If data is ill-formed.
394
+ UnicodeDecodeError: If the reason isn't valid UTF-8.
395
+
396
+ """
397
+ if isinstance(data, memoryview):
398
+ raise AssertionError("only compressed outgoing frames use memoryview")
399
+ if len(data) >= 2:
400
+ (code,) = struct.unpack("!H", data[:2])
401
+ reason = data[2:].decode()
402
+ close = cls(code, reason)
403
+ close.check()
404
+ return close
405
+ elif len(data) == 0:
406
+ return cls(CloseCode.NO_STATUS_RCVD, "")
407
+ else:
408
+ raise ProtocolError("close frame too short")
409
+
410
+ def serialize(self) -> bytes:
411
+ """
412
+ Serialize the payload of a close frame.
413
+
414
+ """
415
+ self.check()
416
+ return struct.pack("!H", self.code) + self.reason.encode()
417
+
418
+ def check(self) -> None:
419
+ """
420
+ Check that the close code has a valid value for a close frame.
421
+
422
+ Raises:
423
+ ProtocolError: If the close code is invalid.
424
+
425
+ """
426
+ if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
427
+ raise ProtocolError("invalid status code")
428
+
429
+
430
+ # At the bottom to break import cycles created by type annotations.
431
+ from . import extensions # noqa: E402
source/websockets/headers.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import binascii
5
+ import ipaddress
6
+ import re
7
+ from collections.abc import Sequence
8
+ from typing import Callable, TypeVar, cast
9
+
10
+ from .exceptions import InvalidHeaderFormat, InvalidHeaderValue
11
+ from .typing import (
12
+ ConnectionOption,
13
+ ExtensionHeader,
14
+ ExtensionName,
15
+ ExtensionParameter,
16
+ Subprotocol,
17
+ UpgradeProtocol,
18
+ )
19
+
20
+
21
+ __all__ = [
22
+ "build_host",
23
+ "parse_connection",
24
+ "parse_upgrade",
25
+ "parse_extension",
26
+ "build_extension",
27
+ "parse_subprotocol",
28
+ "build_subprotocol",
29
+ "validate_subprotocols",
30
+ "build_www_authenticate_basic",
31
+ "parse_authorization_basic",
32
+ "build_authorization_basic",
33
+ ]
34
+
35
+
36
+ T = TypeVar("T")
37
+
38
+
39
+ def build_host(
40
+ host: str,
41
+ port: int,
42
+ secure: bool,
43
+ *,
44
+ always_include_port: bool = False,
45
+ ) -> str:
46
+ """
47
+ Build a ``Host`` header.
48
+
49
+ """
50
+ # https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.2
51
+ # IPv6 addresses must be enclosed in brackets.
52
+ try:
53
+ address = ipaddress.ip_address(host)
54
+ except ValueError:
55
+ # host is a hostname
56
+ pass
57
+ else:
58
+ # host is an IP address
59
+ if address.version == 6:
60
+ host = f"[{host}]"
61
+
62
+ if always_include_port or port != (443 if secure else 80):
63
+ host = f"{host}:{port}"
64
+
65
+ return host
66
+
67
+
68
+ # To avoid a dependency on a parsing library, we implement manually the ABNF
69
+ # described in https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 and
70
+ # https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
71
+
72
+
73
+ def peek_ahead(header: str, pos: int) -> str | None:
74
+ """
75
+ Return the next character from ``header`` at the given position.
76
+
77
+ Return :obj:`None` at the end of ``header``.
78
+
79
+ We never need to peek more than one character ahead.
80
+
81
+ """
82
+ return None if pos == len(header) else header[pos]
83
+
84
+
85
+ _OWS_re = re.compile(r"[\t ]*")
86
+
87
+
88
+ def parse_OWS(header: str, pos: int) -> int:
89
+ """
90
+ Parse optional whitespace from ``header`` at the given position.
91
+
92
+ Return the new position.
93
+
94
+ The whitespace itself isn't returned because it isn't significant.
95
+
96
+ """
97
+ # There's always a match, possibly empty, whose content doesn't matter.
98
+ match = _OWS_re.match(header, pos)
99
+ assert match is not None
100
+ return match.end()
101
+
102
+
103
+ _token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
104
+
105
+
106
+ def parse_token(header: str, pos: int, header_name: str) -> tuple[str, int]:
107
+ """
108
+ Parse a token from ``header`` at the given position.
109
+
110
+ Return the token value and the new position.
111
+
112
+ Raises:
113
+ InvalidHeaderFormat: On invalid inputs.
114
+
115
+ """
116
+ match = _token_re.match(header, pos)
117
+ if match is None:
118
+ raise InvalidHeaderFormat(header_name, "expected token", header, pos)
119
+ return match.group(), match.end()
120
+
121
+
122
+ _quoted_string_re = re.compile(
123
+ r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"'
124
+ )
125
+
126
+
127
+ _unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])")
128
+
129
+
130
+ def parse_quoted_string(header: str, pos: int, header_name: str) -> tuple[str, int]:
131
+ """
132
+ Parse a quoted string from ``header`` at the given position.
133
+
134
+ Return the unquoted value and the new position.
135
+
136
+ Raises:
137
+ InvalidHeaderFormat: On invalid inputs.
138
+
139
+ """
140
+ match = _quoted_string_re.match(header, pos)
141
+ if match is None:
142
+ raise InvalidHeaderFormat(header_name, "expected quoted string", header, pos)
143
+ return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
144
+
145
+
146
+ _quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*")
147
+
148
+
149
+ _quote_re = re.compile(r"([\x22\x5c])")
150
+
151
+
152
+ def build_quoted_string(value: str) -> str:
153
+ """
154
+ Format ``value`` as a quoted string.
155
+
156
+ This is the reverse of :func:`parse_quoted_string`.
157
+
158
+ """
159
+ match = _quotable_re.fullmatch(value)
160
+ if match is None:
161
+ raise ValueError("invalid characters for quoted-string encoding")
162
+ return '"' + _quote_re.sub(r"\\\1", value) + '"'
163
+
164
+
165
+ def parse_list(
166
+ parse_item: Callable[[str, int, str], tuple[T, int]],
167
+ header: str,
168
+ pos: int,
169
+ header_name: str,
170
+ ) -> list[T]:
171
+ """
172
+ Parse a comma-separated list from ``header`` at the given position.
173
+
174
+ This is appropriate for parsing values with the following grammar:
175
+
176
+ 1#item
177
+
178
+ ``parse_item`` parses one item.
179
+
180
+ ``header`` is assumed not to start or end with whitespace.
181
+
182
+ (This function is designed for parsing an entire header value and
183
+ :func:`~websockets.http.read_headers` strips whitespace from values.)
184
+
185
+ Return a list of items.
186
+
187
+ Raises:
188
+ InvalidHeaderFormat: On invalid inputs.
189
+
190
+ """
191
+ # Per https://datatracker.ietf.org/doc/html/rfc7230#section-7, "a recipient
192
+ # MUST parse and ignore a reasonable number of empty list elements";
193
+ # hence while loops that remove extra delimiters.
194
+
195
+ # Remove extra delimiters before the first item.
196
+ while peek_ahead(header, pos) == ",":
197
+ pos = parse_OWS(header, pos + 1)
198
+
199
+ items = []
200
+ while True:
201
+ # Loop invariant: a item starts at pos in header.
202
+ item, pos = parse_item(header, pos, header_name)
203
+ items.append(item)
204
+ pos = parse_OWS(header, pos)
205
+
206
+ # We may have reached the end of the header.
207
+ if pos == len(header):
208
+ break
209
+
210
+ # There must be a delimiter after each element except the last one.
211
+ if peek_ahead(header, pos) == ",":
212
+ pos = parse_OWS(header, pos + 1)
213
+ else:
214
+ raise InvalidHeaderFormat(header_name, "expected comma", header, pos)
215
+
216
+ # Remove extra delimiters before the next item.
217
+ while peek_ahead(header, pos) == ",":
218
+ pos = parse_OWS(header, pos + 1)
219
+
220
+ # We may have reached the end of the header.
221
+ if pos == len(header):
222
+ break
223
+
224
+ # Since we only advance in the header by one character with peek_ahead()
225
+ # or with the end position of a regex match, we can't overshoot the end.
226
+ assert pos == len(header)
227
+
228
+ return items
229
+
230
+
231
+ def parse_connection_option(
232
+ header: str, pos: int, header_name: str
233
+ ) -> tuple[ConnectionOption, int]:
234
+ """
235
+ Parse a Connection option from ``header`` at the given position.
236
+
237
+ Return the protocol value and the new position.
238
+
239
+ Raises:
240
+ InvalidHeaderFormat: On invalid inputs.
241
+
242
+ """
243
+ item, pos = parse_token(header, pos, header_name)
244
+ return cast(ConnectionOption, item), pos
245
+
246
+
247
+ def parse_connection(header: str) -> list[ConnectionOption]:
248
+ """
249
+ Parse a ``Connection`` header.
250
+
251
+ Return a list of HTTP connection options.
252
+
253
+ Args
254
+ header: value of the ``Connection`` header.
255
+
256
+ Raises:
257
+ InvalidHeaderFormat: On invalid inputs.
258
+
259
+ """
260
+ return parse_list(parse_connection_option, header, 0, "Connection")
261
+
262
+
263
+ _protocol_re = re.compile(
264
+ r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?"
265
+ )
266
+
267
+
268
+ def parse_upgrade_protocol(
269
+ header: str, pos: int, header_name: str
270
+ ) -> tuple[UpgradeProtocol, int]:
271
+ """
272
+ Parse an Upgrade protocol from ``header`` at the given position.
273
+
274
+ Return the protocol value and the new position.
275
+
276
+ Raises:
277
+ InvalidHeaderFormat: On invalid inputs.
278
+
279
+ """
280
+ match = _protocol_re.match(header, pos)
281
+ if match is None:
282
+ raise InvalidHeaderFormat(header_name, "expected protocol", header, pos)
283
+ return cast(UpgradeProtocol, match.group()), match.end()
284
+
285
+
286
+ def parse_upgrade(header: str) -> list[UpgradeProtocol]:
287
+ """
288
+ Parse an ``Upgrade`` header.
289
+
290
+ Return a list of HTTP protocols.
291
+
292
+ Args:
293
+ header: Value of the ``Upgrade`` header.
294
+
295
+ Raises:
296
+ InvalidHeaderFormat: On invalid inputs.
297
+
298
+ """
299
+ return parse_list(parse_upgrade_protocol, header, 0, "Upgrade")
300
+
301
+
302
+ def parse_extension_item_param(
303
+ header: str, pos: int, header_name: str
304
+ ) -> tuple[ExtensionParameter, int]:
305
+ """
306
+ Parse a single extension parameter from ``header`` at the given position.
307
+
308
+ Return a ``(name, value)`` pair and the new position.
309
+
310
+ Raises:
311
+ InvalidHeaderFormat: On invalid inputs.
312
+
313
+ """
314
+ # Extract parameter name.
315
+ name, pos = parse_token(header, pos, header_name)
316
+ pos = parse_OWS(header, pos)
317
+ # Extract parameter value, if there is one.
318
+ value: str | None = None
319
+ if peek_ahead(header, pos) == "=":
320
+ pos = parse_OWS(header, pos + 1)
321
+ if peek_ahead(header, pos) == '"':
322
+ pos_before = pos # for proper error reporting below
323
+ value, pos = parse_quoted_string(header, pos, header_name)
324
+ # https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 says:
325
+ # the value after quoted-string unescaping MUST conform to
326
+ # the 'token' ABNF.
327
+ if _token_re.fullmatch(value) is None:
328
+ raise InvalidHeaderFormat(
329
+ header_name, "invalid quoted header content", header, pos_before
330
+ )
331
+ else:
332
+ value, pos = parse_token(header, pos, header_name)
333
+ pos = parse_OWS(header, pos)
334
+
335
+ return (name, value), pos
336
+
337
+
338
+ def parse_extension_item(
339
+ header: str, pos: int, header_name: str
340
+ ) -> tuple[ExtensionHeader, int]:
341
+ """
342
+ Parse an extension definition from ``header`` at the given position.
343
+
344
+ Return an ``(extension name, parameters)`` pair, where ``parameters`` is a
345
+ list of ``(name, value)`` pairs, and the new position.
346
+
347
+ Raises:
348
+ InvalidHeaderFormat: On invalid inputs.
349
+
350
+ """
351
+ # Extract extension name.
352
+ name, pos = parse_token(header, pos, header_name)
353
+ pos = parse_OWS(header, pos)
354
+ # Extract all parameters.
355
+ parameters = []
356
+ while peek_ahead(header, pos) == ";":
357
+ pos = parse_OWS(header, pos + 1)
358
+ parameter, pos = parse_extension_item_param(header, pos, header_name)
359
+ parameters.append(parameter)
360
+ return (cast(ExtensionName, name), parameters), pos
361
+
362
+
363
+ def parse_extension(header: str) -> list[ExtensionHeader]:
364
+ """
365
+ Parse a ``Sec-WebSocket-Extensions`` header.
366
+
367
+ Return a list of WebSocket extensions and their parameters in this format::
368
+
369
+ [
370
+ (
371
+ 'extension name',
372
+ [
373
+ ('parameter name', 'parameter value'),
374
+ ....
375
+ ]
376
+ ),
377
+ ...
378
+ ]
379
+
380
+ Parameter values are :obj:`None` when no value is provided.
381
+
382
+ Raises:
383
+ InvalidHeaderFormat: On invalid inputs.
384
+
385
+ """
386
+ return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions")
387
+
388
+
389
+ parse_extension_list = parse_extension # alias for backwards compatibility
390
+
391
+
392
+ def build_extension_item(
393
+ name: ExtensionName, parameters: Sequence[ExtensionParameter]
394
+ ) -> str:
395
+ """
396
+ Build an extension definition.
397
+
398
+ This is the reverse of :func:`parse_extension_item`.
399
+
400
+ """
401
+ return "; ".join(
402
+ [cast(str, name)]
403
+ + [
404
+ # Quoted strings aren't necessary because values are always tokens.
405
+ name if value is None else f"{name}={value}"
406
+ for name, value in parameters
407
+ ]
408
+ )
409
+
410
+
411
+ def build_extension(extensions: Sequence[ExtensionHeader]) -> str:
412
+ """
413
+ Build a ``Sec-WebSocket-Extensions`` header.
414
+
415
+ This is the reverse of :func:`parse_extension`.
416
+
417
+ """
418
+ return ", ".join(
419
+ build_extension_item(name, parameters) for name, parameters in extensions
420
+ )
421
+
422
+
423
+ build_extension_list = build_extension # alias for backwards compatibility
424
+
425
+
426
+ def parse_subprotocol_item(
427
+ header: str, pos: int, header_name: str
428
+ ) -> tuple[Subprotocol, int]:
429
+ """
430
+ Parse a subprotocol from ``header`` at the given position.
431
+
432
+ Return the subprotocol value and the new position.
433
+
434
+ Raises:
435
+ InvalidHeaderFormat: On invalid inputs.
436
+
437
+ """
438
+ item, pos = parse_token(header, pos, header_name)
439
+ return cast(Subprotocol, item), pos
440
+
441
+
442
+ def parse_subprotocol(header: str) -> list[Subprotocol]:
443
+ """
444
+ Parse a ``Sec-WebSocket-Protocol`` header.
445
+
446
+ Return a list of WebSocket subprotocols.
447
+
448
+ Raises:
449
+ InvalidHeaderFormat: On invalid inputs.
450
+
451
+ """
452
+ return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol")
453
+
454
+
455
+ parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility
456
+
457
+
458
+ def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str:
459
+ """
460
+ Build a ``Sec-WebSocket-Protocol`` header.
461
+
462
+ This is the reverse of :func:`parse_subprotocol`.
463
+
464
+ """
465
+ return ", ".join(subprotocols)
466
+
467
+
468
+ build_subprotocol_list = build_subprotocol # alias for backwards compatibility
469
+
470
+
471
+ def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None:
472
+ """
473
+ Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`.
474
+
475
+ """
476
+ if not isinstance(subprotocols, Sequence):
477
+ raise TypeError("subprotocols must be a list")
478
+ if isinstance(subprotocols, str):
479
+ raise TypeError("subprotocols must be a list, not a str")
480
+ for subprotocol in subprotocols:
481
+ if not _token_re.fullmatch(subprotocol):
482
+ raise ValueError(f"invalid subprotocol: {subprotocol}")
483
+
484
+
485
+ def build_www_authenticate_basic(realm: str) -> str:
486
+ """
487
+ Build a ``WWW-Authenticate`` header for HTTP Basic Auth.
488
+
489
+ Args:
490
+ realm: Identifier of the protection space.
491
+
492
+ """
493
+ # https://datatracker.ietf.org/doc/html/rfc7617#section-2
494
+ realm = build_quoted_string(realm)
495
+ charset = build_quoted_string("UTF-8")
496
+ return f"Basic realm={realm}, charset={charset}"
497
+
498
+
499
+ _token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*")
500
+
501
+
502
+ def parse_token68(header: str, pos: int, header_name: str) -> tuple[str, int]:
503
+ """
504
+ Parse a token68 from ``header`` at the given position.
505
+
506
+ Return the token value and the new position.
507
+
508
+ Raises:
509
+ InvalidHeaderFormat: On invalid inputs.
510
+
511
+ """
512
+ match = _token68_re.match(header, pos)
513
+ if match is None:
514
+ raise InvalidHeaderFormat(header_name, "expected token68", header, pos)
515
+ return match.group(), match.end()
516
+
517
+
518
+ def parse_end(header: str, pos: int, header_name: str) -> None:
519
+ """
520
+ Check that parsing reached the end of header.
521
+
522
+ """
523
+ if pos < len(header):
524
+ raise InvalidHeaderFormat(header_name, "trailing data", header, pos)
525
+
526
+
527
+ def parse_authorization_basic(header: str) -> tuple[str, str]:
528
+ """
529
+ Parse an ``Authorization`` header for HTTP Basic Auth.
530
+
531
+ Return a ``(username, password)`` tuple.
532
+
533
+ Args:
534
+ header: Value of the ``Authorization`` header.
535
+
536
+ Raises:
537
+ InvalidHeaderFormat: On invalid inputs.
538
+ InvalidHeaderValue: On unsupported inputs.
539
+
540
+ """
541
+ # https://datatracker.ietf.org/doc/html/rfc7235#section-2.1
542
+ # https://datatracker.ietf.org/doc/html/rfc7617#section-2
543
+ scheme, pos = parse_token(header, 0, "Authorization")
544
+ if scheme.lower() != "basic":
545
+ raise InvalidHeaderValue(
546
+ "Authorization",
547
+ f"unsupported scheme: {scheme}",
548
+ )
549
+ if peek_ahead(header, pos) != " ":
550
+ raise InvalidHeaderFormat(
551
+ "Authorization", "expected space after scheme", header, pos
552
+ )
553
+ pos += 1
554
+ basic_credentials, pos = parse_token68(header, pos, "Authorization")
555
+ parse_end(header, pos, "Authorization")
556
+
557
+ try:
558
+ user_pass = base64.b64decode(basic_credentials.encode()).decode()
559
+ except binascii.Error:
560
+ raise InvalidHeaderValue(
561
+ "Authorization",
562
+ "expected base64-encoded credentials",
563
+ ) from None
564
+ try:
565
+ username, password = user_pass.split(":", 1)
566
+ except ValueError:
567
+ raise InvalidHeaderValue(
568
+ "Authorization",
569
+ "expected username:password credentials",
570
+ ) from None
571
+
572
+ return username, password
573
+
574
+
575
+ def build_authorization_basic(username: str, password: str) -> str:
576
+ """
577
+ Build an ``Authorization`` header for HTTP Basic Auth.
578
+
579
+ This is the reverse of :func:`parse_authorization_basic`.
580
+
581
+ """
582
+ # https://datatracker.ietf.org/doc/html/rfc7617#section-2
583
+ assert ":" not in username
584
+ user_pass = f"{username}:{password}"
585
+ basic_credentials = base64.b64encode(user_pass.encode()).decode()
586
+ return "Basic " + basic_credentials
source/websockets/http.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+
5
+ from .datastructures import Headers, MultipleValuesError # noqa: F401
6
+
7
+
8
+ with warnings.catch_warnings():
9
+ # Suppress redundant DeprecationWarning raised by websockets.legacy.
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ from .legacy.http import read_request, read_response # noqa: F401
12
+
13
+
14
+ warnings.warn( # deprecated in 9.0 - 2021-09-01
15
+ "Headers and MultipleValuesError were moved "
16
+ "from websockets.http to websockets.datastructures"
17
+ "and read_request and read_response were moved "
18
+ "from websockets.http to websockets.legacy.http",
19
+ DeprecationWarning,
20
+ )
source/websockets/http11.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import os
5
+ import re
6
+ import sys
7
+ import warnings
8
+ from collections.abc import Generator
9
+ from typing import Callable
10
+
11
+ from .datastructures import Headers
12
+ from .exceptions import SecurityError
13
+ from .version import version as websockets_version
14
+
15
+
16
+ __all__ = [
17
+ "SERVER",
18
+ "USER_AGENT",
19
+ "Request",
20
+ "Response",
21
+ ]
22
+
23
+
24
+ PYTHON_VERSION = "{}.{}".format(*sys.version_info)
25
+
26
+ # User-Agent header for HTTP requests.
27
+ USER_AGENT = os.environ.get(
28
+ "WEBSOCKETS_USER_AGENT",
29
+ f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
30
+ )
31
+
32
+ # Server header for HTTP responses.
33
+ SERVER = os.environ.get(
34
+ "WEBSOCKETS_SERVER",
35
+ f"Python/{PYTHON_VERSION} websockets/{websockets_version}",
36
+ )
37
+
38
+ # Maximum total size of headers is around 128 * 8 KiB = 1 MiB.
39
+ MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
40
+
41
+ # Limit request line and header lines. 8KiB is the most common default
42
+ # configuration of popular HTTP servers.
43
+ MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
44
+
45
+ # Support for HTTP response bodies is intended to read an error message
46
+ # returned by a server. It isn't designed to perform large file transfers.
47
+ MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB
48
+
49
+
50
+ def d(value: bytes | bytearray) -> str:
51
+ """
52
+ Decode a bytestring for interpolating into an error message.
53
+
54
+ """
55
+ return value.decode(errors="backslashreplace")
56
+
57
+
58
+ # See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
59
+
60
+ # Regex for validating header names.
61
+
62
+ _token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
63
+
64
+ # Regex for validating header values.
65
+
66
+ # We don't attempt to support obsolete line folding.
67
+
68
+ # Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
69
+
70
+ # The ABNF is complicated because it attempts to express that optional
71
+ # whitespace is ignored. We strip whitespace and don't revalidate that.
72
+
73
+ # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
74
+
75
+ _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
76
+
77
+
78
+ @dataclasses.dataclass
79
+ class Request:
80
+ """
81
+ WebSocket handshake request.
82
+
83
+ Attributes:
84
+ path: Request path, including optional query.
85
+ headers: Request headers.
86
+ """
87
+
88
+ path: str
89
+ headers: Headers
90
+ # body isn't useful is the context of this library.
91
+
92
+ _exception: Exception | None = None
93
+
94
+ @property
95
+ def exception(self) -> Exception | None: # pragma: no cover
96
+ warnings.warn( # deprecated in 10.3 - 2022-04-17
97
+ "Request.exception is deprecated; use ServerProtocol.handshake_exc instead",
98
+ DeprecationWarning,
99
+ )
100
+ return self._exception
101
+
102
+ @classmethod
103
+ def parse(
104
+ cls,
105
+ read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
106
+ ) -> Generator[None, None, Request]:
107
+ """
108
+ Parse a WebSocket handshake request.
109
+
110
+ This is a generator-based coroutine.
111
+
112
+ The request path isn't URL-decoded or validated in any way.
113
+
114
+ The request path and headers are expected to contain only ASCII
115
+ characters. Other characters are represented with surrogate escapes.
116
+
117
+ :meth:`parse` doesn't attempt to read the request body because
118
+ WebSocket handshake requests don't have one. If the request contains a
119
+ body, it may be read from the data stream after :meth:`parse` returns.
120
+
121
+ Args:
122
+ read_line: Generator-based coroutine that reads a LF-terminated
123
+ line or raises an exception if there isn't enough data
124
+
125
+ Raises:
126
+ EOFError: If the connection is closed without a full HTTP request.
127
+ SecurityError: If the request exceeds a security limit.
128
+ ValueError: If the request isn't well formatted.
129
+
130
+ """
131
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
132
+
133
+ # Parsing is simple because fixed values are expected for method and
134
+ # version and because path isn't checked. Since WebSocket software tends
135
+ # to implement HTTP/1.1 strictly, there's little need for lenient parsing.
136
+
137
+ try:
138
+ request_line = yield from parse_line(read_line)
139
+ except EOFError as exc:
140
+ raise EOFError("connection closed while reading HTTP request line") from exc
141
+
142
+ try:
143
+ method, raw_path, protocol = request_line.split(b" ", 2)
144
+ except ValueError: # not enough values to unpack (expected 3, got 1-2)
145
+ raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
146
+ if protocol != b"HTTP/1.1":
147
+ raise ValueError(
148
+ f"unsupported protocol; expected HTTP/1.1: {d(request_line)}"
149
+ )
150
+ if method != b"GET":
151
+ raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}")
152
+ path = raw_path.decode("ascii", "surrogateescape")
153
+
154
+ headers = yield from parse_headers(read_line)
155
+
156
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
157
+
158
+ if "Transfer-Encoding" in headers:
159
+ raise NotImplementedError("transfer codings aren't supported")
160
+
161
+ if "Content-Length" in headers:
162
+ # Some devices send a Content-Length header with a value of 0.
163
+ # This raises ValueError if Content-Length isn't an integer too.
164
+ if int(headers["Content-Length"]) != 0:
165
+ raise ValueError("unsupported request body")
166
+
167
+ return cls(path, headers)
168
+
169
+ def serialize(self) -> bytes:
170
+ """
171
+ Serialize a WebSocket handshake request.
172
+
173
+ """
174
+ # Since the request line and headers only contain ASCII characters,
175
+ # we can keep this simple.
176
+ request = f"GET {self.path} HTTP/1.1\r\n".encode()
177
+ request += self.headers.serialize()
178
+ return request
179
+
180
+
181
+ @dataclasses.dataclass
182
+ class Response:
183
+ """
184
+ WebSocket handshake response.
185
+
186
+ Attributes:
187
+ status_code: Response code.
188
+ reason_phrase: Response reason.
189
+ headers: Response headers.
190
+ body: Response body.
191
+
192
+ """
193
+
194
+ status_code: int
195
+ reason_phrase: str
196
+ headers: Headers
197
+ body: bytes | bytearray = b""
198
+
199
+ _exception: Exception | None = None
200
+
201
+ @property
202
+ def exception(self) -> Exception | None: # pragma: no cover
203
+ warnings.warn( # deprecated in 10.3 - 2022-04-17
204
+ "Response.exception is deprecated; "
205
+ "use ClientProtocol.handshake_exc instead",
206
+ DeprecationWarning,
207
+ )
208
+ return self._exception
209
+
210
+ @classmethod
211
+ def parse(
212
+ cls,
213
+ read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
214
+ read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
215
+ read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
216
+ proxy: bool = False,
217
+ ) -> Generator[None, None, Response]:
218
+ """
219
+ Parse a WebSocket handshake response.
220
+
221
+ This is a generator-based coroutine.
222
+
223
+ The reason phrase and headers are expected to contain only ASCII
224
+ characters. Other characters are represented with surrogate escapes.
225
+
226
+ Args:
227
+ read_line: Generator-based coroutine that reads a LF-terminated
228
+ line or raises an exception if there isn't enough data.
229
+ read_exact: Generator-based coroutine that reads the requested
230
+ bytes or raises an exception if there isn't enough data.
231
+ read_to_eof: Generator-based coroutine that reads until the end
232
+ of the stream.
233
+
234
+ Raises:
235
+ EOFError: If the connection is closed without a full HTTP response.
236
+ SecurityError: If the response exceeds a security limit.
237
+ LookupError: If the response isn't well formatted.
238
+ ValueError: If the response isn't well formatted.
239
+
240
+ """
241
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
242
+
243
+ try:
244
+ status_line = yield from parse_line(read_line)
245
+ except EOFError as exc:
246
+ raise EOFError("connection closed while reading HTTP status line") from exc
247
+
248
+ try:
249
+ protocol, raw_status_code, raw_reason = status_line.split(b" ", 2)
250
+ except ValueError: # not enough values to unpack (expected 3, got 1-2)
251
+ raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
252
+ if proxy: # some proxies still use HTTP/1.0
253
+ if protocol not in [b"HTTP/1.1", b"HTTP/1.0"]:
254
+ raise ValueError(
255
+ f"unsupported protocol; expected HTTP/1.1 or HTTP/1.0: "
256
+ f"{d(status_line)}"
257
+ )
258
+ else:
259
+ if protocol != b"HTTP/1.1":
260
+ raise ValueError(
261
+ f"unsupported protocol; expected HTTP/1.1: {d(status_line)}"
262
+ )
263
+ try:
264
+ status_code = int(raw_status_code)
265
+ except ValueError: # invalid literal for int() with base 10
266
+ raise ValueError(
267
+ f"invalid status code; expected integer; got {d(raw_status_code)}"
268
+ ) from None
269
+ if not 100 <= status_code < 600:
270
+ raise ValueError(
271
+ f"invalid status code; expected 100–599; got {d(raw_status_code)}"
272
+ )
273
+ if not _value_re.fullmatch(raw_reason):
274
+ raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
275
+ reason = raw_reason.decode("ascii", "surrogateescape")
276
+
277
+ headers = yield from parse_headers(read_line)
278
+
279
+ body: bytes | bytearray
280
+ if proxy:
281
+ body = b""
282
+ else:
283
+ body = yield from read_body(
284
+ status_code, headers, read_line, read_exact, read_to_eof
285
+ )
286
+
287
+ return cls(status_code, reason, headers, body)
288
+
289
+ def serialize(self) -> bytes:
290
+ """
291
+ Serialize a WebSocket handshake response.
292
+
293
+ """
294
+ # Since the status line and headers only contain ASCII characters,
295
+ # we can keep this simple.
296
+ response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode()
297
+ response += self.headers.serialize()
298
+ response += self.body
299
+ return response
300
+
301
+
302
+ def parse_line(
303
+ read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
304
+ ) -> Generator[None, None, bytes | bytearray]:
305
+ """
306
+ Parse a single line.
307
+
308
+ CRLF is stripped from the return value.
309
+
310
+ Args:
311
+ read_line: Generator-based coroutine that reads a LF-terminated line
312
+ or raises an exception if there isn't enough data.
313
+
314
+ Raises:
315
+ EOFError: If the connection is closed without a CRLF.
316
+ SecurityError: If the response exceeds a security limit.
317
+
318
+ """
319
+ try:
320
+ line = yield from read_line(MAX_LINE_LENGTH)
321
+ except RuntimeError:
322
+ raise SecurityError("line too long")
323
+ # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
324
+ if not line.endswith(b"\r\n"):
325
+ raise EOFError("line without CRLF")
326
+ return line[:-2]
327
+
328
+
329
+ def parse_headers(
330
+ read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
331
+ ) -> Generator[None, None, Headers]:
332
+ """
333
+ Parse HTTP headers.
334
+
335
+ Non-ASCII characters are represented with surrogate escapes.
336
+
337
+ Args:
338
+ read_line: Generator-based coroutine that reads a LF-terminated line
339
+ or raises an exception if there isn't enough data.
340
+
341
+ Raises:
342
+ EOFError: If the connection is closed without complete headers.
343
+ SecurityError: If the request exceeds a security limit.
344
+ ValueError: If the request isn't well formatted.
345
+
346
+ """
347
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
348
+
349
+ # We don't attempt to support obsolete line folding.
350
+
351
+ headers = Headers()
352
+ for _ in range(MAX_NUM_HEADERS + 1):
353
+ try:
354
+ line = yield from parse_line(read_line)
355
+ except EOFError as exc:
356
+ raise EOFError("connection closed while reading HTTP headers") from exc
357
+ if line == b"":
358
+ break
359
+
360
+ try:
361
+ raw_name, raw_value = line.split(b":", 1)
362
+ except ValueError: # not enough values to unpack (expected 2, got 1)
363
+ raise ValueError(f"invalid HTTP header line: {d(line)}") from None
364
+ if not _token_re.fullmatch(raw_name):
365
+ raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
366
+ raw_value = raw_value.strip(b" \t")
367
+ if not _value_re.fullmatch(raw_value):
368
+ raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
369
+
370
+ name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
371
+ value = raw_value.decode("ascii", "surrogateescape")
372
+ headers[name] = value
373
+
374
+ else:
375
+ raise SecurityError("too many HTTP headers")
376
+
377
+ return headers
378
+
379
+
380
+ def read_body(
381
+ status_code: int,
382
+ headers: Headers,
383
+ read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
384
+ read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
385
+ read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
386
+ ) -> Generator[None, None, bytes | bytearray]:
387
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3
388
+
389
+ # Since websockets only does GET requests (no HEAD, no CONNECT), all
390
+ # responses except 1xx, 204, and 304 include a message body.
391
+ if 100 <= status_code < 200 or status_code == 204 or status_code == 304:
392
+ return b""
393
+
394
+ # MultipleValuesError is sufficiently unlikely that we don't attempt to
395
+ # handle it when accessing headers. Instead we document that its parent
396
+ # class, LookupError, may be raised.
397
+ # Conversions from str to int are protected by sys.set_int_max_str_digits..
398
+
399
+ elif (coding := headers.get("Transfer-Encoding")) is not None:
400
+ if coding != "chunked":
401
+ raise NotImplementedError(f"transfer coding {coding} isn't supported")
402
+
403
+ body = b""
404
+ while True:
405
+ chunk_size_line = yield from parse_line(read_line)
406
+ raw_chunk_size = chunk_size_line.split(b";", 1)[0]
407
+ # Set a lower limit than default_max_str_digits; 1 EB is plenty.
408
+ if len(raw_chunk_size) > 15:
409
+ str_chunk_size = raw_chunk_size.decode(errors="backslashreplace")
410
+ raise SecurityError(f"chunk too large: 0x{str_chunk_size} bytes")
411
+ chunk_size = int(raw_chunk_size, 16)
412
+ if chunk_size == 0:
413
+ break
414
+ if len(body) + chunk_size > MAX_BODY_SIZE:
415
+ raise SecurityError(
416
+ f"chunk too large: {chunk_size} bytes after {len(body)} bytes"
417
+ )
418
+ body += yield from read_exact(chunk_size)
419
+ if (yield from read_exact(2)) != b"\r\n":
420
+ raise ValueError("chunk without CRLF")
421
+ # Read the trailer.
422
+ yield from parse_headers(read_line)
423
+ return body
424
+
425
+ elif (raw_content_length := headers.get("Content-Length")) is not None:
426
+ # Set a lower limit than default_max_str_digits; 1 EiB is plenty.
427
+ if len(raw_content_length) > 18:
428
+ raise SecurityError(f"body too large: {raw_content_length} bytes")
429
+ content_length = int(raw_content_length)
430
+ if content_length > MAX_BODY_SIZE:
431
+ raise SecurityError(f"body too large: {content_length} bytes")
432
+ return (yield from read_exact(content_length))
433
+
434
+ else:
435
+ try:
436
+ return (yield from read_to_eof(MAX_BODY_SIZE))
437
+ except RuntimeError:
438
+ raise SecurityError(f"body too large: over {MAX_BODY_SIZE} bytes")
source/websockets/imports.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from collections.abc import Iterable
5
+ from typing import Any
6
+
7
+
8
+ __all__ = ["lazy_import"]
9
+
10
+
11
+ def import_name(name: str, source: str, namespace: dict[str, Any]) -> Any:
12
+ """
13
+ Import ``name`` from ``source`` in ``namespace``.
14
+
15
+ There are two use cases:
16
+
17
+ - ``name`` is an object defined in ``source``;
18
+ - ``name`` is a submodule of ``source``.
19
+
20
+ Neither :func:`__import__` nor :func:`~importlib.import_module` does
21
+ exactly this. :func:`__import__` is closer to the intended behavior.
22
+
23
+ """
24
+ level = 0
25
+ while source[level] == ".":
26
+ level += 1
27
+ assert level < len(source), "importing from parent isn't supported"
28
+ module = __import__(source[level:], namespace, None, [name], level)
29
+ return getattr(module, name)
30
+
31
+
32
+ def lazy_import(
33
+ namespace: dict[str, Any],
34
+ aliases: dict[str, str] | None = None,
35
+ deprecated_aliases: dict[str, str] | None = None,
36
+ ) -> None:
37
+ """
38
+ Provide lazy, module-level imports.
39
+
40
+ Typical use::
41
+
42
+ __getattr__, __dir__ = lazy_import(
43
+ globals(),
44
+ aliases={
45
+ "<name>": "<source module>",
46
+ ...
47
+ },
48
+ deprecated_aliases={
49
+ ...,
50
+ }
51
+ )
52
+
53
+ This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`.
54
+
55
+ """
56
+ if aliases is None:
57
+ aliases = {}
58
+ if deprecated_aliases is None:
59
+ deprecated_aliases = {}
60
+
61
+ namespace_set = set(namespace)
62
+ aliases_set = set(aliases)
63
+ deprecated_aliases_set = set(deprecated_aliases)
64
+
65
+ assert not namespace_set & aliases_set, "namespace conflict"
66
+ assert not namespace_set & deprecated_aliases_set, "namespace conflict"
67
+ assert not aliases_set & deprecated_aliases_set, "namespace conflict"
68
+
69
+ package = namespace["__name__"]
70
+
71
+ def __getattr__(name: str) -> Any:
72
+ assert aliases is not None # mypy cannot figure this out
73
+ try:
74
+ source = aliases[name]
75
+ except KeyError:
76
+ pass
77
+ else:
78
+ return import_name(name, source, namespace)
79
+
80
+ assert deprecated_aliases is not None # mypy cannot figure this out
81
+ try:
82
+ source = deprecated_aliases[name]
83
+ except KeyError:
84
+ pass
85
+ else:
86
+ warnings.warn(
87
+ f"{package}.{name} is deprecated",
88
+ DeprecationWarning,
89
+ stacklevel=2,
90
+ )
91
+ return import_name(name, source, namespace)
92
+
93
+ raise AttributeError(f"module {package!r} has no attribute {name!r}")
94
+
95
+ namespace["__getattr__"] = __getattr__
96
+
97
+ def __dir__() -> Iterable[str]:
98
+ return sorted(namespace_set | aliases_set | deprecated_aliases_set)
99
+
100
+ namespace["__dir__"] = __dir__
source/websockets/legacy/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+
5
+
6
+ warnings.warn( # deprecated in 14.0 - 2024-11-09
7
+ "websockets.legacy is deprecated; "
8
+ "see https://websockets.readthedocs.io/en/stable/howto/upgrade.html "
9
+ "for upgrade instructions",
10
+ DeprecationWarning,
11
+ )
source/websockets/legacy/auth.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import hmac
5
+ import http
6
+ from collections.abc import Awaitable, Iterable
7
+ from typing import Any, Callable, cast
8
+
9
+ from ..datastructures import Headers
10
+ from ..exceptions import InvalidHeader
11
+ from ..headers import build_www_authenticate_basic, parse_authorization_basic
12
+ from .server import HTTPResponse, WebSocketServerProtocol
13
+
14
+
15
+ __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
16
+
17
+ Credentials = tuple[str, str]
18
+
19
+
20
+ def is_credentials(value: Any) -> bool:
21
+ try:
22
+ username, password = value
23
+ except (TypeError, ValueError):
24
+ return False
25
+ else:
26
+ return isinstance(username, str) and isinstance(password, str)
27
+
28
+
29
+ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
30
+ """
31
+ WebSocket server protocol that enforces HTTP Basic Auth.
32
+
33
+ """
34
+
35
+ realm: str = ""
36
+ """
37
+ Scope of protection.
38
+
39
+ If provided, it should contain only ASCII characters because the
40
+ encoding of non-ASCII characters is undefined.
41
+ """
42
+
43
+ username: str | None = None
44
+ """Username of the authenticated user."""
45
+
46
+ def __init__(
47
+ self,
48
+ *args: Any,
49
+ realm: str | None = None,
50
+ check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
51
+ **kwargs: Any,
52
+ ) -> None:
53
+ if realm is not None:
54
+ self.realm = realm # shadow class attribute
55
+ self._check_credentials = check_credentials
56
+ super().__init__(*args, **kwargs)
57
+
58
+ async def check_credentials(self, username: str, password: str) -> bool:
59
+ """
60
+ Check whether credentials are authorized.
61
+
62
+ This coroutine may be overridden in a subclass, for example to
63
+ authenticate against a database or an external service.
64
+
65
+ Args:
66
+ username: HTTP Basic Auth username.
67
+ password: HTTP Basic Auth password.
68
+
69
+ Returns:
70
+ :obj:`True` if the handshake should continue;
71
+ :obj:`False` if it should fail with an HTTP 401 error.
72
+
73
+ """
74
+ if self._check_credentials is not None:
75
+ return await self._check_credentials(username, password)
76
+
77
+ return False
78
+
79
+ async def process_request(
80
+ self,
81
+ path: str,
82
+ request_headers: Headers,
83
+ ) -> HTTPResponse | None:
84
+ """
85
+ Check HTTP Basic Auth and return an HTTP 401 response if needed.
86
+
87
+ """
88
+ try:
89
+ authorization = request_headers["Authorization"]
90
+ except KeyError:
91
+ return (
92
+ http.HTTPStatus.UNAUTHORIZED,
93
+ [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
94
+ b"Missing credentials\n",
95
+ )
96
+
97
+ try:
98
+ username, password = parse_authorization_basic(authorization)
99
+ except InvalidHeader:
100
+ return (
101
+ http.HTTPStatus.UNAUTHORIZED,
102
+ [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
103
+ b"Unsupported credentials\n",
104
+ )
105
+
106
+ if not await self.check_credentials(username, password):
107
+ return (
108
+ http.HTTPStatus.UNAUTHORIZED,
109
+ [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
110
+ b"Invalid credentials\n",
111
+ )
112
+
113
+ self.username = username
114
+
115
+ return await super().process_request(path, request_headers)
116
+
117
+
118
+ def basic_auth_protocol_factory(
119
+ realm: str | None = None,
120
+ credentials: Credentials | Iterable[Credentials] | None = None,
121
+ check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
122
+ create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
123
+ ) -> Callable[..., BasicAuthWebSocketServerProtocol]:
124
+ """
125
+ Protocol factory that enforces HTTP Basic Auth.
126
+
127
+ :func:`basic_auth_protocol_factory` is designed to integrate with
128
+ :func:`~websockets.legacy.server.serve` like this::
129
+
130
+ serve(
131
+ ...,
132
+ create_protocol=basic_auth_protocol_factory(
133
+ realm="my dev server",
134
+ credentials=("hello", "iloveyou"),
135
+ )
136
+ )
137
+
138
+ Args:
139
+ realm: Scope of protection. It should contain only ASCII characters
140
+ because the encoding of non-ASCII characters is undefined.
141
+ Refer to section 2.2 of :rfc:`7235` for details.
142
+ credentials: Hard coded authorized credentials. It can be a
143
+ ``(username, password)`` pair or a list of such pairs.
144
+ check_credentials: Coroutine that verifies credentials.
145
+ It receives ``username`` and ``password`` arguments
146
+ and returns a :class:`bool`. One of ``credentials`` or
147
+ ``check_credentials`` must be provided but not both.
148
+ create_protocol: Factory that creates the protocol. By default, this
149
+ is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
150
+ by a subclass.
151
+ Raises:
152
+ TypeError: If the ``credentials`` or ``check_credentials`` argument is
153
+ wrong.
154
+
155
+ """
156
+ if (credentials is None) == (check_credentials is None):
157
+ raise TypeError("provide either credentials or check_credentials")
158
+
159
+ if credentials is not None:
160
+ if is_credentials(credentials):
161
+ credentials_list = [cast(Credentials, credentials)]
162
+ elif isinstance(credentials, Iterable):
163
+ credentials_list = list(cast(Iterable[Credentials], credentials))
164
+ if not all(is_credentials(item) for item in credentials_list):
165
+ raise TypeError(f"invalid credentials argument: {credentials}")
166
+ else:
167
+ raise TypeError(f"invalid credentials argument: {credentials}")
168
+
169
+ credentials_dict = dict(credentials_list)
170
+
171
+ async def check_credentials(username: str, password: str) -> bool:
172
+ try:
173
+ expected_password = credentials_dict[username]
174
+ except KeyError:
175
+ return False
176
+ return hmac.compare_digest(expected_password, password)
177
+
178
+ if create_protocol is None:
179
+ create_protocol = BasicAuthWebSocketServerProtocol
180
+
181
+ # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
182
+ # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc]
183
+ create_protocol = cast(
184
+ Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
185
+ )
186
+ return functools.partial(
187
+ create_protocol,
188
+ realm=realm,
189
+ check_credentials=check_credentials,
190
+ )
source/websockets/legacy/client.py ADDED
@@ -0,0 +1,703 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import functools
5
+ import logging
6
+ import os
7
+ import random
8
+ import traceback
9
+ import urllib.parse
10
+ import warnings
11
+ from collections.abc import AsyncIterator, Generator, Sequence
12
+ from types import TracebackType
13
+ from typing import Any, Callable, cast
14
+
15
+ from ..asyncio.compatibility import asyncio_timeout
16
+ from ..datastructures import Headers, HeadersLike
17
+ from ..exceptions import (
18
+ InvalidHeader,
19
+ InvalidHeaderValue,
20
+ InvalidMessage,
21
+ NegotiationError,
22
+ SecurityError,
23
+ )
24
+ from ..extensions import ClientExtensionFactory, Extension
25
+ from ..extensions.permessage_deflate import enable_client_permessage_deflate
26
+ from ..headers import (
27
+ build_authorization_basic,
28
+ build_extension,
29
+ build_host,
30
+ build_subprotocol,
31
+ parse_extension,
32
+ parse_subprotocol,
33
+ validate_subprotocols,
34
+ )
35
+ from ..http11 import USER_AGENT
36
+ from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
37
+ from ..uri import WebSocketURI, parse_uri
38
+ from .exceptions import InvalidStatusCode, RedirectHandshake
39
+ from .handshake import build_request, check_response
40
+ from .http import read_response
41
+ from .protocol import WebSocketCommonProtocol
42
+
43
+
44
+ __all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
45
+
46
+
47
+ class WebSocketClientProtocol(WebSocketCommonProtocol):
48
+ """
49
+ WebSocket client connection.
50
+
51
+ :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
52
+ coroutines for receiving and sending messages.
53
+
54
+ It supports asynchronous iteration to receive messages::
55
+
56
+ async for message in websocket:
57
+ await process(message)
58
+
59
+ The iterator exits normally when the connection is closed with close code
60
+ 1000 (OK) or 1001 (going away) or without a close code. It raises
61
+ a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
62
+ is closed with any other code.
63
+
64
+ See :func:`connect` for the documentation of ``logger``, ``origin``,
65
+ ``extensions``, ``subprotocols``, ``extra_headers``, and
66
+ ``user_agent_header``.
67
+
68
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
69
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
70
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
71
+
72
+ """
73
+
74
+ is_client = True
75
+ side = "client"
76
+
77
+ def __init__(
78
+ self,
79
+ *,
80
+ logger: LoggerLike | None = None,
81
+ origin: Origin | None = None,
82
+ extensions: Sequence[ClientExtensionFactory] | None = None,
83
+ subprotocols: Sequence[Subprotocol] | None = None,
84
+ extra_headers: HeadersLike | None = None,
85
+ user_agent_header: str | None = USER_AGENT,
86
+ **kwargs: Any,
87
+ ) -> None:
88
+ if logger is None:
89
+ logger = logging.getLogger("websockets.client")
90
+ super().__init__(logger=logger, **kwargs)
91
+ self.origin = origin
92
+ self.available_extensions = extensions
93
+ self.available_subprotocols = subprotocols
94
+ self.extra_headers = extra_headers
95
+ self.user_agent_header = user_agent_header
96
+
97
+ def write_http_request(self, path: str, headers: Headers) -> None:
98
+ """
99
+ Write request line and headers to the HTTP request.
100
+
101
+ """
102
+ self.path = path
103
+ self.request_headers = headers
104
+
105
+ if self.debug:
106
+ self.logger.debug("> GET %s HTTP/1.1", path)
107
+ for key, value in headers.raw_items():
108
+ self.logger.debug("> %s: %s", key, value)
109
+
110
+ # Since the path and headers only contain ASCII characters,
111
+ # we can keep this simple.
112
+ request = f"GET {path} HTTP/1.1\r\n"
113
+ request += str(headers)
114
+
115
+ self.transport.write(request.encode())
116
+
117
+ async def read_http_response(self) -> tuple[int, Headers]:
118
+ """
119
+ Read status line and headers from the HTTP response.
120
+
121
+ If the response contains a body, it may be read from ``self.reader``
122
+ after this coroutine returns.
123
+
124
+ Raises:
125
+ InvalidMessage: If the HTTP message is malformed or isn't an
126
+ HTTP/1.1 GET response.
127
+
128
+ """
129
+ try:
130
+ status_code, reason, headers = await read_response(self.reader)
131
+ except Exception as exc:
132
+ raise InvalidMessage("did not receive a valid HTTP response") from exc
133
+
134
+ if self.debug:
135
+ self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
136
+ for key, value in headers.raw_items():
137
+ self.logger.debug("< %s: %s", key, value)
138
+
139
+ self.response_headers = headers
140
+
141
+ return status_code, self.response_headers
142
+
143
+ @staticmethod
144
+ def process_extensions(
145
+ headers: Headers,
146
+ available_extensions: Sequence[ClientExtensionFactory] | None,
147
+ ) -> list[Extension]:
148
+ """
149
+ Handle the Sec-WebSocket-Extensions HTTP response header.
150
+
151
+ Check that each extension is supported, as well as its parameters.
152
+
153
+ Return the list of accepted extensions.
154
+
155
+ Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
156
+ connection.
157
+
158
+ :rfc:`6455` leaves the rules up to the specification of each
159
+ :extension.
160
+
161
+ To provide this level of flexibility, for each extension accepted by
162
+ the server, we check for a match with each extension available in the
163
+ client configuration. If no match is found, an exception is raised.
164
+
165
+ If several variants of the same extension are accepted by the server,
166
+ it may be configured several times, which won't make sense in general.
167
+ Extensions must implement their own requirements. For this purpose,
168
+ the list of previously accepted extensions is provided.
169
+
170
+ Other requirements, for example related to mandatory extensions or the
171
+ order of extensions, may be implemented by overriding this method.
172
+
173
+ """
174
+ accepted_extensions: list[Extension] = []
175
+
176
+ header_values = headers.get_all("Sec-WebSocket-Extensions")
177
+
178
+ if header_values:
179
+ if available_extensions is None:
180
+ raise NegotiationError("no extensions supported")
181
+
182
+ parsed_header_values: list[ExtensionHeader] = sum(
183
+ [parse_extension(header_value) for header_value in header_values], []
184
+ )
185
+
186
+ for name, response_params in parsed_header_values:
187
+ for extension_factory in available_extensions:
188
+ # Skip non-matching extensions based on their name.
189
+ if extension_factory.name != name:
190
+ continue
191
+
192
+ # Skip non-matching extensions based on their params.
193
+ try:
194
+ extension = extension_factory.process_response_params(
195
+ response_params, accepted_extensions
196
+ )
197
+ except NegotiationError:
198
+ continue
199
+
200
+ # Add matching extension to the final list.
201
+ accepted_extensions.append(extension)
202
+
203
+ # Break out of the loop once we have a match.
204
+ break
205
+
206
+ # If we didn't break from the loop, no extension in our list
207
+ # matched what the server sent. Fail the connection.
208
+ else:
209
+ raise NegotiationError(
210
+ f"Unsupported extension: "
211
+ f"name = {name}, params = {response_params}"
212
+ )
213
+
214
+ return accepted_extensions
215
+
216
+ @staticmethod
217
+ def process_subprotocol(
218
+ headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
219
+ ) -> Subprotocol | None:
220
+ """
221
+ Handle the Sec-WebSocket-Protocol HTTP response header.
222
+
223
+ Check that it contains exactly one supported subprotocol.
224
+
225
+ Return the selected subprotocol.
226
+
227
+ """
228
+ subprotocol: Subprotocol | None = None
229
+
230
+ header_values = headers.get_all("Sec-WebSocket-Protocol")
231
+
232
+ if header_values:
233
+ if available_subprotocols is None:
234
+ raise NegotiationError("no subprotocols supported")
235
+
236
+ parsed_header_values: Sequence[Subprotocol] = sum(
237
+ [parse_subprotocol(header_value) for header_value in header_values], []
238
+ )
239
+
240
+ if len(parsed_header_values) > 1:
241
+ raise InvalidHeaderValue(
242
+ "Sec-WebSocket-Protocol",
243
+ f"multiple values: {', '.join(parsed_header_values)}",
244
+ )
245
+
246
+ subprotocol = parsed_header_values[0]
247
+
248
+ if subprotocol not in available_subprotocols:
249
+ raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
250
+
251
+ return subprotocol
252
+
253
+ async def handshake(
254
+ self,
255
+ wsuri: WebSocketURI,
256
+ origin: Origin | None = None,
257
+ available_extensions: Sequence[ClientExtensionFactory] | None = None,
258
+ available_subprotocols: Sequence[Subprotocol] | None = None,
259
+ extra_headers: HeadersLike | None = None,
260
+ ) -> None:
261
+ """
262
+ Perform the client side of the opening handshake.
263
+
264
+ Args:
265
+ wsuri: URI of the WebSocket server.
266
+ origin: Value of the ``Origin`` header.
267
+ extensions: List of supported extensions, in order in which they
268
+ should be negotiated and run.
269
+ subprotocols: List of supported subprotocols, in order of decreasing
270
+ preference.
271
+ extra_headers: Arbitrary HTTP headers to add to the handshake request.
272
+
273
+ Raises:
274
+ InvalidHandshake: If the handshake fails.
275
+
276
+ """
277
+ request_headers = Headers()
278
+
279
+ request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)
280
+
281
+ if wsuri.user_info:
282
+ request_headers["Authorization"] = build_authorization_basic(
283
+ *wsuri.user_info
284
+ )
285
+
286
+ if origin is not None:
287
+ request_headers["Origin"] = origin
288
+
289
+ key = build_request(request_headers)
290
+
291
+ if available_extensions is not None:
292
+ extensions_header = build_extension(
293
+ [
294
+ (extension_factory.name, extension_factory.get_request_params())
295
+ for extension_factory in available_extensions
296
+ ]
297
+ )
298
+ request_headers["Sec-WebSocket-Extensions"] = extensions_header
299
+
300
+ if available_subprotocols is not None:
301
+ protocol_header = build_subprotocol(available_subprotocols)
302
+ request_headers["Sec-WebSocket-Protocol"] = protocol_header
303
+
304
+ if self.extra_headers is not None:
305
+ request_headers.update(self.extra_headers)
306
+
307
+ if self.user_agent_header:
308
+ request_headers.setdefault("User-Agent", self.user_agent_header)
309
+
310
+ self.write_http_request(wsuri.resource_name, request_headers)
311
+
312
+ status_code, response_headers = await self.read_http_response()
313
+ if status_code in (301, 302, 303, 307, 308):
314
+ if "Location" not in response_headers:
315
+ raise InvalidHeader("Location")
316
+ raise RedirectHandshake(response_headers["Location"])
317
+ elif status_code != 101:
318
+ raise InvalidStatusCode(status_code, response_headers)
319
+
320
+ check_response(response_headers, key)
321
+
322
+ self.extensions = self.process_extensions(
323
+ response_headers, available_extensions
324
+ )
325
+
326
+ self.subprotocol = self.process_subprotocol(
327
+ response_headers, available_subprotocols
328
+ )
329
+
330
+ self.connection_open()
331
+
332
+
333
+ class Connect:
334
+ """
335
+ Connect to the WebSocket server at ``uri``.
336
+
337
+ Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
338
+ can then be used to send and receive messages.
339
+
340
+ :func:`connect` can be used as a asynchronous context manager::
341
+
342
+ async with connect(...) as websocket:
343
+ ...
344
+
345
+ The connection is closed automatically when exiting the context.
346
+
347
+ :func:`connect` can be used as an infinite asynchronous iterator to
348
+ reconnect automatically on errors::
349
+
350
+ async for websocket in connect(...):
351
+ try:
352
+ ...
353
+ except websockets.exceptions.ConnectionClosed:
354
+ continue
355
+
356
+ The connection is closed automatically after each iteration of the loop.
357
+
358
+ If an error occurs while establishing the connection, :func:`connect`
359
+ retries with exponential backoff. The backoff delay starts at three
360
+ seconds and increases up to one minute.
361
+
362
+ If an error occurs in the body of the loop, you can handle the exception
363
+ and :func:`connect` will reconnect with the next iteration; or you can
364
+ let the exception bubble up and break out of the loop. This lets you
365
+ decide which errors trigger a reconnection and which errors are fatal.
366
+
367
+ Args:
368
+ uri: URI of the WebSocket server.
369
+ create_protocol: Factory for the :class:`asyncio.Protocol` managing
370
+ the connection. It defaults to :class:`WebSocketClientProtocol`.
371
+ Set it to a wrapper or a subclass to customize connection handling.
372
+ logger: Logger for this client.
373
+ It defaults to ``logging.getLogger("websockets.client")``.
374
+ See the :doc:`logging guide <../../topics/logging>` for details.
375
+ compression: The "permessage-deflate" extension is enabled by default.
376
+ Set ``compression`` to :obj:`None` to disable it. See the
377
+ :doc:`compression guide <../../topics/compression>` for details.
378
+ origin: Value of the ``Origin`` header, for servers that require it.
379
+ extensions: List of supported extensions, in order in which they
380
+ should be negotiated and run.
381
+ subprotocols: List of supported subprotocols, in order of decreasing
382
+ preference.
383
+ extra_headers: Arbitrary HTTP headers to add to the handshake request.
384
+ user_agent_header: Value of the ``User-Agent`` request header.
385
+ It defaults to ``"Python/x.y.z websockets/X.Y"``.
386
+ Setting it to :obj:`None` removes the header.
387
+ open_timeout: Timeout for opening the connection in seconds.
388
+ :obj:`None` disables the timeout.
389
+
390
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
391
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
392
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
393
+
394
+ Any other keyword arguments are passed the event loop's
395
+ :meth:`~asyncio.loop.create_connection` method.
396
+
397
+ For example:
398
+
399
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
400
+ settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
401
+ provided, a TLS context is created
402
+ with :func:`~ssl.create_default_context`.
403
+
404
+ * You can set ``host`` and ``port`` to connect to a different host and
405
+ port from those found in ``uri``. This only changes the destination of
406
+ the TCP connection. The host name from ``uri`` is still used in the TLS
407
+ handshake for secure connections and in the ``Host`` header.
408
+
409
+ Raises:
410
+ InvalidURI: If ``uri`` isn't a valid WebSocket URI.
411
+ OSError: If the TCP connection fails.
412
+ InvalidHandshake: If the opening handshake fails.
413
+ ~asyncio.TimeoutError: If the opening handshake times out.
414
+
415
+ """
416
+
417
+ MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
418
+
419
+ def __init__(
420
+ self,
421
+ uri: str,
422
+ *,
423
+ create_protocol: Callable[..., WebSocketClientProtocol] | None = None,
424
+ logger: LoggerLike | None = None,
425
+ compression: str | None = "deflate",
426
+ origin: Origin | None = None,
427
+ extensions: Sequence[ClientExtensionFactory] | None = None,
428
+ subprotocols: Sequence[Subprotocol] | None = None,
429
+ extra_headers: HeadersLike | None = None,
430
+ user_agent_header: str | None = USER_AGENT,
431
+ open_timeout: float | None = 10,
432
+ ping_interval: float | None = 20,
433
+ ping_timeout: float | None = 20,
434
+ close_timeout: float | None = None,
435
+ max_size: int | None = 2**20,
436
+ max_queue: int | None = 2**5,
437
+ read_limit: int = 2**16,
438
+ write_limit: int = 2**16,
439
+ **kwargs: Any,
440
+ ) -> None:
441
+ # Backwards compatibility: close_timeout used to be called timeout.
442
+ timeout: float | None = kwargs.pop("timeout", None)
443
+ if timeout is None:
444
+ timeout = 10
445
+ else:
446
+ warnings.warn("rename timeout to close_timeout", DeprecationWarning)
447
+ # If both are specified, timeout is ignored.
448
+ if close_timeout is None:
449
+ close_timeout = timeout
450
+
451
+ # Backwards compatibility: create_protocol used to be called klass.
452
+ klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None)
453
+ if klass is None:
454
+ klass = WebSocketClientProtocol
455
+ else:
456
+ warnings.warn("rename klass to create_protocol", DeprecationWarning)
457
+ # If both are specified, klass is ignored.
458
+ if create_protocol is None:
459
+ create_protocol = klass
460
+
461
+ # Backwards compatibility: recv() used to return None on closed connections
462
+ legacy_recv: bool = kwargs.pop("legacy_recv", False)
463
+
464
+ # Backwards compatibility: the loop parameter used to be supported.
465
+ _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
466
+ if _loop is None:
467
+ loop = asyncio.get_event_loop()
468
+ else:
469
+ loop = _loop
470
+ warnings.warn("remove loop argument", DeprecationWarning)
471
+
472
+ wsuri = parse_uri(uri)
473
+ if wsuri.secure:
474
+ kwargs.setdefault("ssl", True)
475
+ elif kwargs.get("ssl") is not None:
476
+ raise ValueError(
477
+ "connect() received a ssl argument for a ws:// URI, "
478
+ "use a wss:// URI to enable TLS"
479
+ )
480
+
481
+ if compression == "deflate":
482
+ extensions = enable_client_permessage_deflate(extensions)
483
+ elif compression is not None:
484
+ raise ValueError(f"unsupported compression: {compression}")
485
+
486
+ if subprotocols is not None:
487
+ validate_subprotocols(subprotocols)
488
+
489
+ # Help mypy and avoid this error: "type[WebSocketClientProtocol] |
490
+ # Callable[..., WebSocketClientProtocol]" not callable [misc]
491
+ create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol)
492
+ factory = functools.partial(
493
+ create_protocol,
494
+ logger=logger,
495
+ origin=origin,
496
+ extensions=extensions,
497
+ subprotocols=subprotocols,
498
+ extra_headers=extra_headers,
499
+ user_agent_header=user_agent_header,
500
+ ping_interval=ping_interval,
501
+ ping_timeout=ping_timeout,
502
+ close_timeout=close_timeout,
503
+ max_size=max_size,
504
+ max_queue=max_queue,
505
+ read_limit=read_limit,
506
+ write_limit=write_limit,
507
+ host=wsuri.host,
508
+ port=wsuri.port,
509
+ secure=wsuri.secure,
510
+ legacy_recv=legacy_recv,
511
+ loop=_loop,
512
+ )
513
+
514
+ if kwargs.pop("unix", False):
515
+ path: str | None = kwargs.pop("path", None)
516
+ create_connection = functools.partial(
517
+ loop.create_unix_connection, factory, path, **kwargs
518
+ )
519
+ else:
520
+ host: str | None
521
+ port: int | None
522
+ if kwargs.get("sock") is None:
523
+ host, port = wsuri.host, wsuri.port
524
+ else:
525
+ # If sock is given, host and port shouldn't be specified.
526
+ host, port = None, None
527
+ if kwargs.get("ssl"):
528
+ kwargs.setdefault("server_hostname", wsuri.host)
529
+ # If host and port are given, override values from the URI.
530
+ host = kwargs.pop("host", host)
531
+ port = kwargs.pop("port", port)
532
+ create_connection = functools.partial(
533
+ loop.create_connection, factory, host, port, **kwargs
534
+ )
535
+
536
+ self.open_timeout = open_timeout
537
+ if logger is None:
538
+ logger = logging.getLogger("websockets.client")
539
+ self.logger = logger
540
+
541
+ # This is a coroutine function.
542
+ self._create_connection = create_connection
543
+ self._uri = uri
544
+ self._wsuri = wsuri
545
+
546
+ def handle_redirect(self, uri: str) -> None:
547
+ # Update the state of this instance to connect to a new URI.
548
+ old_uri = self._uri
549
+ old_wsuri = self._wsuri
550
+ new_uri = urllib.parse.urljoin(old_uri, uri)
551
+ new_wsuri = parse_uri(new_uri)
552
+
553
+ # Forbid TLS downgrade.
554
+ if old_wsuri.secure and not new_wsuri.secure:
555
+ raise SecurityError("redirect from WSS to WS")
556
+
557
+ same_origin = (
558
+ old_wsuri.secure == new_wsuri.secure
559
+ and old_wsuri.host == new_wsuri.host
560
+ and old_wsuri.port == new_wsuri.port
561
+ )
562
+
563
+ # Rewrite secure, host, and port for cross-origin redirects.
564
+ # This preserves connection overrides with the host and port
565
+ # arguments if the redirect points to the same host and port.
566
+ if not same_origin:
567
+ factory = self._create_connection.args[0]
568
+ # Support TLS upgrade.
569
+ if not old_wsuri.secure and new_wsuri.secure:
570
+ factory.keywords["secure"] = True
571
+ self._create_connection.keywords.setdefault("ssl", True)
572
+ # Replace secure, host, and port arguments of the protocol factory.
573
+ factory = functools.partial(
574
+ factory.func,
575
+ *factory.args,
576
+ **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
577
+ )
578
+ # Replace secure, host, and port arguments of create_connection.
579
+ self._create_connection = functools.partial(
580
+ self._create_connection.func,
581
+ *(factory, new_wsuri.host, new_wsuri.port),
582
+ **self._create_connection.keywords,
583
+ )
584
+
585
+ # Set the new WebSocket URI. This suffices for same-origin redirects.
586
+ self._uri = new_uri
587
+ self._wsuri = new_wsuri
588
+
589
+ # async for ... in connect(...):
590
+
591
+ BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
592
+ BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
593
+ BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
594
+ BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
595
+
596
+ async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
597
+ backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
598
+ while True:
599
+ try:
600
+ async with self as protocol:
601
+ yield protocol
602
+ except Exception as exc:
603
+ # Add a random initial delay between 0 and 5 seconds.
604
+ # See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
605
+ if backoff_delay == self.BACKOFF_MIN:
606
+ initial_delay = random.random() * self.BACKOFF_INITIAL
607
+ self.logger.info(
608
+ "connect failed; reconnecting in %.1f seconds: %s",
609
+ initial_delay,
610
+ traceback.format_exception_only(exc)[0].strip(),
611
+ )
612
+ await asyncio.sleep(initial_delay)
613
+ else:
614
+ self.logger.info(
615
+ "connect failed again; retrying in %d seconds: %s",
616
+ int(backoff_delay),
617
+ traceback.format_exception_only(exc)[0].strip(),
618
+ )
619
+ await asyncio.sleep(int(backoff_delay))
620
+ # Increase delay with truncated exponential backoff.
621
+ backoff_delay = backoff_delay * self.BACKOFF_FACTOR
622
+ backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
623
+ continue
624
+ else:
625
+ # Connection succeeded - reset backoff delay
626
+ backoff_delay = self.BACKOFF_MIN
627
+
628
+ # async with connect(...) as ...:
629
+
630
+ async def __aenter__(self) -> WebSocketClientProtocol:
631
+ return await self
632
+
633
+ async def __aexit__(
634
+ self,
635
+ exc_type: type[BaseException] | None,
636
+ exc_value: BaseException | None,
637
+ traceback: TracebackType | None,
638
+ ) -> None:
639
+ await self.protocol.close()
640
+
641
+ # ... = await connect(...)
642
+
643
+ def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
644
+ # Create a suitable iterator by calling __await__ on a coroutine.
645
+ return self.__await_impl__().__await__()
646
+
647
+ async def __await_impl__(self) -> WebSocketClientProtocol:
648
+ async with asyncio_timeout(self.open_timeout):
649
+ for _redirects in range(self.MAX_REDIRECTS_ALLOWED):
650
+ _transport, protocol = await self._create_connection()
651
+ try:
652
+ await protocol.handshake(
653
+ self._wsuri,
654
+ origin=protocol.origin,
655
+ available_extensions=protocol.available_extensions,
656
+ available_subprotocols=protocol.available_subprotocols,
657
+ extra_headers=protocol.extra_headers,
658
+ )
659
+ except RedirectHandshake as exc:
660
+ protocol.fail_connection()
661
+ await protocol.wait_closed()
662
+ self.handle_redirect(exc.uri)
663
+ # Avoid leaking a connected socket when the handshake fails.
664
+ except (Exception, asyncio.CancelledError):
665
+ protocol.fail_connection()
666
+ await protocol.wait_closed()
667
+ raise
668
+ else:
669
+ self.protocol = protocol
670
+ return protocol
671
+ else:
672
+ raise SecurityError("too many redirects")
673
+
674
+ # ... = yield from connect(...) - remove when dropping Python < 3.11
675
+
676
+ __iter__ = __await__
677
+
678
+
679
+ connect = Connect
680
+
681
+
682
+ def unix_connect(
683
+ path: str | None = None,
684
+ uri: str = "ws://localhost/",
685
+ **kwargs: Any,
686
+ ) -> Connect:
687
+ """
688
+ Similar to :func:`connect`, but for connecting to a Unix socket.
689
+
690
+ This function builds upon the event loop's
691
+ :meth:`~asyncio.loop.create_unix_connection` method.
692
+
693
+ It is only available on Unix.
694
+
695
+ It's mainly useful for debugging servers listening on Unix sockets.
696
+
697
+ Args:
698
+ path: File system path to the Unix socket.
699
+ uri: URI of the WebSocket server; the host is used in the TLS
700
+ handshake for secure connections and in the ``Host`` header.
701
+
702
+ """
703
+ return connect(uri=uri, path=path, unix=True, **kwargs)
source/websockets/legacy/exceptions.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import http
2
+
3
+ from .. import datastructures
4
+ from ..exceptions import (
5
+ InvalidHandshake,
6
+ # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1.
7
+ InvalidMessage, # noqa: F401
8
+ ProtocolError as WebSocketProtocolError, # noqa: F401
9
+ )
10
+ from ..typing import StatusLike
11
+
12
+
13
+ class InvalidStatusCode(InvalidHandshake):
14
+ """
15
+ Raised when a handshake response status code is invalid.
16
+
17
+ """
18
+
19
+ def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
20
+ self.status_code = status_code
21
+ self.headers = headers
22
+
23
+ def __str__(self) -> str:
24
+ return f"server rejected WebSocket connection: HTTP {self.status_code}"
25
+
26
+
27
+ class AbortHandshake(InvalidHandshake):
28
+ """
29
+ Raised to abort the handshake on purpose and return an HTTP response.
30
+
31
+ This exception is an implementation detail.
32
+
33
+ The public API is
34
+ :meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
35
+
36
+ Attributes:
37
+ status (~http.HTTPStatus): HTTP status code.
38
+ headers (Headers): HTTP response headers.
39
+ body (bytes): HTTP response body.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ status: StatusLike,
45
+ headers: datastructures.HeadersLike,
46
+ body: bytes = b"",
47
+ ) -> None:
48
+ # If a user passes an int instead of an HTTPStatus, fix it automatically.
49
+ self.status = http.HTTPStatus(status)
50
+ self.headers = datastructures.Headers(headers)
51
+ self.body = body
52
+
53
+ def __str__(self) -> str:
54
+ return (
55
+ f"HTTP {self.status:d}, {len(self.headers)} headers, {len(self.body)} bytes"
56
+ )
57
+
58
+
59
+ class RedirectHandshake(InvalidHandshake):
60
+ """
61
+ Raised when a handshake gets redirected.
62
+
63
+ This exception is an implementation detail.
64
+
65
+ """
66
+
67
+ def __init__(self, uri: str) -> None:
68
+ self.uri = uri
69
+
70
+ def __str__(self) -> str:
71
+ return f"redirect to {self.uri}"
source/websockets/legacy/framing.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import struct
4
+ from collections.abc import Awaitable, Sequence
5
+ from typing import Any, Callable, NamedTuple
6
+
7
+ from .. import extensions, frames
8
+ from ..exceptions import PayloadTooBig, ProtocolError
9
+ from ..typing import BytesLike, DataLike
10
+
11
+
12
+ try:
13
+ from ..speedups import apply_mask
14
+ except ImportError:
15
+ from ..utils import apply_mask
16
+
17
+
18
+ class Frame(NamedTuple):
19
+ fin: bool
20
+ opcode: frames.Opcode
21
+ data: BytesLike
22
+ rsv1: bool = False
23
+ rsv2: bool = False
24
+ rsv3: bool = False
25
+
26
+ @property
27
+ def new_frame(self) -> frames.Frame:
28
+ return frames.Frame(
29
+ self.opcode,
30
+ self.data,
31
+ self.fin,
32
+ self.rsv1,
33
+ self.rsv2,
34
+ self.rsv3,
35
+ )
36
+
37
+ def __str__(self) -> str:
38
+ return str(self.new_frame)
39
+
40
+ def check(self) -> None:
41
+ return self.new_frame.check()
42
+
43
+ @classmethod
44
+ async def read(
45
+ cls,
46
+ reader: Callable[[int], Awaitable[bytes]],
47
+ *,
48
+ mask: bool,
49
+ max_size: int | None = None,
50
+ extensions: Sequence[extensions.Extension] | None = None,
51
+ ) -> Frame:
52
+ """
53
+ Read a WebSocket frame.
54
+
55
+ Args:
56
+ reader: Coroutine that reads exactly the requested number of
57
+ bytes, unless the end of file is reached.
58
+ mask: Whether the frame should be masked i.e. whether the read
59
+ happens on the server side.
60
+ max_size: Maximum payload size in bytes.
61
+ extensions: List of extensions, applied in reverse order.
62
+
63
+ Raises:
64
+ PayloadTooBig: If the frame exceeds ``max_size``.
65
+ ProtocolError: If the frame contains incorrect values.
66
+
67
+ """
68
+
69
+ # Read the header.
70
+ data = await reader(2)
71
+ head1, head2 = struct.unpack("!BB", data)
72
+
73
+ # While not Pythonic, this is marginally faster than calling bool().
74
+ fin = True if head1 & 0b10000000 else False
75
+ rsv1 = True if head1 & 0b01000000 else False
76
+ rsv2 = True if head1 & 0b00100000 else False
77
+ rsv3 = True if head1 & 0b00010000 else False
78
+
79
+ try:
80
+ opcode = frames.Opcode(head1 & 0b00001111)
81
+ except ValueError as exc:
82
+ raise ProtocolError("invalid opcode") from exc
83
+
84
+ if (True if head2 & 0b10000000 else False) != mask:
85
+ raise ProtocolError("incorrect masking")
86
+
87
+ length = head2 & 0b01111111
88
+ if length == 126:
89
+ data = await reader(2)
90
+ (length,) = struct.unpack("!H", data)
91
+ elif length == 127:
92
+ data = await reader(8)
93
+ (length,) = struct.unpack("!Q", data)
94
+ if max_size is not None and length > max_size:
95
+ raise PayloadTooBig(length, max_size)
96
+ if mask:
97
+ mask_bits = await reader(4)
98
+
99
+ # Read the data.
100
+ data = await reader(length)
101
+ if mask:
102
+ data = apply_mask(data, mask_bits)
103
+
104
+ new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)
105
+
106
+ if extensions is None:
107
+ extensions = []
108
+ for extension in reversed(extensions):
109
+ new_frame = extension.decode(new_frame, max_size=max_size)
110
+
111
+ new_frame.check()
112
+
113
+ return cls(
114
+ new_frame.fin,
115
+ new_frame.opcode,
116
+ new_frame.data,
117
+ new_frame.rsv1,
118
+ new_frame.rsv2,
119
+ new_frame.rsv3,
120
+ )
121
+
122
+ def write(
123
+ self,
124
+ write: Callable[[bytes], Any],
125
+ *,
126
+ mask: bool,
127
+ extensions: Sequence[extensions.Extension] | None = None,
128
+ ) -> None:
129
+ """
130
+ Write a WebSocket frame.
131
+
132
+ Args:
133
+ frame: Frame to write.
134
+ write: Function that writes bytes.
135
+ mask: Whether the frame should be masked i.e. whether the write
136
+ happens on the client side.
137
+ extensions: List of extensions, applied in order.
138
+
139
+ Raises:
140
+ ProtocolError: If the frame contains incorrect values.
141
+
142
+ """
143
+ # The frame is written in a single call to write in order to prevent
144
+ # TCP fragmentation. See #68 for details. This also makes it safe to
145
+ # send frames concurrently from multiple coroutines.
146
+ write(self.new_frame.serialize(mask=mask, extensions=extensions))
147
+
148
+
149
+ def prepare_data(data: DataLike) -> tuple[int, BytesLike]:
150
+ """
151
+ Convert a string or byte-like object to an opcode and a bytes-like object.
152
+
153
+ This function is designed for data frames.
154
+
155
+ If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
156
+ object encoding ``data`` in UTF-8.
157
+
158
+ If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
159
+ object.
160
+
161
+ Raises:
162
+ TypeError: If ``data`` doesn't have a supported type.
163
+
164
+ """
165
+ if isinstance(data, str):
166
+ return frames.Opcode.TEXT, data.encode()
167
+ elif isinstance(data, BytesLike):
168
+ return frames.Opcode.BINARY, data
169
+ else:
170
+ raise TypeError("data must be str or bytes-like")
171
+
172
+
173
+ def prepare_ctrl(data: DataLike) -> bytes:
174
+ """
175
+ Convert a string or byte-like object to bytes.
176
+
177
+ This function is designed for ping and pong frames.
178
+
179
+ If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
180
+ ``data`` in UTF-8.
181
+
182
+ If ``data`` is a bytes-like object, return a :class:`bytes` object.
183
+
184
+ Raises:
185
+ TypeError: If ``data`` doesn't have a supported type.
186
+
187
+ """
188
+ if isinstance(data, str):
189
+ return data.encode()
190
+ elif isinstance(data, BytesLike):
191
+ return bytes(data)
192
+ else:
193
+ raise TypeError("data must be str or bytes-like")
194
+
195
+
196
+ # Backwards compatibility with previously documented public APIs
197
+ encode_data = prepare_ctrl
198
+
199
+ # Backwards compatibility with previously documented public APIs
200
+ from ..frames import Close # noqa: E402 F401, I001
201
+
202
+
203
+ def parse_close(data: bytes) -> tuple[int, str]:
204
+ """
205
+ Parse the payload from a close frame.
206
+
207
+ Returns:
208
+ Close code and reason.
209
+
210
+ Raises:
211
+ ProtocolError: If data is ill-formed.
212
+ UnicodeDecodeError: If the reason isn't valid UTF-8.
213
+
214
+ """
215
+ close = Close.parse(data)
216
+ return close.code, close.reason
217
+
218
+
219
+ def serialize_close(code: int, reason: str) -> bytes:
220
+ """
221
+ Serialize the payload for a close frame.
222
+
223
+ """
224
+ return Close(code, reason).serialize()
source/websockets/legacy/handshake.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import binascii
5
+
6
+ from ..datastructures import Headers, MultipleValuesError
7
+ from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
8
+ from ..headers import parse_connection, parse_upgrade
9
+ from ..typing import ConnectionOption, UpgradeProtocol
10
+ from ..utils import accept_key as accept, generate_key
11
+
12
+
13
+ __all__ = ["build_request", "check_request", "build_response", "check_response"]
14
+
15
+
16
+ def build_request(headers: Headers) -> str:
17
+ """
18
+ Build a handshake request to send to the server.
19
+
20
+ Update request headers passed in argument.
21
+
22
+ Args:
23
+ headers: Handshake request headers.
24
+
25
+ Returns:
26
+ ``key`` that must be passed to :func:`check_response`.
27
+
28
+ """
29
+ key = generate_key()
30
+ headers["Upgrade"] = "websocket"
31
+ headers["Connection"] = "Upgrade"
32
+ headers["Sec-WebSocket-Key"] = key
33
+ headers["Sec-WebSocket-Version"] = "13"
34
+ return key
35
+
36
+
37
+ def check_request(headers: Headers) -> str:
38
+ """
39
+ Check a handshake request received from the client.
40
+
41
+ This function doesn't verify that the request is an HTTP/1.1 or higher GET
42
+ request and doesn't perform ``Host`` and ``Origin`` checks. These controls
43
+ are usually performed earlier in the HTTP request handling code. They're
44
+ the responsibility of the caller.
45
+
46
+ Args:
47
+ headers: Handshake request headers.
48
+
49
+ Returns:
50
+ ``key`` that must be passed to :func:`build_response`.
51
+
52
+ Raises:
53
+ InvalidHandshake: If the handshake request is invalid.
54
+ Then, the server must return a 400 Bad Request error.
55
+
56
+ """
57
+ connection: list[ConnectionOption] = sum(
58
+ [parse_connection(value) for value in headers.get_all("Connection")], []
59
+ )
60
+
61
+ if not any(value.lower() == "upgrade" for value in connection):
62
+ raise InvalidUpgrade("Connection", ", ".join(connection))
63
+
64
+ upgrade: list[UpgradeProtocol] = sum(
65
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
66
+ )
67
+
68
+ # For compatibility with non-strict implementations, ignore case when
69
+ # checking the Upgrade header. The RFC always uses "websocket", except
70
+ # in section 11.2. (IANA registration) where it uses "WebSocket".
71
+ if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
72
+ raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
73
+
74
+ try:
75
+ s_w_key = headers["Sec-WebSocket-Key"]
76
+ except KeyError as exc:
77
+ raise InvalidHeader("Sec-WebSocket-Key") from exc
78
+ except MultipleValuesError as exc:
79
+ raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
80
+
81
+ try:
82
+ raw_key = base64.b64decode(s_w_key.encode(), validate=True)
83
+ except binascii.Error as exc:
84
+ raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
85
+ if len(raw_key) != 16:
86
+ raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
87
+
88
+ try:
89
+ s_w_version = headers["Sec-WebSocket-Version"]
90
+ except KeyError as exc:
91
+ raise InvalidHeader("Sec-WebSocket-Version") from exc
92
+ except MultipleValuesError as exc:
93
+ raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
94
+
95
+ if s_w_version != "13":
96
+ raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
97
+
98
+ return s_w_key
99
+
100
+
101
+ def build_response(headers: Headers, key: str) -> None:
102
+ """
103
+ Build a handshake response to send to the client.
104
+
105
+ Update response headers passed in argument.
106
+
107
+ Args:
108
+ headers: Handshake response headers.
109
+ key: Returned by :func:`check_request`.
110
+
111
+ """
112
+ headers["Upgrade"] = "websocket"
113
+ headers["Connection"] = "Upgrade"
114
+ headers["Sec-WebSocket-Accept"] = accept(key)
115
+
116
+
117
+ def check_response(headers: Headers, key: str) -> None:
118
+ """
119
+ Check a handshake response received from the server.
120
+
121
+ This function doesn't verify that the response is an HTTP/1.1 or higher
122
+ response with a 101 status code. These controls are the responsibility of
123
+ the caller.
124
+
125
+ Args:
126
+ headers: Handshake response headers.
127
+ key: Returned by :func:`build_request`.
128
+
129
+ Raises:
130
+ InvalidHandshake: If the handshake response is invalid.
131
+
132
+ """
133
+ connection: list[ConnectionOption] = sum(
134
+ [parse_connection(value) for value in headers.get_all("Connection")], []
135
+ )
136
+
137
+ if not any(value.lower() == "upgrade" for value in connection):
138
+ raise InvalidUpgrade("Connection", " ".join(connection))
139
+
140
+ upgrade: list[UpgradeProtocol] = sum(
141
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
142
+ )
143
+
144
+ # For compatibility with non-strict implementations, ignore case when
145
+ # checking the Upgrade header. The RFC always uses "websocket", except
146
+ # in section 11.2. (IANA registration) where it uses "WebSocket".
147
+ if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
148
+ raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
149
+
150
+ try:
151
+ s_w_accept = headers["Sec-WebSocket-Accept"]
152
+ except KeyError as exc:
153
+ raise InvalidHeader("Sec-WebSocket-Accept") from exc
154
+ except MultipleValuesError as exc:
155
+ raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
156
+
157
+ if s_w_accept != accept(key):
158
+ raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
source/websockets/legacy/http.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import os
5
+ import re
6
+
7
+ from ..datastructures import Headers
8
+ from ..exceptions import SecurityError
9
+
10
+
11
+ __all__ = ["read_request", "read_response"]
12
+
13
+ MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
14
+ MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
15
+
16
+
17
+ def d(value: bytes) -> str:
18
+ """
19
+ Decode a bytestring for interpolating into an error message.
20
+
21
+ """
22
+ return value.decode(errors="backslashreplace")
23
+
24
+
25
+ # See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
26
+
27
+ # Regex for validating header names.
28
+
29
+ _token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
30
+
31
+ # Regex for validating header values.
32
+
33
+ # We don't attempt to support obsolete line folding.
34
+
35
+ # Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
36
+
37
+ # The ABNF is complicated because it attempts to express that optional
38
+ # whitespace is ignored. We strip whitespace and don't revalidate that.
39
+
40
+ # See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
41
+
42
+ _value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
43
+
44
+
45
+ async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]:
46
+ """
47
+ Read an HTTP/1.1 GET request and return ``(path, headers)``.
48
+
49
+ ``path`` isn't URL-decoded or validated in any way.
50
+
51
+ ``path`` and ``headers`` are expected to contain only ASCII characters.
52
+ Other characters are represented with surrogate escapes.
53
+
54
+ :func:`read_request` doesn't attempt to read the request body because
55
+ WebSocket handshake requests don't have one. If the request contains a
56
+ body, it may be read from ``stream`` after this coroutine returns.
57
+
58
+ Args:
59
+ stream: Input to read the request from.
60
+
61
+ Raises:
62
+ EOFError: If the connection is closed without a full HTTP request.
63
+ SecurityError: If the request exceeds a security limit.
64
+ ValueError: If the request isn't well formatted.
65
+
66
+ """
67
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
68
+
69
+ # Parsing is simple because fixed values are expected for method and
70
+ # version and because path isn't checked. Since WebSocket software tends
71
+ # to implement HTTP/1.1 strictly, there's little need for lenient parsing.
72
+
73
+ try:
74
+ request_line = await read_line(stream)
75
+ except EOFError as exc:
76
+ raise EOFError("connection closed while reading HTTP request line") from exc
77
+
78
+ try:
79
+ method, raw_path, version = request_line.split(b" ", 2)
80
+ except ValueError: # not enough values to unpack (expected 3, got 1-2)
81
+ raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
82
+
83
+ if method != b"GET":
84
+ raise ValueError(f"unsupported HTTP method: {d(method)}")
85
+ if version != b"HTTP/1.1":
86
+ raise ValueError(f"unsupported HTTP version: {d(version)}")
87
+ path = raw_path.decode("ascii", "surrogateescape")
88
+
89
+ headers = await read_headers(stream)
90
+
91
+ return path, headers
92
+
93
+
94
+ async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]:
95
+ """
96
+ Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
97
+
98
+ ``reason`` and ``headers`` are expected to contain only ASCII characters.
99
+ Other characters are represented with surrogate escapes.
100
+
101
+ :func:`read_request` doesn't attempt to read the response body because
102
+ WebSocket handshake responses don't have one. If the response contains a
103
+ body, it may be read from ``stream`` after this coroutine returns.
104
+
105
+ Args:
106
+ stream: Input to read the response from.
107
+
108
+ Raises:
109
+ EOFError: If the connection is closed without a full HTTP response.
110
+ SecurityError: If the response exceeds a security limit.
111
+ ValueError: If the response isn't well formatted.
112
+
113
+ """
114
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
115
+
116
+ # As in read_request, parsing is simple because a fixed value is expected
117
+ # for version, status_code is a 3-digit number, and reason can be ignored.
118
+
119
+ try:
120
+ status_line = await read_line(stream)
121
+ except EOFError as exc:
122
+ raise EOFError("connection closed while reading HTTP status line") from exc
123
+
124
+ try:
125
+ version, raw_status_code, raw_reason = status_line.split(b" ", 2)
126
+ except ValueError: # not enough values to unpack (expected 3, got 1-2)
127
+ raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
128
+
129
+ if version != b"HTTP/1.1":
130
+ raise ValueError(f"unsupported HTTP version: {d(version)}")
131
+ try:
132
+ status_code = int(raw_status_code)
133
+ except ValueError: # invalid literal for int() with base 10
134
+ raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
135
+ if not 100 <= status_code < 1000:
136
+ raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
137
+ if not _value_re.fullmatch(raw_reason):
138
+ raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
139
+ reason = raw_reason.decode()
140
+
141
+ headers = await read_headers(stream)
142
+
143
+ return status_code, reason, headers
144
+
145
+
146
+ async def read_headers(stream: asyncio.StreamReader) -> Headers:
147
+ """
148
+ Read HTTP headers from ``stream``.
149
+
150
+ Non-ASCII characters are represented with surrogate escapes.
151
+
152
+ """
153
+ # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
154
+
155
+ # We don't attempt to support obsolete line folding.
156
+
157
+ headers = Headers()
158
+ for _ in range(MAX_NUM_HEADERS + 1):
159
+ try:
160
+ line = await read_line(stream)
161
+ except EOFError as exc:
162
+ raise EOFError("connection closed while reading HTTP headers") from exc
163
+ if line == b"":
164
+ break
165
+
166
+ try:
167
+ raw_name, raw_value = line.split(b":", 1)
168
+ except ValueError: # not enough values to unpack (expected 2, got 1)
169
+ raise ValueError(f"invalid HTTP header line: {d(line)}") from None
170
+ if not _token_re.fullmatch(raw_name):
171
+ raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
172
+ raw_value = raw_value.strip(b" \t")
173
+ if not _value_re.fullmatch(raw_value):
174
+ raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
175
+
176
+ name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
177
+ value = raw_value.decode("ascii", "surrogateescape")
178
+ headers[name] = value
179
+
180
+ else:
181
+ raise SecurityError("too many HTTP headers")
182
+
183
+ return headers
184
+
185
+
186
+ async def read_line(stream: asyncio.StreamReader) -> bytes:
187
+ """
188
+ Read a single line from ``stream``.
189
+
190
+ CRLF is stripped from the return value.
191
+
192
+ """
193
+ # Security: this is bounded by the StreamReader's limit (default = 32 KiB).
194
+ line = await stream.readline()
195
+ # Security: this guarantees header values are small (hard-coded = 8 KiB)
196
+ if len(line) > MAX_LINE_LENGTH:
197
+ raise SecurityError("line too long")
198
+ # Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
199
+ if not line.endswith(b"\r\n"):
200
+ raise EOFError("line without CRLF")
201
+ return line[:-2]
source/websockets/legacy/protocol.py ADDED
@@ -0,0 +1,1635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import codecs
5
+ import collections
6
+ import logging
7
+ import random
8
+ import ssl
9
+ import struct
10
+ import sys
11
+ import time
12
+ import traceback
13
+ import uuid
14
+ import warnings
15
+ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
16
+ from typing import Any, Callable, Deque, cast
17
+
18
+ from ..asyncio.compatibility import asyncio_timeout
19
+ from ..datastructures import Headers
20
+ from ..exceptions import (
21
+ ConnectionClosed,
22
+ ConnectionClosedError,
23
+ ConnectionClosedOK,
24
+ InvalidState,
25
+ PayloadTooBig,
26
+ ProtocolError,
27
+ )
28
+ from ..extensions import Extension
29
+ from ..frames import (
30
+ OK_CLOSE_CODES,
31
+ OP_BINARY,
32
+ OP_CLOSE,
33
+ OP_CONT,
34
+ OP_PING,
35
+ OP_PONG,
36
+ OP_TEXT,
37
+ Close,
38
+ CloseCode,
39
+ Opcode,
40
+ )
41
+ from ..protocol import State
42
+ from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol
43
+ from .framing import Frame, prepare_ctrl, prepare_data
44
+
45
+
46
+ __all__ = ["WebSocketCommonProtocol"]
47
+
48
+
49
+ # In order to ensure consistency, the code always checks the current value of
50
+ # WebSocketCommonProtocol.state before assigning a new value and never yields
51
+ # between the check and the assignment.
52
+
53
+
54
+ class WebSocketCommonProtocol(asyncio.Protocol):
55
+ """
56
+ WebSocket connection.
57
+
58
+ :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket
59
+ servers and clients. You shouldn't use it directly. Instead, use
60
+ :class:`~websockets.legacy.client.WebSocketClientProtocol` or
61
+ :class:`~websockets.legacy.server.WebSocketServerProtocol`.
62
+
63
+ This documentation focuses on low-level details that aren't covered in the
64
+ documentation of :class:`~websockets.legacy.client.WebSocketClientProtocol`
65
+ and :class:`~websockets.legacy.server.WebSocketServerProtocol` for the sake
66
+ of simplicity.
67
+
68
+ Once the connection is open, a Ping_ frame is sent every ``ping_interval``
69
+ seconds. This serves as a keepalive. It helps keeping the connection open,
70
+ especially in the presence of proxies with short timeouts on inactive
71
+ connections. Set ``ping_interval`` to :obj:`None` to disable this behavior.
72
+
73
+ .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
74
+
75
+ If the corresponding Pong_ frame isn't received within ``ping_timeout``
76
+ seconds, the connection is considered unusable and is closed with code 1011.
77
+ This ensures that the remote endpoint remains responsive. Set
78
+ ``ping_timeout`` to :obj:`None` to disable this behavior.
79
+
80
+ .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
81
+
82
+ See the discussion of :doc:`keepalive <../../topics/keepalive>` for details.
83
+
84
+ The ``close_timeout`` parameter defines a maximum wait time for completing
85
+ the closing handshake and terminating the TCP connection. For legacy
86
+ reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds
87
+ for clients and ``4 * close_timeout`` for servers.
88
+
89
+ ``close_timeout`` is a parameter of the protocol because websockets usually
90
+ calls :meth:`close` implicitly upon exit:
91
+
92
+ * on the client side, when using :func:`~websockets.legacy.client.connect`
93
+ as a context manager;
94
+ * on the server side, when the connection handler terminates.
95
+
96
+ To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or
97
+ :func:`~asyncio.wait_for`.
98
+
99
+ The ``max_size`` parameter enforces the maximum size for incoming messages
100
+ in bytes. The default value is 1 MiB. If a larger message is received,
101
+ :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError`
102
+ and the connection will be closed with code 1009.
103
+
104
+ The ``max_queue`` parameter sets the maximum length of the queue that
105
+ holds incoming messages. The default value is ``32``. Messages are added
106
+ to an in-memory queue when they're received; then :meth:`recv` pops from
107
+ that queue. In order to prevent excessive memory consumption when
108
+ messages are received faster than they can be processed, the queue must
109
+ be bounded. If the queue fills up, the protocol stops processing incoming
110
+ data until :meth:`recv` is called. In this situation, various receive
111
+ buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the
112
+ TCP receive window will shrink, slowing down transmission to avoid packet
113
+ loss.
114
+
115
+ Since Python can use up to 4 bytes of memory to represent a single
116
+ character, each connection may use up to ``4 * max_size * max_queue``
117
+ bytes of memory to store incoming messages. By default, this is 128 MiB.
118
+ You may want to lower the limits, depending on your application's
119
+ requirements.
120
+
121
+ The ``read_limit`` argument sets the high-water limit of the buffer for
122
+ incoming bytes. The low-water limit is half the high-water limit. The
123
+ default value is 64 KiB, half of asyncio's default (based on the current
124
+ implementation of :class:`~asyncio.StreamReader`).
125
+
126
+ The ``write_limit`` argument sets the high-water limit of the buffer for
127
+ outgoing bytes. The low-water limit is a quarter of the high-water limit.
128
+ The default value is 64 KiB, equal to asyncio's default (based on the
129
+ current implementation of ``FlowControlMixin``).
130
+
131
+ See the discussion of :doc:`memory usage <../../topics/memory>` for details.
132
+
133
+ Args:
134
+ logger: Logger for this server.
135
+ It defaults to ``logging.getLogger("websockets.protocol")``.
136
+ See the :doc:`logging guide <../../topics/logging>` for details.
137
+ ping_interval: Interval between keepalive pings in seconds.
138
+ :obj:`None` disables keepalive.
139
+ ping_timeout: Timeout for keepalive pings in seconds.
140
+ :obj:`None` disables timeouts.
141
+ close_timeout: Timeout for closing the connection in seconds.
142
+ For legacy reasons, the actual timeout is 4 or 5 times larger.
143
+ max_size: Maximum size of incoming messages in bytes.
144
+ :obj:`None` disables the limit.
145
+ max_queue: Maximum number of incoming messages in receive buffer.
146
+ :obj:`None` disables the limit.
147
+ read_limit: High-water mark of read buffer in bytes.
148
+ write_limit: High-water mark of write buffer in bytes.
149
+
150
+ """
151
+
152
+ # There are only two differences between the client-side and server-side
153
+ # behavior: masking the payload and closing the underlying TCP connection.
154
+ # Set is_client = True/False and side = "client"/"server" to pick a side.
155
+ is_client: bool
156
+ side: str = "undefined"
157
+
158
+ def __init__(
159
+ self,
160
+ *,
161
+ logger: LoggerLike | None = None,
162
+ ping_interval: float | None = 20,
163
+ ping_timeout: float | None = 20,
164
+ close_timeout: float | None = None,
165
+ max_size: int | None = 2**20,
166
+ max_queue: int | None = 2**5,
167
+ read_limit: int = 2**16,
168
+ write_limit: int = 2**16,
169
+ # The following arguments are kept only for backwards compatibility.
170
+ host: str | None = None,
171
+ port: int | None = None,
172
+ secure: bool | None = None,
173
+ legacy_recv: bool = False,
174
+ loop: asyncio.AbstractEventLoop | None = None,
175
+ timeout: float | None = None,
176
+ ) -> None:
177
+ if legacy_recv: # pragma: no cover
178
+ warnings.warn("legacy_recv is deprecated", DeprecationWarning)
179
+
180
+ # Backwards compatibility: close_timeout used to be called timeout.
181
+ if timeout is None:
182
+ timeout = 10
183
+ else:
184
+ warnings.warn("rename timeout to close_timeout", DeprecationWarning)
185
+ # If both are specified, timeout is ignored.
186
+ if close_timeout is None:
187
+ close_timeout = timeout
188
+
189
+ # Backwards compatibility: the loop parameter used to be supported.
190
+ if loop is None:
191
+ loop = asyncio.get_event_loop()
192
+ else:
193
+ warnings.warn("remove loop argument", DeprecationWarning)
194
+
195
+ self.ping_interval = ping_interval
196
+ self.ping_timeout = ping_timeout
197
+ self.close_timeout = close_timeout
198
+ self.max_size = max_size
199
+ self.max_queue = max_queue
200
+ self.read_limit = read_limit
201
+ self.write_limit = write_limit
202
+
203
+ # Unique identifier. For logs.
204
+ self.id: uuid.UUID = uuid.uuid4()
205
+ """Unique identifier of the connection. Useful in logs."""
206
+
207
+ # Logger or LoggerAdapter for this connection.
208
+ if logger is None:
209
+ logger = logging.getLogger("websockets.protocol")
210
+ self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self})
211
+ """Logger for this connection."""
212
+
213
+ # Track if DEBUG is enabled. Shortcut logging calls if it isn't.
214
+ self.debug = logger.isEnabledFor(logging.DEBUG)
215
+
216
+ self.loop = loop
217
+
218
+ self._host = host
219
+ self._port = port
220
+ self._secure = secure
221
+ self.legacy_recv = legacy_recv
222
+
223
+ # Configure read buffer limits. The high-water limit is defined by
224
+ # ``self.read_limit``. The ``limit`` argument controls the line length
225
+ # limit and half the buffer limit of :class:`~asyncio.StreamReader`.
226
+ # That's why it must be set to half of ``self.read_limit``.
227
+ self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
228
+
229
+ # Copied from asyncio.FlowControlMixin
230
+ self._paused = False
231
+ self._drain_waiter: asyncio.Future[None] | None = None
232
+
233
+ # This class implements the data transfer and closing handshake, which
234
+ # are shared between the client-side and the server-side.
235
+ # Subclasses implement the opening handshake and, on success, execute
236
+ # :meth:`connection_open` to change the state to OPEN.
237
+ self.state = State.CONNECTING
238
+ if self.debug:
239
+ self.logger.debug("= connection is CONNECTING")
240
+
241
+ # HTTP protocol parameters.
242
+ self.path: str
243
+ """Path of the opening handshake request."""
244
+ self.request_headers: Headers
245
+ """Opening handshake request headers."""
246
+ self.response_headers: Headers
247
+ """Opening handshake response headers."""
248
+
249
+ # WebSocket protocol parameters.
250
+ self.extensions: list[Extension] = []
251
+ self.subprotocol: Subprotocol | None = None
252
+ """Subprotocol, if one was negotiated."""
253
+
254
+ # Close code and reason, set when a close frame is sent or received.
255
+ self.close_rcvd: Close | None = None
256
+ self.close_sent: Close | None = None
257
+ self.close_rcvd_then_sent: bool | None = None
258
+
259
+ # Completed when the connection state becomes CLOSED. Translates the
260
+ # :meth:`connection_lost` callback to a :class:`~asyncio.Future`
261
+ # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are
262
+ # translated by ``self.stream_reader``).
263
+ self.connection_lost_waiter: asyncio.Future[None] = loop.create_future()
264
+
265
+ # Queue of received messages.
266
+ self.messages: Deque[Data] = collections.deque()
267
+ self._pop_message_waiter: asyncio.Future[None] | None = None
268
+ self._put_message_waiter: asyncio.Future[None] | None = None
269
+
270
+ # Protect sending fragmented messages.
271
+ self._fragmented_message_waiter: asyncio.Future[None] | None = None
272
+
273
+ # Mapping of ping IDs to pong waiters, in chronological order.
274
+ self.pings: dict[bytes, tuple[asyncio.Future[float], float]] = {}
275
+
276
+ self.latency: float = 0
277
+ """
278
+ Latency of the connection, in seconds.
279
+
280
+ Latency is defined as the round-trip time of the connection. It is
281
+ measured by sending a Ping frame and waiting for a matching Pong frame.
282
+ Before the first measurement, :attr:`latency` is ``0``.
283
+
284
+ By default, websockets enables a :ref:`keepalive <keepalive>` mechanism
285
+ that sends Ping frames automatically at regular intervals. You can also
286
+ send Ping frames and measure latency with :meth:`ping`.
287
+ """
288
+
289
+ # Task running the data transfer.
290
+ self.transfer_data_task: asyncio.Task[None]
291
+
292
+ # Exception that occurred during data transfer, if any.
293
+ self.transfer_data_exc: BaseException | None = None
294
+
295
+ # Task sending keepalive pings.
296
+ self.keepalive_ping_task: asyncio.Task[None]
297
+
298
+ # Task closing the TCP connection.
299
+ self.close_connection_task: asyncio.Task[None]
300
+
301
+ # Copied from asyncio.FlowControlMixin
302
+ async def _drain_helper(self) -> None: # pragma: no cover
303
+ if self.connection_lost_waiter.done():
304
+ raise ConnectionResetError("Connection lost")
305
+ if not self._paused:
306
+ return
307
+ waiter = self._drain_waiter
308
+ assert waiter is None or waiter.cancelled()
309
+ waiter = self.loop.create_future()
310
+ self._drain_waiter = waiter
311
+ await waiter
312
+
313
+ # Copied from asyncio.StreamWriter
314
+ async def _drain(self) -> None: # pragma: no cover
315
+ if self.reader is not None:
316
+ exc = self.reader.exception()
317
+ if exc is not None:
318
+ raise exc
319
+ if self.transport is not None:
320
+ if self.transport.is_closing():
321
+ # Yield to the event loop so connection_lost() may be
322
+ # called. Without this, _drain_helper() would return
323
+ # immediately, and code that calls
324
+ # write(...); yield from drain()
325
+ # in a loop would never call connection_lost(), so it
326
+ # would not see an error when the socket is closed.
327
+ await asyncio.sleep(0)
328
+ await self._drain_helper()
329
+
330
+ def connection_open(self) -> None:
331
+ """
332
+ Callback when the WebSocket opening handshake completes.
333
+
334
+ Enter the OPEN state and start the data transfer phase.
335
+
336
+ """
337
+ # 4.1. The WebSocket Connection is Established.
338
+ assert self.state is State.CONNECTING
339
+ self.state = State.OPEN
340
+ if self.debug:
341
+ self.logger.debug("= connection is OPEN")
342
+ # Start the task that receives incoming WebSocket messages.
343
+ self.transfer_data_task = self.loop.create_task(self.transfer_data())
344
+ # Start the task that sends pings at regular intervals.
345
+ self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping())
346
+ # Start the task that eventually closes the TCP connection.
347
+ self.close_connection_task = self.loop.create_task(self.close_connection())
348
+
349
+ @property
350
+ def host(self) -> str | None:
351
+ alternative = "remote_address" if self.is_client else "local_address"
352
+ warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning)
353
+ return self._host
354
+
355
+ @property
356
+ def port(self) -> int | None:
357
+ alternative = "remote_address" if self.is_client else "local_address"
358
+ warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning)
359
+ return self._port
360
+
361
+ @property
362
+ def secure(self) -> bool | None:
363
+ warnings.warn("don't use secure", DeprecationWarning)
364
+ return self._secure
365
+
366
+ # Public API
367
+
368
+ @property
369
+ def local_address(self) -> Any:
370
+ """
371
+ Local address of the connection.
372
+
373
+ For IPv4 connections, this is a ``(host, port)`` tuple.
374
+
375
+ The format of the address depends on the address family;
376
+ see :meth:`~socket.socket.getsockname`.
377
+
378
+ :obj:`None` if the TCP connection isn't established yet.
379
+
380
+ """
381
+ try:
382
+ transport = self.transport
383
+ except AttributeError:
384
+ return None
385
+ else:
386
+ return transport.get_extra_info("sockname")
387
+
388
+ @property
389
+ def remote_address(self) -> Any:
390
+ """
391
+ Remote address of the connection.
392
+
393
+ For IPv4 connections, this is a ``(host, port)`` tuple.
394
+
395
+ The format of the address depends on the address family;
396
+ see :meth:`~socket.socket.getpeername`.
397
+
398
+ :obj:`None` if the TCP connection isn't established yet.
399
+
400
+ """
401
+ try:
402
+ transport = self.transport
403
+ except AttributeError:
404
+ return None
405
+ else:
406
+ return transport.get_extra_info("peername")
407
+
408
+ @property
409
+ def open(self) -> bool:
410
+ """
411
+ :obj:`True` when the connection is open; :obj:`False` otherwise.
412
+
413
+ This attribute may be used to detect disconnections. However, this
414
+ approach is discouraged per the EAFP_ principle. Instead, you should
415
+ handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions.
416
+
417
+ .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp
418
+
419
+ """
420
+ return self.state is State.OPEN and not self.transfer_data_task.done()
421
+
422
+ @property
423
+ def closed(self) -> bool:
424
+ """
425
+ :obj:`True` when the connection is closed; :obj:`False` otherwise.
426
+
427
+ Be aware that both :attr:`open` and :attr:`closed` are :obj:`False`
428
+ during the opening and closing sequences.
429
+
430
+ """
431
+ return self.state is State.CLOSED
432
+
433
+ @property
434
+ def close_code(self) -> int | None:
435
+ """
436
+ WebSocket close code, defined in `section 7.1.5 of RFC 6455`_.
437
+
438
+ .. _section 7.1.5 of RFC 6455:
439
+ https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5
440
+
441
+ :obj:`None` if the connection isn't closed yet.
442
+
443
+ """
444
+ if self.state is not State.CLOSED:
445
+ return None
446
+ elif self.close_rcvd is None:
447
+ return CloseCode.ABNORMAL_CLOSURE
448
+ else:
449
+ return self.close_rcvd.code
450
+
451
+ @property
452
+ def close_reason(self) -> str | None:
453
+ """
454
+ WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_.
455
+
456
+ .. _section 7.1.6 of RFC 6455:
457
+ https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6
458
+
459
+ :obj:`None` if the connection isn't closed yet.
460
+
461
+ """
462
+ if self.state is not State.CLOSED:
463
+ return None
464
+ elif self.close_rcvd is None:
465
+ return ""
466
+ else:
467
+ return self.close_rcvd.reason
468
+
469
+ async def __aiter__(self) -> AsyncIterator[Data]:
470
+ """
471
+ Iterate on incoming messages.
472
+
473
+ The iterator exits normally when the connection is closed with the close
474
+ code 1000 (OK) or 1001 (going away) or without a close code.
475
+
476
+ It raises a :exc:`~websockets.exceptions.ConnectionClosedError`
477
+ exception when the connection is closed with any other code.
478
+
479
+ """
480
+ try:
481
+ while True:
482
+ yield await self.recv()
483
+ except ConnectionClosedOK:
484
+ return
485
+
486
+ async def recv(self) -> Data:
487
+ """
488
+ Receive the next message.
489
+
490
+ When the connection is closed, :meth:`recv` raises
491
+ :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises
492
+ :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
493
+ connection closure and
494
+ :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
495
+ error or a network failure. This is how you detect the end of the
496
+ message stream.
497
+
498
+ Canceling :meth:`recv` is safe. There's no risk of losing the next
499
+ message. The next invocation of :meth:`recv` will return it.
500
+
501
+ This makes it possible to enforce a timeout by wrapping :meth:`recv` in
502
+ :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`.
503
+
504
+ Returns:
505
+ A string (:class:`str`) for a Text_ frame. A bytestring
506
+ (:class:`bytes`) for a Binary_ frame.
507
+
508
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
509
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
510
+
511
+ Raises:
512
+ ConnectionClosed: When the connection is closed.
513
+ RuntimeError: If two coroutines call :meth:`recv` concurrently.
514
+
515
+ """
516
+ if self._pop_message_waiter is not None:
517
+ raise RuntimeError(
518
+ "cannot call recv while another coroutine "
519
+ "is already waiting for the next message"
520
+ )
521
+
522
+ # Don't await self.ensure_open() here:
523
+ # - messages could be available in the queue even if the connection
524
+ # is closed;
525
+ # - messages could be received before the closing frame even if the
526
+ # connection is closing.
527
+
528
+ # Wait until there's a message in the queue (if necessary) or the
529
+ # connection is closed.
530
+ while len(self.messages) <= 0:
531
+ pop_message_waiter: asyncio.Future[None] = self.loop.create_future()
532
+ self._pop_message_waiter = pop_message_waiter
533
+ try:
534
+ # If asyncio.wait() is canceled, it doesn't cancel
535
+ # pop_message_waiter and self.transfer_data_task.
536
+ await asyncio.wait(
537
+ [pop_message_waiter, self.transfer_data_task],
538
+ return_when=asyncio.FIRST_COMPLETED,
539
+ )
540
+ finally:
541
+ self._pop_message_waiter = None
542
+
543
+ # If asyncio.wait(...) exited because self.transfer_data_task
544
+ # completed before receiving a new message, raise a suitable
545
+ # exception (or return None if legacy_recv is enabled).
546
+ if not pop_message_waiter.done():
547
+ if self.legacy_recv:
548
+ return None # type: ignore
549
+ else:
550
+ # Wait until the connection is closed to raise
551
+ # ConnectionClosed with the correct code and reason.
552
+ await self.ensure_open()
553
+
554
+ # Pop a message from the queue.
555
+ message = self.messages.popleft()
556
+
557
+ # Notify transfer_data().
558
+ if self._put_message_waiter is not None:
559
+ self._put_message_waiter.set_result(None)
560
+ self._put_message_waiter = None
561
+
562
+ return message
563
+
564
+ async def send(
565
+ self,
566
+ message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike],
567
+ ) -> None:
568
+ """
569
+ Send a message.
570
+
571
+ A string (:class:`str`) is sent as a Text_ frame. A bytestring or
572
+ bytes-like object (:class:`bytes`, :class:`bytearray`, or
573
+ :class:`memoryview`) is sent as a Binary_ frame.
574
+
575
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
576
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
577
+
578
+ :meth:`send` also accepts an iterable or an asynchronous iterable of
579
+ strings, bytestrings, or bytes-like objects to enable fragmentation_.
580
+ Each item is treated as a message fragment and sent in its own frame.
581
+ All items must be of the same type, or else :meth:`send` will raise a
582
+ :exc:`TypeError` and the connection will be closed.
583
+
584
+ .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4
585
+
586
+ :meth:`send` rejects dict-like objects because this is often an error.
587
+ (If you want to send the keys of a dict-like object as fragments, call
588
+ its :meth:`~dict.keys` method and pass the result to :meth:`send`.)
589
+
590
+ Canceling :meth:`send` is discouraged. Instead, you should close the
591
+ connection with :meth:`close`. Indeed, there are only two situations
592
+ where :meth:`send` may yield control to the event loop and then get
593
+ canceled; in both cases, :meth:`close` has the same effect and is
594
+ more clear:
595
+
596
+ 1. The write buffer is full. If you don't want to wait until enough
597
+ data is sent, your only alternative is to close the connection.
598
+ :meth:`close` will likely time out then abort the TCP connection.
599
+ 2. ``message`` is an asynchronous iterator that yields control.
600
+ Stopping in the middle of a fragmented message will cause a
601
+ protocol error and the connection will be closed.
602
+
603
+ When the connection is closed, :meth:`send` raises
604
+ :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it
605
+ raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal
606
+ connection closure and
607
+ :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol
608
+ error or a network failure.
609
+
610
+ Args:
611
+ message: Message to send.
612
+
613
+ Raises:
614
+ ConnectionClosed: When the connection is closed.
615
+ TypeError: If ``message`` doesn't have a supported type.
616
+
617
+ """
618
+ await self.ensure_open()
619
+
620
+ # While sending a fragmented message, prevent sending other messages
621
+ # until all fragments are sent.
622
+ while self._fragmented_message_waiter is not None:
623
+ await asyncio.shield(self._fragmented_message_waiter)
624
+
625
+ # Unfragmented message -- this case must be handled first because
626
+ # strings and bytes-like objects are iterable.
627
+
628
+ if isinstance(message, (str, bytes, bytearray, memoryview)):
629
+ opcode, data = prepare_data(message)
630
+ await self.write_frame(True, opcode, data)
631
+
632
+ # Catch a common mistake -- passing a dict to send().
633
+
634
+ elif isinstance(message, Mapping):
635
+ raise TypeError("data is a dict-like object")
636
+
637
+ # Fragmented message -- regular iterator.
638
+
639
+ elif isinstance(message, Iterable):
640
+ iter_message = iter(message)
641
+ try:
642
+ fragment = next(iter_message)
643
+ except StopIteration:
644
+ return
645
+ opcode, data = prepare_data(fragment)
646
+
647
+ self._fragmented_message_waiter = self.loop.create_future()
648
+ try:
649
+ # First fragment.
650
+ await self.write_frame(False, opcode, data)
651
+
652
+ # Other fragments.
653
+ for fragment in iter_message:
654
+ confirm_opcode, data = prepare_data(fragment)
655
+ if confirm_opcode != opcode:
656
+ raise TypeError("data contains inconsistent types")
657
+ await self.write_frame(False, OP_CONT, data)
658
+
659
+ # Final fragment.
660
+ await self.write_frame(True, OP_CONT, b"")
661
+
662
+ except (Exception, asyncio.CancelledError):
663
+ # We're half-way through a fragmented message and we can't
664
+ # complete it. This makes the connection unusable.
665
+ self.fail_connection(CloseCode.INTERNAL_ERROR)
666
+ raise
667
+
668
+ finally:
669
+ self._fragmented_message_waiter.set_result(None)
670
+ self._fragmented_message_waiter = None
671
+
672
+ # Fragmented message -- asynchronous iterator
673
+
674
+ elif isinstance(message, AsyncIterable):
675
+ # Implement aiter_message = aiter(message) without aiter
676
+ # Work around https://github.com/python/mypy/issues/5738
677
+ aiter_message = cast(
678
+ Callable[[AsyncIterable[DataLike]], AsyncIterator[DataLike]],
679
+ type(message).__aiter__,
680
+ )(message)
681
+ try:
682
+ # Implement fragment = anext(aiter_message) without anext
683
+ # Work around https://github.com/python/mypy/issues/5738
684
+ fragment = await cast(
685
+ Callable[[AsyncIterator[DataLike]], Awaitable[DataLike]],
686
+ type(aiter_message).__anext__,
687
+ )(aiter_message)
688
+ except StopAsyncIteration:
689
+ return
690
+ opcode, data = prepare_data(fragment)
691
+
692
+ self._fragmented_message_waiter = self.loop.create_future()
693
+ try:
694
+ # First fragment.
695
+ await self.write_frame(False, opcode, data)
696
+
697
+ # Other fragments.
698
+ async for fragment in aiter_message:
699
+ confirm_opcode, data = prepare_data(fragment)
700
+ if confirm_opcode != opcode:
701
+ raise TypeError("data contains inconsistent types")
702
+ await self.write_frame(False, OP_CONT, data)
703
+
704
+ # Final fragment.
705
+ await self.write_frame(True, OP_CONT, b"")
706
+
707
+ except (Exception, asyncio.CancelledError):
708
+ # We're half-way through a fragmented message and we can't
709
+ # complete it. This makes the connection unusable.
710
+ self.fail_connection(CloseCode.INTERNAL_ERROR)
711
+ raise
712
+
713
+ finally:
714
+ self._fragmented_message_waiter.set_result(None)
715
+ self._fragmented_message_waiter = None
716
+
717
+ else:
718
+ raise TypeError("data must be str, bytes-like, or iterable")
719
+
720
+ async def close(
721
+ self,
722
+ code: int = CloseCode.NORMAL_CLOSURE,
723
+ reason: str = "",
724
+ ) -> None:
725
+ """
726
+ Perform the closing handshake.
727
+
728
+ :meth:`close` waits for the other end to complete the handshake and
729
+ for the TCP connection to terminate. As a consequence, there's no need
730
+ to await :meth:`wait_closed` after :meth:`close`.
731
+
732
+ :meth:`close` is idempotent: it doesn't do anything once the
733
+ connection is closed.
734
+
735
+ Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given
736
+ that errors during connection termination aren't particularly useful.
737
+
738
+ Canceling :meth:`close` is discouraged. If it takes too long, you can
739
+ set a shorter ``close_timeout``. If you don't want to wait, let the
740
+ Python process exit, then the OS will take care of closing the TCP
741
+ connection.
742
+
743
+ Args:
744
+ code: WebSocket close code.
745
+ reason: WebSocket close reason.
746
+
747
+ """
748
+ try:
749
+ async with asyncio_timeout(self.close_timeout):
750
+ await self.write_close_frame(Close(code, reason))
751
+ except asyncio.TimeoutError:
752
+ # If the close frame cannot be sent because the send buffers
753
+ # are full, the closing handshake won't complete anyway.
754
+ # Fail the connection to shut down faster.
755
+ self.fail_connection()
756
+
757
+ # If no close frame is received within the timeout, asyncio_timeout()
758
+ # cancels the data transfer task and raises TimeoutError.
759
+
760
+ # If close() is called multiple times concurrently and one of these
761
+ # calls hits the timeout, the data transfer task will be canceled.
762
+ # Other calls will receive a CancelledError here.
763
+
764
+ try:
765
+ # If close() is canceled during the wait, self.transfer_data_task
766
+ # is canceled before the timeout elapses.
767
+ async with asyncio_timeout(self.close_timeout):
768
+ await self.transfer_data_task
769
+ except (asyncio.TimeoutError, asyncio.CancelledError):
770
+ pass
771
+
772
+ # Wait for the close connection task to close the TCP connection.
773
+ await asyncio.shield(self.close_connection_task)
774
+
775
+ async def wait_closed(self) -> None:
776
+ """
777
+ Wait until the connection is closed.
778
+
779
+ This coroutine is identical to the :attr:`closed` attribute, except it
780
+ can be awaited.
781
+
782
+ This can make it easier to detect connection termination, regardless
783
+ of its cause, in tasks that interact with the WebSocket connection.
784
+
785
+ """
786
+ await asyncio.shield(self.connection_lost_waiter)
787
+
788
+ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
789
+ """
790
+ Send a Ping_.
791
+
792
+ .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
793
+
794
+ A ping may serve as a keepalive, as a check that the remote endpoint
795
+ received all messages up to this point, or to measure :attr:`latency`.
796
+
797
+ Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return
798
+ immediately, it means the write buffer is full. If you don't want to
799
+ wait, you should close the connection.
800
+
801
+ Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no
802
+ effect.
803
+
804
+ Args:
805
+ data: Payload of the ping. A string will be encoded to UTF-8.
806
+ If ``data`` is :obj:`None`, the payload is four random bytes.
807
+
808
+ Returns:
809
+ A future that will be completed when the corresponding pong is
810
+ received. You can ignore it if you don't intend to wait. The result
811
+ of the future is the latency of the connection in seconds.
812
+
813
+ ::
814
+
815
+ pong_waiter = await ws.ping()
816
+ # only if you want to wait for the corresponding pong
817
+ latency = await pong_waiter
818
+
819
+ Raises:
820
+ ConnectionClosed: When the connection is closed.
821
+ RuntimeError: If another ping was sent with the same data and
822
+ the corresponding pong wasn't received yet.
823
+
824
+ """
825
+ await self.ensure_open()
826
+
827
+ if data is not None:
828
+ data = prepare_ctrl(data)
829
+
830
+ # Protect against duplicates if a payload is explicitly set.
831
+ if data in self.pings:
832
+ raise RuntimeError("already waiting for a pong with the same data")
833
+
834
+ # Generate a unique random payload otherwise.
835
+ while data is None or data in self.pings:
836
+ data = struct.pack("!I", random.getrandbits(32))
837
+
838
+ pong_waiter = self.loop.create_future()
839
+ # Resolution of time.monotonic() may be too low on Windows.
840
+ ping_timestamp = time.perf_counter()
841
+ self.pings[data] = (pong_waiter, ping_timestamp)
842
+
843
+ await self.write_frame(True, OP_PING, data)
844
+
845
+ return asyncio.shield(pong_waiter)
846
+
847
+ async def pong(self, data: DataLike = b"") -> None:
848
+ """
849
+ Send a Pong_.
850
+
851
+ .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
852
+
853
+ An unsolicited pong may serve as a unidirectional heartbeat.
854
+
855
+ Canceling :meth:`pong` is discouraged. If :meth:`pong` doesn't return
856
+ immediately, it means the write buffer is full. If you don't want to
857
+ wait, you should close the connection.
858
+
859
+ Args:
860
+ data: Payload of the pong. A string will be encoded to UTF-8.
861
+
862
+ Raises:
863
+ ConnectionClosed: When the connection is closed.
864
+
865
+ """
866
+ await self.ensure_open()
867
+
868
+ data = prepare_ctrl(data)
869
+
870
+ await self.write_frame(True, OP_PONG, data)
871
+
872
+ # Private methods - no guarantees.
873
+
874
+ def connection_closed_exc(self) -> ConnectionClosed:
875
+ exc: ConnectionClosed
876
+ if (
877
+ self.close_rcvd is not None
878
+ and self.close_rcvd.code in OK_CLOSE_CODES
879
+ and self.close_sent is not None
880
+ and self.close_sent.code in OK_CLOSE_CODES
881
+ ):
882
+ exc = ConnectionClosedOK(
883
+ self.close_rcvd,
884
+ self.close_sent,
885
+ self.close_rcvd_then_sent,
886
+ )
887
+ else:
888
+ exc = ConnectionClosedError(
889
+ self.close_rcvd,
890
+ self.close_sent,
891
+ self.close_rcvd_then_sent,
892
+ )
893
+ # Chain to the exception that terminated data transfer, if any.
894
+ exc.__cause__ = self.transfer_data_exc
895
+ return exc
896
+
897
+ async def ensure_open(self) -> None:
898
+ """
899
+ Check that the WebSocket connection is open.
900
+
901
+ Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't.
902
+
903
+ """
904
+ # Handle cases from most common to least common for performance.
905
+ if self.state is State.OPEN:
906
+ # If self.transfer_data_task exited without a closing handshake,
907
+ # self.close_connection_task may be closing the connection, going
908
+ # straight from OPEN to CLOSED.
909
+ if self.transfer_data_task.done():
910
+ await asyncio.shield(self.close_connection_task)
911
+ raise self.connection_closed_exc()
912
+ else:
913
+ return
914
+
915
+ if self.state is State.CLOSED:
916
+ raise self.connection_closed_exc()
917
+
918
+ if self.state is State.CLOSING:
919
+ # If we started the closing handshake, wait for its completion to
920
+ # get the proper close code and reason. self.close_connection_task
921
+ # will complete within 4 or 5 * close_timeout after close(). The
922
+ # CLOSING state also occurs when failing the connection. In that
923
+ # case self.close_connection_task will complete even faster.
924
+ await asyncio.shield(self.close_connection_task)
925
+ raise self.connection_closed_exc()
926
+
927
+ # Control may only reach this point in buggy third-party subclasses.
928
+ assert self.state is State.CONNECTING
929
+ raise InvalidState("WebSocket connection isn't established yet")
930
+
931
+ async def transfer_data(self) -> None:
932
+ """
933
+ Read incoming messages and put them in a queue.
934
+
935
+ This coroutine runs in a task until the closing handshake is started.
936
+
937
+ """
938
+ try:
939
+ while True:
940
+ message = await self.read_message()
941
+
942
+ # Exit the loop when receiving a close frame.
943
+ if message is None:
944
+ break
945
+
946
+ # Wait until there's room in the queue (if necessary).
947
+ if self.max_queue is not None:
948
+ while len(self.messages) >= self.max_queue:
949
+ self._put_message_waiter = self.loop.create_future()
950
+ try:
951
+ await asyncio.shield(self._put_message_waiter)
952
+ finally:
953
+ self._put_message_waiter = None
954
+
955
+ # Put the message in the queue.
956
+ self.messages.append(message)
957
+
958
+ # Notify recv().
959
+ if self._pop_message_waiter is not None:
960
+ self._pop_message_waiter.set_result(None)
961
+ self._pop_message_waiter = None
962
+
963
+ except asyncio.CancelledError as exc:
964
+ self.transfer_data_exc = exc
965
+ # If fail_connection() cancels this task, avoid logging the error
966
+ # twice and failing the connection again.
967
+ raise
968
+
969
+ except ProtocolError as exc:
970
+ self.transfer_data_exc = exc
971
+ self.fail_connection(CloseCode.PROTOCOL_ERROR)
972
+
973
+ except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc:
974
+ # Reading data with self.reader.readexactly may raise:
975
+ # - most subclasses of ConnectionError if the TCP connection
976
+ # breaks, is reset, or is aborted;
977
+ # - TimeoutError if the TCP connection times out;
978
+ # - IncompleteReadError, a subclass of EOFError, if fewer
979
+ # bytes are available than requested;
980
+ # - ssl.SSLError if the other side infringes the TLS protocol.
981
+ self.transfer_data_exc = exc
982
+ self.fail_connection(CloseCode.ABNORMAL_CLOSURE)
983
+
984
+ except UnicodeDecodeError as exc:
985
+ self.transfer_data_exc = exc
986
+ self.fail_connection(CloseCode.INVALID_DATA)
987
+
988
+ except PayloadTooBig as exc:
989
+ self.transfer_data_exc = exc
990
+ self.fail_connection(CloseCode.MESSAGE_TOO_BIG)
991
+
992
+ except Exception as exc:
993
+ # This shouldn't happen often because exceptions expected under
994
+ # regular circumstances are handled above. If it does, consider
995
+ # catching and handling more exceptions.
996
+ self.logger.error("data transfer failed", exc_info=True)
997
+
998
+ self.transfer_data_exc = exc
999
+ self.fail_connection(CloseCode.INTERNAL_ERROR)
1000
+
1001
+ async def read_message(self) -> Data | None:
1002
+ """
1003
+ Read a single message from the connection.
1004
+
1005
+ Re-assemble data frames if the message is fragmented.
1006
+
1007
+ Return :obj:`None` when the closing handshake is started.
1008
+
1009
+ """
1010
+ frame = await self.read_data_frame(max_size=self.max_size)
1011
+
1012
+ # A close frame was received.
1013
+ if frame is None:
1014
+ return None
1015
+
1016
+ if frame.opcode == OP_TEXT:
1017
+ text = True
1018
+ elif frame.opcode == OP_BINARY:
1019
+ text = False
1020
+ else: # frame.opcode == OP_CONT
1021
+ raise ProtocolError("unexpected opcode")
1022
+
1023
+ # Shortcut for the common case - no fragmentation
1024
+ if frame.fin:
1025
+ if isinstance(frame.data, memoryview):
1026
+ raise AssertionError("only compressed outgoing frames use memoryview")
1027
+ return frame.data.decode() if text else bytes(frame.data)
1028
+
1029
+ # 5.4. Fragmentation
1030
+ fragments: list[DataLike] = []
1031
+ max_size = self.max_size
1032
+ if text:
1033
+ decoder_factory = codecs.getincrementaldecoder("utf-8")
1034
+ decoder = decoder_factory(errors="strict")
1035
+ if max_size is None:
1036
+
1037
+ def append(frame: Frame) -> None:
1038
+ nonlocal fragments
1039
+ fragments.append(decoder.decode(frame.data, frame.fin))
1040
+
1041
+ else:
1042
+
1043
+ def append(frame: Frame) -> None:
1044
+ nonlocal fragments, max_size
1045
+ fragments.append(decoder.decode(frame.data, frame.fin))
1046
+ assert isinstance(max_size, int)
1047
+ max_size -= len(frame.data)
1048
+
1049
+ else:
1050
+ if max_size is None:
1051
+
1052
+ def append(frame: Frame) -> None:
1053
+ nonlocal fragments
1054
+ fragments.append(frame.data)
1055
+
1056
+ else:
1057
+
1058
+ def append(frame: Frame) -> None:
1059
+ nonlocal fragments, max_size
1060
+ fragments.append(frame.data)
1061
+ assert isinstance(max_size, int)
1062
+ max_size -= len(frame.data)
1063
+
1064
+ append(frame)
1065
+
1066
+ while not frame.fin:
1067
+ frame = await self.read_data_frame(max_size=max_size)
1068
+ if frame is None:
1069
+ raise ProtocolError("incomplete fragmented message")
1070
+ if frame.opcode != OP_CONT:
1071
+ raise ProtocolError("unexpected opcode")
1072
+ append(frame)
1073
+
1074
+ return ("" if text else b"").join(fragments)
1075
+
1076
+ async def read_data_frame(self, max_size: int | None) -> Frame | None:
1077
+ """
1078
+ Read a single data frame from the connection.
1079
+
1080
+ Process control frames received before the next data frame.
1081
+
1082
+ Return :obj:`None` if a close frame is encountered before any data frame.
1083
+
1084
+ """
1085
+ # 6.2. Receiving Data
1086
+ while True:
1087
+ frame = await self.read_frame(max_size)
1088
+
1089
+ # 5.5. Control Frames
1090
+ if frame.opcode == OP_CLOSE:
1091
+ # 7.1.5. The WebSocket Connection Close Code
1092
+ # 7.1.6. The WebSocket Connection Close Reason
1093
+ self.close_rcvd = Close.parse(frame.data)
1094
+ if self.close_sent is not None:
1095
+ self.close_rcvd_then_sent = False
1096
+ try:
1097
+ # Echo the original data instead of re-serializing it with
1098
+ # Close.serialize() because that fails when the close frame
1099
+ # is empty and Close.parse() synthesizes a 1005 close code.
1100
+ await self.write_close_frame(self.close_rcvd, frame.data)
1101
+ except ConnectionClosed:
1102
+ # Connection closed before we could echo the close frame.
1103
+ pass
1104
+ return None
1105
+
1106
+ elif frame.opcode == OP_PING:
1107
+ # Answer pings, unless connection is CLOSING.
1108
+ if self.state is State.OPEN:
1109
+ try:
1110
+ await self.pong(frame.data)
1111
+ except ConnectionClosed:
1112
+ # Connection closed while draining write buffer.
1113
+ pass
1114
+
1115
+ elif frame.opcode == OP_PONG:
1116
+ if frame.data in self.pings:
1117
+ pong_timestamp = time.perf_counter()
1118
+ # Sending a pong for only the most recent ping is legal.
1119
+ # Acknowledge all previous pings too in that case.
1120
+ ping_id = None
1121
+ ping_ids = []
1122
+ for ping_id, (pong_waiter, ping_timestamp) in self.pings.items():
1123
+ ping_ids.append(ping_id)
1124
+ if not pong_waiter.done():
1125
+ pong_waiter.set_result(pong_timestamp - ping_timestamp)
1126
+ if ping_id == frame.data:
1127
+ self.latency = pong_timestamp - ping_timestamp
1128
+ break
1129
+ else:
1130
+ raise AssertionError("solicited pong not found in pings")
1131
+ # Remove acknowledged pings from self.pings.
1132
+ for ping_id in ping_ids:
1133
+ del self.pings[ping_id]
1134
+
1135
+ # 5.6. Data Frames
1136
+ else:
1137
+ return frame
1138
+
1139
+ async def read_frame(self, max_size: int | None) -> Frame:
1140
+ """
1141
+ Read a single frame from the connection.
1142
+
1143
+ """
1144
+ frame = await Frame.read(
1145
+ self.reader.readexactly,
1146
+ mask=not self.is_client,
1147
+ max_size=max_size,
1148
+ extensions=self.extensions,
1149
+ )
1150
+ if self.debug:
1151
+ self.logger.debug("< %s", frame)
1152
+ return frame
1153
+
1154
+ def write_frame_sync(self, fin: bool, opcode: int, data: BytesLike) -> None:
1155
+ frame = Frame(fin, Opcode(opcode), data)
1156
+ if self.debug:
1157
+ self.logger.debug("> %s", frame)
1158
+ frame.write(
1159
+ self.transport.write,
1160
+ mask=self.is_client,
1161
+ extensions=self.extensions,
1162
+ )
1163
+
1164
+ async def drain(self) -> None:
1165
+ try:
1166
+ # Handle flow control automatically.
1167
+ await self._drain()
1168
+ except ConnectionError:
1169
+ # Terminate the connection if the socket died.
1170
+ self.fail_connection()
1171
+ # Wait until the connection is closed to raise ConnectionClosed
1172
+ # with the correct code and reason.
1173
+ await self.ensure_open()
1174
+
1175
+ async def write_frame(
1176
+ self, fin: bool, opcode: int, data: BytesLike, *, _state: int = State.OPEN
1177
+ ) -> None:
1178
+ # Defensive assertion for protocol compliance.
1179
+ if self.state is not _state: # pragma: no cover
1180
+ raise InvalidState(
1181
+ f"Cannot write to a WebSocket in the {self.state.name} state"
1182
+ )
1183
+ self.write_frame_sync(fin, opcode, data)
1184
+ await self.drain()
1185
+
1186
+ async def write_close_frame(
1187
+ self, close: Close, data: BytesLike | None = None
1188
+ ) -> None:
1189
+ """
1190
+ Write a close frame if and only if the connection state is OPEN.
1191
+
1192
+ This dedicated coroutine must be used for writing close frames to
1193
+ ensure that at most one close frame is sent on a given connection.
1194
+
1195
+ """
1196
+ # Test and set the connection state before sending the close frame to
1197
+ # avoid sending two frames in case of concurrent calls.
1198
+ if self.state is State.OPEN:
1199
+ # 7.1.3. The WebSocket Closing Handshake is Started
1200
+ self.state = State.CLOSING
1201
+ if self.debug:
1202
+ self.logger.debug("= connection is CLOSING")
1203
+
1204
+ self.close_sent = close
1205
+ if self.close_rcvd is not None:
1206
+ self.close_rcvd_then_sent = True
1207
+ if data is None:
1208
+ data = close.serialize()
1209
+
1210
+ # 7.1.2. Start the WebSocket Closing Handshake
1211
+ await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING)
1212
+
1213
+ async def keepalive_ping(self) -> None:
1214
+ """
1215
+ Send a Ping frame and wait for a Pong frame at regular intervals.
1216
+
1217
+ This coroutine exits when the connection terminates and one of the
1218
+ following happens:
1219
+
1220
+ - :meth:`ping` raises :exc:`ConnectionClosed`, or
1221
+ - :meth:`close_connection` cancels :attr:`keepalive_ping_task`.
1222
+
1223
+ """
1224
+ if self.ping_interval is None:
1225
+ return
1226
+
1227
+ try:
1228
+ while True:
1229
+ await asyncio.sleep(self.ping_interval)
1230
+
1231
+ if self.debug:
1232
+ self.logger.debug("% sending keepalive ping")
1233
+ pong_waiter = await self.ping()
1234
+
1235
+ if self.ping_timeout is not None:
1236
+ try:
1237
+ async with asyncio_timeout(self.ping_timeout):
1238
+ # Raises CancelledError if the connection is closed,
1239
+ # when close_connection() cancels keepalive_ping().
1240
+ # Raises ConnectionClosed if the connection is lost,
1241
+ # when connection_lost() calls abort_pings().
1242
+ await pong_waiter
1243
+ if self.debug:
1244
+ self.logger.debug("% received keepalive pong")
1245
+ except asyncio.TimeoutError:
1246
+ if self.debug:
1247
+ self.logger.debug("- timed out waiting for keepalive pong")
1248
+ self.fail_connection(
1249
+ CloseCode.INTERNAL_ERROR,
1250
+ "keepalive ping timeout",
1251
+ )
1252
+ break
1253
+
1254
+ except ConnectionClosed:
1255
+ pass
1256
+
1257
+ except Exception:
1258
+ self.logger.error("keepalive ping failed", exc_info=True)
1259
+
1260
+ async def close_connection(self) -> None:
1261
+ """
1262
+ 7.1.1. Close the WebSocket Connection
1263
+
1264
+ When the opening handshake succeeds, :meth:`connection_open` starts
1265
+ this coroutine in a task. It waits for the data transfer phase to
1266
+ complete then it closes the TCP connection cleanly.
1267
+
1268
+ When the opening handshake fails, :meth:`fail_connection` does the
1269
+ same. There's no data transfer phase in that case.
1270
+
1271
+ """
1272
+ try:
1273
+ # Wait for the data transfer phase to complete.
1274
+ if hasattr(self, "transfer_data_task"):
1275
+ try:
1276
+ await self.transfer_data_task
1277
+ except asyncio.CancelledError:
1278
+ pass
1279
+
1280
+ # Cancel the keepalive ping task.
1281
+ if hasattr(self, "keepalive_ping_task"):
1282
+ self.keepalive_ping_task.cancel()
1283
+
1284
+ # A client should wait for a TCP close from the server.
1285
+ if self.is_client and hasattr(self, "transfer_data_task"):
1286
+ if await self.wait_for_connection_lost():
1287
+ return
1288
+ if self.debug:
1289
+ self.logger.debug("- timed out waiting for TCP close")
1290
+
1291
+ # Half-close the TCP connection if possible (when there's no TLS).
1292
+ if self.transport.can_write_eof():
1293
+ if self.debug:
1294
+ self.logger.debug("x half-closing TCP connection")
1295
+ # write_eof() doesn't document which exceptions it raises.
1296
+ # "[Errno 107] Transport endpoint is not connected" happens
1297
+ # but it isn't completely clear under which circumstances.
1298
+ # uvloop can raise RuntimeError here.
1299
+ try:
1300
+ self.transport.write_eof()
1301
+ except (OSError, RuntimeError): # pragma: no cover
1302
+ pass
1303
+
1304
+ if await self.wait_for_connection_lost():
1305
+ return
1306
+ if self.debug:
1307
+ self.logger.debug("- timed out waiting for TCP close")
1308
+
1309
+ finally:
1310
+ # The try/finally ensures that the transport never remains open,
1311
+ # even if this coroutine is canceled (for example).
1312
+ await self.close_transport()
1313
+
1314
+ async def close_transport(self) -> None:
1315
+ """
1316
+ Close the TCP connection.
1317
+
1318
+ """
1319
+ # If connection_lost() was called, the TCP connection is closed.
1320
+ # However, if TLS is enabled, the transport still needs closing.
1321
+ # Else asyncio complains: ResourceWarning: unclosed transport.
1322
+ if self.connection_lost_waiter.done() and self.transport.is_closing():
1323
+ return
1324
+
1325
+ # Close the TCP connection. Buffers are flushed asynchronously.
1326
+ if self.debug:
1327
+ self.logger.debug("x closing TCP connection")
1328
+ self.transport.close()
1329
+
1330
+ if await self.wait_for_connection_lost():
1331
+ return
1332
+ if self.debug:
1333
+ self.logger.debug("- timed out waiting for TCP close")
1334
+
1335
+ # Abort the TCP connection. Buffers are discarded.
1336
+ if self.debug:
1337
+ self.logger.debug("x aborting TCP connection")
1338
+ self.transport.abort()
1339
+
1340
+ # connection_lost() is called quickly after aborting.
1341
+ await self.wait_for_connection_lost()
1342
+
1343
+ async def wait_for_connection_lost(self) -> bool:
1344
+ """
1345
+ Wait until the TCP connection is closed or ``self.close_timeout`` elapses.
1346
+
1347
+ Return :obj:`True` if the connection is closed and :obj:`False`
1348
+ otherwise.
1349
+
1350
+ """
1351
+ if not self.connection_lost_waiter.done():
1352
+ try:
1353
+ async with asyncio_timeout(self.close_timeout):
1354
+ await asyncio.shield(self.connection_lost_waiter)
1355
+ except asyncio.TimeoutError:
1356
+ pass
1357
+ # Re-check self.connection_lost_waiter.done() synchronously because
1358
+ # connection_lost() could run between the moment the timeout occurs
1359
+ # and the moment this coroutine resumes running.
1360
+ return self.connection_lost_waiter.done()
1361
+
1362
+ def fail_connection(
1363
+ self,
1364
+ code: int = CloseCode.ABNORMAL_CLOSURE,
1365
+ reason: str = "",
1366
+ ) -> None:
1367
+ """
1368
+ 7.1.7. Fail the WebSocket Connection
1369
+
1370
+ This requires:
1371
+
1372
+ 1. Stopping all processing of incoming data, which means canceling
1373
+ :attr:`transfer_data_task`. The close code will be 1006 unless a
1374
+ close frame was received earlier.
1375
+
1376
+ 2. Sending a close frame with an appropriate code if the opening
1377
+ handshake succeeded and the other side is likely to process it.
1378
+
1379
+ 3. Closing the connection. :meth:`close_connection` takes care of
1380
+ this once :attr:`transfer_data_task` exits after being canceled.
1381
+
1382
+ (The specification describes these steps in the opposite order.)
1383
+
1384
+ """
1385
+ if self.debug:
1386
+ self.logger.debug("! failing connection with code %d", code)
1387
+
1388
+ # Cancel transfer_data_task if the opening handshake succeeded.
1389
+ # cancel() is idempotent and ignored if the task is done already.
1390
+ if hasattr(self, "transfer_data_task"):
1391
+ self.transfer_data_task.cancel()
1392
+
1393
+ # Send a close frame when the state is OPEN (a close frame was already
1394
+ # sent if it's CLOSING), except when failing the connection because of
1395
+ # an error reading from or writing to the network.
1396
+ # Don't send a close frame if the connection is broken.
1397
+ if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN:
1398
+ close = Close(code, reason)
1399
+
1400
+ # Write the close frame without draining the write buffer.
1401
+
1402
+ # Keeping fail_connection() synchronous guarantees it can't
1403
+ # get stuck and simplifies the implementation of the callers.
1404
+ # Not drainig the write buffer is acceptable in this context.
1405
+
1406
+ # This duplicates a few lines of code from write_close_frame().
1407
+
1408
+ self.state = State.CLOSING
1409
+ if self.debug:
1410
+ self.logger.debug("= connection is CLOSING")
1411
+
1412
+ # If self.close_rcvd was set, the connection state would be
1413
+ # CLOSING. Therefore self.close_rcvd isn't set and we don't
1414
+ # have to set self.close_rcvd_then_sent.
1415
+ assert self.close_rcvd is None
1416
+ self.close_sent = close
1417
+
1418
+ self.write_frame_sync(True, OP_CLOSE, close.serialize())
1419
+
1420
+ # Start close_connection_task if the opening handshake didn't succeed.
1421
+ if not hasattr(self, "close_connection_task"):
1422
+ self.close_connection_task = self.loop.create_task(self.close_connection())
1423
+
1424
+ def abort_pings(self) -> None:
1425
+ """
1426
+ Raise ConnectionClosed in pending keepalive pings.
1427
+
1428
+ They'll never receive a pong once the connection is closed.
1429
+
1430
+ """
1431
+ assert self.state is State.CLOSED
1432
+ exc = self.connection_closed_exc()
1433
+
1434
+ for pong_waiter, _ping_timestamp in self.pings.values():
1435
+ pong_waiter.set_exception(exc)
1436
+ # If the exception is never retrieved, it will be logged when ping
1437
+ # is garbage-collected. This is confusing for users.
1438
+ # Given that ping is done (with an exception), canceling it does
1439
+ # nothing, but it prevents logging the exception.
1440
+ pong_waiter.cancel()
1441
+
1442
+ # asyncio.Protocol methods
1443
+
1444
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
1445
+ """
1446
+ Configure write buffer limits.
1447
+
1448
+ The high-water limit is defined by ``self.write_limit``.
1449
+
1450
+ The low-water limit currently defaults to ``self.write_limit // 4`` in
1451
+ :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should
1452
+ be all right for reasonable use cases of this library.
1453
+
1454
+ This is the earliest point where we can get hold of the transport,
1455
+ which means it's the best point for configuring it.
1456
+
1457
+ """
1458
+ transport = cast(asyncio.Transport, transport)
1459
+ transport.set_write_buffer_limits(self.write_limit)
1460
+ self.transport = transport
1461
+
1462
+ # Copied from asyncio.StreamReaderProtocol
1463
+ self.reader.set_transport(transport)
1464
+
1465
+ def connection_lost(self, exc: Exception | None) -> None:
1466
+ """
1467
+ 7.1.4. The WebSocket Connection is Closed.
1468
+
1469
+ """
1470
+ self.state = State.CLOSED
1471
+ if self.debug:
1472
+ self.logger.debug("= connection is CLOSED")
1473
+
1474
+ self.abort_pings()
1475
+
1476
+ # If self.connection_lost_waiter isn't pending, that's a bug, because:
1477
+ # - it's set only here in connection_lost() which is called only once;
1478
+ # - it must never be canceled.
1479
+ self.connection_lost_waiter.set_result(None)
1480
+
1481
+ if True: # pragma: no cover
1482
+ # Copied from asyncio.StreamReaderProtocol
1483
+ if self.reader is not None:
1484
+ if exc is None:
1485
+ self.reader.feed_eof()
1486
+ else:
1487
+ self.reader.set_exception(exc)
1488
+
1489
+ # Copied from asyncio.FlowControlMixin
1490
+ # Wake up the writer if currently paused.
1491
+ if not self._paused:
1492
+ return
1493
+ waiter = self._drain_waiter
1494
+ if waiter is None:
1495
+ return
1496
+ self._drain_waiter = None
1497
+ if waiter.done():
1498
+ return
1499
+ if exc is None:
1500
+ waiter.set_result(None)
1501
+ else:
1502
+ waiter.set_exception(exc)
1503
+
1504
+ def pause_writing(self) -> None: # pragma: no cover
1505
+ assert not self._paused
1506
+ self._paused = True
1507
+
1508
+ def resume_writing(self) -> None: # pragma: no cover
1509
+ assert self._paused
1510
+ self._paused = False
1511
+
1512
+ waiter = self._drain_waiter
1513
+ if waiter is not None:
1514
+ self._drain_waiter = None
1515
+ if not waiter.done():
1516
+ waiter.set_result(None)
1517
+
1518
+ def data_received(self, data: bytes) -> None:
1519
+ self.reader.feed_data(data)
1520
+
1521
+ def eof_received(self) -> None:
1522
+ """
1523
+ Close the transport after receiving EOF.
1524
+
1525
+ The WebSocket protocol has its own closing handshake: endpoints close
1526
+ the TCP or TLS connection after sending and receiving a close frame.
1527
+
1528
+ As a consequence, they never need to write after receiving EOF, so
1529
+ there's no reason to keep the transport open by returning :obj:`True`.
1530
+
1531
+ Besides, that doesn't work on TLS connections.
1532
+
1533
+ """
1534
+ self.reader.feed_eof()
1535
+
1536
+
1537
+ # broadcast() is defined in the protocol module even though it's primarily
1538
+ # used by servers and documented in the server module because it works with
1539
+ # client connections too and because it's easier to test together with the
1540
+ # WebSocketCommonProtocol class.
1541
+
1542
+
1543
+ def broadcast(
1544
+ websockets: Iterable[WebSocketCommonProtocol],
1545
+ message: DataLike,
1546
+ raise_exceptions: bool = False,
1547
+ ) -> None:
1548
+ """
1549
+ Broadcast a message to several WebSocket connections.
1550
+
1551
+ A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like
1552
+ object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent
1553
+ as a Binary_ frame.
1554
+
1555
+ .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
1556
+ .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
1557
+
1558
+ :func:`broadcast` pushes the message synchronously to all connections even
1559
+ if their write buffers are overflowing. There's no backpressure.
1560
+
1561
+ If you broadcast messages faster than a connection can handle them, messages
1562
+ will pile up in its write buffer until the connection times out. Keep
1563
+ ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage
1564
+ from slow connections.
1565
+
1566
+ Unlike :meth:`~websockets.legacy.protocol.WebSocketCommonProtocol.send`,
1567
+ :func:`broadcast` doesn't support sending fragmented messages. Indeed,
1568
+ fragmentation is useful for sending large messages without buffering them in
1569
+ memory, while :func:`broadcast` buffers one copy per connection as fast as
1570
+ possible.
1571
+
1572
+ :func:`broadcast` skips connections that aren't open in order to avoid
1573
+ errors on connections where the closing handshake is in progress.
1574
+
1575
+ :func:`broadcast` ignores failures to write the message on some connections.
1576
+ It continues writing to other connections. On Python 3.11 and above, you may
1577
+ set ``raise_exceptions`` to :obj:`True` to record failures and raise all
1578
+ exceptions in a :pep:`654` :exc:`ExceptionGroup`.
1579
+
1580
+ While :func:`broadcast` makes more sense for servers, it works identically
1581
+ with clients, if you have a use case for opening connections to many servers
1582
+ and broadcasting a message to them.
1583
+
1584
+ Args:
1585
+ websockets: WebSocket connections to which the message will be sent.
1586
+ message: Message to send.
1587
+ raise_exceptions: Whether to raise an exception in case of failures.
1588
+
1589
+ Raises:
1590
+ TypeError: If ``message`` doesn't have a supported type.
1591
+
1592
+ """
1593
+ if not isinstance(message, (str, bytes, bytearray, memoryview)):
1594
+ raise TypeError("data must be str or bytes-like")
1595
+
1596
+ if raise_exceptions:
1597
+ if sys.version_info[:2] < (3, 11): # pragma: no cover
1598
+ raise ValueError("raise_exceptions requires at least Python 3.11")
1599
+ exceptions = []
1600
+
1601
+ opcode, data = prepare_data(message)
1602
+
1603
+ for websocket in websockets:
1604
+ if websocket.state is not State.OPEN:
1605
+ continue
1606
+
1607
+ if websocket._fragmented_message_waiter is not None:
1608
+ if raise_exceptions:
1609
+ exception = RuntimeError("sending a fragmented message")
1610
+ exceptions.append(exception)
1611
+ else:
1612
+ websocket.logger.warning(
1613
+ "skipped broadcast: sending a fragmented message",
1614
+ )
1615
+ continue
1616
+
1617
+ try:
1618
+ websocket.write_frame_sync(True, opcode, data)
1619
+ except Exception as write_exception:
1620
+ if raise_exceptions:
1621
+ exception = RuntimeError("failed to write message")
1622
+ exception.__cause__ = write_exception
1623
+ exceptions.append(exception)
1624
+ else:
1625
+ websocket.logger.warning(
1626
+ "skipped broadcast: failed to write message: %s",
1627
+ traceback.format_exception_only(write_exception)[0].strip(),
1628
+ )
1629
+
1630
+ if raise_exceptions and exceptions:
1631
+ raise ExceptionGroup("skipped broadcast", exceptions)
1632
+
1633
+
1634
+ # Pretend that broadcast is actually defined in the server module.
1635
+ broadcast.__module__ = "websockets.legacy.server"
source/websockets/legacy/server.py ADDED
@@ -0,0 +1,1191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import email.utils
5
+ import functools
6
+ import http
7
+ import inspect
8
+ import logging
9
+ import socket
10
+ import warnings
11
+ from collections.abc import Awaitable, Generator, Iterable, Sequence
12
+ from types import TracebackType
13
+ from typing import Any, Callable, cast
14
+
15
+ from ..asyncio.compatibility import asyncio_timeout
16
+ from ..datastructures import Headers, HeadersLike, MultipleValuesError
17
+ from ..exceptions import (
18
+ InvalidHandshake,
19
+ InvalidHeader,
20
+ InvalidMessage,
21
+ InvalidOrigin,
22
+ InvalidUpgrade,
23
+ NegotiationError,
24
+ )
25
+ from ..extensions import Extension, ServerExtensionFactory
26
+ from ..extensions.permessage_deflate import enable_server_permessage_deflate
27
+ from ..headers import (
28
+ build_extension,
29
+ parse_extension,
30
+ parse_subprotocol,
31
+ validate_subprotocols,
32
+ )
33
+ from ..http11 import SERVER
34
+ from ..protocol import State
35
+ from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol
36
+ from .exceptions import AbortHandshake
37
+ from .handshake import build_response, check_request
38
+ from .http import read_request
39
+ from .protocol import WebSocketCommonProtocol, broadcast
40
+
41
+
42
+ __all__ = [
43
+ "broadcast",
44
+ "serve",
45
+ "unix_serve",
46
+ "WebSocketServerProtocol",
47
+ "WebSocketServer",
48
+ ]
49
+
50
+
51
+ HeadersLikeOrCallable = HeadersLike | Callable[[str, Headers], HeadersLike]
52
+
53
+ HTTPResponse = tuple[StatusLike, HeadersLike, bytes]
54
+
55
+
56
+ class WebSocketServerProtocol(WebSocketCommonProtocol):
57
+ """
58
+ WebSocket server connection.
59
+
60
+ :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send`
61
+ coroutines for receiving and sending messages.
62
+
63
+ It supports asynchronous iteration to receive messages::
64
+
65
+ async for message in websocket:
66
+ await process(message)
67
+
68
+ The iterator exits normally when the connection is closed with close code
69
+ 1000 (OK) or 1001 (going away) or without a close code. It raises
70
+ a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
71
+ is closed with any other code.
72
+
73
+ You may customize the opening handshake in a subclass by
74
+ overriding :meth:`process_request` or :meth:`select_subprotocol`.
75
+
76
+ Args:
77
+ ws_server: WebSocket server that created this connection.
78
+
79
+ See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``,
80
+ ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``.
81
+
82
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
83
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
84
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
85
+
86
+ """
87
+
88
+ is_client = False
89
+ side = "server"
90
+
91
+ def __init__(
92
+ self,
93
+ # The version that accepts the path in the second argument is deprecated.
94
+ ws_handler: (
95
+ Callable[[WebSocketServerProtocol], Awaitable[Any]]
96
+ | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
97
+ ),
98
+ ws_server: WebSocketServer,
99
+ *,
100
+ logger: LoggerLike | None = None,
101
+ origins: Sequence[Origin | None] | None = None,
102
+ extensions: Sequence[ServerExtensionFactory] | None = None,
103
+ subprotocols: Sequence[Subprotocol] | None = None,
104
+ extra_headers: HeadersLikeOrCallable | None = None,
105
+ server_header: str | None = SERVER,
106
+ process_request: (
107
+ Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None
108
+ ) = None,
109
+ select_subprotocol: (
110
+ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None
111
+ ) = None,
112
+ open_timeout: float | None = 10,
113
+ **kwargs: Any,
114
+ ) -> None:
115
+ if logger is None:
116
+ logger = logging.getLogger("websockets.server")
117
+ super().__init__(logger=logger, **kwargs)
118
+ # For backwards compatibility with 6.0 or earlier.
119
+ if origins is not None and "" in origins:
120
+ warnings.warn("use None instead of '' in origins", DeprecationWarning)
121
+ origins = [None if origin == "" else origin for origin in origins]
122
+ # For backwards compatibility with 10.0 or earlier. Done here in
123
+ # addition to serve to trigger the deprecation warning on direct
124
+ # use of WebSocketServerProtocol.
125
+ self.ws_handler = remove_path_argument(ws_handler)
126
+ self.ws_server = ws_server
127
+ self.origins = origins
128
+ self.available_extensions = extensions
129
+ self.available_subprotocols = subprotocols
130
+ self.extra_headers = extra_headers
131
+ self.server_header = server_header
132
+ self._process_request = process_request
133
+ self._select_subprotocol = select_subprotocol
134
+ self.open_timeout = open_timeout
135
+
136
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
137
+ """
138
+ Register connection and initialize a task to handle it.
139
+
140
+ """
141
+ super().connection_made(transport)
142
+ # Register the connection with the server before creating the handler
143
+ # task. Registering at the beginning of the handler coroutine would
144
+ # create a race condition between the creation of the task, which
145
+ # schedules its execution, and the moment the handler starts running.
146
+ self.ws_server.register(self)
147
+ self.handler_task = self.loop.create_task(self.handler())
148
+
149
+ async def handler(self) -> None:
150
+ """
151
+ Handle the lifecycle of a WebSocket connection.
152
+
153
+ Since this method doesn't have a caller able to handle exceptions, it
154
+ attempts to log relevant ones and guarantees that the TCP connection is
155
+ closed before exiting.
156
+
157
+ """
158
+ try:
159
+ try:
160
+ async with asyncio_timeout(self.open_timeout):
161
+ await self.handshake(
162
+ origins=self.origins,
163
+ available_extensions=self.available_extensions,
164
+ available_subprotocols=self.available_subprotocols,
165
+ extra_headers=self.extra_headers,
166
+ )
167
+ except asyncio.TimeoutError: # pragma: no cover
168
+ raise
169
+ except ConnectionError:
170
+ raise
171
+ except Exception as exc:
172
+ if isinstance(exc, AbortHandshake):
173
+ status, headers, body = exc.status, exc.headers, exc.body
174
+ elif isinstance(exc, InvalidOrigin):
175
+ if self.debug:
176
+ self.logger.debug("! invalid origin", exc_info=True)
177
+ status, headers, body = (
178
+ http.HTTPStatus.FORBIDDEN,
179
+ Headers(),
180
+ f"Failed to open a WebSocket connection: {exc}.\n".encode(),
181
+ )
182
+ elif isinstance(exc, InvalidUpgrade):
183
+ if self.debug:
184
+ self.logger.debug("! invalid upgrade", exc_info=True)
185
+ status, headers, body = (
186
+ http.HTTPStatus.UPGRADE_REQUIRED,
187
+ Headers([("Upgrade", "websocket")]),
188
+ (
189
+ f"Failed to open a WebSocket connection: {exc}.\n"
190
+ f"\n"
191
+ f"You cannot access a WebSocket server directly "
192
+ f"with a browser. You need a WebSocket client.\n"
193
+ ).encode(),
194
+ )
195
+ elif isinstance(exc, InvalidHandshake):
196
+ if self.debug:
197
+ self.logger.debug("! invalid handshake", exc_info=True)
198
+ exc_chain = cast(BaseException, exc)
199
+ exc_str = f"{exc_chain}"
200
+ while exc_chain.__cause__ is not None:
201
+ exc_chain = exc_chain.__cause__
202
+ exc_str += f"; {exc_chain}"
203
+ status, headers, body = (
204
+ http.HTTPStatus.BAD_REQUEST,
205
+ Headers(),
206
+ f"Failed to open a WebSocket connection: {exc_str}.\n".encode(),
207
+ )
208
+ else:
209
+ self.logger.error("opening handshake failed", exc_info=True)
210
+ status, headers, body = (
211
+ http.HTTPStatus.INTERNAL_SERVER_ERROR,
212
+ Headers(),
213
+ (
214
+ b"Failed to open a WebSocket connection.\n"
215
+ b"See server log for more information.\n"
216
+ ),
217
+ )
218
+
219
+ headers.setdefault("Date", email.utils.formatdate(usegmt=True))
220
+ if self.server_header:
221
+ headers.setdefault("Server", self.server_header)
222
+
223
+ headers.setdefault("Content-Length", str(len(body)))
224
+ headers.setdefault("Content-Type", "text/plain")
225
+ headers.setdefault("Connection", "close")
226
+
227
+ self.write_http_response(status, headers, body)
228
+ self.logger.info(
229
+ "connection rejected (%d %s)", status.value, status.phrase
230
+ )
231
+ await self.close_transport()
232
+ return
233
+
234
+ try:
235
+ await self.ws_handler(self)
236
+ except Exception:
237
+ self.logger.error("connection handler failed", exc_info=True)
238
+ if not self.closed:
239
+ self.fail_connection(1011)
240
+ raise
241
+
242
+ try:
243
+ await self.close()
244
+ except ConnectionError:
245
+ raise
246
+ except Exception:
247
+ self.logger.error("closing handshake failed", exc_info=True)
248
+ raise
249
+
250
+ except Exception:
251
+ # Last-ditch attempt to avoid leaking connections on errors.
252
+ try:
253
+ self.transport.close()
254
+ except Exception: # pragma: no cover
255
+ pass
256
+
257
+ finally:
258
+ # Unregister the connection with the server when the handler task
259
+ # terminates. Registration is tied to the lifecycle of the handler
260
+ # task because the server waits for tasks attached to registered
261
+ # connections before terminating.
262
+ self.ws_server.unregister(self)
263
+ self.logger.info("connection closed")
264
+
265
+ async def read_http_request(self) -> tuple[str, Headers]:
266
+ """
267
+ Read request line and headers from the HTTP request.
268
+
269
+ If the request contains a body, it may be read from ``self.reader``
270
+ after this coroutine returns.
271
+
272
+ Raises:
273
+ InvalidMessage: If the HTTP message is malformed or isn't an
274
+ HTTP/1.1 GET request.
275
+
276
+ """
277
+ try:
278
+ path, headers = await read_request(self.reader)
279
+ except asyncio.CancelledError: # pragma: no cover
280
+ raise
281
+ except Exception as exc:
282
+ raise InvalidMessage("did not receive a valid HTTP request") from exc
283
+
284
+ if self.debug:
285
+ self.logger.debug("< GET %s HTTP/1.1", path)
286
+ for key, value in headers.raw_items():
287
+ self.logger.debug("< %s: %s", key, value)
288
+
289
+ self.path = path
290
+ self.request_headers = headers
291
+
292
+ return path, headers
293
+
294
+ def write_http_response(
295
+ self, status: http.HTTPStatus, headers: Headers, body: bytes | None = None
296
+ ) -> None:
297
+ """
298
+ Write status line and headers to the HTTP response.
299
+
300
+ This coroutine is also able to write a response body.
301
+
302
+ """
303
+ self.response_headers = headers
304
+
305
+ if self.debug:
306
+ self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase)
307
+ for key, value in headers.raw_items():
308
+ self.logger.debug("> %s: %s", key, value)
309
+ if body is not None:
310
+ self.logger.debug("> [body] (%d bytes)", len(body))
311
+
312
+ # Since the status line and headers only contain ASCII characters,
313
+ # we can keep this simple.
314
+ response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
315
+ response += str(headers)
316
+
317
+ self.transport.write(response.encode())
318
+
319
+ if body is not None:
320
+ self.transport.write(body)
321
+
322
+ async def process_request(
323
+ self, path: str, request_headers: Headers
324
+ ) -> HTTPResponse | None:
325
+ """
326
+ Intercept the HTTP request and return an HTTP response if appropriate.
327
+
328
+ You may override this method in a :class:`WebSocketServerProtocol`
329
+ subclass, for example:
330
+
331
+ * to return an HTTP 200 OK response on a given path; then a load
332
+ balancer can use this path for a health check;
333
+ * to authenticate the request and return an HTTP 401 Unauthorized or an
334
+ HTTP 403 Forbidden when authentication fails.
335
+
336
+ You may also override this method with the ``process_request``
337
+ argument of :func:`serve` and :class:`WebSocketServerProtocol`. This
338
+ is equivalent, except ``process_request`` won't have access to the
339
+ protocol instance, so it can't store information for later use.
340
+
341
+ :meth:`process_request` is expected to complete quickly. If it may run
342
+ for a long time, then it should await :meth:`wait_closed` and exit if
343
+ :meth:`wait_closed` completes, or else it could prevent the server
344
+ from shutting down.
345
+
346
+ Args:
347
+ path: Request path, including optional query string.
348
+ request_headers: Request headers.
349
+
350
+ Returns:
351
+ tuple[StatusLike, HeadersLike, bytes] | None: :obj:`None` to
352
+ continue the WebSocket handshake normally.
353
+
354
+ An HTTP response, represented by a 3-uple of the response status,
355
+ headers, and body, to abort the WebSocket handshake and return
356
+ that HTTP response instead.
357
+
358
+ """
359
+ if self._process_request is not None:
360
+ response = self._process_request(path, request_headers)
361
+ if isinstance(response, Awaitable):
362
+ return await response
363
+ else:
364
+ # For backwards compatibility with 7.0.
365
+ warnings.warn(
366
+ "declare process_request as a coroutine", DeprecationWarning
367
+ )
368
+ return response
369
+ return None
370
+
371
+ @staticmethod
372
+ def process_origin(
373
+ headers: Headers, origins: Sequence[Origin | None] | None = None
374
+ ) -> Origin | None:
375
+ """
376
+ Handle the Origin HTTP request header.
377
+
378
+ Args:
379
+ headers: Request headers.
380
+ origins: Optional list of acceptable origins.
381
+
382
+ Raises:
383
+ InvalidOrigin: If the origin isn't acceptable.
384
+
385
+ """
386
+ # "The user agent MUST NOT include more than one Origin header field"
387
+ # per https://datatracker.ietf.org/doc/html/rfc6454#section-7.3.
388
+ try:
389
+ origin = headers.get("Origin")
390
+ except MultipleValuesError as exc:
391
+ raise InvalidHeader("Origin", "multiple values") from exc
392
+ if origin is not None:
393
+ origin = cast(Origin, origin)
394
+ if origins is not None:
395
+ if origin not in origins:
396
+ raise InvalidOrigin(origin)
397
+ return origin
398
+
399
+ @staticmethod
400
+ def process_extensions(
401
+ headers: Headers,
402
+ available_extensions: Sequence[ServerExtensionFactory] | None,
403
+ ) -> tuple[str | None, list[Extension]]:
404
+ """
405
+ Handle the Sec-WebSocket-Extensions HTTP request header.
406
+
407
+ Accept or reject each extension proposed in the client request.
408
+ Negotiate parameters for accepted extensions.
409
+
410
+ Return the Sec-WebSocket-Extensions HTTP response header and the list
411
+ of accepted extensions.
412
+
413
+ :rfc:`6455` leaves the rules up to the specification of each
414
+ :extension.
415
+
416
+ To provide this level of flexibility, for each extension proposed by
417
+ the client, we check for a match with each extension available in the
418
+ server configuration. If no match is found, the extension is ignored.
419
+
420
+ If several variants of the same extension are proposed by the client,
421
+ it may be accepted several times, which won't make sense in general.
422
+ Extensions must implement their own requirements. For this purpose,
423
+ the list of previously accepted extensions is provided.
424
+
425
+ This process doesn't allow the server to reorder extensions. It can
426
+ only select a subset of the extensions proposed by the client.
427
+
428
+ Other requirements, for example related to mandatory extensions or the
429
+ order of extensions, may be implemented by overriding this method.
430
+
431
+ Args:
432
+ headers: Request headers.
433
+ extensions: Optional list of supported extensions.
434
+
435
+ Raises:
436
+ InvalidHandshake: To abort the handshake with an HTTP 400 error.
437
+
438
+ """
439
+ response_header_value: str | None = None
440
+
441
+ extension_headers: list[ExtensionHeader] = []
442
+ accepted_extensions: list[Extension] = []
443
+
444
+ header_values = headers.get_all("Sec-WebSocket-Extensions")
445
+
446
+ if header_values and available_extensions:
447
+ parsed_header_values: list[ExtensionHeader] = sum(
448
+ [parse_extension(header_value) for header_value in header_values], []
449
+ )
450
+
451
+ for name, request_params in parsed_header_values:
452
+ for ext_factory in available_extensions:
453
+ # Skip non-matching extensions based on their name.
454
+ if ext_factory.name != name:
455
+ continue
456
+
457
+ # Skip non-matching extensions based on their params.
458
+ try:
459
+ response_params, extension = ext_factory.process_request_params(
460
+ request_params, accepted_extensions
461
+ )
462
+ except NegotiationError:
463
+ continue
464
+
465
+ # Add matching extension to the final list.
466
+ extension_headers.append((name, response_params))
467
+ accepted_extensions.append(extension)
468
+
469
+ # Break out of the loop once we have a match.
470
+ break
471
+
472
+ # If we didn't break from the loop, no extension in our list
473
+ # matched what the client sent. The extension is declined.
474
+
475
+ # Serialize extension header.
476
+ if extension_headers:
477
+ response_header_value = build_extension(extension_headers)
478
+
479
+ return response_header_value, accepted_extensions
480
+
481
+ # Not @staticmethod because it calls self.select_subprotocol()
482
+ def process_subprotocol(
483
+ self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
484
+ ) -> Subprotocol | None:
485
+ """
486
+ Handle the Sec-WebSocket-Protocol HTTP request header.
487
+
488
+ Return Sec-WebSocket-Protocol HTTP response header, which is the same
489
+ as the selected subprotocol.
490
+
491
+ Args:
492
+ headers: Request headers.
493
+ available_subprotocols: Optional list of supported subprotocols.
494
+
495
+ Raises:
496
+ InvalidHandshake: To abort the handshake with an HTTP 400 error.
497
+
498
+ """
499
+ subprotocol: Subprotocol | None = None
500
+
501
+ header_values = headers.get_all("Sec-WebSocket-Protocol")
502
+
503
+ if header_values and available_subprotocols:
504
+ parsed_header_values: list[Subprotocol] = sum(
505
+ [parse_subprotocol(header_value) for header_value in header_values], []
506
+ )
507
+
508
+ subprotocol = self.select_subprotocol(
509
+ parsed_header_values, available_subprotocols
510
+ )
511
+
512
+ return subprotocol
513
+
514
+ def select_subprotocol(
515
+ self,
516
+ client_subprotocols: Sequence[Subprotocol],
517
+ server_subprotocols: Sequence[Subprotocol],
518
+ ) -> Subprotocol | None:
519
+ """
520
+ Pick a subprotocol among those supported by the client and the server.
521
+
522
+ If several subprotocols are available, select the preferred subprotocol
523
+ by giving equal weight to the preferences of the client and the server.
524
+
525
+ If no subprotocol is available, proceed without a subprotocol.
526
+
527
+ You may provide a ``select_subprotocol`` argument to :func:`serve` or
528
+ :class:`WebSocketServerProtocol` to override this logic. For example,
529
+ you could reject the handshake if the client doesn't support a
530
+ particular subprotocol, rather than accept the handshake without that
531
+ subprotocol.
532
+
533
+ Args:
534
+ client_subprotocols: List of subprotocols offered by the client.
535
+ server_subprotocols: List of subprotocols available on the server.
536
+
537
+ Returns:
538
+ Selected subprotocol, if a common subprotocol was found.
539
+
540
+ :obj:`None` to continue without a subprotocol.
541
+
542
+ """
543
+ if self._select_subprotocol is not None:
544
+ return self._select_subprotocol(client_subprotocols, server_subprotocols)
545
+
546
+ subprotocols = set(client_subprotocols) & set(server_subprotocols)
547
+ if not subprotocols:
548
+ return None
549
+ return sorted(
550
+ subprotocols,
551
+ key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p),
552
+ )[0]
553
+
554
+ async def handshake(
555
+ self,
556
+ origins: Sequence[Origin | None] | None = None,
557
+ available_extensions: Sequence[ServerExtensionFactory] | None = None,
558
+ available_subprotocols: Sequence[Subprotocol] | None = None,
559
+ extra_headers: HeadersLikeOrCallable | None = None,
560
+ ) -> str:
561
+ """
562
+ Perform the server side of the opening handshake.
563
+
564
+ Args:
565
+ origins: List of acceptable values of the Origin HTTP header;
566
+ include :obj:`None` if the lack of an origin is acceptable.
567
+ extensions: List of supported extensions, in order in which they
568
+ should be tried.
569
+ subprotocols: List of supported subprotocols, in order of
570
+ decreasing preference.
571
+ extra_headers: Arbitrary HTTP headers to add to the response when
572
+ the handshake succeeds.
573
+
574
+ Returns:
575
+ path of the URI of the request.
576
+
577
+ Raises:
578
+ InvalidHandshake: If the handshake fails.
579
+
580
+ """
581
+ path, request_headers = await self.read_http_request()
582
+
583
+ # Hook for customizing request handling, for example checking
584
+ # authentication or treating some paths as plain HTTP endpoints.
585
+ early_response_awaitable = self.process_request(path, request_headers)
586
+ if isinstance(early_response_awaitable, Awaitable):
587
+ early_response = await early_response_awaitable
588
+ else:
589
+ # For backwards compatibility with 7.0.
590
+ warnings.warn("declare process_request as a coroutine", DeprecationWarning)
591
+ early_response = early_response_awaitable
592
+
593
+ # The connection may drop while process_request is running.
594
+ if self.state is State.CLOSED:
595
+ # This subclass of ConnectionError is silently ignored in handler().
596
+ raise BrokenPipeError("connection closed during opening handshake")
597
+
598
+ # Change the response to a 503 error if the server is shutting down.
599
+ if not self.ws_server.is_serving():
600
+ early_response = (
601
+ http.HTTPStatus.SERVICE_UNAVAILABLE,
602
+ [],
603
+ b"Server is shutting down.\n",
604
+ )
605
+
606
+ if early_response is not None:
607
+ raise AbortHandshake(*early_response)
608
+
609
+ key = check_request(request_headers)
610
+
611
+ self.origin = self.process_origin(request_headers, origins)
612
+
613
+ extensions_header, self.extensions = self.process_extensions(
614
+ request_headers, available_extensions
615
+ )
616
+
617
+ protocol_header = self.subprotocol = self.process_subprotocol(
618
+ request_headers, available_subprotocols
619
+ )
620
+
621
+ response_headers = Headers()
622
+
623
+ build_response(response_headers, key)
624
+
625
+ if extensions_header is not None:
626
+ response_headers["Sec-WebSocket-Extensions"] = extensions_header
627
+
628
+ if protocol_header is not None:
629
+ response_headers["Sec-WebSocket-Protocol"] = protocol_header
630
+
631
+ if callable(extra_headers):
632
+ extra_headers = extra_headers(path, self.request_headers)
633
+ if extra_headers is not None:
634
+ response_headers.update(extra_headers)
635
+
636
+ response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
637
+ if self.server_header is not None:
638
+ response_headers.setdefault("Server", self.server_header)
639
+
640
+ self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers)
641
+
642
+ self.logger.info("connection open")
643
+
644
+ self.connection_open()
645
+
646
+ return path
647
+
648
+
649
+ class WebSocketServer:
650
+ """
651
+ WebSocket server returned by :func:`serve`.
652
+
653
+ This class mirrors the API of :class:`~asyncio.Server`.
654
+
655
+ It keeps track of WebSocket connections in order to close them properly
656
+ when shutting down.
657
+
658
+ Args:
659
+ logger: Logger for this server.
660
+ It defaults to ``logging.getLogger("websockets.server")``.
661
+ See the :doc:`logging guide <../../topics/logging>` for details.
662
+
663
+ """
664
+
665
+ def __init__(self, logger: LoggerLike | None = None) -> None:
666
+ if logger is None:
667
+ logger = logging.getLogger("websockets.server")
668
+ self.logger = logger
669
+
670
+ # Keep track of active connections.
671
+ self.websockets: set[WebSocketServerProtocol] = set()
672
+
673
+ # Task responsible for closing the server and terminating connections.
674
+ self.close_task: asyncio.Task[None] | None = None
675
+
676
+ # Completed when the server is closed and connections are terminated.
677
+ self.closed_waiter: asyncio.Future[None]
678
+
679
+ def wrap(self, server: asyncio.base_events.Server) -> None:
680
+ """
681
+ Attach to a given :class:`~asyncio.Server`.
682
+
683
+ Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
684
+ custom ``Server`` class, the easiest solution that doesn't rely on
685
+ private :mod:`asyncio` APIs is to:
686
+
687
+ - instantiate a :class:`WebSocketServer`
688
+ - give the protocol factory a reference to that instance
689
+ - call :meth:`~asyncio.loop.create_server` with the factory
690
+ - attach the resulting :class:`~asyncio.Server` with this method
691
+
692
+ """
693
+ self.server = server
694
+ for sock in server.sockets:
695
+ if sock.family == socket.AF_INET:
696
+ name = "%s:%d" % sock.getsockname()
697
+ elif sock.family == socket.AF_INET6:
698
+ name = "[%s]:%d" % sock.getsockname()[:2]
699
+ elif sock.family == socket.AF_UNIX:
700
+ name = sock.getsockname()
701
+ # In the unlikely event that someone runs websockets over a
702
+ # protocol other than IP or Unix sockets, avoid crashing.
703
+ else: # pragma: no cover
704
+ name = str(sock.getsockname())
705
+ self.logger.info("server listening on %s", name)
706
+
707
+ # Initialized here because we need a reference to the event loop.
708
+ # This could be moved back to __init__ now that Python < 3.10 isn't
709
+ # supported anymore, but I'm not taking that risk in legacy code.
710
+ self.closed_waiter = server.get_loop().create_future()
711
+
712
+ def register(self, protocol: WebSocketServerProtocol) -> None:
713
+ """
714
+ Register a connection with this server.
715
+
716
+ """
717
+ self.websockets.add(protocol)
718
+
719
+ def unregister(self, protocol: WebSocketServerProtocol) -> None:
720
+ """
721
+ Unregister a connection with this server.
722
+
723
+ """
724
+ self.websockets.remove(protocol)
725
+
726
+ def close(self, close_connections: bool = True) -> None:
727
+ """
728
+ Close the server.
729
+
730
+ * Close the underlying :class:`~asyncio.Server`.
731
+ * When ``close_connections`` is :obj:`True`, which is the default,
732
+ close existing connections. Specifically:
733
+
734
+ * Reject opening WebSocket connections with an HTTP 503 (service
735
+ unavailable) error. This happens when the server accepted the TCP
736
+ connection but didn't complete the opening handshake before closing.
737
+ * Close open WebSocket connections with close code 1001 (going away).
738
+
739
+ * Wait until all connection handlers terminate.
740
+
741
+ :meth:`close` is idempotent.
742
+
743
+ """
744
+ if self.close_task is None:
745
+ self.close_task = self.get_loop().create_task(
746
+ self._close(close_connections)
747
+ )
748
+
749
+ async def _close(self, close_connections: bool) -> None:
750
+ """
751
+ Implementation of :meth:`close`.
752
+
753
+ This calls :meth:`~asyncio.Server.close` on the underlying
754
+ :class:`~asyncio.Server` object to stop accepting new connections and
755
+ then closes open connections with close code 1001.
756
+
757
+ """
758
+ self.logger.info("server closing")
759
+
760
+ # Stop accepting new connections.
761
+ self.server.close()
762
+
763
+ # Wait until all accepted connections reach connection_made() and call
764
+ # register(). See https://github.com/python/cpython/issues/79033 for
765
+ # details. This workaround can be removed when dropping Python < 3.11.
766
+ await asyncio.sleep(0)
767
+
768
+ if close_connections:
769
+ # Close OPEN connections with close code 1001. After server.close(),
770
+ # handshake() closes OPENING connections with an HTTP 503 error.
771
+ close_tasks = [
772
+ asyncio.create_task(websocket.close(1001))
773
+ for websocket in self.websockets
774
+ if websocket.state is not State.CONNECTING
775
+ ]
776
+ # asyncio.wait doesn't accept an empty first argument.
777
+ if close_tasks:
778
+ await asyncio.wait(close_tasks)
779
+
780
+ # Wait until all TCP connections are closed.
781
+ await self.server.wait_closed()
782
+
783
+ # Wait until all connection handlers terminate.
784
+ # asyncio.wait doesn't accept an empty first argument.
785
+ if self.websockets:
786
+ await asyncio.wait(
787
+ [websocket.handler_task for websocket in self.websockets]
788
+ )
789
+
790
+ # Tell wait_closed() to return.
791
+ self.closed_waiter.set_result(None)
792
+
793
+ self.logger.info("server closed")
794
+
795
+ async def wait_closed(self) -> None:
796
+ """
797
+ Wait until the server is closed.
798
+
799
+ When :meth:`wait_closed` returns, all TCP connections are closed and
800
+ all connection handlers have returned.
801
+
802
+ To ensure a fast shutdown, a connection handler should always be
803
+ awaiting at least one of:
804
+
805
+ * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed,
806
+ it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
807
+ * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is
808
+ closed, it returns.
809
+
810
+ Then the connection handler is immediately notified of the shutdown;
811
+ it can clean up and exit.
812
+
813
+ """
814
+ await asyncio.shield(self.closed_waiter)
815
+
816
+ def get_loop(self) -> asyncio.AbstractEventLoop:
817
+ """
818
+ See :meth:`asyncio.Server.get_loop`.
819
+
820
+ """
821
+ return self.server.get_loop()
822
+
823
+ def is_serving(self) -> bool:
824
+ """
825
+ See :meth:`asyncio.Server.is_serving`.
826
+
827
+ """
828
+ return self.server.is_serving()
829
+
830
+ async def start_serving(self) -> None: # pragma: no cover
831
+ """
832
+ See :meth:`asyncio.Server.start_serving`.
833
+
834
+ Typical use::
835
+
836
+ server = await serve(..., start_serving=False)
837
+ # perform additional setup here...
838
+ # ... then start the server
839
+ await server.start_serving()
840
+
841
+ """
842
+ await self.server.start_serving()
843
+
844
+ async def serve_forever(self) -> None: # pragma: no cover
845
+ """
846
+ See :meth:`asyncio.Server.serve_forever`.
847
+
848
+ Typical use::
849
+
850
+ server = await serve(...)
851
+ # this coroutine doesn't return
852
+ # canceling it stops the server
853
+ await server.serve_forever()
854
+
855
+ This is an alternative to using :func:`serve` as an asynchronous context
856
+ manager. Shutdown is triggered by canceling :meth:`serve_forever`
857
+ instead of exiting a :func:`serve` context.
858
+
859
+ """
860
+ await self.server.serve_forever()
861
+
862
+ @property
863
+ def sockets(self) -> Iterable[socket.socket]:
864
+ """
865
+ See :attr:`asyncio.Server.sockets`.
866
+
867
+ """
868
+ return self.server.sockets
869
+
870
+ async def __aenter__(self) -> WebSocketServer: # pragma: no cover
871
+ return self
872
+
873
+ async def __aexit__(
874
+ self,
875
+ exc_type: type[BaseException] | None,
876
+ exc_value: BaseException | None,
877
+ traceback: TracebackType | None,
878
+ ) -> None: # pragma: no cover
879
+ self.close()
880
+ await self.wait_closed()
881
+
882
+
883
+ class Serve:
884
+ """
885
+ Start a WebSocket server listening on ``host`` and ``port``.
886
+
887
+ Whenever a client connects, the server creates a
888
+ :class:`WebSocketServerProtocol`, performs the opening handshake, and
889
+ delegates to the connection handler, ``ws_handler``.
890
+
891
+ The handler receives the :class:`WebSocketServerProtocol` and uses it to
892
+ send and receive messages.
893
+
894
+ Once the handler completes, either normally or with an exception, the
895
+ server performs the closing handshake and closes the connection.
896
+
897
+ Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object
898
+ provides a :meth:`~WebSocketServer.close` method to shut down the server::
899
+
900
+ # set this future to exit the server
901
+ stop = asyncio.get_running_loop().create_future()
902
+
903
+ server = await serve(...)
904
+ await stop
905
+ server.close()
906
+ await server.wait_closed()
907
+
908
+ :func:`serve` can be used as an asynchronous context manager. Then, the
909
+ server is shut down automatically when exiting the context::
910
+
911
+ # set this future to exit the server
912
+ stop = asyncio.get_running_loop().create_future()
913
+
914
+ async with serve(...):
915
+ await stop
916
+
917
+ Args:
918
+ ws_handler: Connection handler. It receives the WebSocket connection,
919
+ which is a :class:`WebSocketServerProtocol`, in argument.
920
+ host: Network interfaces the server binds to.
921
+ See :meth:`~asyncio.loop.create_server` for details.
922
+ port: TCP port the server listens on.
923
+ See :meth:`~asyncio.loop.create_server` for details.
924
+ create_protocol: Factory for the :class:`asyncio.Protocol` managing
925
+ the connection. It defaults to :class:`WebSocketServerProtocol`.
926
+ Set it to a wrapper or a subclass to customize connection handling.
927
+ logger: Logger for this server.
928
+ It defaults to ``logging.getLogger("websockets.server")``.
929
+ See the :doc:`logging guide <../../topics/logging>` for details.
930
+ compression: The "permessage-deflate" extension is enabled by default.
931
+ Set ``compression`` to :obj:`None` to disable it. See the
932
+ :doc:`compression guide <../../topics/compression>` for details.
933
+ origins: Acceptable values of the ``Origin`` header, for defending
934
+ against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
935
+ in the list if the lack of an origin is acceptable.
936
+ extensions: List of supported extensions, in order in which they
937
+ should be negotiated and run.
938
+ subprotocols: List of supported subprotocols, in order of decreasing
939
+ preference.
940
+ extra_headers (HeadersLike | Callable[[str, Headers] | HeadersLike]):
941
+ Arbitrary HTTP headers to add to the response. This can be
942
+ a :data:`~websockets.datastructures.HeadersLike` or a callable
943
+ taking the request path and headers in arguments and returning
944
+ a :data:`~websockets.datastructures.HeadersLike`.
945
+ server_header: Value of the ``Server`` response header.
946
+ It defaults to ``"Python/x.y.z websockets/X.Y"``.
947
+ Setting it to :obj:`None` removes the header.
948
+ process_request (Callable[[str, Headers], \
949
+ Awaitable[tuple[StatusLike, HeadersLike, bytes] | None]] | None):
950
+ Intercept HTTP request before the opening handshake.
951
+ See :meth:`~WebSocketServerProtocol.process_request` for details.
952
+ select_subprotocol: Select a subprotocol supported by the client.
953
+ See :meth:`~WebSocketServerProtocol.select_subprotocol` for details.
954
+ open_timeout: Timeout for opening connections in seconds.
955
+ :obj:`None` disables the timeout.
956
+
957
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
958
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
959
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
960
+
961
+ Any other keyword arguments are passed the event loop's
962
+ :meth:`~asyncio.loop.create_server` method.
963
+
964
+ For example:
965
+
966
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.
967
+
968
+ * You can set ``sock`` to a :obj:`~socket.socket` that you created
969
+ outside of websockets.
970
+
971
+ Returns:
972
+ WebSocket server.
973
+
974
+ """
975
+
976
+ def __init__(
977
+ self,
978
+ # The version that accepts the path in the second argument is deprecated.
979
+ ws_handler: (
980
+ Callable[[WebSocketServerProtocol], Awaitable[Any]]
981
+ | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
982
+ ),
983
+ host: str | Sequence[str] | None = None,
984
+ port: int | None = None,
985
+ *,
986
+ create_protocol: Callable[..., WebSocketServerProtocol] | None = None,
987
+ logger: LoggerLike | None = None,
988
+ compression: str | None = "deflate",
989
+ origins: Sequence[Origin | None] | None = None,
990
+ extensions: Sequence[ServerExtensionFactory] | None = None,
991
+ subprotocols: Sequence[Subprotocol] | None = None,
992
+ extra_headers: HeadersLikeOrCallable | None = None,
993
+ server_header: str | None = SERVER,
994
+ process_request: (
995
+ Callable[[str, Headers], Awaitable[HTTPResponse | None]] | None
996
+ ) = None,
997
+ select_subprotocol: (
998
+ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] | None
999
+ ) = None,
1000
+ open_timeout: float | None = 10,
1001
+ ping_interval: float | None = 20,
1002
+ ping_timeout: float | None = 20,
1003
+ close_timeout: float | None = None,
1004
+ max_size: int | None = 2**20,
1005
+ max_queue: int | None = 2**5,
1006
+ read_limit: int = 2**16,
1007
+ write_limit: int = 2**16,
1008
+ **kwargs: Any,
1009
+ ) -> None:
1010
+ # Backwards compatibility: close_timeout used to be called timeout.
1011
+ timeout: float | None = kwargs.pop("timeout", None)
1012
+ if timeout is None:
1013
+ timeout = 10
1014
+ else:
1015
+ warnings.warn("rename timeout to close_timeout", DeprecationWarning)
1016
+ # If both are specified, timeout is ignored.
1017
+ if close_timeout is None:
1018
+ close_timeout = timeout
1019
+
1020
+ # Backwards compatibility: create_protocol used to be called klass.
1021
+ klass: type[WebSocketServerProtocol] | None = kwargs.pop("klass", None)
1022
+ if klass is None:
1023
+ klass = WebSocketServerProtocol
1024
+ else:
1025
+ warnings.warn("rename klass to create_protocol", DeprecationWarning)
1026
+ # If both are specified, klass is ignored.
1027
+ if create_protocol is None:
1028
+ create_protocol = klass
1029
+
1030
+ # Backwards compatibility: recv() used to return None on closed connections
1031
+ legacy_recv: bool = kwargs.pop("legacy_recv", False)
1032
+
1033
+ # Backwards compatibility: the loop parameter used to be supported.
1034
+ _loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
1035
+ if _loop is None:
1036
+ loop = asyncio.get_event_loop()
1037
+ else:
1038
+ loop = _loop
1039
+ warnings.warn("remove loop argument", DeprecationWarning)
1040
+
1041
+ ws_server = WebSocketServer(logger=logger)
1042
+
1043
+ secure = kwargs.get("ssl") is not None
1044
+
1045
+ if compression == "deflate":
1046
+ extensions = enable_server_permessage_deflate(extensions)
1047
+ elif compression is not None:
1048
+ raise ValueError(f"unsupported compression: {compression}")
1049
+
1050
+ if subprotocols is not None:
1051
+ validate_subprotocols(subprotocols)
1052
+
1053
+ # Help mypy and avoid this error: "type[WebSocketServerProtocol] |
1054
+ # Callable[..., WebSocketServerProtocol]" not callable [misc]
1055
+ create_protocol = cast(Callable[..., WebSocketServerProtocol], create_protocol)
1056
+ factory = functools.partial(
1057
+ create_protocol,
1058
+ # For backwards compatibility with 10.0 or earlier. Done here in
1059
+ # addition to WebSocketServerProtocol to trigger the deprecation
1060
+ # warning once per serve() call rather than once per connection.
1061
+ remove_path_argument(ws_handler),
1062
+ ws_server,
1063
+ host=host,
1064
+ port=port,
1065
+ secure=secure,
1066
+ open_timeout=open_timeout,
1067
+ ping_interval=ping_interval,
1068
+ ping_timeout=ping_timeout,
1069
+ close_timeout=close_timeout,
1070
+ max_size=max_size,
1071
+ max_queue=max_queue,
1072
+ read_limit=read_limit,
1073
+ write_limit=write_limit,
1074
+ loop=_loop,
1075
+ legacy_recv=legacy_recv,
1076
+ origins=origins,
1077
+ extensions=extensions,
1078
+ subprotocols=subprotocols,
1079
+ extra_headers=extra_headers,
1080
+ server_header=server_header,
1081
+ process_request=process_request,
1082
+ select_subprotocol=select_subprotocol,
1083
+ logger=logger,
1084
+ )
1085
+
1086
+ if kwargs.pop("unix", False):
1087
+ path: str | None = kwargs.pop("path", None)
1088
+ # unix_serve(path) must not specify host and port parameters.
1089
+ assert host is None and port is None
1090
+ create_server = functools.partial(
1091
+ loop.create_unix_server, factory, path, **kwargs
1092
+ )
1093
+ else:
1094
+ create_server = functools.partial(
1095
+ loop.create_server, factory, host, port, **kwargs
1096
+ )
1097
+
1098
+ # This is a coroutine function.
1099
+ self._create_server = create_server
1100
+ self.ws_server = ws_server
1101
+
1102
+ # async with serve(...)
1103
+
1104
+ async def __aenter__(self) -> WebSocketServer:
1105
+ return await self
1106
+
1107
+ async def __aexit__(
1108
+ self,
1109
+ exc_type: type[BaseException] | None,
1110
+ exc_value: BaseException | None,
1111
+ traceback: TracebackType | None,
1112
+ ) -> None:
1113
+ self.ws_server.close()
1114
+ await self.ws_server.wait_closed()
1115
+
1116
+ # await serve(...)
1117
+
1118
+ def __await__(self) -> Generator[Any, None, WebSocketServer]:
1119
+ # Create a suitable iterator by calling __await__ on a coroutine.
1120
+ return self.__await_impl__().__await__()
1121
+
1122
+ async def __await_impl__(self) -> WebSocketServer:
1123
+ server = await self._create_server()
1124
+ self.ws_server.wrap(server)
1125
+ return self.ws_server
1126
+
1127
+ # yield from serve(...) - remove when dropping Python < 3.11
1128
+
1129
+ __iter__ = __await__
1130
+
1131
+
1132
+ serve = Serve
1133
+
1134
+
1135
+ def unix_serve(
1136
+ # The version that accepts the path in the second argument is deprecated.
1137
+ ws_handler: (
1138
+ Callable[[WebSocketServerProtocol], Awaitable[Any]]
1139
+ | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
1140
+ ),
1141
+ path: str | None = None,
1142
+ **kwargs: Any,
1143
+ ) -> Serve:
1144
+ """
1145
+ Start a WebSocket server listening on a Unix socket.
1146
+
1147
+ This function is identical to :func:`serve`, except the ``host`` and
1148
+ ``port`` arguments are replaced by ``path``. It is only available on Unix.
1149
+
1150
+ Unrecognized keyword arguments are passed the event loop's
1151
+ :meth:`~asyncio.loop.create_unix_server` method.
1152
+
1153
+ It's useful for deploying a server behind a reverse proxy such as nginx.
1154
+
1155
+ Args:
1156
+ path: File system path to the Unix socket.
1157
+
1158
+ """
1159
+ return serve(ws_handler, path=path, unix=True, **kwargs)
1160
+
1161
+
1162
+ def remove_path_argument(
1163
+ ws_handler: (
1164
+ Callable[[WebSocketServerProtocol], Awaitable[Any]]
1165
+ | Callable[[WebSocketServerProtocol, str], Awaitable[Any]]
1166
+ ),
1167
+ ) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
1168
+ try:
1169
+ inspect.signature(ws_handler).bind(None)
1170
+ except TypeError:
1171
+ try:
1172
+ inspect.signature(ws_handler).bind(None, "")
1173
+ except TypeError: # pragma: no cover
1174
+ # ws_handler accepts neither one nor two arguments; leave it alone.
1175
+ pass
1176
+ else:
1177
+ # ws_handler accepts two arguments; activate backwards compatibility.
1178
+ warnings.warn("remove second argument of ws_handler", DeprecationWarning)
1179
+
1180
+ async def _ws_handler(websocket: WebSocketServerProtocol) -> Any:
1181
+ return await cast(
1182
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
1183
+ ws_handler,
1184
+ )(websocket, websocket.path)
1185
+
1186
+ return _ws_handler
1187
+
1188
+ return cast(
1189
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
1190
+ ws_handler,
1191
+ )
source/websockets/protocol.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import enum
4
+ import logging
5
+ import uuid
6
+ from collections.abc import Generator
7
+
8
+ from .exceptions import (
9
+ ConnectionClosed,
10
+ ConnectionClosedError,
11
+ ConnectionClosedOK,
12
+ InvalidState,
13
+ PayloadTooBig,
14
+ ProtocolError,
15
+ )
16
+ from .extensions import Extension
17
+ from .frames import (
18
+ OK_CLOSE_CODES,
19
+ OP_BINARY,
20
+ OP_CLOSE,
21
+ OP_CONT,
22
+ OP_PING,
23
+ OP_PONG,
24
+ OP_TEXT,
25
+ Close,
26
+ CloseCode,
27
+ Frame,
28
+ )
29
+ from .http11 import Request, Response
30
+ from .streams import StreamReader
31
+ from .typing import BytesLike, LoggerLike, Origin, Subprotocol
32
+
33
+
34
+ __all__ = [
35
+ "Protocol",
36
+ "Side",
37
+ "State",
38
+ "SEND_EOF",
39
+ ]
40
+
41
+ Event = Request | Response | Frame
42
+ """Events that :meth:`~Protocol.events_received` may return."""
43
+
44
+
45
+ class Side(enum.IntEnum):
46
+ """A WebSocket connection is either a server or a client."""
47
+
48
+ SERVER, CLIENT = range(2)
49
+
50
+
51
+ SERVER = Side.SERVER
52
+ CLIENT = Side.CLIENT
53
+
54
+
55
+ class State(enum.IntEnum):
56
+ """A WebSocket connection is in one of these four states."""
57
+
58
+ CONNECTING, OPEN, CLOSING, CLOSED = range(4)
59
+
60
+
61
+ CONNECTING = State.CONNECTING
62
+ OPEN = State.OPEN
63
+ CLOSING = State.CLOSING
64
+ CLOSED = State.CLOSED
65
+
66
+
67
+ SEND_EOF = b""
68
+ """Sentinel signaling that the TCP connection must be half-closed."""
69
+
70
+
71
+ class Protocol:
72
+ """
73
+ Sans-I/O implementation of a WebSocket connection.
74
+
75
+ Args:
76
+ side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`.
77
+ state: Initial state of the WebSocket connection.
78
+ max_size: Maximum size of incoming messages in bytes.
79
+ :obj:`None` disables the limit. You may pass a ``(max_message_size,
80
+ max_fragment_size)`` tuple to set different limits for messages and
81
+ fragments when you expect long messages sent in short fragments.
82
+ logger: Logger for this connection; depending on ``side``,
83
+ defaults to ``logging.getLogger("websockets.client")``
84
+ or ``logging.getLogger("websockets.server")``;
85
+ see the :doc:`logging guide <../../topics/logging>` for details.
86
+
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ side: Side,
92
+ *,
93
+ state: State = OPEN,
94
+ max_size: tuple[int | None, int | None] | int | None = 2**20,
95
+ logger: LoggerLike | None = None,
96
+ ) -> None:
97
+ # Unique identifier. For logs.
98
+ self.id: uuid.UUID = uuid.uuid4()
99
+ """Unique identifier of the connection. Useful in logs."""
100
+
101
+ # Logger or LoggerAdapter for this connection.
102
+ if logger is None:
103
+ logger = logging.getLogger(f"websockets.{side.name.lower()}")
104
+ self.logger: LoggerLike = logger
105
+ """Logger for this connection."""
106
+
107
+ # Track if DEBUG is enabled. Shortcut logging calls if it isn't.
108
+ self.debug = logger.isEnabledFor(logging.DEBUG)
109
+
110
+ # Connection side. CLIENT or SERVER.
111
+ self.side = side
112
+
113
+ # Connection state. Initially OPEN because subclasses handle CONNECTING.
114
+ self.state = state
115
+
116
+ # Maximum size of incoming messages in bytes.
117
+ if isinstance(max_size, int) or max_size is None:
118
+ self.max_message_size, self.max_fragment_size = max_size, None
119
+ else:
120
+ self.max_message_size, self.max_fragment_size = max_size
121
+
122
+ # Current size of incoming message in bytes. Only set while reading a
123
+ # fragmented message i.e. a data frames with the FIN bit not set.
124
+ self.current_size: int | None = None
125
+
126
+ # True while sending a fragmented message i.e. a data frames with the
127
+ # FIN bit not set.
128
+ self.expect_continuation_frame = False
129
+
130
+ # WebSocket protocol parameters.
131
+ self.origin: Origin | None = None
132
+ self.extensions: list[Extension] = []
133
+ self.subprotocol: Subprotocol | None = None
134
+
135
+ # Close code and reason, set when a close frame is sent or received.
136
+ self.close_rcvd: Close | None = None
137
+ self.close_sent: Close | None = None
138
+ self.close_rcvd_then_sent: bool | None = None
139
+
140
+ # Track if an exception happened during the handshake.
141
+ self.handshake_exc: Exception | None = None
142
+ """
143
+ Exception to raise if the opening handshake failed.
144
+
145
+ :obj:`None` if the opening handshake succeeded.
146
+
147
+ """
148
+
149
+ # Track if send_eof() was called.
150
+ self.eof_sent = False
151
+
152
+ # Parser state.
153
+ self.reader = StreamReader()
154
+ self.events: list[Event] = []
155
+ self.writes: list[bytes] = []
156
+ self.parser = self.parse()
157
+ next(self.parser) # start coroutine
158
+ self.parser_exc: Exception | None = None
159
+
160
+ @property
161
+ def state(self) -> State:
162
+ """
163
+ State of the WebSocket connection.
164
+
165
+ Defined in 4.1_, 4.2_, 7.1.3_, and 7.1.4_ of :rfc:`6455`.
166
+
167
+ .. _4.1: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
168
+ .. _4.2: https://datatracker.ietf.org/doc/html/rfc6455#section-4.2
169
+ .. _7.1.3: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.3
170
+ .. _7.1.4: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.4
171
+
172
+ """
173
+ return self._state
174
+
175
+ @state.setter
176
+ def state(self, state: State) -> None:
177
+ if self.debug:
178
+ self.logger.debug("= connection is %s", state.name)
179
+ self._state = state
180
+
181
+ @property
182
+ def close_code(self) -> int | None:
183
+ """
184
+ WebSocket close code received from the remote endpoint.
185
+
186
+ Defined in 7.1.5_ of :rfc:`6455`.
187
+
188
+ .. _7.1.5: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.5
189
+
190
+ :obj:`None` if the connection isn't closed yet.
191
+
192
+ """
193
+ if self.state is not CLOSED:
194
+ return None
195
+ elif self.close_rcvd is None:
196
+ return CloseCode.ABNORMAL_CLOSURE
197
+ else:
198
+ return self.close_rcvd.code
199
+
200
+ @property
201
+ def close_reason(self) -> str | None:
202
+ """
203
+ WebSocket close reason received from the remote endpoint.
204
+
205
+ Defined in 7.1.6_ of :rfc:`6455`.
206
+
207
+ .. _7.1.6: https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.6
208
+
209
+ :obj:`None` if the connection isn't closed yet.
210
+
211
+ """
212
+ if self.state is not CLOSED:
213
+ return None
214
+ elif self.close_rcvd is None:
215
+ return ""
216
+ else:
217
+ return self.close_rcvd.reason
218
+
219
+ @property
220
+ def close_exc(self) -> ConnectionClosed:
221
+ """
222
+ Exception to raise when trying to interact with a closed connection.
223
+
224
+ Don't raise this exception while the connection :attr:`state`
225
+ is :attr:`~websockets.protocol.State.CLOSING`; wait until
226
+ it's :attr:`~websockets.protocol.State.CLOSED`.
227
+
228
+ Indeed, the exception includes the close code and reason, which are
229
+ known only once the connection is closed.
230
+
231
+ Raises:
232
+ AssertionError: If the connection isn't closed yet.
233
+
234
+ """
235
+ assert self.state is CLOSED, "connection isn't closed yet"
236
+ exc_type: type[ConnectionClosed]
237
+ if (
238
+ self.close_rcvd is not None
239
+ and self.close_sent is not None
240
+ and self.close_rcvd.code in OK_CLOSE_CODES
241
+ and self.close_sent.code in OK_CLOSE_CODES
242
+ ):
243
+ exc_type = ConnectionClosedOK
244
+ else:
245
+ exc_type = ConnectionClosedError
246
+ exc: ConnectionClosed = exc_type(
247
+ self.close_rcvd,
248
+ self.close_sent,
249
+ self.close_rcvd_then_sent,
250
+ )
251
+ # Chain to the exception raised in the parser, if any.
252
+ exc.__cause__ = self.parser_exc
253
+ return exc
254
+
255
+ # Public methods for receiving data.
256
+
257
+ def receive_data(self, data: bytes | bytearray) -> None:
258
+ """
259
+ Receive data from the network.
260
+
261
+ After calling this method:
262
+
263
+ - You must call :meth:`data_to_send` and send this data to the network.
264
+ - You should call :meth:`events_received` and process resulting events.
265
+
266
+ Raises:
267
+ EOFError: If :meth:`receive_eof` was called earlier.
268
+
269
+ """
270
+ self.reader.feed_data(data)
271
+ next(self.parser)
272
+
273
+ def receive_eof(self) -> None:
274
+ """
275
+ Receive the end of the data stream from the network.
276
+
277
+ After calling this method:
278
+
279
+ - You must call :meth:`data_to_send` and send this data to the network;
280
+ it will return ``[b""]``, signaling the end of the stream, or ``[]``.
281
+ - You aren't expected to call :meth:`events_received`; it won't return
282
+ any new events.
283
+
284
+ :meth:`receive_eof` is idempotent.
285
+
286
+ """
287
+ if self.reader.eof:
288
+ return
289
+ self.reader.feed_eof()
290
+ next(self.parser)
291
+
292
+ # Public methods for sending events.
293
+
294
+ def send_continuation(self, data: BytesLike, fin: bool) -> None:
295
+ """
296
+ Send a `Continuation frame`_.
297
+
298
+ .. _Continuation frame:
299
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
300
+
301
+ Parameters:
302
+ data: payload containing the same kind of data
303
+ as the initial frame.
304
+ fin: FIN bit; set it to :obj:`True` if this is the last frame
305
+ of a fragmented message and to :obj:`False` otherwise.
306
+
307
+ Raises:
308
+ ProtocolError: If a fragmented message isn't in progress.
309
+
310
+ """
311
+ if not self.expect_continuation_frame:
312
+ raise ProtocolError("unexpected continuation frame")
313
+ if self._state is not OPEN:
314
+ raise InvalidState(f"connection is {self.state.name.lower()}")
315
+ self.expect_continuation_frame = not fin
316
+ self.send_frame(Frame(OP_CONT, data, fin))
317
+
318
+ def send_text(self, data: BytesLike, fin: bool = True) -> None:
319
+ """
320
+ Send a `Text frame`_.
321
+
322
+ .. _Text frame:
323
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
324
+
325
+ Parameters:
326
+ data: payload containing text encoded with UTF-8.
327
+ fin: FIN bit; set it to :obj:`False` if this is the first frame of
328
+ a fragmented message.
329
+
330
+ Raises:
331
+ ProtocolError: If a fragmented message is in progress.
332
+
333
+ """
334
+ if self.expect_continuation_frame:
335
+ raise ProtocolError("expected a continuation frame")
336
+ if self._state is not OPEN:
337
+ raise InvalidState(f"connection is {self.state.name.lower()}")
338
+ self.expect_continuation_frame = not fin
339
+ self.send_frame(Frame(OP_TEXT, data, fin))
340
+
341
+ def send_binary(self, data: BytesLike, fin: bool = True) -> None:
342
+ """
343
+ Send a `Binary frame`_.
344
+
345
+ .. _Binary frame:
346
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
347
+
348
+ Parameters:
349
+ data: payload containing arbitrary binary data.
350
+ fin: FIN bit; set it to :obj:`False` if this is the first frame of
351
+ a fragmented message.
352
+
353
+ Raises:
354
+ ProtocolError: If a fragmented message is in progress.
355
+
356
+ """
357
+ if self.expect_continuation_frame:
358
+ raise ProtocolError("expected a continuation frame")
359
+ if self._state is not OPEN:
360
+ raise InvalidState(f"connection is {self.state.name.lower()}")
361
+ self.expect_continuation_frame = not fin
362
+ self.send_frame(Frame(OP_BINARY, data, fin))
363
+
364
+ def send_close(self, code: CloseCode | int | None = None, reason: str = "") -> None:
365
+ """
366
+ Send a `Close frame`_.
367
+
368
+ .. _Close frame:
369
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
370
+
371
+ Parameters:
372
+ code: close code.
373
+ reason: close reason.
374
+
375
+ Raises:
376
+ ProtocolError: If the code isn't valid or if a reason is provided
377
+ without a code.
378
+
379
+ """
380
+ # While RFC 6455 doesn't rule out sending more than one close Frame,
381
+ # websockets is conservative in what it sends and doesn't allow that.
382
+ if self._state is not OPEN:
383
+ raise InvalidState(f"connection is {self.state.name.lower()}")
384
+ if code is None:
385
+ if reason != "":
386
+ raise ProtocolError("cannot send a reason without a code")
387
+ close = Close(CloseCode.NO_STATUS_RCVD, "")
388
+ data = b""
389
+ else:
390
+ close = Close(code, reason)
391
+ data = close.serialize()
392
+ # 7.1.3. The WebSocket Closing Handshake is Started
393
+ self.send_frame(Frame(OP_CLOSE, data))
394
+ # Since the state is OPEN, no close frame was received yet.
395
+ # As a consequence, self.close_rcvd_then_sent remains None.
396
+ assert self.close_rcvd is None
397
+ self.close_sent = close
398
+ self.state = CLOSING
399
+
400
+ def send_ping(self, data: BytesLike) -> None:
401
+ """
402
+ Send a `Ping frame`_.
403
+
404
+ .. _Ping frame:
405
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
406
+
407
+ Parameters:
408
+ data: payload containing arbitrary binary data.
409
+
410
+ """
411
+ # RFC 6455 allows control frames after starting the closing handshake.
412
+ if self._state is not OPEN and self._state is not CLOSING:
413
+ raise InvalidState(f"connection is {self.state.name.lower()}")
414
+ self.send_frame(Frame(OP_PING, data))
415
+
416
+ def send_pong(self, data: BytesLike) -> None:
417
+ """
418
+ Send a `Pong frame`_.
419
+
420
+ .. _Pong frame:
421
+ https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
422
+
423
+ Parameters:
424
+ data: payload containing arbitrary binary data.
425
+
426
+ """
427
+ # RFC 6455 allows control frames after starting the closing handshake.
428
+ if self._state is not OPEN and self._state is not CLOSING:
429
+ raise InvalidState(f"connection is {self.state.name.lower()}")
430
+ self.send_frame(Frame(OP_PONG, data))
431
+
432
+ def fail(self, code: CloseCode | int, reason: str = "") -> None:
433
+ """
434
+ `Fail the WebSocket connection`_.
435
+
436
+ .. _Fail the WebSocket connection:
437
+ https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7
438
+
439
+ Parameters:
440
+ code: close code
441
+ reason: close reason
442
+
443
+ Raises:
444
+ ProtocolError: If the code isn't valid.
445
+ """
446
+ # 7.1.7. Fail the WebSocket Connection
447
+
448
+ # Send a close frame when the state is OPEN (a close frame was already
449
+ # sent if it's CLOSING), except when failing the connection because
450
+ # of an error reading from or writing to the network.
451
+ if self.state is OPEN:
452
+ if code != CloseCode.ABNORMAL_CLOSURE:
453
+ close = Close(code, reason)
454
+ data = close.serialize()
455
+ self.send_frame(Frame(OP_CLOSE, data))
456
+ self.close_sent = close
457
+ # If recv_messages() raised an exception upon receiving a close
458
+ # frame but before echoing it, then close_rcvd is not None even
459
+ # though the state is OPEN. This happens when the connection is
460
+ # closed while receiving a fragmented message.
461
+ if self.close_rcvd is not None:
462
+ self.close_rcvd_then_sent = True
463
+ self.state = CLOSING
464
+
465
+ # When failing the connection, a server closes the TCP connection
466
+ # without waiting for the client to complete the handshake, while a
467
+ # client waits for the server to close the TCP connection, possibly
468
+ # after sending a close frame that the client will ignore.
469
+ if self.side is SERVER and not self.eof_sent:
470
+ self.send_eof()
471
+
472
+ # 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue
473
+ # to attempt to process data(including a responding Close frame) from
474
+ # the remote endpoint after being instructed to _Fail the WebSocket
475
+ # Connection_."
476
+ self.parser = self.discard()
477
+ next(self.parser) # start coroutine
478
+
479
+ # Public method for getting incoming events after receiving data.
480
+
481
+ def events_received(self) -> list[Event]:
482
+ """
483
+ Fetch events generated from data received from the network.
484
+
485
+ Call this method immediately after any of the ``receive_*()`` methods.
486
+
487
+ Process resulting events, likely by passing them to the application.
488
+
489
+ Returns:
490
+ Events read from the connection.
491
+ """
492
+ events, self.events = self.events, []
493
+ return events
494
+
495
+ # Public method for getting outgoing data after receiving data or sending events.
496
+
497
+ def data_to_send(self) -> list[bytes]:
498
+ """
499
+ Obtain data to send to the network.
500
+
501
+ Call this method immediately after any of the ``receive_*()``,
502
+ ``send_*()``, or :meth:`fail` methods.
503
+
504
+ Write resulting data to the connection.
505
+
506
+ The empty bytestring :data:`~websockets.protocol.SEND_EOF` signals
507
+ the end of the data stream. When you receive it, half-close the TCP
508
+ connection.
509
+
510
+ Returns:
511
+ Data to write to the connection.
512
+
513
+ """
514
+ writes, self.writes = self.writes, []
515
+ return writes
516
+
517
+ def close_expected(self) -> bool:
518
+ """
519
+ Tell if the TCP connection is expected to close soon.
520
+
521
+ Call this method immediately after any of the ``receive_*()``,
522
+ ``send_close()``, or :meth:`fail` methods.
523
+
524
+ If it returns :obj:`True`, schedule closing the TCP connection after a
525
+ short timeout if the other side hasn't already closed it.
526
+
527
+ Returns:
528
+ Whether the TCP connection is expected to close soon.
529
+
530
+ """
531
+ # During the opening handshake, when our state is CONNECTING, we expect
532
+ # a TCP close if and only if the hansdake fails. When it does, we start
533
+ # the TCP closing handshake by sending EOF with send_eof().
534
+
535
+ # Once the opening handshake completes successfully, we expect a TCP
536
+ # close if and only if we sent a close frame, meaning that our state
537
+ # progressed to CLOSING:
538
+
539
+ # * Normal closure: once we send a close frame, we expect a TCP close:
540
+ # server waits for client to complete the TCP closing handshake;
541
+ # client waits for server to initiate the TCP closing handshake.
542
+
543
+ # * Abnormal closure: we always send a close frame and the same logic
544
+ # applies, except on EOFError where we don't send a close frame
545
+ # because we already received the TCP close, so we don't expect it.
546
+
547
+ # If our state is CLOSED, we already received a TCP close so we don't
548
+ # expect it anymore.
549
+
550
+ # Micro-optimization: put the most common case first
551
+ if self.state is OPEN:
552
+ return False
553
+ if self.state is CLOSING:
554
+ return True
555
+ if self.state is CLOSED:
556
+ return False
557
+ assert self.state is CONNECTING
558
+ return self.eof_sent
559
+
560
+ # Private methods for receiving data.
561
+
562
+ def parse(self) -> Generator[None]:
563
+ """
564
+ Parse incoming data into frames.
565
+
566
+ :meth:`receive_data` and :meth:`receive_eof` run this generator
567
+ coroutine until it needs more data or reaches EOF.
568
+
569
+ :meth:`parse` never raises an exception. Instead, it sets the
570
+ :attr:`parser_exc` and yields control.
571
+
572
+ """
573
+ try:
574
+ while True:
575
+ if (yield from self.reader.at_eof()):
576
+ if self.debug:
577
+ self.logger.debug("< EOF")
578
+ # If the WebSocket connection is closed cleanly, with a
579
+ # closing handhshake, recv_frame() substitutes parse()
580
+ # with discard(). This branch is reached only when the
581
+ # connection isn't closed cleanly.
582
+ raise EOFError("unexpected end of stream")
583
+
584
+ max_size = None
585
+
586
+ if self.max_message_size is not None:
587
+ if self.current_size is None:
588
+ max_size = self.max_message_size
589
+ else:
590
+ max_size = self.max_message_size - self.current_size
591
+
592
+ if self.max_fragment_size is not None:
593
+ if max_size is None:
594
+ max_size = self.max_fragment_size
595
+ else:
596
+ max_size = min(max_size, self.max_fragment_size)
597
+
598
+ # During a normal closure, execution ends here on the next
599
+ # iteration of the loop after receiving a close frame. At
600
+ # this point, recv_frame() replaced parse() by discard().
601
+ frame = yield from Frame.parse(
602
+ self.reader.read_exact,
603
+ mask=self.side is SERVER,
604
+ max_size=max_size,
605
+ extensions=self.extensions,
606
+ )
607
+
608
+ if self.debug:
609
+ self.logger.debug("< %s", frame)
610
+
611
+ self.recv_frame(frame)
612
+
613
+ except ProtocolError as exc:
614
+ self.fail(CloseCode.PROTOCOL_ERROR, str(exc))
615
+ self.parser_exc = exc
616
+
617
+ except EOFError as exc:
618
+ self.fail(CloseCode.ABNORMAL_CLOSURE, str(exc))
619
+ self.parser_exc = exc
620
+
621
+ except UnicodeDecodeError as exc:
622
+ self.fail(CloseCode.INVALID_DATA, f"{exc.reason} at position {exc.start}")
623
+ self.parser_exc = exc
624
+
625
+ except PayloadTooBig as exc:
626
+ exc.set_current_size(self.current_size)
627
+ self.fail(CloseCode.MESSAGE_TOO_BIG, str(exc))
628
+ self.parser_exc = exc
629
+
630
+ except Exception as exc:
631
+ self.logger.error("parser failed", exc_info=True)
632
+ # Don't include exception details, which may be security-sensitive.
633
+ self.fail(CloseCode.INTERNAL_ERROR)
634
+ self.parser_exc = exc
635
+
636
+ # During an abnormal closure, execution ends here after catching an
637
+ # exception. At this point, fail() replaced parse() by discard().
638
+ yield
639
+ raise AssertionError("parse() shouldn't step after error")
640
+
641
+ def discard(self) -> Generator[None]:
642
+ """
643
+ Discard incoming data.
644
+
645
+ This coroutine replaces :meth:`parse`:
646
+
647
+ - after receiving a close frame, during a normal closure (1.4);
648
+ - after sending a close frame, during an abnormal closure (7.1.7).
649
+
650
+ """
651
+ # After the opening handshake completes, the server closes the TCP
652
+ # connection in the same circumstances where discard() replaces parse().
653
+ # The client closes it when it receives EOF from the server or times
654
+ # out. (The latter case cannot be handled in this Sans-I/O layer.)
655
+ assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent)
656
+ while not (yield from self.reader.at_eof()):
657
+ self.reader.discard()
658
+ if self.debug:
659
+ self.logger.debug("< EOF")
660
+ # A server closes the TCP connection immediately, while a client
661
+ # waits for the server to close the TCP connection.
662
+ if self.side is CLIENT and self.state is not CONNECTING:
663
+ self.send_eof()
664
+ self.state = CLOSED
665
+ # If discard() completes normally, execution ends here.
666
+ yield
667
+ # Once the reader reaches EOF, its feed_data/eof() methods raise an
668
+ # error, so our receive_data/eof() methods don't step the generator.
669
+ raise AssertionError("discard() shouldn't step after EOF")
670
+
671
+ def recv_frame(self, frame: Frame) -> None:
672
+ """
673
+ Process an incoming frame.
674
+
675
+ """
676
+ if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY:
677
+ if self.current_size is not None:
678
+ raise ProtocolError("expected a continuation frame")
679
+ if not frame.fin:
680
+ self.current_size = len(frame.data)
681
+
682
+ elif frame.opcode is OP_CONT:
683
+ if self.current_size is None:
684
+ raise ProtocolError("unexpected continuation frame")
685
+ if frame.fin:
686
+ self.current_size = None
687
+ else:
688
+ self.current_size += len(frame.data)
689
+
690
+ elif frame.opcode is OP_PING:
691
+ # 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST
692
+ # send a Pong frame in response"
693
+ pong_frame = Frame(OP_PONG, frame.data)
694
+ self.send_frame(pong_frame)
695
+
696
+ elif frame.opcode is OP_PONG:
697
+ # 5.5.3 Pong: "A response to an unsolicited Pong frame is not
698
+ # expected."
699
+ pass
700
+
701
+ elif frame.opcode is OP_CLOSE:
702
+ # 7.1.5. The WebSocket Connection Close Code
703
+ # 7.1.6. The WebSocket Connection Close Reason
704
+ self.close_rcvd = Close.parse(frame.data)
705
+ if self.state is CLOSING:
706
+ assert self.close_sent is not None
707
+ self.close_rcvd_then_sent = False
708
+
709
+ if self.current_size is not None:
710
+ raise ProtocolError("incomplete fragmented message")
711
+
712
+ # 5.5.1 Close: "If an endpoint receives a Close frame and did
713
+ # not previously send a Close frame, the endpoint MUST send a
714
+ # Close frame in response. (When sending a Close frame in
715
+ # response, the endpoint typically echos the status code it
716
+ # received.)"
717
+
718
+ if self.state is OPEN:
719
+ # Echo the original data instead of re-serializing it with
720
+ # Close.serialize() because that fails when the close frame
721
+ # is empty and Close.parse() synthesizes a 1005 close code.
722
+ # The rest is identical to send_close().
723
+ self.send_frame(Frame(OP_CLOSE, frame.data))
724
+ self.close_sent = self.close_rcvd
725
+ self.close_rcvd_then_sent = True
726
+ self.state = CLOSING
727
+
728
+ # 7.1.2. Start the WebSocket Closing Handshake: "Once an
729
+ # endpoint has both sent and received a Close control frame,
730
+ # that endpoint SHOULD _Close the WebSocket Connection_"
731
+
732
+ # A server closes the TCP connection immediately, while a client
733
+ # waits for the server to close the TCP connection.
734
+ if self.side is SERVER:
735
+ self.send_eof()
736
+
737
+ # 1.4. Closing Handshake: "after receiving a control frame
738
+ # indicating the connection should be closed, a peer discards
739
+ # any further data received."
740
+ # RFC 6455 allows reading Ping and Pong frames after a Close frame.
741
+ # However, that doesn't seem useful; websockets doesn't support it.
742
+ self.parser = self.discard()
743
+ next(self.parser) # start coroutine
744
+
745
+ else:
746
+ # This can't happen because Frame.parse() validates opcodes.
747
+ raise AssertionError(f"unexpected opcode: {frame.opcode:02x}")
748
+
749
+ self.events.append(frame)
750
+
751
+ # Private methods for sending events.
752
+
753
+ def send_frame(self, frame: Frame) -> None:
754
+ if self.debug:
755
+ self.logger.debug("> %s", frame)
756
+ self.writes.append(
757
+ frame.serialize(
758
+ mask=self.side is CLIENT,
759
+ extensions=self.extensions,
760
+ )
761
+ )
762
+
763
+ def send_eof(self) -> None:
764
+ assert not self.eof_sent
765
+ self.eof_sent = True
766
+ if self.debug:
767
+ self.logger.debug("> EOF")
768
+ self.writes.append(SEND_EOF)
source/websockets/proxy.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import urllib.parse
5
+ import urllib.request
6
+
7
+ from .datastructures import Headers
8
+ from .exceptions import InvalidProxy
9
+ from .headers import build_authorization_basic, build_host
10
+ from .http11 import USER_AGENT
11
+ from .uri import DELIMS, WebSocketURI
12
+
13
+
14
+ __all__ = ["get_proxy", "parse_proxy", "Proxy"]
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Proxy:
19
+ """
20
+ Proxy address.
21
+
22
+ Attributes:
23
+ scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``,
24
+ ``"https"``, or ``"http"``.
25
+ host: Normalized to lower case.
26
+ port: Always set even if it's the default.
27
+ username: Available when the proxy address contains `User Information`_.
28
+ password: Available when the proxy address contains `User Information`_.
29
+
30
+ .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1
31
+
32
+ """
33
+
34
+ scheme: str
35
+ host: str
36
+ port: int
37
+ username: str | None = None
38
+ password: str | None = None
39
+
40
+ @property
41
+ def user_info(self) -> tuple[str, str] | None:
42
+ if self.username is None:
43
+ return None
44
+ assert self.password is not None
45
+ return (self.username, self.password)
46
+
47
+
48
+ def parse_proxy(proxy: str) -> Proxy:
49
+ """
50
+ Parse and validate a proxy.
51
+
52
+ Args:
53
+ proxy: proxy.
54
+
55
+ Returns:
56
+ Parsed proxy.
57
+
58
+ Raises:
59
+ InvalidProxy: If ``proxy`` isn't a valid proxy.
60
+
61
+ """
62
+ parsed = urllib.parse.urlparse(proxy)
63
+ if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]:
64
+ raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported")
65
+ if parsed.hostname is None:
66
+ raise InvalidProxy(proxy, "hostname isn't provided")
67
+ if parsed.path not in ["", "/"]:
68
+ raise InvalidProxy(proxy, "path is meaningless")
69
+ if parsed.query != "":
70
+ raise InvalidProxy(proxy, "query is meaningless")
71
+ if parsed.fragment != "":
72
+ raise InvalidProxy(proxy, "fragment is meaningless")
73
+
74
+ scheme = parsed.scheme
75
+ host = parsed.hostname
76
+ port = parsed.port or (443 if parsed.scheme == "https" else 80)
77
+ username = parsed.username
78
+ password = parsed.password
79
+ # urllib.parse.urlparse accepts URLs with a username but without a
80
+ # password. This doesn't make sense for HTTP Basic Auth credentials.
81
+ if username is not None and password is None:
82
+ raise InvalidProxy(proxy, "username provided without password")
83
+
84
+ try:
85
+ proxy.encode("ascii")
86
+ except UnicodeEncodeError:
87
+ # Input contains non-ASCII characters.
88
+ # It must be an IRI. Convert it to a URI.
89
+ host = host.encode("idna").decode()
90
+ if username is not None:
91
+ assert password is not None
92
+ username = urllib.parse.quote(username, safe=DELIMS)
93
+ password = urllib.parse.quote(password, safe=DELIMS)
94
+
95
+ return Proxy(scheme, host, port, username, password)
96
+
97
+
98
+ def get_proxy(uri: WebSocketURI) -> str | None:
99
+ """
100
+ Return the proxy to use for connecting to the given WebSocket URI, if any.
101
+
102
+ """
103
+ if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"):
104
+ return None
105
+
106
+ # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if
107
+ # available, else favor the proxy for HTTPS connections over the proxy for
108
+ # HTTP connections.
109
+
110
+ # The priority of a proxy for WebSocket connections is unspecified. We give
111
+ # it the highest priority. This makes it easy to configure a specific proxy
112
+ # for websockets.
113
+
114
+ # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or
115
+ # as {"https": "socks5h://host:port"} depending on whether they're declared
116
+ # in the operating system or in environment variables.
117
+
118
+ proxies = urllib.request.getproxies()
119
+ if uri.secure:
120
+ schemes = ["wss", "socks", "https"]
121
+ else:
122
+ schemes = ["ws", "socks", "https", "http"]
123
+
124
+ for scheme in schemes:
125
+ proxy = proxies.get(scheme)
126
+ if proxy is not None:
127
+ if scheme == "socks" and proxy.startswith("http://"):
128
+ proxy = "socks5h://" + proxy[7:]
129
+ return proxy
130
+ else:
131
+ return None
132
+
133
+
134
+ def prepare_connect_request(
135
+ proxy: Proxy,
136
+ ws_uri: WebSocketURI,
137
+ user_agent_header: str | None = USER_AGENT,
138
+ ) -> bytes:
139
+ host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
140
+ headers = Headers()
141
+ headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
142
+ if user_agent_header is not None:
143
+ headers["User-Agent"] = user_agent_header
144
+ if proxy.username is not None:
145
+ assert proxy.password is not None # enforced by parse_proxy()
146
+ headers["Proxy-Authorization"] = build_authorization_basic(
147
+ proxy.username, proxy.password
148
+ )
149
+ # We cannot use the Request class because it supports only GET requests.
150
+ return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()
source/websockets/py.typed ADDED
File without changes