|
|
""" |
|
|
Firewall Module |
|
|
|
|
|
Implements packet filtering and access control: |
|
|
- Rule-based packet filtering (allow/block by IP, port, protocol) |
|
|
- Ordered rule processing |
|
|
- Logging and statistics |
|
|
- Dynamic rule management via API |
|
|
""" |
|
|
|
|
|
import time |
|
|
import threading |
|
|
import ipaddress |
|
|
import re |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
from .ip_parser import ParsedPacket, TCPHeader, UDPHeader |
|
|
|
|
|
|
|
|
class FirewallAction(Enum):
    """Verdict assigned to a packet by a firewall rule or by the default policy.

    Each member's value is its own upper-case name, so a member can be
    recovered from a configuration string with ``FirewallAction("DROP")``.
    """

    # Let the packet through.
    ACCEPT = "ACCEPT"
    # Blocked verdicts. The engine in this file counts DROP and REJECT in
    # separate statistics; how a REJECT is signalled to the sender is not
    # implemented here and is presumably left to the caller.
    DROP = "DROP"
    REJECT = "REJECT"
|
|
|
|
|
|
|
|
class FirewallDirection(Enum):
    """Traffic direction a firewall rule applies to.

    Each member's value is its own upper-case name, so a member can be
    recovered from a configuration string with ``FirewallDirection("INBOUND")``.
    """

    INBOUND = "INBOUND"
    OUTBOUND = "OUTBOUND"
    # A rule marked BOTH is evaluated for packets of either direction.
    BOTH = "BOTH"
|
|
|
|
|
|
|
|
@dataclass
class FirewallRule:
    """One packet-filtering rule plus its hit bookkeeping.

    ``rule_id``, ``priority``, ``action`` and ``direction`` are required.
    The match criteria are optional; the engine in this module skips any
    criterion that is left unset.
    """

    # Identity, ordering and verdict.
    rule_id: str
    priority: int
    action: FirewallAction
    direction: FirewallDirection

    # Optional match criteria.
    source_ip: Optional[str] = None
    dest_ip: Optional[str] = None
    source_port: Optional[str] = None
    dest_port: Optional[str] = None
    protocol: Optional[str] = None

    # Metadata and usage counters.
    description: str = ""
    enabled: bool = True
    created_time: float = 0
    hit_count: int = 0
    last_hit: Optional[float] = None

    def __post_init__(self):
        """Stamp the creation time when the caller did not supply one."""
        # 0 is the "not set" sentinel used by the field default.
        if self.created_time != 0:
            return
        self.created_time = time.time()

    def record_hit(self):
        """Count one more match and remember when it happened."""
        self.hit_count = self.hit_count + 1
        self.last_hit = time.time()

    def to_dict(self) -> Dict:
        """Return the rule as a plain dictionary.

        ``action`` and ``direction`` are emitted as their string values; all
        other fields are emitted unchanged, in declaration order.
        """
        plain_fields = (
            'rule_id', 'priority', 'action', 'direction',
            'source_ip', 'dest_ip', 'source_port', 'dest_port', 'protocol',
            'description', 'enabled', 'created_time', 'hit_count', 'last_hit',
        )
        snapshot = {name: getattr(self, name) for name in plain_fields}
        # Re-assigning an existing key keeps its position in the dict.
        snapshot['action'] = self.action.value
        snapshot['direction'] = self.direction.value
        return snapshot
|
|
|
|
|
|
|
|
@dataclass
class FirewallLogEntry:
    """One record in the firewall log, describing a processed packet."""

    # When the entry was created and what was decided.
    timestamp: float
    action: str
    rule_id: Optional[str]

    # Addressing information taken from the packet.
    source_ip: str
    dest_ip: str
    source_port: int
    dest_port: int
    protocol: str

    packet_size: int
    reason: str = ""

    def to_dict(self) -> Dict:
        """Return the entry as a plain dictionary, fields in declaration order."""
        ordered_names = (
            'timestamp', 'action', 'rule_id',
            'source_ip', 'dest_ip', 'source_port', 'dest_port',
            'protocol', 'packet_size', 'reason',
        )
        return {name: getattr(self, name) for name in ordered_names}
