"""
Firewall Module

Implements packet filtering and access control:
- Rule-based packet filtering (allow/block by IP, port, protocol)
- Ordered rule processing (lowest priority number is evaluated first)
- Logging and statistics
- Dynamic rule management via API
"""

import functools
import time
import threading
import ipaddress
import re
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, fields
from enum import Enum

from .ip_parser import ParsedPacket, TCPHeader, UDPHeader


class FirewallAction(Enum):
    ACCEPT = "ACCEPT"
    DROP = "DROP"
    REJECT = "REJECT"


class FirewallDirection(Enum):
    INBOUND = "INBOUND"
    OUTBOUND = "OUTBOUND"
    BOTH = "BOTH"


# Statistics counter incremented for each final verdict.
_ACTION_STAT_KEYS = {
    FirewallAction.ACCEPT: 'packets_accepted',
    FirewallAction.DROP: 'packets_dropped',
    FirewallAction.REJECT: 'packets_rejected',
}

# IP protocol numbers recognised when a packet carries no TCP/UDP header.
_IP_PROTOCOL_NAMES = {1: 'ICMP', 58: 'ICMPV6'}


def _new_stats() -> Dict[str, int]:
    """Return a freshly zeroed statistics dictionary."""
    return {
        'packets_processed': 0,
        'packets_accepted': 0,
        'packets_dropped': 0,
        'packets_rejected': 0,
        'rules_hit': 0,
        'default_policy_hits': 0,
    }


def _coerce_action(value: Any) -> FirewallAction:
    """Convert an enum member or a case-insensitive string to FirewallAction.

    Raises:
        ValueError: if ``value`` does not name a FirewallAction.
    """
    if isinstance(value, FirewallAction):
        return value
    return FirewallAction(str(value).upper())


def _coerce_direction(value: Any) -> FirewallDirection:
    """Convert an enum member or a case-insensitive string to FirewallDirection.

    Raises:
        ValueError: if ``value`` does not name a FirewallDirection.
    """
    if isinstance(value, FirewallDirection):
        return value
    return FirewallDirection(str(value).upper())


@functools.lru_cache(maxsize=1024)
def _parse_network(pattern: str):
    """Parse a CIDR pattern once; rules are re-matched on every packet.

    Raises:
        ValueError: if ``pattern`` is not a valid IPv4/IPv6 network.
    """
    return ipaddress.ip_network(pattern, strict=False)


@functools.lru_cache(maxsize=1024)
def _parse_port_spec(spec: str) -> Tuple[Tuple[int, int], ...]:
    """Parse "80", "80-90", "80,443" or mixed "22,8000-8080" into inclusive ranges.

    Raises:
        ValueError: if any comma-separated element is not a port or port range.
    """
    ranges = []
    for part in spec.split(','):
        part = part.strip()
        if '-' in part:
            low, _, high = part.partition('-')
            ranges.append((int(low), int(high)))
        else:
            value = int(part)
            ranges.append((value, value))
    return tuple(ranges)


@dataclass
class FirewallRule:
    """Represents a firewall rule"""
    rule_id: str
    priority: int  # Lower number = higher priority
    action: FirewallAction
    direction: FirewallDirection

    # Match criteria
    source_ip: Optional[str] = None  # IP or CIDR
    dest_ip: Optional[str] = None  # IP or CIDR
    source_port: Optional[str] = None  # Port, range or list (e.g., "80", "80-90", "80,443")
    dest_port: Optional[str] = None  # Port, range or list
    protocol: Optional[str] = None  # TCP, UDP, ICMP, or None for any

    # Metadata
    description: str = ""
    enabled: bool = True
    created_time: float = 0
    hit_count: int = 0
    last_hit: Optional[float] = None

    def __post_init__(self):
        # 0 means "not supplied": stamp the creation time now.
        if self.created_time == 0:
            self.created_time = time.time()

    def record_hit(self):
        """Record a rule hit"""
        self.hit_count += 1
        self.last_hit = time.time()

    def to_dict(self) -> Dict:
        """Convert rule to dictionary"""
        return {
            'rule_id': self.rule_id,
            'priority': self.priority,
            'action': self.action.value,
            'direction': self.direction.value,
            'source_ip': self.source_ip,
            'dest_ip': self.dest_ip,
            'source_port': self.source_port,
            'dest_port': self.dest_port,
            'protocol': self.protocol,
            'description': self.description,
            'enabled': self.enabled,
            'created_time': self.created_time,
            'hit_count': self.hit_count,
            'last_hit': self.last_hit,
        }


