Spaces:
Paused
Paused
| from contextlib import asynccontextmanager | |
| from dataclasses import asdict, dataclass | |
| from enum import Enum | |
| import re | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| AsyncGenerator, | |
| Dict, | |
| MutableMapping, | |
| Optional, | |
| cast, | |
| ) | |
| import uuid | |
| from asgiref.typing import ( | |
| ASGI3Application, | |
| ASGIReceiveCallable, | |
| ASGIReceiveEvent, | |
| ASGISendCallable, | |
| ASGISendEvent, | |
| Scope as ASGIScope, | |
| ) | |
| from loguru import logger | |
| from starlette.requests import Request | |
| from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE | |
| from open_webui.utils.auth import get_current_user, get_http_authorization_cred | |
| from open_webui.models.users import UserModel | |
| if TYPE_CHECKING: | |
| from loguru import Logger | |
@dataclass(frozen=True)
class AuditLogEntry:
    """
    Immutable record describing a single audited HTTP request.

    It must be a dataclass: the audit logger serialises it with
    ``dataclasses.asdict`` and the middleware builds it with keyword arguments.
    Fields are grouped by the audit level at which they become meaningful.
    """

    # `Metadata` audit level properties
    id: str
    user: Optional[dict[str, Any]]
    audit_level: str
    verb: str
    request_uri: str
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None
    # `Request` audit level properties
    request_object: Any = None
    # `Request Response` level
    response_object: Any = None
    response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
    """
    How much detail the audit middleware records for each audited request.

    Members are also ``str`` instances, so they compare equal to their
    plain-string names (e.g. ``AuditLevel.NONE == "NONE"``).
    """

    # Auditing disabled.
    NONE = "NONE"
    # Request metadata only (user, verb, URI, ...).
    METADATA = "METADATA"
    # Metadata plus the captured request body.
    REQUEST = "REQUEST"
    # Metadata plus captured request and response bodies.
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
    """
    Small helper that emits audit entries through a Loguru logger.

    The wrapped logger is bound with ``auditable=True`` so that every record it
    produces carries that flag, letting sinks filter audit records apart from
    ordinary application logs.

    Parameters:
        logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: "Logger"):
        # Bind once; all records written through this instance inherit the flag.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """
        Log ``audit_entry`` at ``log_level`` with an empty message.

        The entry's fields are passed to ``logger.log`` as keyword arguments.
        A truthy ``extra`` mapping is attached under the ``"extra"`` key.
        """
        fields = asdict(audit_entry)
        fields.update({"extra": extra} if extra else {})
        self.logger.log(log_level, "", **fields)
class AuditContext:
    """
    Per-request buffer for the HTTP request and response payloads.

    At most ``max_body_size`` bytes of each payload are kept. Anything beyond
    that is dropped, which bounds memory use for large bodies.

    Attributes:
        request_body (bytearray): Captured prefix of the request payload.
        response_body (bytearray): Captured prefix of the response payload.
        max_body_size (int): Upper bound, in bytes, on each captured payload.
        metadata (Dict[str, Any]): Free-form audit metadata filled in by the
            caller while the request is processed.
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.max_body_size = max_body_size
        self.request_body = bytearray()
        self.response_body = bytearray()
        self.metadata: Dict[str, Any] = {}

    def _append_capped(self, buffer: bytearray, chunk: bytes) -> None:
        """Append as much of ``chunk`` to ``buffer`` as fits under ``max_body_size``."""
        room = self.max_body_size - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, respecting the size cap."""
        self._append_capped(self.request_body, chunk)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, respecting the size cap."""
        self._append_capped(self.response_body, chunk)
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit
    logging. It captures request/response bodies (depending on audit level),
    headers, HTTP methods, and user information, then logs a structured audit
    entry at the end of the request cycle.

    Skip rules, applied in order:
      1. Non-HTTP scopes, methods outside ``AUDITED_METHODS``, or an instance
         ``audit_level`` of ``NONE`` are never audited.
      2. Paths starting with one of ``ALWAYS_LOG_ENDPOINTS`` are always audited.
      3. Requests without an ``Authorization`` header are skipped.
      4. Paths of the form ``/api/<excluded>`` or ``/api/v1/<excluded>`` are
         skipped, where ``<excluded>`` is one of ``excluded_paths``.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    # Auth endpoints are audited even without an Authorization header (a
    # sign-in request cannot carry a token yet). They are matched by prefix
    # against the lower-cased request path.
    ALWAYS_LOG_ENDPOINTS = (
        "/api/v1/auths/signin",
        "/api/v1/auths/signout",
        "/api/v1/auths/signup",
    )

    # Matches the string value of any JSON key containing "password"
    # (e.g. "password", "new_password"), case-insensitively. The value part
    # honours backslash escapes, so an escaped quote inside a password does not
    # end the match early. It also accepts a value cut off at the end of the
    # buffer (bodies are truncated to max_body_size), so no fragment leaks.
    _PASSWORD_FIELD_RE = re.compile(
        r'("[^"\\]*password[^"\\]*"\s*:\s*)"(?:[^"\\]|\\.)*(?:"|\\?\Z)',
        re.IGNORECASE,
    )

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level

        # Compile the exclusion pattern once. An empty alternation "()" would
        # match every /api/... path, so only build it from non-empty entries.
        # Entries are literal path segments, hence re.escape.
        alternatives = [re.escape(p) for p in self.excluded_paths if p]
        self._excluded_pattern = (
            re.compile(r"^/api(?:/v1)?/(" + "|".join(alternatives) + r")\b")
            if alternatives
            else None
        )

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        capture_request = self.audit_level in (
            AuditLevel.REQUEST,
            AuditLevel.REQUEST_RESPONSE,
        )
        capture_response = self.audit_level == AuditLevel.REQUEST_RESPONSE

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Observe outgoing messages before forwarding them unchanged.
                if capture_response:
                    await self._capture_response(message, context)
                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                # Observe incoming messages, then hand them to the app unchanged.
                message = await receive()
                if capture_request:
                    await self._capture_request(message, context)
                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that yields a fresh AuditContext and records an
        audit log entry when the block exits, whether normally or by exception.
        """
        context = AuditContext(self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
        """
        Best-effort user lookup from the request's Authorization header.

        Any exception raised during the lookup is logged at debug level, and
        ``None`` is returned instead.
        """
        auth_header = request.headers.get("Authorization")

        try:
            user = get_current_user(
                request, None, get_http_authorization_cred(auth_header)
            )
            return user
        except Exception as e:
            logger.debug(f"Failed to get authenticated user: {str(e)}")
            return None

    def _should_skip_auditing(self, request: Request) -> bool:
        """Return True when this request must not produce an audit entry."""
        if (
            request.method not in self.AUDITED_METHODS
            or self.audit_level == AuditLevel.NONE
        ):
            return True

        path = request.url.path.lower()
        if path.startswith(self.ALWAYS_LOG_ENDPOINTS):
            return False  # Do NOT skip logging for auth endpoints

        # Skip logging if the request is not authenticated
        if not request.headers.get("authorization"):
            return True

        # match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
        if self._excluded_pattern and self._excluded_pattern.match(request.url.path):
            return True

        return False

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Append the body of an ``http.request`` message to the context."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the response status and append response body chunks."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Build an AuditLogEntry from the request and captured context, then
        write it. Failures are logged and never propagate to the request.
        """
        try:
            user_model = await self._get_authenticated_user(request)
            user = (
                user_model.model_dump(include={"id", "name", "email", "role"})
                if user_model
                else {}
            )

            request_body = context.request_body.decode("utf-8", errors="replace")
            response_body = context.response_body.decode("utf-8", errors="replace")

            # Redact sensitive information (any "*password*" JSON string value)
            if "password" in request_body.lower():
                request_body = self._PASSWORD_FIELD_RE.sub(
                    r'\1"********"', request_body
                )
            # NOTE(review): response bodies are logged unredacted at
            # REQUEST_RESPONSE level; confirm whether auth responses need masking.

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user,
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                request_object=request_body,
                response_object=response_body,
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")