File size: 8,115 Bytes
bab1185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
QCrypt RNG - API Middleware
Enterprise-grade middleware for rate limiting, authentication, and monitoring
"""

import asyncio
import hashlib
import hmac as _hmac
import time
from typing import Any, Awaitable, Callable, Optional, Set

from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse

from app.utils.rate_limiting import rate_limiter
from app.config import settings
from app.utils.logging import logger, get_security_logger

# Dedicated logger for security-relevant events (rate-limit hits, missing or
# invalid API keys) written by the middlewares in this module.
_security_log = get_security_logger()


# ---------------------------------------------------------------------------
# API key allow-list (loaded once at import time from settings)
# ---------------------------------------------------------------------------
def _load_valid_api_keys() -> Optional[Set[str]]:
    """Parse VALID_API_KEYS from settings into a frozen set.

    Returns None when no allow-list is configured (fall back to
    length-based validation).
    """
    raw = settings.valid_api_keys
    if not raw:
        return None
    keys = {k.strip() for k in raw.split(",") if k.strip()}
    return keys if keys else None


# Allow-list parsed once at import time. None means no allow-list is
# configured, in which case api_key_middleware falls back to a length-only
# check. Because this is evaluated once, later changes to settings are not
# picked up without a reload.
_VALID_API_KEYS: Optional[Set[str]] = _load_valid_api_keys()


def _constant_time_key_check(candidate: str, valid_keys: Set[str]) -> bool:
    """Check membership with constant-time comparison per key."""
    candidate_bytes = candidate.encode("utf-8")
    found = False
    for key in valid_keys:
        if _hmac.compare_digest(candidate_bytes, key.encode("utf-8")):
            found = True
    return found


def _mask_api_key(api_key: str) -> str:
    """Return a safe prefix hash for audit logs (never log the raw key)."""
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:12]


def _response_content_length(response: Any) -> int:
    """Return the response's Content-Length header as an int (0 if absent or malformed)."""
    try:
        return int(response.headers.get("content-length", 0))
    except (TypeError, ValueError):
        # The header may be missing (e.g. streamed bodies) or malformed; either
        # way, don't turn an already-produced response into a server error.
        return 0


async def _track_request_usage(
    request: Request,
    api_key: str,
    response_time: float,
    bytes_processed: int,
    success: bool,
) -> None:
    """Record one request in usage tracking and increment the key's usage counters."""
    await rate_limiter.record_usage(
        api_key=api_key,
        endpoint=request.url.path,
        method=request.method,
        response_time=response_time,
        bytes_processed=bytes_processed,
        success=success,
    )
    await rate_limiter.increment_usage(api_key, bytes_processed)


async def rate_limit_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Any]]
):
    """
    Rate limiting middleware that checks usage against tier limits.

    Does nothing when ``settings.enable_usage_tracking`` is off, and lets
    key-less requests through untracked when ``settings.require_api_key`` is
    off. Otherwise the limit is checked via ``rate_limiter.check_limit``:

    * over-limit requests get a 429 JSON response (plus rate-limit headers)
      and a security-log entry, and are not forwarded;
    * allowed requests are forwarded, their usage is recorded exactly once
      (failed requests included, with 0 bytes), and ``X-RateLimit-Remaining``,
      ``X-RateLimit-Reset`` and ``X-Response-Time`` headers are added.

    Exceptions raised downstream are re-raised after their usage is recorded.
    """
    if not settings.enable_usage_tracking:
        return await call_next(request)

    # Extract API key from header
    api_key = request.headers.get(settings.api_key_header, "")

    # Anonymous requests are exempt only when API keys are optional.
    if not settings.require_api_key and not api_key:
        return await call_next(request)

    is_allowed, remaining, reset_time = await rate_limiter.check_limit(
        api_key,
        request.url.path
    )

    if not is_allowed:
        client_ip = request.client.host if request.client else "unknown"
        _security_log.warning(
            f"rate_limit_exceeded | IP: {client_ip} | "
            f"Path: {request.method} {request.url.path} | "
            f"Key: {_mask_api_key(api_key) if api_key else 'none'} | "
            f"Reset: {reset_time}s"
        )
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "error": "rate_limit_exceeded",
                "message": f"Rate limit exceeded. Try again in {reset_time} seconds.",
                "remaining_requests": 0,
                "reset_time": reset_time
            },
            # Same headers as successful responses, so clients can read
            # them uniformly.
            headers={
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": str(reset_time),
            },
        )

    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()

    # Only the downstream call is guarded: a failure in our own bookkeeping
    # must not be mistaken for a failed request and counted a second time.
    try:
        response = await call_next(request)
    except Exception:
        # Failed requests still count against the quota.
        await _track_request_usage(
            request,
            api_key,
            response_time=time.perf_counter() - start_time,
            bytes_processed=0,
            success=False,
        )
        raise

    response_time = time.perf_counter() - start_time
    await _track_request_usage(
        request,
        api_key,
        response_time=response_time,
        bytes_processed=_response_content_length(response),
        success=response.status_code < 400,
    )

    # NOTE(review): assumes `remaining` is the count before this request --
    # confirm against rate_limiter.check_limit. Clamp so it never goes negative.
    response.headers["X-RateLimit-Remaining"] = str(max(remaining - 1, 0))
    response.headers["X-RateLimit-Reset"] = str(reset_time)
    response.headers["X-Response-Time"] = f"{response_time:.3f}s"

    return response