# Attribute names that update_rule() is allowed to set (data fields only, never methods).
_RULE_FIELD_NAMES = frozenset(f.name for f in fields(FirewallRule))


@dataclass
class FirewallLogEntry:
    """Represents a firewall log entry"""
    timestamp: float
    action: str
    rule_id: Optional[str]
    source_ip: str
    dest_ip: str
    source_port: int
    dest_port: int
    protocol: str
    packet_size: int
    reason: str = ""

    def to_dict(self) -> Dict:
        """Convert log entry to dictionary"""
        return {
            'timestamp': self.timestamp,
            'action': self.action,
            'rule_id': self.rule_id,
            'source_ip': self.source_ip,
            'dest_ip': self.dest_ip,
            'source_port': self.source_port,
            'dest_port': self.dest_port,
            'protocol': self.protocol,
            'packet_size': self.packet_size,
            'reason': self.reason,
        }


class FirewallEngine:
    """Firewall engine: ordered rule evaluation, logging and statistics.

    ``self.lock`` guards ``rules``, ``logs``, ``stats`` and per-rule hit counters.
    """

    def __init__(self, config: Dict):
        """Build the engine from a configuration dict.

        Recognised keys: ``default_policy`` (default "ACCEPT", case-insensitive),
        ``log_blocked`` (default True), ``log_accepted`` (default False),
        ``max_log_entries`` (default 10000) and ``rules`` (list of rule dicts).
        """
        self.config = config
        self.rules: Dict[str, FirewallRule] = {}
        self.logs: List[FirewallLogEntry] = []
        self.lock = threading.Lock()

        # Configuration
        self.default_policy = _coerce_action(config.get('default_policy', 'ACCEPT'))
        self.log_blocked = config.get('log_blocked', True)
        self.log_accepted = config.get('log_accepted', False)
        self.max_log_entries = config.get('max_log_entries', 10000)

        # Statistics
        self.stats = _new_stats()

        # Load initial rules
        for rule_config in config.get('rules', []):
            self._add_rule_from_config(rule_config)

    def _rule_from_config(self, rule_config: Dict) -> FirewallRule:
        """Build (and validate) a FirewallRule from a config dict without storing it.

        ``action`` and ``direction`` are accepted case-insensitively. ``created_time``
        is preserved when present, so that export_rules() output round-trips.

        Raises:
            KeyError: if ``rule_id`` or ``action`` is missing.
            ValueError: if ``action`` or ``direction`` is not a valid enum value.
        """
        return FirewallRule(
            rule_id=rule_config['rule_id'],
            priority=rule_config.get('priority', 100),
            action=_coerce_action(rule_config['action']),
            direction=_coerce_direction(rule_config.get('direction', 'BOTH')),
            source_ip=rule_config.get('source_ip'),
            dest_ip=rule_config.get('dest_ip'),
            source_port=rule_config.get('source_port'),
            dest_port=rule_config.get('dest_port'),
            protocol=rule_config.get('protocol'),
            description=rule_config.get('description', ''),
            enabled=rule_config.get('enabled', True),
            created_time=rule_config.get('created_time') or 0,
        )

    def _add_rule_from_config(self, rule_config: Dict):
        """Add (or overwrite, keyed by rule_id) a rule from a config dict."""
        rule = self._rule_from_config(rule_config)
        with self.lock:
            self.rules[rule.rule_id] = rule

    def _match_ip(self, ip: str, pattern: str) -> bool:
        """Match an IP address against a pattern (single IP or CIDR).

        Addresses are compared after parsing, so equivalent spellings match
        (e.g. IPv6 "::1" and "0:0:0:0:0:0:0:1"). If either side is not a valid
        address, the function falls back to literal string equality; an
        unparsable CIDR therefore never matches.
        """
        try:
            if '/' in pattern:
                return ipaddress.ip_address(ip) in _parse_network(pattern)
            return ipaddress.ip_address(ip) == ipaddress.ip_address(pattern)
        except (ipaddress.AddressValueError, ValueError, TypeError):
            return ip == pattern

    def _match_port(self, port: int, pattern: str) -> bool:
        """Match a port against a port, range, list, or mixed list ("22,8000-8080").

        Integer patterns (as commonly loaded from JSON/YAML) are accepted.
        Malformed patterns never match.
        """
        try:
            ranges = _parse_port_spec(str(pattern).strip())
            return any(low <= port <= high for low, high in ranges)
        except (ValueError, TypeError):
            return False

    def _match_protocol(self, protocol: str, pattern: Optional[str]) -> bool:
        """Match protocol name case-insensitively; a None pattern matches anything."""
        if pattern is None:
            return True  # Match any protocol
        return protocol.upper() == pattern.upper()

    def _packet_fields(self, packet: ParsedPacket) -> Tuple[str, int, int]:
        """Return (protocol, source_port, dest_port) used for matching and logging.

        Packets without a TCP/UDP header report ports 0 and protocol 'OTHER',
        unless the IP header identifies ICMP/ICMPv6.
        """
        header = packet.transport_header
        if isinstance(header, TCPHeader):
            return 'TCP', header.source_port, header.dest_port
        if isinstance(header, UDPHeader):
            return 'UDP', header.source_port, header.dest_port
        # NOTE(review): assumes the IP header exposes its protocol number as an int
        # attribute named `protocol` — TODO confirm against ip_parser. When it is
        # absent, this degrades to 'OTHER' exactly as before. Without this lookup,
        # rules with protocol="ICMP" could never match anything.
        ip_proto = getattr(packet.ip_header, 'protocol', None)
        name = _IP_PROTOCOL_NAMES.get(ip_proto) if isinstance(ip_proto, int) else None
        return (name or 'OTHER'), 0, 0

    def _evaluate_rule(self, rule: FirewallRule, packet: ParsedPacket,
                       direction: FirewallDirection) -> bool:
        """Return True if ``rule`` is enabled and every criterion it sets matches ``packet``."""
        if not rule.enabled:
            return False

        if rule.direction not in (FirewallDirection.BOTH, direction):
            return False

        ip_header = packet.ip_header
        if rule.source_ip and not self._match_ip(ip_header.source_ip, rule.source_ip):
            return False
        if rule.dest_ip and not self._match_ip(ip_header.dest_ip, rule.dest_ip):
            return False

        protocol, source_port, dest_port = self._packet_fields(packet)
        if not self._match_protocol(protocol, rule.protocol):
            return False

        if rule.source_port and not self._match_port(source_port, rule.source_port):
            return False
        if rule.dest_port and not self._match_port(dest_port, rule.dest_port):
            return False

        return True

    def _log_packet(self, action: str, packet: ParsedPacket,
                    rule_id: Optional[str] = None, reason: str = ""):
        """Append a log entry for ``packet`` if logging is enabled for ``action``.

        ACCEPT is logged only when ``log_accepted`` is set; DROP/REJECT only when
        ``log_blocked`` is set. The log is trimmed to the most recent
        ``max_log_entries`` entries.
        """
        if not (self.log_blocked or self.log_accepted):
            return
        if action == 'ACCEPT' and not self.log_accepted:
            return
        if action in ('DROP', 'REJECT') and not self.log_blocked:
            return

        protocol, source_port, dest_port = self._packet_fields(packet)
        log_entry = FirewallLogEntry(
            timestamp=time.time(),
            action=action,
            rule_id=rule_id,
            source_ip=packet.ip_header.source_ip,
            dest_ip=packet.ip_header.dest_ip,
            source_port=source_port,
            dest_port=dest_port,
            protocol=protocol,
            packet_size=len(packet.raw_packet),
            reason=reason,
        )

        with self.lock:
            self.logs.append(log_entry)
            # Delete from the front explicitly: slicing with [-0:] would keep
            # everything when max_log_entries is 0.
            excess = len(self.logs) - self.max_log_entries
            if excess > 0:
                del self.logs[:excess]

    def process_packet(self, packet: ParsedPacket, direction: FirewallDirection) -> FirewallAction:
        """Run ``packet`` through the rules in priority order and return the verdict.

        The first matching rule decides. If no rule matches, the default policy
        applies. Counters are updated under ``self.lock``; logging acquires the
        lock itself, so it is called outside the locked sections.
        """
        with self.lock:
            self.stats['packets_processed'] += 1
            sorted_rules = sorted(self.rules.values(), key=lambda r: r.priority)

        for rule in sorted_rules:
            if not self._evaluate_rule(rule, packet, direction):
                continue
            # Read once so the logged, counted and returned verdicts agree even
            # if update_rule() runs concurrently.
            action = rule.action
            with self.lock:
                rule.record_hit()
                self.stats['rules_hit'] += 1
                self.stats[_ACTION_STAT_KEYS[action]] += 1
            self._log_packet(action.value, packet, rule.rule_id,
                             f"Matched rule: {rule.description}")
            return action

        # No rule matched, apply default policy
        policy = self.default_policy
        with self.lock:
            self.stats['default_policy_hits'] += 1
            self.stats[_ACTION_STAT_KEYS[policy]] += 1
        self._log_packet(policy.value, packet, None, "Default policy")
        return policy

    def add_rule(self, rule: FirewallRule) -> bool:
        """Add firewall rule; return False if a rule with the same id already exists."""
        with self.lock:
            if rule.rule_id in self.rules:
                return False
            self.rules[rule.rule_id] = rule
            return True

    def remove_rule(self, rule_id: str) -> bool:
        """Remove firewall rule; return False if it does not exist."""
        with self.lock:
            if rule_id in self.rules:
                del self.rules[rule_id]
                return True
            return False

    def update_rule(self, rule_id: str, **kwargs) -> bool:
        """Update fields of an existing rule.

        Only FirewallRule data fields can be set; other keys are ignored.
        ``action`` and ``direction`` accept enum members or case-insensitive
        strings. Changing ``rule_id`` re-keys the rule.

        Returns:
            False if the rule does not exist, or if the new rule_id is already taken.

        Raises:
            ValueError: for an invalid action/direction. The rule is left unchanged.
        """
        with self.lock:
            rule = self.rules.get(rule_id)
            if rule is None:
                return False

            # Validate/convert everything before mutating, so a bad value
            # cannot leave a half-applied update behind.
            updates = {}
            for key, value in kwargs.items():
                if key not in _RULE_FIELD_NAMES:
                    continue
                if key == 'action':
                    value = _coerce_action(value)
                elif key == 'direction':
                    value = _coerce_direction(value)
                updates[key] = value

            new_id = updates.get('rule_id', rule_id)
            if new_id != rule_id and new_id in self.rules:
                return False

            for key, value in updates.items():
                setattr(rule, key, value)

            if new_id != rule_id:
                # Keep the dict key in sync with rule.rule_id.
                del self.rules[rule_id]
                self.rules[new_id] = rule
            return True

    def enable_rule(self, rule_id: str) -> bool:
        """Enable firewall rule"""
        return self.update_rule(rule_id, enabled=True)

    def disable_rule(self, rule_id: str) -> bool:
        """Disable firewall rule"""
        return self.update_rule(rule_id, enabled=False)

    def get_rules(self) -> List[Dict]:
        """Get all firewall rules as dicts, ordered by priority"""
        with self.lock:
            return [rule.to_dict() for rule in sorted(self.rules.values(), key=lambda r: r.priority)]

    def get_rule(self, rule_id: str) -> Optional[Dict]:
        """Get specific firewall rule as a dict, or None if missing"""
        with self.lock:
            rule = self.rules.get(rule_id)
            return rule.to_dict() if rule else None

    def get_logs(self, limit: int = 100, filter_action: Optional[str] = None) -> List[Dict]:
        """Return up to ``limit`` most recent log entries, optionally filtered by action.

        A ``limit`` of 0 or less returns an empty list. A plain ``logs[-limit:]``
        would return every entry when ``limit`` is 0.
        """
        with self.lock:
            logs = list(self.logs)

        if filter_action:
            wanted = filter_action.upper()
            logs = [log for log in logs if log.action == wanted]

        if limit <= 0:
            return []
        return [log.to_dict() for log in logs[-limit:]]

    def clear_logs(self):
        """Clear firewall logs"""
        with self.lock:
            self.logs.clear()

    def get_stats(self) -> Dict:
        """Get a snapshot of firewall statistics plus rule/log counts"""
        with self.lock:
            stats = self.stats.copy()
            stats['total_rules'] = len(self.rules)
            stats['enabled_rules'] = sum(1 for rule in self.rules.values() if rule.enabled)
            stats['log_entries'] = len(self.logs)
            stats['default_policy'] = self.default_policy.value
            return stats

    def reset_stats(self):
        """Reset firewall statistics and per-rule hit counters"""
        with self.lock:
            self.stats = _new_stats()
            for rule in self.rules.values():
                rule.hit_count = 0
                rule.last_hit = None

    def set_default_policy(self, policy: str):
        """Set default firewall policy (case-insensitive string or FirewallAction)"""
        self.default_policy = _coerce_action(policy)

    def export_rules(self) -> List[Dict]:
        """Export rules for backup/configuration"""
        return self.get_rules()

    def import_rules(self, rules_config: List[Dict], replace: bool = False):
        """Import rules from configuration.

        Every rule is validated before any change is made, and the swap happens
        under the lock. A malformed config therefore leaves the existing rules
        untouched, and ``replace=True`` never exposes an empty rule set to
        concurrent packets.
        """
        new_rules = [self._rule_from_config(rule_config) for rule_config in rules_config]
        with self.lock:
            if replace:
                self.rules.clear()
            for rule in new_rules:
                self.rules[rule.rule_id] = rule


