Spaces:
Sleeping
Sleeping
import json
import logging
import time
from datetime import datetime
from typing import Any, Dict, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.nosql_client import db
# Module logger. No handlers or levels are configured here; output is governed
# by the application's logging configuration.
logger: logging.Logger = logging.getLogger(name=__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Enhanced security middleware for request logging, device tracking, and security monitoring.

    For every request this middleware:
      * writes an entry to the ``security_logs`` collection when the request hits a
        sensitive endpoint or the response status is >= 400,
      * upserts a record keyed by (client IP, User-Agent prefix) in the
        ``device_tracking`` collection, and
      * adds a fixed set of security headers to the response.

    Database errors are logged and swallowed so they never break the response.
    """

    # Path fragments that mark an endpoint as security-sensitive. They are matched
    # as plain substrings anywhere in the request path.
    _SENSITIVE_PATHS = (
        "/auth/",
        "/login",
        "/register",
        "/otp",
        "/oauth",
        "/profile",
        "/account",
        "/security",
    )

    # Query-parameter names (compared case-insensitively) whose values are
    # credentials or one-time secrets (OTP codes, OAuth codes, tokens). They must
    # not be persisted in plain text in the security log.
    _REDACTED_QUERY_PARAMS = frozenset({
        "access_token",
        "refresh_token",
        "id_token",
        "token",
        "code",
        "otp",
        "password",
        "secret",
        "client_secret",
        "api_key",
    })
    _REDACTED_VALUE = "[REDACTED]"

    def __init__(self, app):
        super().__init__(app)
        self.security_collection = db.security_logs
        self.device_collection = db.device_tracking

    def get_client_ip(self, request: Request) -> str:
        """Return the best-guess client IP for ``request``.

        Preference order:
          1. the first entry of ``X-Forwarded-For``;
          2. ``X-Real-IP``;
          3. the socket peer address.
        Returns ``"unknown"`` when none of these is available.
        """
        # NOTE(review): both headers are client-controlled unless a trusted proxy
        # overwrites them. Without such a proxy, a caller can spoof the IP recorded
        # here, and therefore the device fingerprint as well.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fallback to the direct client IP.
        return request.client.host if request.client else "unknown"

    def extract_device_info(self, request: Request) -> Dict[str, Any]:
        """Collect device/browser hints from request headers.

        Returns a dict with the raw ``user_agent``, ``accept_language`` and
        ``accept_encoding`` header values (empty string when absent). It also
        includes a coarse ``platform`` and ``browser`` parsed from the User-Agent.
        """
        user_agent = request.headers.get("User-Agent", "")
        accept_language = request.headers.get("Accept-Language", "")
        accept_encoding = request.headers.get("Accept-Encoding", "")

        return {
            "user_agent": user_agent,
            "accept_language": accept_language,
            "accept_encoding": accept_encoding,
            "platform": self._parse_platform(user_agent),
            "browser": self._parse_browser(user_agent),
        }

    def _parse_platform(self, user_agent: str) -> str:
        """Map a User-Agent string to a coarse OS name.

        Mobile platforms are checked first on purpose:
          * iOS UAs contain "like Mac OS X", so a macOS check would match them.
          * Android UAs contain "Linux", so a Linux check would match them.
        """
        ua = user_agent.lower()
        if "iphone" in ua or "ipad" in ua or "ipod" in ua:
            return "iOS"
        if "android" in ua:
            return "Android"
        if "windows" in ua:
            return "Windows"
        if "macintosh" in ua or "mac os" in ua:
            return "macOS"
        if "linux" in ua:
            return "Linux"
        return "Unknown"

    def _parse_browser(self, user_agent: str) -> str:
        """Map a User-Agent string to a coarse browser name.

        Check order matters:
          * Edge ("Edg") and Opera ("OPR") UAs also contain "Chrome".
          * Chrome, Edge and Opera UAs also contain "Safari".
        So the more specific tokens are tested first. "CriOS" and "FxiOS" are the
        iOS builds of Chrome and Firefox.
        """
        ua = user_agent.lower()
        if "edg" in ua:
            return "Edge"
        if "opr/" in ua or "opera" in ua:
            return "Opera"
        if "chrome" in ua or "crios" in ua:
            return "Chrome"
        if "firefox" in ua or "fxios" in ua:
            return "Firefox"
        if "safari" in ua:
            return "Safari"
        return "Unknown"

    def is_sensitive_endpoint(self, path: str) -> bool:
        """Return True if ``path`` contains any of the security-sensitive fragments."""
        return any(fragment in path for fragment in self._SENSITIVE_PATHS)

    def _redact_query_params(self, request: Request) -> Dict[str, Any]:
        """Return the request's query params as a dict with secret-bearing values masked."""
        return {
            key: (self._REDACTED_VALUE if key.lower() in self._REDACTED_QUERY_PARAMS else value)
            for key, value in request.query_params.items()
        }

    def _extract_customer_id(self, request: Request) -> Optional[str]:
        """Return the ``sub`` claim of the request's Bearer token, or None.

        This is best effort. A missing, empty, malformed, invalid or expired token
        yields None instead of raising.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            return None
        token = auth_header[len("Bearer "):].strip()
        if not token:
            return None
        try:
            # Imported lazily, as in the original implementation.
            from app.utils.jwt import decode_token
            payload = decode_token(token)
            return payload.get("sub")
        except Exception as exc:  # token might be invalid or expired
            # Log only the exception type so token material never reaches the logs.
            logger.debug("Could not decode bearer token for security log: %s",
                         type(exc).__name__)
            return None

    async def log_security_event(self, request: Request, response: Response,
                                 processing_time: float, client_ip: str,
                                 device_info: Dict[str, Any]):
        """Persist a security log entry for sensitive or failed requests.

        Requests that neither hit a sensitive endpoint nor produced a status >= 400
        are skipped. Any error while building or inserting the entry is logged and
        swallowed.

        Args:
            request: The incoming request.
            response: The response returned by ``call_next``.
            processing_time: Seconds spent in ``call_next``.
            client_ip: Client IP as returned by ``get_client_ip``.
            device_info: Dict as returned by ``extract_device_info``.
        """
        try:
            path = str(request.url.path)
            # Only log sensitive endpoints or failed requests.
            if not (self.is_sensitive_endpoint(path) or response.status_code >= 400):
                return

            log_entry = {
                "timestamp": datetime.utcnow(),
                "method": request.method,
                "path": path,
                "query_params": self._redact_query_params(request),
                "client_ip": client_ip,
                "status_code": response.status_code,
                "processing_time_ms": round(processing_time * 1000, 2),
                "device_info": device_info,
                "headers": {
                    "user_agent": request.headers.get("User-Agent", ""),
                    "referer": request.headers.get("Referer", ""),
                    "content_type": request.headers.get("Content-Type", ""),
                },
                "is_suspicious": self._detect_suspicious_activity(request, response, client_ip),
            }

            # Attach the authenticated user when a decodable Bearer token is present.
            customer_id = self._extract_customer_id(request)
            if customer_id is not None:
                log_entry["customer_id"] = customer_id

            await self.security_collection.insert_one(log_entry)
        except Exception as e:
            logger.error("Failed to log security event: %s", e)

    async def track_device(self, client_ip: str, device_info: Dict[str, Any],
                           customer_id: Optional[str] = None):
        """Upsert a device-tracking record for this (IP, User-Agent) pair.

        On every call, ``last_seen`` and ``device_info`` are refreshed and
        ``access_count`` is incremented. ``first_seen``, ``client_ip``,
        ``customer_id`` and ``is_trusted`` are written only when the record is
        first created. Errors are logged and swallowed.
        """
        try:
            # The fingerprint is the client IP plus the first 100 characters of the User-Agent.
            device_fingerprint = f"{client_ip}_{device_info.get('user_agent', '')[:100]}"
            # One timestamp, so first_seen == last_seen on the inserting call.
            now = datetime.utcnow()

            await self.device_collection.update_one(
                {"device_fingerprint": device_fingerprint},
                {
                    "$set": {
                        "last_seen": now,
                        "device_info": device_info,
                    },
                    "$inc": {"access_count": 1},
                    "$setOnInsert": {
                        "device_fingerprint": device_fingerprint,
                        "client_ip": client_ip,
                        "first_seen": now,
                        "customer_id": customer_id,
                        "is_trusted": False,
                    },
                },
                upsert=True,
            )
        except Exception as e:
            logger.error("Failed to track device: %s", e)

    def _detect_suspicious_activity(self, request: Request, response: Response,
                                    client_ip: str) -> bool:
        """Return True if any basic suspicious-activity heuristic fires.

        The heuristics are:
          * a 401 on a path containing "login";
          * a missing User-Agent, or one shorter than 10 characters;
          * ``request.state.request_count`` > 10, when some earlier component has
            set it;
          * a 403 on a sensitive endpoint with no Authorization header.
        ``client_ip`` is currently unused.
        """
        suspicious_indicators = []

        # Failed login attempt.
        if response.status_code == 401 and "login" in str(request.url.path):
            suspicious_indicators.append("failed_login")

        # Missing or unusually short user agent.
        user_agent = request.headers.get("User-Agent", "")
        if not user_agent or len(user_agent) < 10:
            suspicious_indicators.append("suspicious_user_agent")

        # Rapid requests. This relies on request_count being set upstream; it is
        # not set anywhere in this class.
        if hasattr(request.state, "request_count") and request.state.request_count > 10:
            suspicious_indicators.append("rapid_requests")

        # Access to a sensitive endpoint was forbidden and no credentials were sent.
        if (self.is_sensitive_endpoint(str(request.url.path)) and
                response.status_code == 403 and
                not request.headers.get("Authorization")):
            suspicious_indicators.append("unauthorized_sensitive_access")

        return len(suspicious_indicators) > 0

    async def dispatch(self, request: Request, call_next):
        """Run the request, record security telemetry, and add security headers."""
        # Monotonic clock: latency is unaffected by wall-clock adjustments.
        start = time.perf_counter()

        # Extract client information before handing the request downstream.
        client_ip = self.get_client_ip(request)
        device_info = self.extract_device_info(request)

        response = await call_next(request)

        processing_time = time.perf_counter() - start

        # Both calls are awaited inline, so their DB round trips add to the response
        # latency. Each helper already catches and logs its own errors; this guard
        # is a last line of defence.
        try:
            await self.log_security_event(request, response, processing_time,
                                          client_ip, device_info)
            # NOTE(review): customer_id is never passed here, so device records
            # are stored with customer_id=None.
            await self.track_device(client_ip, device_info)
        except Exception as e:
            logger.error("Security middleware error: %s", e)

        # Add security headers to the response. These overwrite any values set by the route.
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # NOTE(review): current OWASP guidance is to send "0" or omit this legacy
        # header; the value is kept unchanged here.
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response