# zenith-backend/app/security/threat_detection.py
"""
Advanced threat detection and security analytics
"""
import json
import logging
import os
import re
import time
from collections import defaultdict
from datetime import datetime
from functools import wraps
from typing import Any, Optional
from fastapi import HTTPException
from geoip2.database import Reader
from geoip2.errors import AddressNotFoundError
logger = logging.getLogger(__name__)
class ThreatDetection:
    """Advanced threat detection and analytics.

    Scores incoming requests against known attack signatures (SQL injection,
    XSS, path traversal, command injection), IP reputation / geolocation data,
    and user-agent heuristics.  Findings can optionally be persisted to Redis
    for later summarization.
    """

    def __init__(self, redis_client=None):
        """
        Args:
            redis_client: optional async Redis client used to persist threat
                logs and IP lists.  When None, state is kept in memory only.
        """
        self.redis = redis_client
        # The GeoIP database may be absent (dev/test environments); degrade
        # gracefully instead of raising at construction / module import time.
        try:
            self.geoip_reader = Reader(os.getenv("GEOIP_DATABASE_PATH", "/usr/share/GeoLite2-City.mmdb"))
        except Exception as e:
            logger.warning(f"GeoIP database unavailable, geolocation checks disabled: {e}")
            self.geoip_reader = None
        self.threat_signatures = self._load_threat_signatures()
        self.attack_patterns = self._load_attack_patterns()
        self.suspicious_ips = set()   # IPs flagged but not yet blocked
        self.blocked_ips = set()      # IPs actively denied
        self.threat_logs = []         # in-memory fallback log (unused when Redis is set)

    def _load_threat_signatures(self) -> dict[str, list[str]]:
        """Load threat signature regexes, grouped by attack category.

        All patterns are valid Python regexes; several of the originals did
        not compile (unbalanced parentheses, unsupported ``(?|`` branch-reset
        syntax) and raised ``re.error`` on first use.
        """
        # SQL injection patterns
        sql_injection_patterns = [
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION|ALL)\b.*\b(FROM|INTO)\b)",
            # SQL keyword immediately after a quote/backtick/semicolon (string break-out)
            r"['\"`;]\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION|ALL)\b.*\b(FROM|INTO)\b",
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION|ALL)\b.*\b(LOAD_FILE|OUTFILE)\b)",
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION|ALL)\b.*\b(BENCHMARK|SLEEP)\b)",
            # statement chaining: "; SELECT ..." or "| DROP ..."
            r"(;|\|)\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC)\b",
            # SQL keywords hidden inside a /* ... */ comment
            r"/\*.*?(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC).*?\*/",
            # EXEC( '...' style dynamic execution
            r"EXEC\s*\(\s*['\"]",
        ]
        # XSS patterns
        xss_patterns = [
            r"<script[^>]*>.*?</script>",
            r"javascript:",
            r"vbscript:",
            r"onload\s*=",
            r"onerror\s*=",
            r"onclick\s*=",
            r"onmouseover\s*=",
            r"<iframe.*</iframe>",
            r"<object.*</object>",
            r"<embed.*</embed>",
            r"<link.*rel=.*stylesheet",
            r"<meta.*http-equiv=\"refresh",
        ]
        # Path traversal patterns
        path_traversal_patterns = [
            r"\.\.[\\/]",
            r"\.\.[\\/]\.\.",
            r"\.\/[^\\\/]",
            r"\.{2,}",
            r"[\\/]\.{2,}",
            r"%[\\/]",
            r"(\.\.[\\/])",
            r"(\/[\\]\w+\.[\\]\w+)(?![^\\/])",
        ]
        # Command injection patterns
        command_injection_patterns = [
            # shell metacharacters
            r"[;&|`$(){}]",
            # command chaining operators
            r"(\|\||&&|;;)",
            # redirection / newline followed by another token
            r"(>>|\n|\r|\|)\s*\w+",
            # $(...) command substitution
            r"\$\([^)]*\)",
        ]
        return {
            "sql_injection": sql_injection_patterns,
            "xss": xss_patterns,
            "path_traversal": path_traversal_patterns,
            "command_injection": command_injection_patterns,
        }

    def _load_attack_patterns(self) -> dict[str, list[str]]:
        """Load known attack patterns (credential stuffing / recon probes)."""
        # Brute force patterns
        brute_force_patterns = [
            r"(admin|administrator|root|test|guest|user|login|password)",
            r"\b(123456|password|qwerty|letmein|abc123|admin123|root123|guest123)\b",
            r"(error|forbidden|unauthorized|denied|invalid|exception)",
            # was "insert+select" which matched the literal "insert" + "t+select"
            r"(union\s+select|insert\s+select)",
            r"(having\s+\d+)",  # Common SQLi technique
            r"(\b\d{1,4}\s+)",  # Numeric escape sequences
        ]
        # Directory traversal patterns
        directory_traversal_patterns = [
            r"(\/etc\/passwd|\/etc\/shadow|\/etc\/hosts|\/proc\/version)",
            r"(\/var\/log|\/var\/run|\/tmp)",
            r"(\/boot\.ini|\/etc\/fstab)",
            r"(\/system32|\/windows|\/system64)",
            r"(\/dev\/random|\/dev\/urandom)",
        ]
        return {
            "brute_force": brute_force_patterns,
            "directory_traversal": directory_traversal_patterns,
        }

    async def analyze_request(self, ip: str, user_agent: str, request_data: dict[str, Any]) -> dict[str, Any]:
        """Analyze a request for potential threats.

        Args:
            ip: client IP address.
            user_agent: raw User-Agent header value.
            request_data: request payload (headers / query params) to scan.

        Returns:
            Dict with matched ``threats``, an aggregate ``risk_score`` and an
            ``analysis`` summary (geo info, automation heuristics, etc.).
        """
        threats = []
        risk_score = 0
        # Check each signature individually so one hit is reported once.
        # (Previously the whole pattern list was re-checked per pattern, so a
        # single match flagged every pattern and inflated the risk score.)
        for pattern_type, patterns in self.threat_signatures.items():
            for pattern in patterns:
                if self._check_pattern(pattern_type, request_data, [pattern]):
                    threats.append({"type": pattern_type, "pattern": pattern, "severity": "high"})
                    risk_score += 20
        # Check IP reputation
        ip_threat = await self._check_ip_reputation(ip)
        if ip_threat:
            threats.append(ip_threat)
            risk_score += ip_threat["risk_score"]
        # Check user agent anomalies
        ua_threats = self._analyze_user_agent(user_agent)
        if ua_threats:
            threats.extend(ua_threats)
            risk_score += sum(threat["risk_score"] for threat in ua_threats)
        return {
            "threats": threats,
            "risk_score": risk_score,
            "analysis": {
                "ip": ip,
                "user_agent": user_agent,
                "request_size": len(str(request_data)) if request_data else 0,
                "suspicious_patterns": len(threats) > 0,
                "automated": self._is_automated_request(request_data),
                "geo_info": await self._get_geo_info(ip),
            },
        }

    def _check_pattern(self, pattern_type: str, request_data: dict[str, Any], patterns: list[str]) -> bool:
        """Return True if the stringified request data matches any pattern."""
        # Convert request data to string for pattern matching
        data_str = str(request_data)
        for pattern in patterns:
            try:
                if re.search(pattern, data_str, re.IGNORECASE):
                    logger.warning(f"Threat pattern detected - {pattern_type}: {pattern}")
                    return True
            except re.error as e:
                # A bad signature must not break request analysis.
                logger.error(f"Invalid threat pattern ({pattern_type}): {pattern} - {e}")
        return False

    async def _check_ip_reputation(self, ip: str) -> Optional[dict[str, Any]]:
        """Check IP reputation against local lists and GeoIP heuristics.

        Returns a threat dict (``threat_level``, ``risk_score``, ``reason``)
        or None when nothing suspicious is known about the IP.
        """
        # Check local suspicious/blocked lists first (cheapest check)
        if ip in self.suspicious_ips or ip in self.blocked_ips:
            return {"threat_level": "blocked", "risk_score": 100, "reason": "IP in blocklist"}
        try:
            # Check against GeoIP database
            geo_info = await self._get_geo_info(ip)
            # Check for high-risk countries
            high_risk_countries = ["CN", "RU", "KP", "IR", "SY", "TR", "VN", "MM"]
            if geo_info and geo_info.get("country") in high_risk_countries:
                return {"threat_level": "high", "risk_score": 30, "reason": f"High risk country: {geo_info['country']}"}
            # Check for Tor exit nodes
            if geo_info and geo_info.get("is_tor", False):
                return {"threat_level": "high", "risk_score": 25, "reason": "Tor exit node detected"}
            # Check for suspicious hosting
            if geo_info and geo_info.get("is_proxying", False):
                return {"threat_level": "medium", "risk_score": 15, "reason": "Proxying detected"}
        except AddressNotFoundError:
            pass  # IP not present in the GeoIP database
        return None

    async def _get_geo_info(self, ip: str) -> Optional[dict[str, Any]]:
        """Get geolocation information for an IP, or None if unavailable."""
        if self.geoip_reader is None:
            return None  # database was not available at construction time
        try:
            response = self.geoip_reader.city(ip)
            return {
                "country": response.country.iso_code,
                # subdivisions is a collection; most_specific carries the code
                "region": response.subdivisions.most_specific.iso_code,
                "city": response.city.name,
                "latitude": response.location.latitude,
                "longitude": response.location.longitude,
                # Anonymity flags live on the traits record, not the response
                # object itself (the old attribute access raised AttributeError
                # and silently disabled these checks).
                "is_tor": getattr(response.traits, "is_tor_exit_node", False),
                "is_proxying": getattr(response.traits, "is_anonymous_proxy", False),
                "postal_code": response.postal.code,
            }
        except AddressNotFoundError:
            return None
        except Exception as e:
            logger.error(f"Error getting GeoIP info: {e}")
            return None

    def _analyze_user_agent(self, user_agent: str) -> list[dict[str, Any]]:
        """Analyze the user agent for bot / tooling anomalies.

        Every returned threat dict carries ``type``, ``pattern``, ``severity``
        and ``risk_score`` keys.
        """
        threats = []
        # Check for bot signatures
        bot_patterns = [
            r"(bot|crawler|spider|scraper|automated)",
            r"(wget|curl|python|java|perl|ruby|php|node)",
            r"(scrapy|nmap|masscan|nikto)",
        ]
        for pattern in bot_patterns:
            if re.search(pattern, user_agent, re.IGNORECASE):
                threats.append({"type": "bot_detection", "pattern": pattern, "severity": "medium", "risk_score": 15})
        # Check for headless browsers
        headless_patterns = [
            r"(headless|phantomjs|selenium|chrome-headless)",
            r"(googlebot|bingbot|slurp|duckduckbot)",
        ]
        for pattern in headless_patterns:
            if re.search(pattern, user_agent, re.IGNORECASE):
                # key was "security" before, breaking severity aggregation
                threats.append({"type": "headless_browser", "pattern": pattern, "severity": "medium", "risk_score": 10})
        # Check for automated tools
        automated_patterns = [
            r"(postman|insomnia|burp|nmap|metasploit)",
            r"(sqlmap|nmap|nikto|masscan|nessus)",
        ]
        for pattern in automated_patterns:
            if re.search(pattern, user_agent, re.IGNORECASE):
                threats.append({"type": "automated_tool", "pattern": pattern, "severity": "high", "risk_score": 20})
        return threats

    def _is_automated_request(self, request_data: dict[str, Any]) -> bool:
        """Heuristically decide whether a request appears automated.

        Requires at least 3 of 5 indicators (missing common headers or a very
        small payload) to reduce false positives.
        """
        headers = request_data.get("headers", {})
        automation_indicators = [
            "User-Agent" not in headers,
            "Content-Length" not in headers,
            "Accept-Encoding" not in headers,
            len(str(request_data)) < 100,  # Very short requests can be probes
            "Content-Type" not in headers,
        ]
        automation_score = sum(1 for indicator in automation_indicators if indicator is True)
        return automation_score >= 3  # Require at least 3 indicators

    async def log_threat(self, threat_data: dict[str, Any]):
        """Persist a detected threat to Redis (if configured) and log it."""
        if self.redis:
            # threat_id already carries the "threat:" prefix; don't prefix twice
            threat_id = f"threat:{int(time.time())}"
            await self.redis.setex(
                threat_id,
                86400,  # 24 hours
                json.dumps({**threat_data, "logged_at": datetime.utcnow().isoformat(), "severity": "high"}),
            )
            # Add to daily threat summary
            await self.redis.lpush("daily_threats", json.dumps(threat_data))
        logger.warning(
            f"Threat detected: {threat_data.get('threat_type', 'unknown')} from {threat_data.get('ip', 'unknown')}"
        )

    async def get_threat_summary(self, hours: int = 24) -> dict[str, Any]:
        """Get a threat summary.

        NOTE(review): ``hours`` is echoed back in the result but the Redis
        list is not actually time-filtered — confirm whether filtering is
        expected by callers.
        """
        if not self.redis:
            return {}
        try:
            # Get daily threats
            daily_threats = await self.redis.lrange("daily_threats", 0, -1)
            threats = []
            for threat_json in daily_threats:
                try:
                    threats.append(json.loads(threat_json))
                except json.JSONDecodeError:
                    continue  # skip corrupt entries rather than failing the summary
            # Analyze patterns
            threat_types = defaultdict(int)
            risk_levels = defaultdict(int)
            for threat in threats:
                # log_threat callers store "threat_type"; analyze_request uses
                # "type" — count both shapes.
                threat_type = threat.get("threat_type") or threat.get("type", "unknown")
                threat_types[threat_type] += 1
                risk_levels[threat.get("severity", "unknown")] += 1
            return {
                "total_threats": len(threats),
                "timeframe_hours": hours,
                "threat_types": dict(threat_types),
                "risk_levels": dict(risk_levels),
                # (name, count) pairs of the most frequent entries
                "most_common_threat": max(threat_types.items(), key=lambda x: x[1]) if threat_types else None,
                "highest_risk_level": max(risk_levels.items(), key=lambda x: x[1]) if risk_levels else None,
            }
        except Exception as e:
            logger.error(f"Error getting threat summary: {e}")
            return {}

    async def add_suspicious_ip(self, ip: str, reason: str = "Suspicious activity") -> bool:
        """Add an IP to the suspicious list.

        Returns True if the IP was newly added, False if already present.
        """
        if ip not in self.suspicious_ips:
            self.suspicious_ips.add(ip)
            logger.warning(f"Added suspicious IP: {ip} - {reason}")
            # Mirror to Redis so other workers share the list
            if self.redis:
                await self.redis.sadd("suspicious_ips", ip)
                await self.redis.setex(
                    f"suspicious_ip:{ip}",
                    86400,
                    json.dumps({"added_at": datetime.utcnow().isoformat(), "reason": reason, "ip": ip}),
                )
            return True
        return False

    async def block_ip(self, ip: str, duration: int = 3600, reason: str = "Blocked") -> bool:
        """Block an IP address for ``duration`` seconds.

        Returns True if the IP was newly blocked, False if already blocked.
        NOTE(review): the in-memory set has no expiry; only the Redis record
        expires after ``duration`` — confirm unblocking is handled elsewhere.
        """
        if ip not in self.blocked_ips:
            self.blocked_ips.add(ip)
            logger.warning(f"Blocked IP: {ip} for {duration}s - {reason}")
            # Store in Redis with expiration
            if self.redis:
                await self.redis.sadd("blocked_ips", ip)
                await self.redis.setex(
                    f"blocked_ip:{ip}",
                    duration,
                    json.dumps(
                        {"blocked_at": datetime.utcnow().isoformat(), "reason": reason, "ip": ip, "duration": duration}
                    ),
                )
            return True
        return False
