| | from __future__ import annotations |
| |
|
| | import logging |
| | import re |
| | import threading |
| | import types |
| | import typing |
| |
|
| | import h2.config |
| | import h2.connection |
| | import h2.events |
| |
|
| | from .._base_connection import _TYPE_BODY |
| | from .._collections import HTTPHeaderDict |
| | from ..connection import HTTPSConnection, _get_default_user_agent |
| | from ..exceptions import ConnectionError |
| | from ..response import BaseHTTPResponse |
| |
|
# Second module-level name bound to the HTTPSConnection class imported above.
# NOTE(review): nothing in the visible code reads it -- presumably it lets the
# original class be restored after it has been swapped out elsewhere; confirm.
orig_HTTPSConnection = HTTPSConnection

# Type variable for the object wrapped by the generic `_LockedObject` class.
T = typing.TypeVar("T")

# Module-level logger named after this module.
log = logging.getLogger(__name__)
| |
|
# Legal field name: one or more lowercase "token" characters (no uppercase
# letters and no ":"). The pattern is anchored with `\Z` rather than `$`
# because `$` also matches just before a trailing "\n", which would let a
# name such as b"host\n" pass validation.
RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+\Z")
# Illegal field value: NUL, CR or LF at any position, or a space/tab as the
# first or last byte.
RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\x00\r\n]|^[ \t]|[ \t]\Z")
| |
|
| |
|
def _is_legal_header_name(name: bytes) -> bool:
    """
    Report whether ``name`` is acceptable as an HTTP/2 field name.

    "An implementation that validates fields according to the definitions in Sections
    5.1 and 5.5 of [HTTP] only needs an additional check that field names do not
    include uppercase characters." (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    The check is a match against ``RE_IS_LEGAL_HEADER_NAME``, which covers both
    the HTTP/1.1 field-name grammar and the lowercase-only requirement.

    The ``:`` character is not permitted by that pattern, so this function
    must not be used to validate pseudo-headers.
    """
    matched = RE_IS_LEGAL_HEADER_NAME.match(name)
    return matched is not None
| |
|
| |
|
def _is_illegal_header_value(value: bytes) -> bool:
    """
    Report whether ``value`` is forbidden as an HTTP/2 field value.

    "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed
    (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field
    value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB,
    0x20 or 0x09)." (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    Returns ``True`` when ``RE_IS_ILLEGAL_HEADER_VALUE`` finds an offending
    byte anywhere in ``value``.
    """
    offending = RE_IS_ILLEGAL_HEADER_VALUE.search(value)
    return offending is not None
