File size: 3,259 Bytes
99fddd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import typing
from functools import partial

import anyio
from fastapi.responses import StreamingResponse
from starlette.types import Send, Scope, Receive


class PathMatchingTree:
    """
    PathMatchingTree is a data structure that can be used to match a path with a value.
    It supports exact match, partial (prefix) match, and wildcard match.
    For example, if the tree is built with the following config:
      {
          "/foo/bar": "value1",
          "/baz/qux": "value2",
          "/foo/*": "value3",
          "/foo/*/bar": "value4"
      }
    Then the following path will match the corresponding value:
      /foo/bar -> value1
      /baz/qux -> value2
      /foo/baz -> value3
      /foo/baz/bar -> value4
      /foo/baz/bar2 -> value3

    Matching walks the path segment by segment, preferring an exact segment
    over the ``*`` wildcard, and returns the value of the deepest node on the
    walked path that has a value (``None`` if no node on the path has one).
    Empty segments (leading/trailing/double slashes) are ignored both when
    building and when matching.
    """
    # Children keyed by path segment; the key '*' is the wildcard child.
    child: dict
    # Value stored at this node, or None when no config key ends here.
    value = None

    def __init__(self, config):
        """Build the tree from ``config``, a mapping of path pattern -> value."""
        self.child = {}
        self._build_tree(config)

    def _build_tree(self, config):
        """Insert every ``pattern -> value`` pair of ``config`` into the tree."""
        for k, v in config.items():
            parts = k.split('/')
            self._add(parts, v)

    def _add(self, parts, value):
        """Create (or reuse) the node chain for ``parts`` and store ``value`` at its end."""
        node = self
        for part in parts:
            if part == '':
                continue
            if part not in node.child:
                node.child[part] = PathMatchingTree({})
            node = node.child[part]
        node.value = value

    def get_matching(self, path):
        """Return the value configured for ``path``, or ``None`` if nothing matches.

        Each segment follows the exact child if present, otherwise the ``*``
        child; the walk stops at the first segment with neither. The result is
        the value of the deepest visited node that has one, so a configured
        ancestor still applies when the walk ends on (or passes through) a
        purely intermediate node.
        """
        node = self
        best = node.value
        for part in path.split('/'):
            if not part:
                continue
            nxt = node.child.get(part)
            if nxt is None:
                nxt = node.child.get('*')
            if nxt is None:
                break
            node = nxt
            if node.value is not None:
                best = node.value
        return best


class OverrideStreamResponse(StreamingResponse):
    """
    StreamingResponse that sends the status code and headers lazily.

    ``http.response.start`` is not sent until the body iterator yields its
    first chunk (or finishes without yielding). Code that runs inside the body
    iterator before that point can therefore still change ``status_code`` or
    the headers, and the changed values are the ones sent to the client.
    """

    async def stream_response(self, send: Send) -> None:
        """Send the headers on the first chunk, then every chunk, then the end marker."""
        first_chunk = True
        async for chunk in self.body_iterator:
            if first_chunk:
                # Deferred until now so the iterator may adjust status/headers first.
                await self.send_request_header(send)
                first_chunk = False
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            await send({'type': 'http.response.body', 'body': chunk, 'more_body': True})

        if first_chunk:
            # Empty body: the start message still has to be sent exactly once.
            await self.send_request_header(send)
        await send({'type': 'http.response.body', 'body': b'', 'more_body': False})

    async def send_request_header(self, send: Send) -> None:
        """Send the ``http.response.start`` message (status and raw headers)."""
        await send(
            {
                'type': 'http.response.start',
                'status': self.status_code,
                'headers': self.raw_headers,
            }
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body while listening for a client disconnect.

        Whichever of the two coroutines finishes first cancels the other; the
        background task, if any, runs after both have stopped.
        """
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
                await func()
                # CancelScope.cancel() is synchronous in AnyIO. On AnyIO 4 it
                # returns None, so awaiting it would raise TypeError.
                task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()