| from __future__ import annotations |
|
|
| import io |
| import itertools |
| import sys |
| import typing |
|
|
| from .._models import Request, Response |
| from .._types import SyncByteStream |
| from .base import BaseTransport |
|
|
| if typing.TYPE_CHECKING: |
| from _typeshed import OptExcInfo |
| from _typeshed.wsgi import WSGIApplication |
|
|
# Names exported by `from <this module> import *`; WSGIByteStream is not listed.
__all__ = ["WSGITransport"]

# Generic element type used by _skip_leading_empty_chunks' signature.
_T = typing.TypeVar("_T")
|
|
|
|
| def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]: |
| body = iter(body) |
| for chunk in body: |
| if chunk: |
| return itertools.chain([chunk], body) |
| return [] |
|
|
|
|
class WSGIByteStream(SyncByteStream):
    """Synchronous byte stream over the iterable returned by a WSGI app.

    On construction, leading empty chunks are skipped via
    _skip_leading_empty_chunks, which iterates the app's result eagerly up to
    its first non-empty chunk. ``close()`` forwards to the original iterable's
    ``close`` method when it has one.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Grab close() from the raw iterable first: the wrapper produced by
        # _skip_leading_empty_chunks (a chain or a list) does not carry it.
        self._close = getattr(result, "close", None)
        self._result = _skip_leading_empty_chunks(result)

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self._result

    def close(self) -> None:
        closer = self._close
        if closer is None:
            return
        closer()
|
|
|
|
class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if an exception reported by the
       application through `start_response(..., exc_info)` should be re-raised.
       Defaults to `True`. Can be set to `False` for use cases such as testing
       the content of a client 500 response. Exceptions raised directly by the
       application always propagate.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - Text stream exposed to the application as `wsgi.errors`.
       Falls back to `sys.stderr` when not given.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def handle_request(self, request: Request) -> Response:
        """Run *request* through the WSGI app and wrap its output in a Response.

        The request body is read fully into memory and exposed as
        ``wsgi.input``. The response body is streamed from the iterable the
        app returns.

        Raises:
            The exception the app passed to ``start_response`` as ``exc_info``,
            when ``raise_app_exceptions`` is true.
            AssertionError if the app never called ``start_response``.
            Any exception raised by the application itself.
        """
        request.read()
        environ = self._build_environ(request, io.BytesIO(request.content))

        seen_status: str | None = None
        seen_response_headers: list[tuple[str, str]] | None = None
        seen_exc_info: OptExcInfo | None = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # The legacy write() callable is not supported: data passed to it
            # is discarded.
            return lambda _: None

        result = self.app(environ, start_response)

        try:
            # Building the stream pulls the app's first non-empty chunk, which
            # lets generator-based apps reach their start_response() call.
            stream = WSGIByteStream(result)

            assert seen_status is not None
            assert seen_response_headers is not None
            if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
                raise seen_exc_info[1]

            status_code = int(seen_status.split()[0])
            # PEP 3333: header values are latin-1 "native strings".
            headers = [
                (key.encode("ascii"), value.encode("latin-1"))
                for key, value in seen_response_headers
            ]
        except BaseException:
            # PEP 3333 requires the iterable's close() to be called even when
            # the request is terminated early by an error; otherwise nobody
            # would, since no Response (and so no stream) is handed back.
            close = getattr(result, "close", None)
            if close is not None:
                close()
            raise

        return Response(status_code, headers=headers, stream=stream)

    def _build_environ(
        self, request: Request, wsgi_input: io.BytesIO
    ) -> dict[str, typing.Any]:
        """Build the WSGI ``environ`` mapping for *request*.

        Fixed server/WSGI keys are filled from the request URL and this
        transport's settings. Request headers are mapped to ``HTTP_*`` keys,
        except ``Content-Type`` and ``Content-Length``, which go to their
        unprefixed CGI names.
        """
        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ: dict[str, typing.Any] = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": wsgi_input,
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            # PEP 3333: environ values are latin-1 "native strings"; decoding
            # as ascii would crash on any non-ASCII header byte.
            value = header_value.decode("latin-1")
            # RFC 3875 section 4.1.18: repeated header fields must be merged
            # into a single value rather than the last one silently winning.
            if key in environ:
                value = f"{environ[key]},{value}"
            environ[key] = value
        return environ
|
|