# Global threat detection instance shared by the @threat_protection decorator.
# NOTE(review): constructed at import time with no Redis client, and
# ThreatDetection.__init__ opens the GeoIP database — a missing
# /usr/share/GeoLite2-City.mmdb (or unset GEOIP_DATABASE_PATH) surfaces here
# at module import; confirm deploy environments ship the database file.
threat_detection = ThreatDetection()
# Decorator for threat protection
def threat_protection(threshold: int = 50):
    """Decorator factory adding threat analysis to async request handlers.

    Args:
        threshold: risk score at or above which the request is rejected
            with HTTP 403.

    The wrapped handler is analyzed only when a FastAPI ``Request``-like
    object (one exposing ``client`` and ``headers``) is found among its
    positional arguments; otherwise it runs untouched.
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Locate a Request-like object among the positional arguments.
            # (The old hasattr(args[0], "request") check never matched a real
            # Request instance, so protection silently never triggered.)
            request = None
            for arg in args:
                if hasattr(arg, "client") and hasattr(arg, "headers"):
                    request = arg
                    break
            if request is not None and request.client is not None:
                # analyze_request takes a single payload dict; the old call
                # passed query params and body as two extra positional
                # arguments, raising TypeError.
                request_data = {
                    "headers": dict(request.headers),
                    "query_params": dict(getattr(request, "query_params", {}) or {}),
                }
                analysis = await threat_detection.analyze_request(
                    request.client.host,
                    request.headers.get("user-agent", ""),
                    request_data,
                )
                if analysis["risk_score"] >= threshold:
                    await threat_detection.log_threat(
                        {
                            "threat_type": "high_risk_request",
                            "ip": request.client.host,
                            "user_agent": request.headers.get("user-agent", ""),
                            "analysis": analysis,
                        }
                    )
                    # HTTPException requires a status code; the old call passed
                    # the message as status_code, which is invalid.
                    raise HTTPException(
                        status_code=403,
                        detail=f"Request blocked due to high risk score: {analysis['risk_score']}",
                    )
            return await func(*args, **kwargs)

        return wrapper

    return decorator