File size: 8,633 Bytes
a7c2198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
96e312e
a7c2198
 
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
 
 
 
 
96e312e
a7c2198
 
96e312e
a7c2198
 
 
 
 
96e312e
a7c2198
 
 
96e312e
 
a7c2198
96e312e
a7c2198
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
 
 
 
96e312e
a7c2198
 
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
 
 
 
 
 
96e312e
a7c2198
 
96e312e
a7c2198
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
 
96e312e
a7c2198
 
 
 
 
 
 
 
 
 
 
 
 
 
96e312e
 
a7c2198
 
 
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
96e312e
a7c2198
 
96e312e
a7c2198
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
 
 
96e312e
a7c2198
 
 
 
 
96e312e
 
a7c2198
 
 
 
 
 
 
 
 
 
 
 
96e312e
a7c2198
 
 
 
 
 
96e312e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Security middleware for input validation, rate limiting, and request sanitization.
"""

import functools
import json
import logging
import secrets
import time
from collections import defaultdict, deque
from typing import Any, Dict, Optional

from fastapi import HTTPException, Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.utils.input_sanitizer import InputSanitizer

# Module-level standard-library logger for this file.
# Plain `logging` is used here to avoid circular import dependencies.
logger = logging.getLogger(__name__)


class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Comprehensive security middleware that provides:
    - Request size limiting (based on the declared Content-Length header)
    - Per-client rate limiting (delegated to RateLimiter)
    - Request logging
    - Security headers on every response it returns

    Args:
        app: The ASGI application being wrapped.
        max_request_size: Largest accepted Content-Length, in bytes
            (default 10 MB).
        trust_forwarded_headers: When True (the default, preserving the
            original behavior), the client identity used for rate limiting is
            read from ``X-Forwarded-For`` / ``X-Real-IP`` when present. Those
            headers are client-controlled unless a trusted reverse proxy
            overwrites them; without such a proxy a client can rotate them to
            evade rate limiting, so pass False in that deployment.
    """

    def __init__(
        self,
        app,
        max_request_size: int = 10 * 1024 * 1024,  # 10MB default
        trust_forwarded_headers: bool = True,
    ):
        super().__init__(app)
        self.max_request_size = max_request_size
        self.trust_forwarded_headers = trust_forwarded_headers
        self.rate_limiter = RateLimiter()

    async def dispatch(self, request: Request, call_next):
        """Run size and rate checks, call the app, then add security headers.

        Any unexpected exception (including one raised by the downstream app)
        is logged with its traceback and turned into a generic 500 response so
        internal details are not exposed to the client.
        """
        start_time = time.perf_counter()

        try:
            rejection = self._check_request_size(request)
            if rejection is None:
                rejection = self._check_rate_limit(request)
            if rejection is not None:
                return self._apply_security_headers(rejection)

            response = await call_next(request)
            self._apply_security_headers(response)

            # Basic stdlib logging only, to avoid circular dependencies.
            processing_time = time.perf_counter() - start_time
            logger.info(
                "Request processed: %s %s in %.3fs with status %s",
                request.method,
                request.url.path,
                processing_time,
                response.status_code,
            )
            return response

        except Exception:
            # Keep the traceback in the logs; the client only sees a generic error.
            logger.exception("Security middleware error occurred")
            return self._apply_security_headers(
                JSONResponse(
                    status_code=500,
                    content={"error": "Internal server error"}
                )
            )

    def _check_request_size(self, request: Request) -> Optional[JSONResponse]:
        """Return a 400/413 response for a bad or oversized Content-Length, else None."""
        raw_length = request.headers.get("content-length")
        if raw_length is None:
            # Bodies without a declared length (e.g. chunked) are not limited here.
            return None

        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1  # treat unparsable values like negative ones

        if content_length < 0:
            # A malformed header is a client error, not a server fault.
            logger.warning("Invalid Content-Length header")
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid Content-Length header"}
            )
        if content_length > self.max_request_size:
            logger.warning("Request size too large")
            return JSONResponse(
                status_code=413,
                content={"error": "Request entity too large"}
            )
        return None

    def _check_rate_limit(self, request: Request) -> Optional[JSONResponse]:
        """Return a 429 response if the client is over its limit, else None."""
        client_ip = self._get_client_ip(request)
        if self.rate_limiter.is_allowed(client_ip, request.url.path):
            return None
        logger.warning("Rate limit exceeded for client")
        return JSONResponse(
            status_code=429,
            content={"error": "Rate limit exceeded"},
            # Conservative hint: the full window guarantees at least one slot frees up.
            headers={"Retry-After": str(int(self.rate_limiter.window_size))},
        )

    @staticmethod
    def _apply_security_headers(response):
        """Set the standard hardening headers on ``response`` and return it."""
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client identity used as the rate-limiting key.

        Forwarded headers are consulted only when ``trust_forwarded_headers``
        is set; blank header values fall through to the next source.
        """
        if self.trust_forwarded_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip and real_ip.strip():
                return real_ip.strip()

        # Fallback to the direct peer address
        return request.client.host if request.client else "unknown"


class RateLimiter:
    """
    Simple in-memory rate limiter with a sliding window.

    Requests are counted per (client, endpoint pattern) pair, so traffic to
    one endpoint group does not consume another group's budget. State is held
    in this process's memory only; in production, use Redis or a similar
    distributed cache so limits are shared across workers.

    Args:
        limits: Mapping of path prefix -> maximum requests per window. A path
            matches a prefix when it equals it or continues with "/"; the first
            matching entry (in dict order) wins. The "default" entry applies to
            unmatched paths and is set to 60 if omitted. Defaults to a copy of
            ``DEFAULT_LIMITS``.
        window_size: Sliding-window length in seconds (default 60).
    """

    # Requests per window (per minute with the default window) for endpoint groups.
    DEFAULT_LIMITS = {
        "/api/v1/merchants": 100,
        "/api/v1/helpers": 200,
        "/api/v1/nlp": 50,
        "default": 60,
    }

    def __init__(self, limits: Optional[Dict[str, int]] = None, window_size: float = 60):
        # (client_ip, pattern) -> deque of monotonic request timestamps, oldest first.
        self.requests = defaultdict(deque)
        self.limits = dict(self.DEFAULT_LIMITS if limits is None else limits)
        self.limits.setdefault("default", 60)
        self.window_size = window_size  # 1 minute window by default
        self._last_sweep = time.monotonic()

    def is_allowed(self, client_ip: str, path: str) -> bool:
        """Record and allow the request unless the client's budget is used up.

        Returns True (and counts this request) when fewer than the path's
        limit requests from ``client_ip`` hit the same endpoint pattern in the
        last ``window_size`` seconds; otherwise returns False without counting.
        """
        # Monotonic clock: unaffected by wall-clock jumps (NTP, manual changes).
        now = time.monotonic()
        cutoff = now - self.window_size
        self._sweep_idle_buckets(now, cutoff)

        pattern = self._match_pattern(path)
        bucket = self.requests[(client_ip, pattern)]

        # Clean old requests outside the window
        while bucket and bucket[0] < cutoff:
            bucket.popleft()

        if len(bucket) >= self.limits[pattern]:
            return False

        bucket.append(now)
        return True

    def _sweep_idle_buckets(self, now: float, cutoff: float) -> None:
        """Drop buckets with no request inside the window, at most once per window.

        Without this, every client ever seen keeps an entry forever, so memory
        grows without bound.
        """
        if now - self._last_sweep < self.window_size:
            return
        self._last_sweep = now
        idle = [key for key, bucket in self.requests.items()
                if not bucket or bucket[-1] < cutoff]
        for key in idle:
            del self.requests[key]

    def _match_pattern(self, path: str) -> str:
        """Return the first configured prefix that ``path`` falls under, else "default"."""
        for pattern in self.limits:
            if pattern == "default":
                continue
            prefix = pattern.rstrip("/")
            # Segment-aware prefix match: "/api/v1/nlp" must not match
            # "/api/v1/nlpx" or "/other/api/v1/nlp".
            if path == prefix or path.startswith(prefix + "/"):
                return pattern
        return "default"

    def _get_limit_for_path(self, path: str) -> int:
        """Get rate limit for specific path"""
        return self.limits[self._match_pattern(path)]


class RequestValidator:
    """Validates common request patterns and parameters via InputSanitizer."""

    @staticmethod
    def validate_pagination(limit: Optional[int], offset: Optional[int]) -> tuple:
        """Validate pagination parameters.

        Each non-None value is sanitized on its own through
        ``InputSanitizer.sanitize_pagination``, using a placeholder for the
        other half of the pair (offset 0 when checking ``limit``, limit 10
        when checking ``offset``). ``None`` values are returned unchanged.

        Returns:
            ``(limit, offset)`` after sanitization.
        """
        if limit is not None:
            limit = InputSanitizer.sanitize_pagination(limit, 0)[0]
        if offset is not None:
            offset = InputSanitizer.sanitize_pagination(10, offset)[1]
        return limit, offset

    @staticmethod
    def validate_search_params(params: Dict[str, Any]) -> Dict[str, Any]:
        """Validate and sanitize a dict of search parameters.

        Keys whose value is ``None`` are dropped. Known keys go through the
        matching InputSanitizer helper; ``latitude``/``longitude`` and
        ``limit``/``offset`` are validated together as pairs (once per pair,
        with limit/offset defaulting to 10/0 when absent or None). Other
        string values go through ``sanitize_string``; non-string values are
        kept as-is.

        Raises:
            HTTPException: 400 when a sanitizer rejects a value with
                ``ValueError`` or ``TypeError``.
        """
        validated: Dict[str, Any] = {}

        for key, value in params.items():
            if value is None or key in validated:
                # Skip absent values, and the second key of a pair that was
                # already filled in together with its partner.
                continue

            try:
                if key == "location_id":
                    validated[key] = InputSanitizer.sanitize_location_id(value)
                elif key == "merchant_id":
                    validated[key] = InputSanitizer.sanitize_merchant_id(value)
                elif key in ("latitude", "longitude"):
                    lat, lng = InputSanitizer.sanitize_coordinates(
                        params.get("latitude"), params.get("longitude"))
                    validated["latitude"] = lat
                    validated["longitude"] = lng
                elif key in ("limit", "offset"):
                    # params.get(k, default) still yields None for a key present
                    # with a None value, so apply the defaults explicitly.
                    limit = params.get("limit")
                    offset = params.get("offset")
                    limit, offset = InputSanitizer.sanitize_pagination(
                        10 if limit is None else limit,
                        0 if offset is None else offset,
                    )
                    validated["limit"] = limit
                    validated["offset"] = offset
                elif isinstance(value, str):
                    validated[key] = InputSanitizer.sanitize_string(value)
                else:
                    validated[key] = value
            except (ValueError, TypeError) as e:
                # Wrong-typed or malformed client input is a 400, not a 500.
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid parameter {key}: {str(e)}"
                ) from e

        return validated


class CSRFProtection:
    """Basic CSRF protection for state-changing operations.

    By default (``cookie_name=None``) a POST/PUT/DELETE/PATCH request passes
    as long as it carries a non-blank ``X-CSRF-Token`` header. The token value
    itself is not verified in that mode, which on its own is weak protection.

    Pass ``cookie_name`` to enable a double-submit-cookie check: the header
    token must also equal the value of that cookie, compared in constant
    time. Issuing the cookie is the application's responsibility.

    Args:
        cookie_name: Name of the cookie holding the expected CSRF token, or
            None to keep the presence-only check.
    """

    def __init__(self, cookie_name: Optional[str] = None):
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
        self.cookie_name = cookie_name

    def validate_request(self, request: Request) -> bool:
        """Return True if ``request`` passes the CSRF check.

        Methods outside ``protected_methods`` always pass.
        """
        if request.method.upper() not in self.protected_methods:
            return True

        # Check for CSRF token in headers; blank values count as missing.
        csrf_token = (request.headers.get("X-CSRF-Token") or "").strip()
        if not csrf_token:
            return False

        if self.cookie_name is None:
            # Presence-only check (legacy behavior).
            return True

        expected = (request.cookies.get(self.cookie_name) or "").strip()
        if not expected:
            return False
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(
            csrf_token.encode("utf-8"), expected.encode("utf-8")
        )


def create_security_middleware(app, **kwargs):
    """
    Build a SecurityMiddleware wrapping ``app``.

    Every keyword argument is forwarded unchanged to the SecurityMiddleware
    constructor, so this is equivalent to instantiating it directly.
    """
    middleware = SecurityMiddleware(app, **kwargs)
    return middleware

# Utility decorators for endpoint protection


def require_valid_input(validation_func):
    """Decorator factory that validates an async endpoint's keyword arguments.

    ``validation_func`` receives the call's keyword arguments as a dict and
    returns the dict the endpoint is actually called with. A ``ValueError``
    raised by ``validation_func`` becomes an HTTP 400; errors raised by the
    endpoint itself propagate unchanged, so server-side bugs are not
    misreported as client errors.

    The wrapper copies the endpoint's metadata (``functools.wraps``), which
    also exposes the original signature through ``__wrapped__`` to
    signature-inspecting frameworks.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Keep the try minimal: only validation failures map to 400.
            try:
                validated_kwargs = validation_func(kwargs)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e)) from e
            return await func(*args, **validated_kwargs)
        return wrapper
    return decorator


def rate_limit(requests_per_minute: int = 60):
    """
    Placeholder decorator factory for endpoint-specific rate limiting.

    NOTE: not implemented yet. The returned decorator hands back the endpoint
    function unchanged and ``requests_per_minute`` is ignored, so decorating
    an endpoint with this does NOT limit it. Rate limiting in this module is
    performed by SecurityMiddleware via RateLimiter.
    """
    def _identity(endpoint):
        # Intentionally a pass-through until this is wired to a limiter.
        return endpoint

    return _identity