# Minimum key length accepted when no VALID_API_KEYS allow-list is configured.
_MIN_FALLBACK_API_KEY_LENGTH = 10


async def api_key_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Any]]
):
    """
    API key validation middleware.

    Skipped entirely when ``settings.require_api_key`` is off. Otherwise the
    key is read from the ``settings.api_key_header`` header and:

    * rejected with 401 ``api_key_required`` when missing or empty;
    * when VALID_API_KEYS is configured, checked against that allow-list
      using constant-time comparison (401 ``invalid_api_key`` on mismatch);
    * otherwise accepted if it has at least ``_MIN_FALLBACK_API_KEY_LENGTH``
      characters, so existing setups keep working.

    Every rejection is written to the security log; the raw key itself is
    never logged. Accepted keys are stored on ``request.state.api_key`` for
    downstream handlers.
    """
    if not settings.require_api_key:
        return await call_next(request)

    client_ip = request.client.host if request.client else "unknown"
    api_key = request.headers.get(settings.api_key_header)

    if not api_key:
        _security_log.warning(
            f"api_key_missing | IP: {client_ip} | "
            f"Path: {request.method} {request.url.path}"
        )
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"error": "api_key_required", "message": f"API key required in {settings.api_key_header} header"}
        )

    if _VALID_API_KEYS is not None:
        # Validate against the configured allow-list.
        if not _constant_time_key_check(api_key, _VALID_API_KEYS):
            _security_log.warning(
                f"api_key_invalid | IP: {client_ip} | "
                f"Path: {request.method} {request.url.path} | "
                f"KeyHash: {_mask_api_key(api_key)}"
            )
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"error": "invalid_api_key", "message": "Invalid API key"}
            )
    elif len(api_key) < _MIN_FALLBACK_API_KEY_LENGTH:
        # NOTE: without an allow-list this is a format check only -- any key
        # of sufficient length is accepted.
        _security_log.warning(
            f"api_key_invalid | IP: {client_ip} | "
            f"Path: {request.method} {request.url.path} | "
            f"Reason: key too short"
        )
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"error": "invalid_api_key", "message": "Invalid API key format"}
        )

    # Expose the accepted key to downstream handlers.
    request.state.api_key = api_key

    return await call_next(request)


async def monitoring_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Any]]
):
    """
    Monitoring and analytics middleware.

    Times every request and adds an ``X-Process-Time`` header (milliseconds)
    to the response. When ``settings.enable_detailed_logging`` is on, the
    incoming request and the outgoing status/duration are logged. Exceptions
    raised downstream are logged with their duration and re-raised unchanged
    so FastAPI's exception handlers still see them.
    """
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()

    # request.client may be None; guard it the same way the other
    # middlewares do instead of crashing when detailed logging is enabled.
    client_ip = request.client.host if request.client else "unknown"

    if settings.enable_detailed_logging:
        logger.info(f"Request: {request.method} {request.url.path} - IP: {client_ip}")

    try:
        response = await call_next(request)
    except Exception as e:
        process_time = time.perf_counter() - start_time
        logger.error(f"Error in {request.method} {request.url.path}: {str(e)} - Time: {process_time*1000:.2f}ms")
        # Re-raise the exception to be handled by FastAPI's exception handlers
        raise

    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = f"{process_time*1000:.2f}ms"

    if settings.enable_detailed_logging:
        logger.info(f"Response: {response.status_code} - Time: {process_time*1000:.2f}ms")

    return response