Spaces:
Paused
Paused
| from __future__ import annotations | |
| import io | |
| import itertools | |
| import sys | |
| import typing | |
| from .._models import Request, Response | |
| from .._types import SyncByteStream | |
| from .base import BaseTransport | |
| if typing.TYPE_CHECKING: | |
| from _typeshed import OptExcInfo # pragma: no cover | |
| from _typeshed.wsgi import WSGIApplication # pragma: no cover | |
| _T = typing.TypeVar("_T") | |
| __all__ = ["WSGITransport"] | |
| def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]: | |
| body = iter(body) | |
| for chunk in body: | |
| if chunk: | |
| return itertools.chain([chunk], body) | |
| return [] | |
class WSGIByteStream(SyncByteStream):
    """
    Byte stream over the iterable returned by a WSGI application.

    Leading empty chunks of the iterable are skipped when the stream is
    created. Closing the stream calls the iterable's own `close()` method,
    if it has one.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Capture `close` from the original iterable: the object stored in
        # `_result` is whatever `_skip_leading_empty_chunks` returns, not
        # `result` itself.
        self._close = getattr(result, "close", None)
        self._result = _skip_leading_empty_chunks(result)

    def __iter__(self) -> typing.Iterator[bytes]:
        """Yield the remaining body chunks in order."""
        yield from self._result

    def close(self) -> None:
        """Invoke the underlying iterable's `close()` when one was found."""
        closer = self._close
        if closer is None:
            return
        closer()
class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Default to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - Text stream passed to the application as `wsgi.errors`.
       When not provided, `sys.stderr` is used.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def handle_request(self, request: Request) -> Response:
        """
        Call the WSGI application with `request` and wrap what it returns.

        The request body is read fully and handed to the application as an
        in-memory `wsgi.input` stream. The status line and headers the
        application passes to `start_response` become the status code and
        headers of the returned `Response`; the application's result iterable
        becomes its body stream.

        Raises:
            The exception the application reported through the `exc_info`
            argument of `start_response`, when `raise_app_exceptions` is true.
            On any failure after the application has returned, the result
            iterable's `close()` is called before the exception propagates.
        """
        request.read()
        wsgi_input = io.BytesIO(request.content)

        # Fall back to the scheme's default port when the URL has none.
        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": wsgi_input,
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            # Content-Type / Content-Length keep their bare CGI names; every
            # other request header gets the HTTP_ prefix.
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            # WSGI strings are ISO-8859-1 (PEP 3333). latin-1 is a superset of
            # ASCII, so values that decoded before still decode identically,
            # while non-ASCII header bytes no longer raise UnicodeDecodeError.
            environ[key] = header_value.decode("latin-1")

        seen_status = None
        seen_response_headers = None
        seen_exc_info = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            # Record what the application reports; it is turned into a
            # Response once the application call has returned.
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # The legacy `write()` callable is accepted but its data is dropped.
            return lambda _: None

        result = self.app(environ, start_response)
        try:
            # Build the stream first: WSGIByteStream advances the iterable to
            # its first non-empty chunk, which gives an application that
            # defers `start_response` until iteration the chance to call it
            # before the checks below.
            stream = WSGIByteStream(result)

            assert seen_status is not None
            assert seen_response_headers is not None
            if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
                raise seen_exc_info[1]

            status_code = int(seen_status.split()[0])
            headers = [
                (key.encode("ascii"), value.encode("latin-1"))
                for key, value in seen_response_headers
            ]

            return Response(status_code, headers=headers, stream=stream)
        except BaseException:
            # No Response will own the iterable on this path, so close it
            # here; otherwise the application's `close()` hook never runs.
            close = getattr(result, "close", None)
            if close is not None:
                close()
            raise