|
|
|
|
|
|
|
|
class FirewallEngine:
    """Ordered, rule-based packet filter with logging and statistics.

    Rules live in ``self.rules`` keyed by ``rule_id`` and are evaluated in
    ascending ``priority`` order. The first enabled rule that matches a packet
    decides its verdict; when no rule matches, ``default_policy`` applies.

    ``self.lock`` is a plain, non-reentrant ``threading.Lock`` protecting
    ``rules``, ``logs`` and ``stats``. Helpers documented as "caller holds the
    lock" must never acquire it themselves.
    """

    # Rule attributes update_rule() may assign. Anything else (methods, dunder
    # attributes, unknown names) is ignored so that keyword arguments coming
    # from an API cannot overwrite behaviour on the rule object.
    _UPDATABLE_FIELDS = frozenset((
        'priority', 'action', 'direction',
        'source_ip', 'dest_ip', 'source_port', 'dest_port', 'protocol',
        'description', 'enabled', 'created_time', 'hit_count', 'last_hit',
    ))

    # Statistics counter incremented for each verdict.
    _ACTION_COUNTERS = {
        FirewallAction.ACCEPT: 'packets_accepted',
        FirewallAction.DROP: 'packets_dropped',
        FirewallAction.REJECT: 'packets_rejected',
    }

    def __init__(self, config: Dict):
        """Create the engine from a configuration dictionary.

        Recognised keys: ``default_policy`` (default ``'ACCEPT'``),
        ``log_blocked`` (default True), ``log_accepted`` (default False),
        ``max_log_entries`` (default 10000) and ``rules`` (a list of rule
        configuration dictionaries, default empty).

        Raises:
            ValueError: if ``default_policy`` or a rule's action/direction is
                not a valid enum value.
            KeyError: if a rule configuration lacks ``rule_id`` or ``action``.
        """
        self.config = config
        self.rules: Dict[str, FirewallRule] = {}
        self.logs: List[FirewallLogEntry] = []
        self.lock = threading.Lock()

        self.default_policy = self._coerce_enum(
            FirewallAction, config.get('default_policy', 'ACCEPT'))
        self.log_blocked = config.get('log_blocked', True)
        self.log_accepted = config.get('log_accepted', False)
        self.max_log_entries = config.get('max_log_entries', 10000)

        self.stats = self._new_stats()

        # "or []" tolerates an explicit ``rules: null`` in the configuration.
        for rule_config in config.get('rules') or []:
            self._add_rule_from_config(rule_config)

    @staticmethod
    def _new_stats() -> Dict[str, int]:
        """Return a fresh statistics dictionary with every counter at zero."""
        return {
            'packets_processed': 0,
            'packets_accepted': 0,
            'packets_dropped': 0,
            'packets_rejected': 0,
            'rules_hit': 0,
            'default_policy_hits': 0
        }

    @staticmethod
    def _coerce_enum(enum_cls, value):
        """Convert ``value`` to a member of ``enum_cls``.

        Strings are upper-cased first, so ``'drop'`` and ``'DROP'`` are
        equivalent; existing members are accepted unchanged.

        Raises:
            ValueError: if ``value`` does not name a member.
        """
        if isinstance(value, str):
            value = value.upper()
        return enum_cls(value)

    def _build_rule(self, rule_config: Dict) -> FirewallRule:
        """Build (but do not register) a rule from a configuration dictionary.

        Raises:
            KeyError: if ``rule_id`` or ``action`` is missing.
            ValueError: if ``action`` or ``direction`` is not a valid value.
        """
        return FirewallRule(
            rule_id=rule_config['rule_id'],
            priority=rule_config.get('priority', 100),
            action=self._coerce_enum(FirewallAction, rule_config['action']),
            direction=self._coerce_enum(
                FirewallDirection, rule_config.get('direction', 'BOTH')),
            source_ip=rule_config.get('source_ip'),
            dest_ip=rule_config.get('dest_ip'),
            source_port=rule_config.get('source_port'),
            dest_port=rule_config.get('dest_port'),
            protocol=rule_config.get('protocol'),
            description=rule_config.get('description', ''),
            enabled=rule_config.get('enabled', True)
        )

    def _add_rule_from_config(self, rule_config: Dict):
        """Build a rule from configuration and store it.

        An existing rule with the same ``rule_id`` is overwritten.
        """
        rule = self._build_rule(rule_config)
        with self.lock:
            self.rules[rule.rule_id] = rule

    def _match_ip(self, ip: str, pattern: str) -> bool:
        """Return True when ``ip`` matches ``pattern``.

        ``pattern`` is either a single address or a CIDR network. Single
        addresses are compared after parsing, so equivalent spellings of the
        same address (notably for IPv6) match. Unparseable input never matches,
        except for a literal string-equal single address.
        """
        is_network = '/' in pattern
        # Fast path; also keeps string-equal values matching even when they
        # are not parseable as IP addresses.
        if not is_network and ip == pattern:
            return True
        try:
            address = ipaddress.ip_address(ip)
            if is_network:
                return address in ipaddress.ip_network(pattern, strict=False)
            return address == ipaddress.ip_address(pattern)
        except ValueError:
            # AddressValueError is a ValueError subclass, so this covers both.
            return False

    def _match_port(self, port: int, pattern: str) -> bool:
        """Return True when ``port`` matches ``pattern``.

        ``pattern`` may be a single port (``"80"`` or the integer ``80``), an
        inclusive range (``"8000-8100"``), or a comma-separated list whose
        items are ports or ranges (``"22,80,8000-8100"``). A pattern containing
        any malformed item never matches.
        """
        try:
            ranges = []
            # str() lets integer ports from JSON/YAML configs work too.
            for item in str(pattern).split(','):
                item = item.strip()
                if '-' in item:
                    low, high = map(int, item.split('-', 1))
                else:
                    low = high = int(item)
                ranges.append((low, high))
        except (ValueError, TypeError):
            return False
        return any(low <= port <= high for low, high in ranges)

    def _match_protocol(self, protocol: str, pattern: str) -> bool:
        """Return True when ``protocol`` equals ``pattern``, ignoring case.

        An unset or empty pattern matches every protocol, mirroring how unset
        IP and port criteria are treated.
        """
        if not pattern:
            return True
        return protocol.upper() == pattern.upper()

    @staticmethod
    def _transport_info(packet: ParsedPacket) -> Tuple[str, int, int]:
        """Return ``(protocol, source_port, dest_port)`` for a packet.

        Packets whose transport header is neither TCP nor UDP (or is missing)
        are reported as ``('OTHER', 0, 0)``.
        """
        header = packet.transport_header
        if isinstance(header, TCPHeader):
            return 'TCP', header.source_port, header.dest_port
        if isinstance(header, UDPHeader):
            return 'UDP', header.source_port, header.dest_port
        return 'OTHER', 0, 0

    def _evaluate_rule(self, rule: FirewallRule, packet: ParsedPacket, direction: FirewallDirection) -> bool:
        """Return True when ``rule`` is enabled and every criterion it sets matches."""
        if not rule.enabled:
            return False

        if rule.direction != FirewallDirection.BOTH and rule.direction != direction:
            return False

        if rule.source_ip and not self._match_ip(packet.ip_header.source_ip, rule.source_ip):
            return False

        if rule.dest_ip and not self._match_ip(packet.ip_header.dest_ip, rule.dest_ip):
            return False

        protocol, source_port, dest_port = self._transport_info(packet)

        if not self._match_protocol(protocol, rule.protocol):
            return False

        if rule.source_port and not self._match_port(source_port, rule.source_port):
            return False

        if rule.dest_port and not self._match_port(dest_port, rule.dest_port):
            return False

        return True

    def _log_packet(self, action: str, packet: ParsedPacket, rule_id: Optional[str] = None, reason: str = ""):
        """Append a log entry for ``packet`` if logging is enabled for ``action``.

        ``'ACCEPT'`` is logged only when ``log_accepted`` is set; ``'DROP'`` and
        ``'REJECT'`` only when ``log_blocked`` is set. The log is capped at
        ``max_log_entries`` by discarding the oldest entries.

        Acquires ``self.lock``; must not be called with the lock held.
        """
        if not (self.log_blocked or self.log_accepted):
            return
        if action == 'ACCEPT' and not self.log_accepted:
            return
        if action in ('DROP', 'REJECT') and not self.log_blocked:
            return

        protocol, source_port, dest_port = self._transport_info(packet)

        log_entry = FirewallLogEntry(
            timestamp=time.time(),
            action=action,
            rule_id=rule_id,
            source_ip=packet.ip_header.source_ip,
            dest_ip=packet.ip_header.dest_ip,
            source_port=source_port,
            dest_port=dest_port,
            protocol=protocol,
            packet_size=len(packet.raw_packet),
            reason=reason
        )

        with self.lock:
            self.logs.append(log_entry)
            # Count the overflow explicitly: a "[-max:]" slice would keep the
            # whole list when max is 0, because -0 == 0.
            overflow = len(self.logs) - max(self.max_log_entries, 0)
            if overflow > 0:
                del self.logs[:overflow]

    def _count_action(self, action: FirewallAction):
        """Increment the counter for ``action``. Caller holds the lock."""
        counter = self._ACTION_COUNTERS.get(action)
        if counter is not None:
            self.stats[counter] += 1

    def process_packet(self, packet: ParsedPacket, direction: FirewallDirection) -> FirewallAction:
        """Run ``packet`` through the rules and return the verdict.

        Rules are tried in ascending priority; the first match wins, is
        credited with a hit and determines the returned action. When nothing
        matches, the default policy is returned. Statistics are updated and
        the decision is logged subject to the logging flags.
        """
        with self.lock:
            self.stats['packets_processed'] += 1
            # Snapshot so that matching runs without holding the lock.
            ordered_rules = sorted(self.rules.values(), key=lambda r: r.priority)

        matched = next(
            (rule for rule in ordered_rules
             if self._evaluate_rule(rule, packet, direction)),
            None,
        )

        with self.lock:
            if matched is not None:
                matched.record_hit()
                self.stats['rules_hit'] += 1
                # Read the action once so the counter, the log entry and the
                # return value agree even if the rule is updated concurrently.
                action = matched.action
                rule_id = matched.rule_id
                reason = f"Matched rule: {matched.description}"
            else:
                self.stats['default_policy_hits'] += 1
                action = self.default_policy
                rule_id = None
                reason = "Default policy"
            self._count_action(action)

        # _log_packet takes the lock itself, so call it after releasing ours.
        self._log_packet(action.value, packet, rule_id, reason)
        return action

    def add_rule(self, rule: FirewallRule) -> bool:
        """Register ``rule``; return False when its ``rule_id`` already exists."""
        with self.lock:
            if rule.rule_id in self.rules:
                return False
            self.rules[rule.rule_id] = rule
            return True

    def remove_rule(self, rule_id: str) -> bool:
        """Delete the rule with ``rule_id``; return False when it does not exist."""
        with self.lock:
            return self.rules.pop(rule_id, None) is not None

    def update_rule(self, rule_id: str, **kwargs) -> bool:
        """Update attributes of an existing rule.

        Only the rule's data fields can be changed; any other keyword is
        ignored. ``action`` and ``direction`` accept enum members or their
        names in any case. All values are validated before anything is
        assigned, so a failed call leaves the rule untouched.

        Returns:
            True when the rule exists, False otherwise.

        Raises:
            ValueError: if ``action`` or ``direction`` is not a valid value.
        """
        with self.lock:
            rule = self.rules.get(rule_id)
            if rule is None:
                return False

            changes = {}
            for key, value in kwargs.items():
                if key not in self._UPDATABLE_FIELDS:
                    continue
                if key == 'action':
                    value = self._coerce_enum(FirewallAction, value)
                elif key == 'direction':
                    value = self._coerce_enum(FirewallDirection, value)
                changes[key] = value

            for key, value in changes.items():
                setattr(rule, key, value)

            return True

    def enable_rule(self, rule_id: str) -> bool:
        """Enable the rule; return False when it does not exist."""
        return self.update_rule(rule_id, enabled=True)

    def disable_rule(self, rule_id: str) -> bool:
        """Disable the rule; return False when it does not exist."""
        return self.update_rule(rule_id, enabled=False)

    def get_rules(self) -> List[Dict]:
        """Return every rule as a dictionary, sorted by ascending priority."""
        with self.lock:
            ordered_rules = sorted(self.rules.values(), key=lambda r: r.priority)
            return [rule.to_dict() for rule in ordered_rules]

    def get_rule(self, rule_id: str) -> Optional[Dict]:
        """Return the rule as a dictionary, or None when it does not exist."""
        with self.lock:
            rule = self.rules.get(rule_id)
            return rule.to_dict() if rule else None

    def get_logs(self, limit: int = 100, filter_action: Optional[str] = None) -> List[Dict]:
        """Return the most recent log entries, oldest first.

        Args:
            limit: maximum number of entries to return; zero or a negative
                value yields an empty list.
            filter_action: when given, only entries with this action
                (compared in upper case) are considered.
        """
        # Guard first: "logs[-0:]" would return the whole list.
        if limit <= 0:
            return []

        with self.lock:
            logs = self.logs.copy()

        if filter_action:
            wanted = filter_action.upper()
            logs = [log for log in logs if log.action == wanted]

        return [log.to_dict() for log in logs[-limit:]]

    def clear_logs(self):
        """Remove all log entries."""
        with self.lock:
            self.logs.clear()

    def get_stats(self) -> Dict:
        """Return a copy of the counters plus rule, log and policy summary values."""
        with self.lock:
            stats = self.stats.copy()
            stats['total_rules'] = len(self.rules)
            stats['enabled_rules'] = sum(1 for rule in self.rules.values() if rule.enabled)
            stats['log_entries'] = len(self.logs)
            stats['default_policy'] = self.default_policy.value

        return stats

    def reset_stats(self):
        """Zero every counter and clear the hit bookkeeping of every rule."""
        with self.lock:
            self.stats = self._new_stats()
            for rule in self.rules.values():
                rule.hit_count = 0
                rule.last_hit = None

    def set_default_policy(self, policy: str):
        """Set the verdict used when no rule matches.

        Raises:
            ValueError: if ``policy`` does not name a FirewallAction.
        """
        new_policy = self._coerce_enum(FirewallAction, policy)
        with self.lock:
            self.default_policy = new_policy

    def export_rules(self) -> List[Dict]:
        """Return all rules as dictionaries (same output as get_rules)."""
        return self.get_rules()

    def import_rules(self, rules_config: List[Dict], replace: bool = False):
        """Load rules from configuration dictionaries.

        Every configuration is validated before the rule set is touched, and
        the change is applied under a single lock acquisition. An invalid
        entry therefore leaves the current rules intact.

        Args:
            rules_config: rule configuration dictionaries.
            replace: when True, existing rules are discarded first; otherwise
                imported rules are merged, overwriting rules with equal ids.

        Raises:
            KeyError: if a configuration lacks ``rule_id`` or ``action``.
            ValueError: if an ``action`` or ``direction`` value is invalid.
        """
        new_rules = [self._build_rule(rule_config) for rule_config in rules_config]

        with self.lock:
            if replace:
                self.rules.clear()
            for rule in new_rules:
                self.rules[rule.rule_id] = rule