class FirewallRuleBuilder:
    """Helper class to build firewall rules with a fluent interface"""

    def __init__(self, rule_id: str):
        self.rule_id = rule_id
        self.priority = 100
        self.action = FirewallAction.ACCEPT
        self.direction = FirewallDirection.BOTH
        self.source_ip = None
        self.dest_ip = None
        self.source_port = None
        self.dest_port = None
        self.protocol = None
        self.description = ""
        self.enabled = True

    def set_priority(self, priority: int) -> 'FirewallRuleBuilder':
        self.priority = priority
        return self

    def set_action(self, action: str) -> 'FirewallRuleBuilder':
        self.action = _coerce_action(action)
        return self

    def set_direction(self, direction: str) -> 'FirewallRuleBuilder':
        self.direction = _coerce_direction(direction)
        return self

    def set_source_ip(self, ip: str) -> 'FirewallRuleBuilder':
        self.source_ip = ip
        return self

    def set_dest_ip(self, ip: str) -> 'FirewallRuleBuilder':
        self.dest_ip = ip
        return self

    def set_source_port(self, port: str) -> 'FirewallRuleBuilder':
        self.source_port = port
        return self

    def set_dest_port(self, port: str) -> 'FirewallRuleBuilder':
        self.dest_port = port
        return self

    def set_protocol(self, protocol: str) -> 'FirewallRuleBuilder':
        self.protocol = protocol.upper()
        return self

    def set_description(self, description: str) -> 'FirewallRuleBuilder':
        self.description = description
        return self

    def set_enabled(self, enabled: bool) -> 'FirewallRuleBuilder':
        self.enabled = enabled
        return self

    def build(self) -> FirewallRule:
        """Build the firewall rule"""
        return FirewallRule(
            rule_id=self.rule_id,
            priority=self.priority,
            action=self.action,
            direction=self.direction,
            source_ip=self.source_ip,
            dest_ip=self.dest_ip,
            source_port=self.source_port,
            dest_port=self.dest_port,
            protocol=self.protocol,
            description=self.description,
            enabled=self.enabled,
        )