Spaces:
Build error
Build error
| """ | |
| Pure-ASGI replacements for the project's previous | |
| `@app.middleware('http')` / `BaseHTTPMiddleware` middlewares. | |
| Why this matters | |
| ---------------- | |
| Starlette's `BaseHTTPMiddleware` (which `@app.middleware('http')` is | |
| sugar for) runs the downstream app inside an `anyio` task group. When | |
| the wrapper exits — for any reason: response complete, client | |
| disconnect, an outer middleware bailing out — the task group cancels | |
| the inner task. That `CancelledError` then propagates into whatever | |
| the inner task was doing, including in-flight DB queries, embedding | |
| calls and disk I/O. | |
| In Open WebUI this surfaces as: | |
| * SQLAlchemy logging multi-page `NotImplementedError: | |
| terminate_force_close()` tracebacks at ERROR every time a request is | |
| cancelled mid-DB-call (the aiosqlite connector cleanup path). | |
| * Spurious cancellations cascading through the four stacked | |
| `@app.middleware('http')` wrappers. | |
| Pure ASGI middleware does not introduce a cancel scope around the | |
| downstream app, so client disconnects propagate the way ASGI was | |
| designed to (via `receive()` returning `http.disconnect`) instead of | |
| being injected as `CancelledError` into arbitrary `await` points. | |
| Reference: https://www.starlette.io/middleware/#limitations | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| import time | |
| from urllib.parse import parse_qs, urlencode | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from fastapi.security import HTTPAuthorizationCredentials | |
| from starlette.datastructures import MutableHeaders | |
| from starlette.requests import Request | |
| from starlette.types import ASGIApp, Message, Receive, Scope, Send | |
| from open_webui.env import CUSTOM_API_KEY_HEADER | |
| from open_webui.internal.db import ScopedSession | |
| from open_webui.utils.auth import get_http_authorization_cred | |
| log = logging.getLogger(__name__) | |
class CommitSessionMiddleware:
    """Finalize the sync `ScopedSession` once an HTTP request is done.

    The sync session is a legacy path (most requests use the async
    session), but whatever it touched during a request still has to be
    committed or rolled back, and its connection handed back to the
    pool so it does not linger as "idle in transaction" in PostgreSQL.

    Outcomes
    --------
    * Downstream raised: roll back, release the session, re-raise the
      original exception unchanged. Nothing is committed for a request
      that did not finish cleanly.
    * Downstream returned: commit. If the commit fails it is logged,
      rolled back and re-raised. Because this is pure ASGI, the response
      messages were already emitted before `self.app(...)` returned, so
      the client cannot be told about a late commit failure; buffering
      the response to change that would break streaming responses.

    Non-HTTP scopes are forwarded untouched. Handlers that need
    commit-before-send must manage the sync session themselves.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope['type'] == 'http':
            try:
                await self.app(scope, receive, send)
            except BaseException:
                # BaseException on purpose: cancellation and similar
                # signals must also trigger rollback + release before
                # they continue to propagate.
                try:
                    try:
                        ScopedSession.rollback()
                    except Exception:
                        log.exception('CommitSessionMiddleware: rollback failed after downstream error')
                finally:
                    ScopedSession.remove()
                raise
            else:
                # Success path: persist pending sync work.
                try:
                    ScopedSession.commit()
                except Exception:
                    log.exception(
                        'CommitSessionMiddleware: post-request commit failed; response was already sent to client'
                    )
                    try:
                        ScopedSession.rollback()
                    except Exception:
                        log.exception('CommitSessionMiddleware: rollback failed after commit failure')
                    raise
                finally:
                    # remove() is what returns the connection to the
                    # pool; it must run whether or not commit succeeded.
                    ScopedSession.remove()
        else:
            await self.app(scope, receive, send)
class AuthTokenMiddleware:
    """Resolve the caller's credential and publish it on `request.state`.

    Lookup order for `request.state.token`:

    1. the ``Authorization`` header, parsed by
       `get_http_authorization_cred`;
    2. the ``token`` cookie;
    3. the header named by ``CUSTOM_API_KEY_HEADER`` (an env-driven
       setting, default ``x-api-key``), which lets deployments behind a
       reverse proxy that consumes ``Authorization`` carry the key in a
       dedicated header such as ``X-OpenWebUI-Key``.

    Sources 2 and 3 are wrapped as ``Bearer`` credentials. If none
    yields a value the token is ``None``.

    The middleware also copies the current ``ENABLE_API_KEYS`` config
    value to `request.state.enable_api_keys` and adds an
    ``X-Process-Time`` header (whole seconds) to the response start
    message. Non-HTTP scopes are forwarded untouched.
    """

    def __init__(self, app: ASGIApp, *, fastapi_app) -> None:
        self.app = app
        self._fastapi_app = fastapi_app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope['type'] != 'http':
            await self.app(scope, receive, send)
            return

        started_at = time.monotonic()
        request = Request(scope)

        credential = get_http_authorization_cred(request.headers.get('Authorization'))
        if credential is None:
            # Cookie wins over the API-key header; the header is only
            # consulted when the cookie is missing or empty.
            fallback_secret = request.cookies.get('token') or request.headers.get(CUSTOM_API_KEY_HEADER)
            if fallback_secret:
                credential = HTTPAuthorizationCredentials(scheme='Bearer', credentials=fallback_secret)

        request.state.token = credential
        request.state.enable_api_keys = self._fastapi_app.state.config.ENABLE_API_KEYS

        async def send_with_timing(message: Message) -> None:
            # Only the response-start message carries headers.
            if message['type'] == 'http.response.start':
                elapsed_seconds = int(time.monotonic() - started_at)
                MutableHeaders(scope=message)['X-Process-Time'] = str(elapsed_seconds)
            await send(message)

        await self.app(scope, receive, send_with_timing)
