"""
Request Size Limits Middleware for AegisLM
Provides middleware to enforce request size limits and to add security
response headers for API protection.
"""
import os
from typing import Callable, Optional

from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import ClientDisconnect
from starlette.types import ASGIApp
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware to enforce a maximum request body size.

    A request whose declared ``Content-Length`` or actually received body
    exceeds the limit is answered with HTTP 413 and never reaches the route
    handler. The body is read incrementally, so an oversized payload
    (including a chunked upload with no ``Content-Length``) is rejected as
    soon as the limit is crossed instead of being buffered in full first.

    Protects against:
    - Denial of Service (DoS) attacks via large payloads
    - Memory exhaustion from oversized requests
    """

    def __init__(
        self,
        app: ASGIApp,
        max_request_size_bytes: Optional[int] = None,
    ):
        """
        Initialize the middleware.

        Args:
            app: ASGI application
            max_request_size_bytes: Maximum request body size in bytes.
                When None, the limit comes from the
                AEGISLM_MAX_REQUEST_SIZE_BYTES environment variable, falling
                back to 10 MB. An explicit 0 is honoured (only empty bodies
                are accepted) rather than being treated as "not specified".

        Raises:
            ValueError: If max_request_size_bytes is negative.
        """
        super().__init__(app)
        if max_request_size_bytes is None:
            max_request_size_bytes = self._get_default_max_size()
        elif max_request_size_bytes < 0:
            raise ValueError(
                f"max_request_size_bytes must be non-negative, got {max_request_size_bytes}"
            )
        self.max_request_size_bytes = max_request_size_bytes

    def _get_default_max_size(self) -> int:
        """Get the default max request size from the environment or use 10 MB."""
        # Single source of truth for the env-var parsing.
        return get_request_size_limit()

    def _payload_too_large(self) -> JSONResponse:
        """Build the 413 Payload Too Large response for the configured limit."""
        return JSONResponse(
            status_code=413,  # Payload Too Large
            content={
                "error": "request_too_large",
                "message": f"Request body exceeds maximum allowed size of {self.max_request_size_bytes} bytes",
                "max_size_bytes": self.max_request_size_bytes,
            },
        )

    @staticmethod
    def _replay_receive(body: bytes, receive: Callable) -> Callable:
        """Return an ASGI receive callable that yields *body* once, then defers to *receive*."""
        replayed = False

        async def receive_once():
            nonlocal replayed
            if not replayed:
                replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            # After the body has been delivered, forward the server's real
            # messages (e.g. http.disconnect) instead of replaying the body
            # forever, which would never yield to a disconnect listener.
            return await receive()

        return receive_once

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process the request and enforce the size limit.

        Args:
            request: The incoming request
            call_next: The next middleware or route handler

        Returns:
            A 413 JSON response if the request is too large, otherwise the
            downstream response.
        """
        limit = self.max_request_size_bytes

        # Fast path: reject on a declared Content-Length without reading anything.
        declared = request.headers.get("content-length")
        if declared:
            try:
                if int(declared) > limit:
                    return self._payload_too_large()
            except ValueError:
                pass  # Malformed header; the streamed check below still applies.

        # Content-Length may be absent (chunked upload) or wrong, so count the
        # bytes actually received and stop as soon as the limit is crossed.
        chunks = []
        received = 0
        try:
            async for chunk in request.stream():
                received += len(chunk)
                if received > limit:
                    return self._payload_too_large()
                chunks.append(chunk)
        except ClientDisconnect:
            # The client went away mid-upload; nobody is left to receive a
            # 413, so hand the request on as the original code did.
            return await call_next(request)

        body = b"".join(chunks)
        # Cache the buffered body the same way Request.body() does so that
        # downstream reads see it, and re-feed it through receive.
        # NOTE(review): newer Starlette versions replay a cached body in
        # BaseHTTPMiddleware natively; confirm the receive override is still
        # needed for the pinned Starlette version.
        request._body = body
        request._receive = self._replay_receive(body, request.receive)
        return await call_next(request)
# =============================================================================
# Additional security middleware
# =============================================================================
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Middleware to add security headers to responses.

    Adds headers for:
    - Content Security Policy
    - X-Frame-Options
    - X-Content-Type-Options
    - Strict-Transport-Security
    - X-XSS-Protection

    A header the route handler already set is left untouched, so an
    individual endpoint can deliberately opt into a different policy
    (for example a looser Content-Security-Policy).
    """

    # (header name, value) pairs applied to every response lacking them.
    # Subclasses may override this to customise the policy.
    security_headers = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        # "0" turns the legacy browser XSS auditor off. "1; mode=block" is
        # deprecated, and the auditor itself could be abused to create
        # cross-site leaks; CSP below is the real XSS mitigation.
        ("X-XSS-Protection", "0"),
        ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
        ("Content-Security-Policy", "default-src 'self'"),
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Add security headers to the response.

        Args:
            request: The incoming request
            call_next: The next middleware or route handler

        Returns:
            The downstream response with any missing security headers added.
        """
        response = await call_next(request)
        for name, value in self.security_headers:
            # setdefault: never clobber a header the handler set on purpose.
            response.headers.setdefault(name, value)
        return response
def get_request_size_limit(default: int = 10 * 1024 * 1024) -> int:
    """
    Get the configured request size limit.

    Reads AEGISLM_MAX_REQUEST_SIZE_BYTES from the environment. A missing,
    empty, non-integer, or non-positive value falls back to *default*.

    Args:
        default: Limit returned when the environment variable is unusable.
            Defaults to 10 MB.

    Returns:
        Maximum request size in bytes
    """
    raw = os.getenv("AEGISLM_MAX_REQUEST_SIZE_BYTES")
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        return default
    # A zero or negative limit would make the size middleware reject every
    # request carrying a body (negative: every request at all), which is
    # almost certainly a misconfiguration rather than intent.
    return value if value > 0 else default
|