File size: 7,259 Bytes
4a2ab42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae946d
 
 
4a2ab42
 
4ae946d
 
 
4a2ab42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ae946d
 
 
4a2ab42
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
Request validation middleware for enhanced input validation and security
"""

import logging
from urllib.parse import unquote

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)


class RequestValidationMiddleware(BaseHTTPMiddleware):
    """Middleware that validates request size / content type and logs suspicious requests.

    For POST, PUT and PATCH requests it rejects bodies whose declared
    ``Content-Length`` exceeds ``max_body_size`` (413), bodies sent without a
    ``Content-Type`` header (400) and content types outside a fixed allow-list
    (415). For every request it logs a warning when the query string or the
    user agent matches simple attack heuristics; logging never blocks a request.

    Rejections are returned as JSON responses of the form ``{"detail": ...}``.
    """

    # HTTP methods that are validated as body-carrying requests.
    _BODY_METHODS = frozenset({"POST", "PUT", "PATCH"})

    # Accepted content-type prefixes (prefix match, so parameters such as
    # "; charset=utf-8" or "; boundary=..." are accepted too).
    _ALLOWED_CONTENT_TYPES = (
        "application/json",
        "application/x-www-form-urlencoded",
        "multipart/form-data",
        "text/plain",
        "application/xml",
        "application/octet-stream",
    )

    _SQL_PATTERNS = (
        "union",
        "select",
        "insert",
        "update",
        "delete",
        "drop",
        "exec",
        "script",
    )
    _XSS_PATTERNS = ("<script", "javascript:", "onload=", "onerror=")
    _SUSPICIOUS_USER_AGENTS = ("sqlmap", "nmap", "masscan", "dirbuster", "gobuster")

    # Query strings longer than this (in encoded form) are flagged in the log.
    _MAX_QUERY_LENGTH = 2000

    def __init__(self, app, max_body_size: int = 10 * 1024 * 1024):  # 10MB default
        """Wrap ``app``; ``max_body_size`` is the largest accepted Content-Length in bytes."""
        super().__init__(app)
        self.max_body_size = max_body_size

    async def dispatch(self, request: Request, call_next):
        """Run all validations, then forward the request to the next handler.

        Returns:
            The downstream response, or a JSON error response when a
            validation step raised ``HTTPException``.

        Raises:
            Exception: any non-HTTP error is logged and re-raised unchanged.
        """
        try:
            await self._validate_request_size(request)
            await self._validate_content_type(request)
            await self._log_suspicious_requests(request)
            return await call_next(request)
        except HTTPException as exc:
            # Exception handlers registered on the application do not wrap
            # user middleware, so a re-raised HTTPException would surface as a
            # 500. Build the error response here instead.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )
        except Exception as exc:
            logger.error("Request validation middleware error: %s", exc)
            raise

    @staticmethod
    def _declares_body(request: Request) -> bool:
        """Return True when the request headers announce a non-empty body."""
        if "transfer-encoding" in request.headers:
            return True
        declared = request.headers.get("content-length")
        if not declared:
            return False
        try:
            return int(declared) > 0
        except ValueError:
            # Malformed header: leave the decision to the framework/server.
            return False

    async def _validate_request_size(self, request: Request) -> None:
        """Reject body-carrying requests whose Content-Length exceeds the limit.

        Only the declared ``Content-Length`` is checked; requests without the
        header (or with a non-integer value) are passed through.

        Raises:
            HTTPException: 413 when the declared size is above ``max_body_size``.
        """
        if request.method not in self._BODY_METHODS:
            return

        declared = request.headers.get("content-length")
        if not declared:
            return

        try:
            size = int(declared)
        except ValueError:
            return  # Invalid content-length header, let FastAPI handle it

        if size > self.max_body_size:
            raise HTTPException(
                status_code=413,
                detail=f"Request body too large. Maximum size: {self.max_body_size} bytes",
            )

    async def _validate_content_type(self, request: Request) -> None:
        """Require an allowed Content-Type on requests that carry a body.

        Raises:
            HTTPException: 400 when a body is announced without a Content-Type,
                415 when the Content-Type is not in the allow-list.
        """
        if request.method not in self._BODY_METHODS:
            return

        content_type = request.headers.get("content-type", "").lower()

        if not content_type:
            # Decide from the headers rather than reading the body: reading it
            # here would buffer the whole payload in memory (regardless of
            # max_body_size) and consume the stream before the route sees it.
            if self._declares_body(request):
                raise HTTPException(
                    status_code=400,
                    detail="Content-Type header required for requests with body",
                )
            return

        if not content_type.startswith(self._ALLOWED_CONTENT_TYPES):
            raise HTTPException(
                status_code=415, detail=f"Unsupported content type: {content_type}"
            )

    async def _log_suspicious_requests(self, request: Request) -> None:
        """Log a warning when the request matches simple attack heuristics.

        The checks are plain substring matches and are used for monitoring
        only; nothing is rejected here.
        """
        suspicious_indicators = []

        # str(query_params) is the URL-encoded form ("<" becomes "%3C"), so the
        # patterns are matched against the decoded key/value pairs instead.
        encoded_query = str(request.query_params)
        decoded_query = "&".join(
            f"{key}={value}" for key, value in request.query_params.multi_items()
        ).lower()

        if any(pattern in decoded_query for pattern in self._SQL_PATTERNS):
            suspicious_indicators.append("sql_injection_patterns")

        if any(pattern in decoded_query for pattern in self._XSS_PATTERNS):
            suspicious_indicators.append("xss_patterns")

        if len(encoded_query) > self._MAX_QUERY_LENGTH:
            suspicious_indicators.append("long_query_string")

        user_agent = request.headers.get("user-agent", "").lower()
        if any(ua in user_agent for ua in self._SUSPICIOUS_USER_AGENTS):
            suspicious_indicators.append("suspicious_user_agent")

        if suspicious_indicators:
            logger.warning(
                "Suspicious request detected: %s %s",
                request.method,
                request.url.path,
                extra={
                    "client_ip": request.client.host if request.client else "unknown",
                    "user_agent": user_agent,
                    "indicators": suspicious_indicators,
                    "query_params_length": len(encoded_query),
                },
            )


class InputValidationMiddleware(BaseHTTPMiddleware):
    """Middleware that validates the request path and query parameter names.

    Requests whose path contains a ``..`` segment or whose query parameter
    names contain markup/quoting characters are rejected with a 400 JSON
    response of the form ``{"detail": ...}``. Proxy-related headers are only
    logged; no header is modified.
    """

    # Headers whose presence is logged for monitoring.
    _PROXY_HEADERS = ("x-forwarded-for", "x-real-ip", "x-client-ip")

    # Characters that are not accepted in query parameter names.
    _FORBIDDEN_PARAM_CHARS = frozenset("<>\"';()")

    async def dispatch(self, request: Request, call_next):
        """Run the checks, then forward the request to the next handler.

        Returns:
            The downstream response, or a JSON error response when a check
            raised ``HTTPException``.

        Raises:
            Exception: any non-HTTP error is logged and re-raised unchanged.
        """
        try:
            await self._sanitize_headers(request)
            await self._validate_request_parameters(request)
            return await call_next(request)
        except HTTPException as exc:
            # Exception handlers registered on the application do not wrap
            # user middleware, so a re-raised HTTPException would surface as a
            # 500. Build the error response here instead.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=getattr(exc, "headers", None),
            )
        except Exception as exc:
            logger.error("Input validation middleware error: %s", exc)
            raise

    async def _sanitize_headers(self, request: Request) -> None:
        """Log which proxy-related headers are present on the request.

        Despite the name, this does not remove or rewrite any header.
        """
        found_suspicious = [h for h in self._PROXY_HEADERS if h in request.headers]

        if found_suspicious:
            logger.info("Request with proxy headers: %s", found_suspicious)

    async def _validate_request_parameters(self, request: Request) -> None:
        """Reject path traversal attempts and unsafe query parameter names.

        Raises:
            HTTPException: 400 when the path contains a ``..`` segment (also
                when percent-encoded or written with backslashes), or when a
                query parameter name contains one of ``< > " ' ; ( )``.
        """
        # Decode any percent-escapes still present in the path and treat
        # backslashes as separators, then look for a parent-directory segment.
        # Comparing whole segments catches a trailing "/.." and does not flag
        # names that merely end in dots (e.g. "foo../bar").
        normalized_path = unquote(request.url.path).replace("\\", "/")
        if ".." in normalized_path.split("/"):
            raise HTTPException(
                status_code=400, detail="Invalid path: path traversal detected"
            )

        # Iterating query_params yields the parameter names.
        for param_name in request.query_params:
            if not self._FORBIDDEN_PARAM_CHARS.isdisjoint(param_name):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid query parameter name: {param_name}",
                )