class WebsocketUpgradeGuardMiddleware:
    """Answer 400 to malformed socket.io websocket handshakes.

    An HTTP request whose path contains `/ws/socket.io` and whose query
    string says `transport=websocket` must also carry
    `Upgrade: websocket` and a `Connection` header listing `upgrade`.
    Requests that do not are rejected here instead of being passed to
    engineio, which mishandles them (see
    https://github.com/miguelgrinberg/python-engineio/issues/367).
    Everything else is forwarded unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _is_invalid_upgrade(scope: Scope) -> bool:
        """Return True when an HTTP scope is a broken websocket handshake."""
        if '/ws/socket.io' not in scope.get('path', ''):
            return False

        raw_query = scope.get('query_string', b'').decode('latin-1', errors='replace')
        requested_transport = parse_qs(raw_query).get('transport', [''])[0]
        if requested_transport != 'websocket':
            return False

        headers = _scope_headers(scope)
        upgrade_value = headers.get('upgrade', '').lower()
        # `Connection` may list several comma-separated tokens.
        connection_tokens = {part.strip() for part in headers.get('connection', '').lower().split(',')}
        return not (upgrade_value == 'websocket' and 'upgrade' in connection_tokens)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope['type'] == 'http' and self._is_invalid_upgrade(scope):
            rejection = JSONResponse(
                status_code=400,
                content={'detail': 'Invalid WebSocket upgrade request'},
            )
            await rejection(scope, receive, send)
            return

        await self.app(scope, receive, send)
class RedirectMiddleware:
    """Redirect legacy GET entry-points to the SPA's query-string routes.

    * ``GET …/watch?v=ID`` → ``/?youtube=ID``
    * ``GET …?shared=TEXT`` (PWA share-target): when TEXT starts with
      an ``https://`` URL, redirect to ``/?youtube=ID`` if a video id
      can be parsed from it, otherwise to ``/?load-url=URL``; when it
      does not, redirect to ``/?q=TEXT``.

    Requests that produce no redirect parameters, non-GET requests and
    non-HTTP scopes are forwarded to the wrapped app.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    def _redirect_target(scope: Scope) -> str | None:
        """Return the redirect URL for this GET scope, or None if none applies."""
        request_path = scope.get('path', '')
        raw_query = scope.get('query_string', b'').decode('latin-1', errors='replace')
        query = parse_qs(raw_query)
        params: dict[str, str] = {}

        watch_ids = query.get('v')
        if request_path.endswith('/watch') and watch_ids:
            params['youtube'] = watch_ids[0]

        shared_values = query.get('shared') or ['']
        shared_text = shared_values[0]
        if shared_text:
            link = re.match(r'https://\S+', shared_text)
            if link is None:
                params['q'] = shared_text
            else:
                # Imported lazily: the youtube loader drags in heavy
                # dependencies and is only needed for shared URLs.
                from open_webui.retrieval.loaders.youtube import _parse_video_id

                shared_video_id = _parse_video_id(link[0])
                if shared_video_id:
                    params['youtube'] = shared_video_id
                else:
                    params['load-url'] = link[0]

        return f'/?{urlencode(params)}' if params else None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        is_http_get = scope['type'] == 'http' and scope.get('method', '').upper() == 'GET'
        if is_http_get:
            target = self._redirect_target(scope)
            if target is not None:
                await RedirectResponse(url=target)(scope, receive, send)
                return

        await self.app(scope, receive, send)
def _scope_headers(scope: Scope) -> dict[str, str]:
    """Flatten the ASGI header list of *scope* into a plain dict.

    ASGI supplies headers as (bytes, bytes) pairs. Names are decoded as
    latin-1 and lower-cased, values are decoded as latin-1. A name that
    appears more than once has its values joined with ``', '`` in
    arrival order, mirroring how HTTP/1.1 combines repeated fields. A
    scope without a ``headers`` entry yields an empty dict.
    """
    values_by_name: dict[str, list[str]] = {}
    for name_bytes, value_bytes in scope.get('headers', []):
        name = name_bytes.decode('latin-1').lower()
        values_by_name.setdefault(name, []).append(value_bytes.decode('latin-1'))
    return {name: ', '.join(values) for name, values in values_by_name.items()}