Spaces:
Paused
Paused
| """ | |
| Firewall Module | |
| Implements packet filtering and access control: | |
| - Rule-based packet filtering (allow/block by IP, port, protocol) | |
| - Ordered rule processing | |
| - Logging and statistics | |
| - Dynamic rule management via API | |
| """ | |
| import time | |
| import threading | |
| import ipaddress | |
| import re | |
| from typing import Dict, List, Optional, Tuple, Any | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from .ip_parser import ParsedPacket, TCPHeader, UDPHeader | |
class FirewallAction(Enum):
    """Verdict returned for a packet by a matching rule or the default policy."""

    ACCEPT = 'ACCEPT'  # counted in the engine's 'packets_accepted' statistic
    DROP = 'DROP'  # counted in 'packets_dropped'
    REJECT = 'REJECT'  # counted in 'packets_rejected'
class FirewallDirection(Enum):
    """Traffic direction a rule applies to.

    A rule whose direction is INBOUND or OUTBOUND only matches packets that are
    processed with that same direction; BOTH matches regardless of direction.
    """

    INBOUND = 'INBOUND'
    OUTBOUND = 'OUTBOUND'
    BOTH = 'BOTH'  # wildcard: the direction check is skipped
@dataclass
class FirewallRule:
    """A firewall rule: match criteria plus the verdict applied on a match.

    Must be a dataclass: callers construct it with keyword arguments
    (``FirewallRule(rule_id=..., priority=..., ...)``) and rely on
    ``__post_init__`` to stamp ``created_time``; without the decorator the
    class has no ``__init__`` accepting those arguments.

    Match criteria left as ``None`` act as wildcards.
    """
    rule_id: str
    priority: int  # Lower number = higher priority (evaluated first)
    action: FirewallAction
    direction: FirewallDirection
    # Match criteria
    source_ip: Optional[str] = None  # IP or CIDR
    dest_ip: Optional[str] = None  # IP or CIDR
    source_port: Optional[str] = None  # Port or range (e.g., "80", "80-90", "80,443")
    dest_port: Optional[str] = None  # Port or range
    protocol: Optional[str] = None  # TCP, UDP, ICMP, or None for any
    # Metadata
    description: str = ""
    enabled: bool = True
    created_time: float = 0  # 0 means "now": replaced in __post_init__
    hit_count: int = 0
    last_hit: Optional[float] = None  # time.time() of the most recent match

    def __post_init__(self):
        # Stamp the creation time unless the caller supplied one (e.g. on import).
        if self.created_time == 0:
            self.created_time = time.time()

    def record_hit(self):
        """Increment the hit counter and remember when the rule last matched."""
        self.hit_count += 1
        self.last_hit = time.time()

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict; enum fields are exported by their value."""
        return {
            'rule_id': self.rule_id,
            'priority': self.priority,
            'action': self.action.value,
            'direction': self.direction.value,
            'source_ip': self.source_ip,
            'dest_ip': self.dest_ip,
            'source_port': self.source_port,
            'dest_port': self.dest_port,
            'protocol': self.protocol,
            'description': self.description,
            'enabled': self.enabled,
            'created_time': self.created_time,
            'hit_count': self.hit_count,
            'last_hit': self.last_hit
        }
@dataclass
class FirewallLogEntry:
    """One record in the firewall log describing a processed packet.

    Must be a dataclass: entries are constructed with keyword arguments
    (``FirewallLogEntry(timestamp=..., action=..., ...)``), which a plain class
    without ``__init__`` rejects.
    """
    timestamp: float  # time.time() when the entry was created
    action: str  # FirewallAction value, e.g. "ACCEPT" / "DROP" / "REJECT"
    rule_id: Optional[str]  # None when the default policy decided
    source_ip: str
    dest_ip: str
    source_port: int  # 0 when the packet has no TCP/UDP header
    dest_port: int  # 0 when the packet has no TCP/UDP header
    protocol: str  # "TCP", "UDP" or "OTHER"
    packet_size: int  # length of the raw packet in bytes
    reason: str = ""

    def to_dict(self) -> Dict:
        """Return the entry as a plain dict (keys match the field names)."""
        return {
            'timestamp': self.timestamp,
            'action': self.action,
            'rule_id': self.rule_id,
            'source_ip': self.source_ip,
            'dest_ip': self.dest_ip,
            'source_port': self.source_port,
            'dest_port': self.dest_port,
            'protocol': self.protocol,
            'packet_size': self.packet_size,
            'reason': self.reason
        }
class FirewallEngine:
    """Rule-based packet filter with logging and statistics.

    Rules live in ``self.rules`` keyed by ``rule_id``.  For every packet the
    enabled rules are evaluated in ascending ``priority`` order (lower number
    = evaluated first) and the first match decides the verdict; when nothing
    matches, ``default_policy`` is applied.

    Rules, logs, statistics, rule hit counters and the default policy are
    read and written under ``self.lock``.

    Recognised ``config`` keys: ``default_policy`` (default ``'ACCEPT'``),
    ``log_blocked`` (default True), ``log_accepted`` (default False),
    ``max_log_entries`` (default 10000) and ``rules`` (list of rule dicts in
    the format accepted by ``import_rules``).
    """

    # Verdict -> statistics counter it increments.
    _ACTION_STAT_KEYS = {
        FirewallAction.ACCEPT: 'packets_accepted',
        FirewallAction.DROP: 'packets_dropped',
        FirewallAction.REJECT: 'packets_rejected',
    }

    # Rule attributes update_rule() may change.  Other names (unknown ones and
    # methods such as to_dict/record_hit) are ignored so API callers cannot
    # replace behaviour on a rule object.
    _UPDATABLE_FIELDS = frozenset({
        'rule_id', 'priority', 'action', 'direction', 'source_ip', 'dest_ip',
        'source_port', 'dest_port', 'protocol', 'description', 'enabled',
        'created_time', 'hit_count', 'last_hit',
    })

    def __init__(self, config: Dict):
        self.config = config
        self.rules: Dict[str, FirewallRule] = {}
        self.logs: List[FirewallLogEntry] = []
        self.lock = threading.Lock()
        # Configuration (enum names are accepted case-insensitively)
        self.default_policy = self._coerce_enum(FirewallAction, config.get('default_policy', 'ACCEPT'))
        self.log_blocked = config.get('log_blocked', True)
        self.log_accepted = config.get('log_accepted', False)
        self.max_log_entries = config.get('max_log_entries', 10000)
        # Statistics
        self.stats = self._new_stats()
        # Load initial rules
        for rule_config in config.get('rules', []):
            self._add_rule_from_config(rule_config)

    @staticmethod
    def _new_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dict."""
        return {
            'packets_processed': 0,
            'packets_accepted': 0,
            'packets_dropped': 0,
            'packets_rejected': 0,
            'rules_hit': 0,
            'default_policy_hits': 0
        }

    @staticmethod
    def _coerce_enum(enum_cls, value):
        """Return *value* as a member of *enum_cls*.

        Accepts an existing member or its string value in any letter case.
        Raises ValueError for values that are not valid for *enum_cls*.
        """
        if isinstance(value, enum_cls):
            return value
        if isinstance(value, str):
            value = value.upper()
        return enum_cls(value)

    def _rule_from_config(self, rule_config: Dict) -> FirewallRule:
        """Build (but do not insert) a FirewallRule from a config dict.

        ``rule_id`` and ``action`` are required (KeyError if missing); invalid
        ``action``/``direction`` values raise ValueError.
        """
        return FirewallRule(
            rule_id=rule_config['rule_id'],
            priority=rule_config.get('priority', 100),
            action=self._coerce_enum(FirewallAction, rule_config['action']),
            direction=self._coerce_enum(FirewallDirection, rule_config.get('direction', 'BOTH')),
            source_ip=rule_config.get('source_ip'),
            dest_ip=rule_config.get('dest_ip'),
            source_port=rule_config.get('source_port'),
            dest_port=rule_config.get('dest_port'),
            protocol=rule_config.get('protocol'),
            description=rule_config.get('description', ''),
            enabled=rule_config.get('enabled', True)
        )

    def _add_rule_from_config(self, rule_config: Dict):
        """Build a rule from config and insert it, replacing any rule with the same id."""
        rule = self._rule_from_config(rule_config)
        with self.lock:
            self.rules[rule.rule_id] = rule

    def _match_ip(self, ip: str, pattern: str) -> bool:
        """Match an IP address against a single IP or a CIDR network.

        Exact IPs are compared after normalisation, so equivalent spellings
        (e.g. IPv6 "::1" and "0:0:0:0:0:0:0:1") match.  Unparseable input
        never matches (except an identical string).
        """
        try:
            if '/' in pattern:
                network = ipaddress.ip_network(pattern.strip(), strict=False)
                return ipaddress.ip_address(ip) in network
            if ip == pattern:
                return True
            return ipaddress.ip_address(ip) == ipaddress.ip_address(pattern.strip())
        except (TypeError, ValueError):  # AddressValueError subclasses ValueError
            return False

    def _match_port(self, port: int, pattern: str) -> bool:
        """Match a port against a port pattern.

        The pattern is a comma-separated list whose items are single ports
        ("80") or inclusive ranges ("8000-8080"), e.g. "22,80,8000-8080".
        Integer patterns (e.g. from JSON/YAML config) are accepted too.
        A malformed pattern never matches.
        """
        try:
            ranges = []
            for item in str(pattern).split(','):
                low, sep, high = item.strip().partition('-')
                start = int(low)
                end = int(high) if sep else start
                ranges.append((start, end))
            return any(start <= port <= end for start, end in ranges)
        except (ValueError, TypeError):
            return False

    def _match_protocol(self, protocol: str, pattern: Optional[str]) -> bool:
        """Case-insensitive protocol match; an unset/empty pattern matches any protocol."""
        if not pattern:
            return True  # Match any protocol
        return protocol.upper() == pattern.upper()

    @staticmethod
    def _transport_info(packet: ParsedPacket) -> Tuple[str, int, int]:
        """Return ``(protocol, source_port, dest_port)`` for *packet*.

        TCP and UDP headers yield their ports; anything else (including a
        missing transport header) is reported as ``('OTHER', 0, 0)``.
        NOTE(review): ICMP is therefore reported as 'OTHER', so rules with
        protocol='ICMP' cannot currently match — confirm how ParsedPacket
        exposes the IP protocol number before adding ICMP detection.
        """
        header = packet.transport_header
        if isinstance(header, TCPHeader):
            return 'TCP', header.source_port, header.dest_port
        if isinstance(header, UDPHeader):
            return 'UDP', header.source_port, header.dest_port
        return 'OTHER', 0, 0

    def _evaluate_rule(self, rule: FirewallRule, packet: ParsedPacket, direction: FirewallDirection) -> bool:
        """Return True if *rule* is enabled and every criterion it sets matches *packet*.

        Unset criteria act as wildcards.
        """
        if not rule.enabled:
            return False
        # Check direction
        if rule.direction != FirewallDirection.BOTH and rule.direction != direction:
            return False
        # Check source / destination IP
        if rule.source_ip and not self._match_ip(packet.ip_header.source_ip, rule.source_ip):
            return False
        if rule.dest_ip and not self._match_ip(packet.ip_header.dest_ip, rule.dest_ip):
            return False
        # Check protocol
        protocol, source_port, dest_port = self._transport_info(packet)
        if not self._match_protocol(protocol, rule.protocol):
            return False
        # Check source / destination port
        if rule.source_port and not self._match_port(source_port, rule.source_port):
            return False
        if rule.dest_port and not self._match_port(dest_port, rule.dest_port):
            return False
        return True

    def _log_packet(self, action: str, packet: ParsedPacket, rule_id: Optional[str] = None, reason: str = ""):
        """Append a log entry for *packet* if logging is enabled for *action*.

        ACCEPT verdicts are logged only when ``log_accepted`` is set; DROP and
        REJECT only when ``log_blocked`` is set.  The log is capped at
        ``max_log_entries`` by discarding the oldest entries.
        """
        if not (self.log_blocked or self.log_accepted):
            return
        # Only log if configured
        if action == 'ACCEPT' and not self.log_accepted:
            return
        if action in ('DROP', 'REJECT') and not self.log_blocked:
            return
        protocol, source_port, dest_port = self._transport_info(packet)
        log_entry = FirewallLogEntry(
            timestamp=time.time(),
            action=action,
            rule_id=rule_id,
            source_ip=packet.ip_header.source_ip,
            dest_ip=packet.ip_header.dest_ip,
            source_port=source_port,
            dest_port=dest_port,
            protocol=protocol,
            packet_size=len(packet.raw_packet),
            reason=reason
        )
        with self.lock:
            self.logs.append(log_entry)
            # Trim in place.  max(..., 0) makes a limit of 0 keep nothing; the
            # previous logs[-limit:] slice kept everything when limit was 0.
            excess = len(self.logs) - max(self.max_log_entries, 0)
            if excess > 0:
                del self.logs[:excess]

    def _count_action(self, action: FirewallAction):
        """Increment the per-verdict counter for *action*.  Caller must hold self.lock."""
        key = self._ACTION_STAT_KEYS.get(action)
        if key is not None:
            self.stats[key] += 1

    def process_packet(self, packet: ParsedPacket, direction: FirewallDirection) -> FirewallAction:
        """Run *packet* through the rules and return the resulting verdict.

        The first matching rule (by ascending priority) wins; otherwise the
        default policy is returned.  Statistics and hit counters are updated
        and the verdict is logged according to the logging settings.
        """
        with self.lock:
            self.stats['packets_processed'] += 1
            # Snapshot so rule changes during evaluation don't affect this packet.
            sorted_rules = sorted(self.rules.values(), key=lambda r: r.priority)
            default_policy = self.default_policy
        # Evaluate rules in order
        for rule in sorted_rules:
            if self._evaluate_rule(rule, packet, direction):
                with self.lock:
                    rule.record_hit()
                    self.stats['rules_hit'] += 1
                    self._count_action(rule.action)
                self._log_packet(rule.action.value, packet, rule.rule_id, f"Matched rule: {rule.description}")
                return rule.action
        # No rule matched, apply default policy
        with self.lock:
            self.stats['default_policy_hits'] += 1
            self._count_action(default_policy)
        self._log_packet(default_policy.value, packet, None, "Default policy")
        return default_policy

    def add_rule(self, rule: FirewallRule) -> bool:
        """Add *rule*; return False (and change nothing) if its id already exists."""
        with self.lock:
            if rule.rule_id in self.rules:
                return False
            self.rules[rule.rule_id] = rule
            return True

    def remove_rule(self, rule_id: str) -> bool:
        """Remove the rule with *rule_id*; return False if it does not exist."""
        with self.lock:
            return self.rules.pop(rule_id, None) is not None

    def update_rule(self, rule_id: str, **kwargs) -> bool:
        """Update fields of an existing rule in place.

        Only rule data fields are updatable; other keys are ignored.
        ``action``/``direction`` accept enum members or case-insensitive
        strings.  Passing ``rule_id`` renames the rule and re-keys it.  All
        values are validated before anything changes, so an invalid enum value
        raises ValueError and leaves the rule untouched.

        Returns False if the rule does not exist or a rename would collide
        with another rule, True otherwise.
        """
        with self.lock:
            rule = self.rules.get(rule_id)
            if rule is None:
                return False
            changes = {}
            for name, value in kwargs.items():
                if name not in self._UPDATABLE_FIELDS:
                    continue  # unknown names are ignored, as before
                if name == 'action':
                    value = self._coerce_enum(FirewallAction, value)
                elif name == 'direction':
                    value = self._coerce_enum(FirewallDirection, value)
                changes[name] = value
            new_id = changes.get('rule_id', rule_id)
            if new_id != rule_id and new_id in self.rules:
                return False  # renaming would overwrite another rule
            for name, value in changes.items():
                setattr(rule, name, value)
            if new_id != rule_id:
                # Re-key in place, keeping the rule's position: insertion order
                # breaks priority ties in process_packet.
                items = list(self.rules.items())
                self.rules.clear()
                self.rules.update((new_id if key == rule_id else key, value) for key, value in items)
            return True

    def enable_rule(self, rule_id: str) -> bool:
        """Enable firewall rule"""
        return self.update_rule(rule_id, enabled=True)

    def disable_rule(self, rule_id: str) -> bool:
        """Disable firewall rule"""
        return self.update_rule(rule_id, enabled=False)

    def get_rules(self) -> List[Dict]:
        """Return all rules as dicts, sorted by ascending priority."""
        with self.lock:
            return [rule.to_dict() for rule in sorted(self.rules.values(), key=lambda r: r.priority)]

    def get_rule(self, rule_id: str) -> Optional[Dict]:
        """Return one rule as a dict, or None if it does not exist."""
        with self.lock:
            rule = self.rules.get(rule_id)
            return rule.to_dict() if rule else None

    def get_logs(self, limit: Optional[int] = 100, filter_action: Optional[str] = None) -> List[Dict]:
        """Return up to *limit* most recent log entries, oldest first.

        ``limit=None`` returns every (matching) entry; ``limit <= 0`` returns
        none.  *filter_action* restricts the result to one verdict
        (case-insensitive).
        """
        if limit is not None and limit <= 0:
            return []
        with self.lock:
            logs = list(self.logs)
        # Filter by action if specified
        if filter_action:
            wanted = filter_action.upper()
            logs = [log for log in logs if log.action == wanted]
        if limit is not None:
            logs = logs[-limit:]
        return [log.to_dict() for log in logs]

    def clear_logs(self):
        """Clear firewall logs"""
        with self.lock:
            self.logs.clear()

    def get_stats(self) -> Dict:
        """Return a snapshot of the counters plus rule/log totals and the default policy."""
        with self.lock:
            stats = self.stats.copy()
            stats['total_rules'] = len(self.rules)
            stats['enabled_rules'] = sum(1 for rule in self.rules.values() if rule.enabled)
            stats['log_entries'] = len(self.logs)
            stats['default_policy'] = self.default_policy.value
            return stats

    def reset_stats(self):
        """Zero all counters and every rule's hit count / last-hit time."""
        with self.lock:
            self.stats = self._new_stats()
            for rule in self.rules.values():
                rule.hit_count = 0
                rule.last_hit = None

    def set_default_policy(self, policy: str):
        """Set the default policy from a case-insensitive name or a FirewallAction."""
        new_policy = self._coerce_enum(FirewallAction, policy)
        with self.lock:
            self.default_policy = new_policy

    def export_rules(self) -> List[Dict]:
        """Export rules for backup/configuration"""
        return self.get_rules()

    def import_rules(self, rules_config: List[Dict], replace: bool = False):
        """Import rules from a list of rule config dicts.

        Each dict needs ``rule_id`` and ``action``; optional keys are
        ``priority`` (default 100), ``direction`` (default 'BOTH'),
        ``source_ip``, ``dest_ip``, ``source_port``, ``dest_port``,
        ``protocol``, ``description`` and ``enabled``.  Rules with an existing
        id are replaced.  With ``replace=True`` all current rules are removed
        first.  Every entry is validated before the rule set is touched, so a
        bad entry raises without losing or partially importing rules.
        """
        new_rules = [self._rule_from_config(rule_config) for rule_config in rules_config]
        with self.lock:
            if replace:
                self.rules.clear()
            for rule in new_rules:
                self.rules[rule.rule_id] = rule