| |
|
| |
|
class _LockedObject(typing.Generic[T]):
    """
    Guard an object with a lock and hand it out only while the lock is held.

    Use an instance as a context manager: entering the ``with`` block acquires
    the lock and yields the wrapped object, leaving the block releases the
    lock again. This serialises access to an object that must not be used
    from several threads at once.
    """

    __slots__ = ("lock", "_obj")

    def __init__(self, obj: T):
        """Wrap ``obj`` and create the lock that protects it."""
        self._obj = obj
        # A re-entrant lock, so the owning thread may nest `with` blocks.
        self.lock = threading.RLock()

    def __enter__(self) -> T:
        """Acquire the lock (blocking) and return the wrapped object."""
        self.lock.acquire()
        return self._obj

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock; any exception from the block propagates."""
        self.lock.release()
| |
|
| |
|
class HTTP2Connection(HTTPSConnection):
    """
    An ``HTTPSConnection`` that talks HTTP/2 using an ``h2`` state machine.

    The ``h2`` connection object is kept behind a ``_LockedObject`` and every
    interaction with it happens while that lock is held. One stream id is
    tracked at a time (``_h2_stream``), and ``getresponse`` buffers the whole
    response body before returning.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        Create the connection object without opening a socket.

        :param host: Host to connect to.
        :param port: Port to connect to; ``None`` is passed through to the
            base class.
        :param kwargs: Forwarded to ``HTTPSConnection.__init__``.
        :raises NotImplementedError: If ``proxy`` or ``proxy_config`` is
            given, or if the base class ends up with a tunnel host set.
        """
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        # Header list for the request currently being built by
        # putrequest()/putheader(); sent and cleared by endheaders().
        self._headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

        if self._tunnel_host is not None:
            raise NotImplementedError("Tunneling isn't supported with HTTP/2")

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Return a fresh client-side ``H2Connection`` wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def _send_pending_data(self, conn: h2.connection.H2Connection) -> None:
        """Write whatever bytes ``conn`` has queued to the socket, if any."""
        if data_to_send := conn.data_to_send():
            self.sock.sendall(data_to_send)

    def connect(self) -> None:
        """Open the underlying connection, then send the HTTP/2 preamble."""
        super().connect()
        with self._h2_conn as conn:
            conn.initiate_connection()
            self._send_pending_data(conn)

    def putrequest(
        self,
        method: str,
        url: str,
        **kwargs: typing.Any,
    ) -> None:
        """putrequest
        This deviates from the HTTPConnection method signature since we never need to override
        sending accept-encoding headers or the host header.

        Queues the ``:scheme``, ``:method``, ``:authority`` and ``:path``
        pseudo-headers and reserves the next available stream id.

        :raises NotImplementedError: If ``skip_host`` or
            ``skip_accept_encoding`` is passed.
        """
        if "skip_host" in kwargs:
            raise NotImplementedError("`skip_host` isn't supported")
        if "skip_accept_encoding" in kwargs:
            raise NotImplementedError("`skip_accept_encoding` isn't supported")

        # An empty target is normalised to "/". The normalised value is used
        # everywhere below so that `:path` is never sent empty.
        self._request_url = url or "/"
        self._validate_path(self._request_url)

        # A host containing ":" is treated as an IPv6 literal and bracketed.
        if ":" in self.host:
            authority = f"[{self.host}]:{self.port or 443}"
        else:
            authority = f"{self.host}:{self.port or 443}"

        self._headers.append((b":scheme", b"https"))
        self._headers.append((b":method", method.encode()))
        self._headers.append((b":authority", authority.encode()))
        self._headers.append((b":path", self._request_url.encode()))

        with self._h2_conn as conn:
            self._h2_stream = conn.get_next_available_stream_id()

    def putheader(self, header: str | bytes, *values: str | bytes) -> None:
        """
        Queue one header entry per value for the request being built.

        The header name is lower-cased before validation. ``str`` names and
        values are encoded with the default codec.

        :raises ValueError: If the name or any value fails validation.
        """
        header = header.encode() if isinstance(header, str) else header
        header = header.lower()
        if not _is_legal_header_name(header):
            raise ValueError(f"Illegal header name {str(header)}")

        for value in values:
            value = value.encode() if isinstance(value, str) else value
            if _is_illegal_header_value(value):
                raise ValueError(f"Illegal header value {str(value)}")
            self._headers.append((header, value))

    def endheaders(self, message_body: typing.Any = None) -> None:
        """
        Send the queued headers on the current stream and clear the queue.

        The stream is ended with the headers when ``message_body`` is ``None``.

        :raises ConnectionError: If ``putrequest`` has not been called.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._headers,
                end_stream=(message_body is None),
            )
            self._send_pending_data(conn)
        self._headers = []

    def send(self, data: typing.Any) -> None:
        """Send data to the server.
        `data` can be: `str`, `bytes`, an iterable, or file-like objects
        that support a .read() method.

        The stream is ended once all of ``data`` has been written.

        :raises ConnectionError: If ``putrequest`` has not been called.
        :raises TypeError: If ``data`` is none of the supported kinds.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            self._send_pending_data(conn)

            # File-like objects are read and forwarded in `blocksize` pieces.
            if hasattr(data, "read"):
                while True:
                    chunk = data.read(self.blocksize)
                    if not chunk:
                        break
                    if isinstance(chunk, str):
                        chunk = chunk.encode()
                    conn.send_data(self._h2_stream, chunk, end_stream=False)
                    self._send_pending_data(conn)
                conn.end_stream(self._h2_stream)
                # end_stream() only queues the END_STREAM frame; write it out
                # now rather than leaving it buffered until getresponse().
                self._send_pending_data(conn)
                return

            if isinstance(data, str):
                data = data.encode()

            try:
                if isinstance(data, bytes):
                    conn.send_data(self._h2_stream, data, end_stream=True)
                    self._send_pending_data(conn)
                else:
                    for chunk in data:
                        conn.send_data(self._h2_stream, chunk, end_stream=False)
                        self._send_pending_data(conn)
                    conn.end_stream(self._h2_stream)
                    # Same as above: flush the queued END_STREAM frame.
                    self._send_pending_data(conn)
            except TypeError:
                raise TypeError(
                    "`data` should be str, bytes, iterable, or file. got %r"
                    % type(data)
                )

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Always raise ``NotImplementedError``; tunnels are unsupported."""
        raise NotImplementedError(
            "HTTP/2 does not support setting up a tunnel through a proxy"
        )

    def getresponse(
        self,
    ) -> HTTP2Response:
        """
        Read from the socket until the stream ends and return the response.

        The complete body is accumulated in memory.

        :raises ConnectionError: If the peer closes the connection before
            the stream has ended.
        """
        status = None
        # Bound up front so the name exists even if no response headers
        # event is ever seen.
        headers = HTTPHeaderDict()
        data = bytearray()
        with self._h2_conn as conn:
            end_stream = False
            while not end_stream:
                # The read size is an arbitrary upper bound per recv() call.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # recv() returning b"" means the peer closed the socket;
                    # without this check the loop would spin forever.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response was complete"
                    )

                for event in conn.receive_data(received_data):
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )

                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        # Tell h2 the bytes were consumed so it can reopen
                        # the flow-control window.
                        conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )

                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True

                self._send_pending_data(conn)

        assert status is not None
        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        *,
        preload_content: bool = True,
        decode_content: bool = True,
        enforce_content_length: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Send an HTTP/2 request

        ``preload_content``, ``decode_content``, ``enforce_content_length``
        and any extra keyword arguments (including ``chunked``) are accepted
        but not used by this method.
        """
        if self.sock is not None:
            self.sock.settimeout(self.timeout)

        self.putrequest(method, url)

        headers = headers or {}
        for k, v in headers.items():
            # A "Transfer-Encoding: chunked" header is dropped, not sent.
            if k.lower() == "transfer-encoding" and v == "chunked":
                continue
            self.putheader(k, v)

        # Add the default user agent unless the caller supplied one.
        if all(name != b"user-agent" for name, _ in self._headers):
            self.putheader(b"user-agent", _get_default_user_agent())

        if body:
            self.endheaders(message_body=body)
            self.send(body)
        else:
            self.endheaders()

    def close(self) -> None:
        """
        Close the HTTP/2 connection, reset local state and close the socket.
        """
        with self._h2_conn as conn:
            try:
                conn.close_connection()
                if data := conn.data_to_send():
                    self.sock.sendall(data)
            except Exception:
                # Best effort only: a failure to say goodbye (for example
                # when no socket is open) must not prevent the close.
                pass

        # Reset all HTTP/2 state so the object starts clean if reused.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._headers = []

        super().close()
| |
|
| |
|
class HTTP2Response(BaseHTTPResponse):
    """
    A response whose body has already been read completely into memory.
    """

    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,
    ) -> None:
        """
        Store the response metadata and the fully buffered body.

        :param status: Numeric status code.
        :param headers: Response headers.
        :param request_url: URL of the request this response answers.
        :param data: The complete response body.
        :param decode_content: Forwarded unchanged to the base class.
        """
        super().__init__(
            status=status,
            headers=headers,
            request_url=request_url,
            decode_content=decode_content,
            # No reason phrase is supplied for these responses.
            reason=None,
            version=20,
            version_string="HTTP/2",
        )
        # The body is already complete, so nothing remains to be read.
        self.length_remaining = 0
        self._data = data

    @property
    def data(self) -> bytes:
        """The buffered response body."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always return ``None``."""
        return None

    def close(self) -> None:
        """Do nothing."""
        pass
| |
|