from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from datetime import datetime
import json
import logging
import time
from typing import Dict, Any, Optional
from app.core.nosql_client import db

# Use app-level logging configuration
logger = logging.getLogger(__name__)


class SecurityMiddleware(BaseHTTPMiddleware):
    """
    Enhanced security middleware for request logging, device tracking, and security monitoring.

    For every request it:
      * writes a record to ``db.security_logs`` when the path is security-sensitive
        or the response status is >= 400;
      * upserts a per-(client IP, User-Agent) record in ``db.device_tracking``;
      * sets a fixed set of security headers on the response.
    """

    # Path substrings that mark an endpoint as security-sensitive (matched anywhere in the path).
    _SENSITIVE_PATH_FRAGMENTS = (
        "/auth/",
        "/login",
        "/register",
        "/otp",
        "/oauth",
        "/profile",
        "/account",
        "/security",
    )

    # Query-parameter names (compared case-insensitively) whose values must not be
    # persisted: auth endpoints routinely carry OAuth codes, reset tokens or OTPs in the URL.
    _REDACTED_QUERY_KEYS = frozenset({
        "password",
        "passwd",
        "token",
        "access_token",
        "refresh_token",
        "id_token",
        "code",
        "otp",
        "secret",
        "client_secret",
        "api_key",
        "apikey",
    })
    _REDACTED_VALUE = "[REDACTED]"

    # Ordered (label, needles) rules matched against the lower-cased User-Agent; first match wins.
    # Order matters: Android UAs contain "Linux", and iOS UAs contain "like Mac OS X".
    _PLATFORM_RULES = (
        ("Windows", ("windows",)),
        ("Android", ("android",)),
        ("iOS", ("iphone", "ipad")),
        ("macOS", ("macintosh", "mac os")),
        ("Linux", ("linux",)),
    )

    # Order matters: Edge ("Edg/", "EdgA/", "EdgiOS/") and Chromium-based Opera ("OPR/")
    # UAs also contain "Chrome" and "Safari"; Chrome UAs also contain "Safari".
    _BROWSER_RULES = (
        ("Edge", ("edg",)),
        ("Opera", ("opr/", "opera")),
        ("Chrome", ("chrome", "crios")),
        ("Firefox", ("firefox", "fxios")),
        ("Safari", ("safari",)),
    )

    def __init__(self, app):
        super().__init__(app)
        self.security_collection = db.security_logs
        self.device_collection = db.device_tracking

    def get_client_ip(self, request: Request) -> str:
        """
        Extract the client IP from the request.

        Prefers the first entry of ``X-Forwarded-For``, then ``X-Real-IP``, then the
        socket peer address; returns ``"unknown"`` when none is available.
        """
        # NOTE(review): forwarded headers are client-controlled unless a trusted proxy
        # overwrites them; confirm the deployment strips/sets them, or IP-based
        # monitoring can be spoofed.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fallback to direct client IP
        return request.client.host if request.client else "unknown"

    def extract_device_info(self, request: Request) -> Dict[str, Any]:
        """Extract device and browser information from request headers."""
        user_agent = request.headers.get("User-Agent", "")
        accept_language = request.headers.get("Accept-Language", "")
        accept_encoding = request.headers.get("Accept-Encoding", "")

        return {
            "user_agent": user_agent,
            "accept_language": accept_language,
            "accept_encoding": accept_encoding,
            "platform": self._parse_platform(user_agent),
            "browser": self._parse_browser(user_agent),
        }

    @staticmethod
    def _match_first_rule(user_agent: str, rules) -> str:
        """Return the label of the first rule with a needle found in ``user_agent`` (case-insensitive), else "Unknown"."""
        user_agent_lower = user_agent.lower()
        for label, needles in rules:
            if any(needle in user_agent_lower for needle in needles):
                return label
        return "Unknown"

    def _parse_platform(self, user_agent: str) -> str:
        """Parse the OS platform (Windows/Android/iOS/macOS/Linux/Unknown) from a User-Agent string."""
        return self._match_first_rule(user_agent, self._PLATFORM_RULES)

    def _parse_browser(self, user_agent: str) -> str:
        """Parse the browser family (Edge/Opera/Chrome/Firefox/Safari/Unknown) from a User-Agent string."""
        return self._match_first_rule(user_agent, self._BROWSER_RULES)

    def is_sensitive_endpoint(self, path: str) -> bool:
        """Return True if ``path`` contains any security-sensitive fragment and should be logged."""
        return any(fragment in path for fragment in self._SENSITIVE_PATH_FRAGMENTS)

    def _redact_query_params(self, query_params) -> Dict[str, Any]:
        """Return query params as a plain dict, masking values of credential-like keys."""
        return {
            key: self._REDACTED_VALUE if key.lower() in self._REDACTED_QUERY_KEYS else value
            for key, value in dict(query_params).items()
        }

    async def log_security_event(self, request: Request, response: Response,
                                 processing_time: float, client_ip: str,
                                 device_info: Dict[str, Any]):
        """
        Persist a security log entry for sensitive endpoints or failed (>= 400) requests.

        Args:
            request: the incoming request.
            response: the response produced by the downstream app.
            processing_time: request handling duration in seconds.
            client_ip: client address as resolved by ``get_client_ip``.
            device_info: dict produced by ``extract_device_info``.

        Failures are logged and swallowed so logging never breaks the request.
        """
        try:
            path = str(request.url.path)

            # Only log sensitive endpoints or failed requests
            if not (self.is_sensitive_endpoint(path) or response.status_code >= 400):
                return

            log_entry = {
                "timestamp": datetime.utcnow(),
                "method": request.method,
                "path": path,
                "query_params": self._redact_query_params(request.query_params),
                "client_ip": client_ip,
                "status_code": response.status_code,
                "processing_time_ms": round(processing_time * 1000, 2),
                "device_info": device_info,
                "headers": {
                    "user_agent": request.headers.get("User-Agent", ""),
                    "referer": request.headers.get("Referer", ""),
                    "content_type": request.headers.get("Content-Type", ""),
                },
                "is_suspicious": self._detect_suspicious_activity(request, response, client_ip),
            }

            # Add user ID if available from JWT token
            auth_header = request.headers.get("Authorization")
            if auth_header and auth_header.startswith("Bearer "):
                try:
                    from app.utils.jwt import decode_token
                    token = auth_header.split(" ", 1)[1].strip()
                    payload = decode_token(token)
                    log_entry["customer_id"] = payload.get("sub")
                except Exception as exc:
                    # Token might be invalid or expired; the event is still logged without customer_id.
                    logger.debug("Could not decode bearer token for security log: %s", exc)

            await self.security_collection.insert_one(log_entry)

        except Exception as e:
            logger.error("Failed to log security event: %s", e)

    async def track_device(self, client_ip: str, device_info: Dict[str, Any],
                           customer_id: Optional[str] = None):
        """
        Upsert a device-tracking record keyed by client IP + first 100 chars of the User-Agent.

        Each call refreshes ``last_seen``/``device_info`` and increments ``access_count``.
        ``first_seen``, ``client_ip``, ``customer_id`` and ``is_trusted`` are written only
        when the record is first created. Failures are logged and swallowed.
        """
        try:
            device_fingerprint = f"{client_ip}_{device_info.get('user_agent', '')[:100]}"
            # One timestamp so first_seen == last_seen on the inserting request.
            now = datetime.utcnow()

            await self.device_collection.update_one(
                {"device_fingerprint": device_fingerprint},
                {
                    "$set": {
                        "last_seen": now,
                        "device_info": device_info,
                    },
                    "$inc": {"access_count": 1},
                    "$setOnInsert": {
                        "device_fingerprint": device_fingerprint,
                        "client_ip": client_ip,
                        "first_seen": now,
                        "customer_id": customer_id,
                        "is_trusted": False,
                    },
                },
                upsert=True,
            )

        except Exception as e:
            logger.error("Failed to track device: %s", e)

    def _detect_suspicious_activity(self, request: Request, response: Response, client_ip: str) -> bool:
        """Return True if any basic heuristic flags the request as suspicious."""
        suspicious_indicators = []

        # Failed login attempt
        if response.status_code == 401 and "login" in str(request.url.path):
            suspicious_indicators.append("failed_login")

        # Missing or implausibly short User-Agent
        user_agent = request.headers.get("User-Agent", "")
        if not user_agent or len(user_agent) < 10:
            suspicious_indicators.append("suspicious_user_agent")

        # Rapid requests; relies on some other component setting request.state.request_count
        if getattr(request.state, "request_count", 0) > 10:
            suspicious_indicators.append("rapid_requests")

        # Forbidden access to a sensitive endpoint with no credentials supplied
        if (self.is_sensitive_endpoint(str(request.url.path))
                and response.status_code == 403
                and not request.headers.get("Authorization")):
            suspicious_indicators.append("unauthorized_sensitive_access")

        return len(suspicious_indicators) > 0

    async def dispatch(self, request: Request, call_next):
        """Run the request, record security/device data, then add security headers."""
        # Monotonic clock: unaffected by wall-clock adjustments during the request.
        start = time.perf_counter()

        # Extract client information
        client_ip = self.get_client_ip(request)
        device_info = self.extract_device_info(request)

        # Process the request
        response = await call_next(request)

        processing_time = time.perf_counter() - start

        # Both writes are awaited inline, so their DB latency is added to every response.
        try:
            await self.log_security_event(request, response, processing_time, client_ip, device_info)
            # NOTE(review): customer_id is never passed here, so new device records are
            # always stored with customer_id=None.
            await self.track_device(client_ip, device_info)
        except Exception as e:
            logger.error("Security middleware error: %s", e)

        # Add security headers to response
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # NOTE(review): X-XSS-Protection is deprecated; OWASP recommends "0" (or omitting it).
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        return response