class FirewallRuleBuilder:
    """Fluent helper for assembling a FirewallRule.

    Each ``set_*`` method records one field and returns the builder, so calls
    can be chained; ``build()`` then creates the FirewallRule.  Action,
    direction and protocol strings are upper-cased before being stored.
    """

    def __init__(self, rule_id: str):
        self.rule_id = rule_id
        # Defaults: priority 100, ACCEPT, both directions, no match criteria set.
        self.priority = 100
        self.action = FirewallAction.ACCEPT
        self.direction = FirewallDirection.BOTH
        self.source_ip = self.dest_ip = None
        self.source_port = self.dest_port = None
        self.protocol = None
        self.description = ""
        self.enabled = True

    def _assign(self, attribute: str, value) -> 'FirewallRuleBuilder':
        """Store *value* on the builder and return the builder for chaining."""
        setattr(self, attribute, value)
        return self

    def set_priority(self, priority: int) -> 'FirewallRuleBuilder':
        """Set the priority (lower numbers are evaluated first by the engine)."""
        return self._assign('priority', priority)

    def set_action(self, action: str) -> 'FirewallRuleBuilder':
        """Set the verdict from a case-insensitive name such as "drop"."""
        return self._assign('action', FirewallAction(action.upper()))

    def set_direction(self, direction: str) -> 'FirewallRuleBuilder':
        """Set the direction from a case-insensitive name such as "inbound"."""
        return self._assign('direction', FirewallDirection(direction.upper()))

    def set_source_ip(self, ip: str) -> 'FirewallRuleBuilder':
        """Set the source IP or CIDR to match."""
        return self._assign('source_ip', ip)

    def set_dest_ip(self, ip: str) -> 'FirewallRuleBuilder':
        """Set the destination IP or CIDR to match."""
        return self._assign('dest_ip', ip)

    def set_source_port(self, port: str) -> 'FirewallRuleBuilder':
        """Set the source port pattern to match."""
        return self._assign('source_port', port)

    def set_dest_port(self, port: str) -> 'FirewallRuleBuilder':
        """Set the destination port pattern to match."""
        return self._assign('dest_port', port)

    def set_protocol(self, protocol: str) -> 'FirewallRuleBuilder':
        """Set the protocol name to match (stored upper-case)."""
        return self._assign('protocol', protocol.upper())

    def set_description(self, description: str) -> 'FirewallRuleBuilder':
        """Set the free-text description."""
        return self._assign('description', description)

    def set_enabled(self, enabled: bool) -> 'FirewallRuleBuilder':
        """Set whether the rule starts enabled."""
        return self._assign('enabled', enabled)

    def build(self) -> FirewallRule:
        """Create a FirewallRule from the values collected so far."""
        rule_fields = (
            'rule_id', 'priority', 'action', 'direction',
            'source_ip', 'dest_ip', 'source_port', 'dest_port',
            'protocol', 'description', 'enabled',
        )
        return FirewallRule(**{name: getattr(self, name) for name in rule_fields})