|
|
|
|
|
|
|
|
class FirewallRuleBuilder:
    """Fluent helper that collects rule settings and produces a FirewallRule.

    Every ``set_*`` method returns the builder itself so calls can be chained;
    ``build()`` creates the rule from the values collected so far.
    """

    def __init__(self, rule_id: str):
        """Start a builder for ``rule_id`` with the default settings."""
        self.rule_id = rule_id
        self.priority = 100
        self.action = FirewallAction.ACCEPT
        self.direction = FirewallDirection.BOTH
        # Match criteria start out unset.
        self.source_ip = None
        self.dest_ip = None
        self.source_port = None
        self.dest_port = None
        self.protocol = None
        self.description = ""
        self.enabled = True

    def _with(self, **changes) -> 'FirewallRuleBuilder':
        """Assign the given attributes and return the builder for chaining."""
        for attribute, value in changes.items():
            setattr(self, attribute, value)
        return self

    def set_priority(self, priority: int) -> 'FirewallRuleBuilder':
        """Set the rule priority."""
        return self._with(priority=priority)

    def set_action(self, action: str) -> 'FirewallRuleBuilder':
        """Set the action from its name (any case)."""
        return self._with(action=FirewallAction(action.upper()))

    def set_direction(self, direction: str) -> 'FirewallRuleBuilder':
        """Set the direction from its name (any case)."""
        return self._with(direction=FirewallDirection(direction.upper()))

    def set_source_ip(self, ip: str) -> 'FirewallRuleBuilder':
        """Set the source IP criterion."""
        return self._with(source_ip=ip)

    def set_dest_ip(self, ip: str) -> 'FirewallRuleBuilder':
        """Set the destination IP criterion."""
        return self._with(dest_ip=ip)

    def set_source_port(self, port: str) -> 'FirewallRuleBuilder':
        """Set the source port criterion."""
        return self._with(source_port=port)

    def set_dest_port(self, port: str) -> 'FirewallRuleBuilder':
        """Set the destination port criterion."""
        return self._with(dest_port=port)

    def set_protocol(self, protocol: str) -> 'FirewallRuleBuilder':
        """Set the protocol criterion, stored in upper case."""
        return self._with(protocol=protocol.upper())

    def set_description(self, description: str) -> 'FirewallRuleBuilder':
        """Set the free-text description."""
        return self._with(description=description)

    def set_enabled(self, enabled: bool) -> 'FirewallRuleBuilder':
        """Set whether the rule is enabled."""
        return self._with(enabled=enabled)

    def build(self) -> FirewallRule:
        """Create a FirewallRule from the settings collected so far."""
        rule_fields = (
            'rule_id', 'priority', 'action', 'direction',
            'source_ip', 'dest_ip', 'source_port', 'dest_port', 'protocol',
            'description', 'enabled',
        )
        return FirewallRule(**{name: getattr(self, name) for name in rule_fields})
|
|
|
|
|
|