File size: 6,441 Bytes
0d97019
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import httpx
import tempfile
import shutil
import brotli
from fastapi import Request
from fastapi.responses import StreamingResponse, FileResponse
from starlette.background import BackgroundTask

def _get_file_headers(local_path: str) -> dict:
    headers = {
        "Cross-Origin-Opener-Policy": "same-origin",
        "Cross-Origin-Embedder-Policy": "require-corp"
    }
    
    if local_path.endswith(".br"):
        headers["Content-Encoding"] = "br"
        headers["Content-Type"] = "application/octet-stream"
    
    return headers

def _get_media_type(local_path: str) -> str:
    """Get appropriate media type based on file extension."""
    if local_path.endswith(".wasm.br") or local_path.endswith(".wasm"):
        return "application/wasm"
    if local_path.endswith(".br"):
        return "application/octet-stream"
    return None  # Let FileResponse auto-detect

def get_local_file(local_path: str, request: Request = None) -> FileResponse | StreamingResponse | None:
    """
    Get a local file as response. If it's a .br file and client doesn't accept brotli,
    decompress it on the fly.
    
    Args:
        local_path: Path to the local file
        request: Optional request object to check Accept-Encoding header
        
    Returns:
        FileResponse, StreamingResponse (for decompressed .br), or None if file not found
    """
    if not os.path.isfile(local_path):
        return None
    
    headers = _get_file_headers(local_path)
    media_type = _get_media_type(local_path)
    
    # Check if we need to decompress .br file for client
    is_br_file = local_path.endswith(".br")
    need_decompress = is_br_file and request and not _client_accepts_brotli(request)
    
    if need_decompress:
        # Stream decompressed content
        headers.pop("Content-Encoding", None)
        headers["Content-Type"] = "application/octet-stream"
        
        def iterate_decompressed():
            with open(local_path, "rb") as f:
                decompressor = brotli.Decompressor()
                while chunk := f.read(65536):  # 64KB chunks
                    yield decompressor.process(chunk)
        
        return StreamingResponse(
            iterate_decompressed(),
            media_type="application/octet-stream",
            headers=headers
        )
    
    if media_type:
        return FileResponse(local_path, media_type=media_type, headers=headers)
    return FileResponse(local_path, headers=headers)

def _client_accepts_brotli(request: Request) -> bool:
    """Check if client accepts brotli encoding."""
    accept_encoding = request.headers.get("accept-encoding", "")
    return "br" in accept_encoding.lower()

async def proxy_and_cache(request: Request, url: str, local_path: str = None, disable_cache: bool = False):
    """
    Proxy request to upstream URL and optionally cache the response.
    
    Args:
        request: FastAPI request object
        url: Upstream URL to proxy to
        local_path: Local file path for caching (required if disable_cache is False)
        disable_cache: If True, just proxy without caching or reading from local file
    """
    if not disable_cache and local_path:
        if response := get_local_file(local_path, request):
            return response
    
    # Check if this is a .br file and client doesn't support brotli
    is_br_file = url.endswith(".br")
    client_accepts_br = _client_accepts_brotli(request)
    need_decompress = is_br_file and not client_accepts_br
    
    client = httpx.AsyncClient(timeout=None)
    headers = {k: v for k, v in request.headers.items() if k.lower() not in ["host", "content-length", "accept-encoding"]}
    
    req = client.build_request(request.method, url, headers=headers)
    r = await client.send(req, stream=True)
    
    excluded_headers = {"transfer-encoding", "connection", "keep-alive", "upgrade", "content-security-policy"}
    response_headers = {k: v for k, v in r.headers.items() if k.lower() not in excluded_headers}
    response_headers["Cross-Origin-Opener-Policy"] = "same-origin"
    response_headers["Cross-Origin-Embedder-Policy"] = "require-corp"
    
    # If decompressing, remove content-encoding from response
    if need_decompress:
        response_headers.pop("content-encoding", None)
        response_headers.pop("Content-Encoding", None)
        # Also remove content-length since it will change after decompression
        response_headers.pop("content-length", None)
        response_headers.pop("Content-Length", None)
    
    if r.status_code != 200 or disable_cache or not local_path:
        async def stream_with_decompress():
            decompressor = brotli.Decompressor() if need_decompress else None
            try:
                async for chunk in r.aiter_raw():
                    if decompressor:
                        yield decompressor.process(chunk)
                    else:
                        yield chunk
            finally:
                await r.aclose()
                await client.aclose()
        
        return StreamingResponse(
            stream_with_decompress(),
            status_code=r.status_code,
            headers=response_headers
        )

    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(local_path))
    temp_file_path = temp_file.name
    
    async def iterate_and_save():
        success = False
        decompressor = brotli.Decompressor() if need_decompress else None
        try:
            async for chunk in r.aiter_raw():
                # Always save raw (compressed) data to cache
                temp_file.write(chunk)
                # Yield decompressed data if needed
                if decompressor:
                    yield decompressor.process(chunk)
                else:
                    yield chunk
            temp_file.close()
            
            # If we reached here, the stream was fully consumed
            shutil.move(temp_file_path, local_path)
            success = True
        finally:
            if not success:
                temp_file.close()
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)
            await r.aclose()
            await client.aclose()

    return StreamingResponse(
        iterate_and_save(),
        status_code=r.status_code,
        headers=response_headers
    )