| """
|
| Traffic Forwarding Engine
|
| Handles IP packet forwarding and NAT for VPN tunnels
|
| """
|
|
|
| import asyncio
|
| import socket
|
| import struct
|
| from typing import Dict, Optional, Tuple, Union
|
| from dataclasses import dataclass
|
| import os
|
| from .ip_parser import IPv4Header, IPParser
|
| from .logger import Logger, LogCategory
|
| from .nat_engine import NATEngine
|
|
|
@dataclass()
class ForwardSession:
    """Bookkeeping record for a single forwarded flow.

    One instance exists per distinct (src_ip, dst_ip, protocol, src_port,
    dst_port) combination seen by the forwarder. Timestamps are whatever
    clock the creator supplies (the forwarder in this module uses the
    running event loop's ``time()``); byte counters start at zero.
    """

    # --- flow identity -------------------------------------------------
    src_ip: str
    dst_ip: str
    src_port: int
    dst_port: int
    protocol: int  # IP protocol number, e.g. socket.IPPROTO_TCP / IPPROTO_UDP

    # --- timing ----------------------------------------------------------
    created_at: float  # when the flow was first recorded
    last_seen: float   # refreshed on every packet; used for idle expiry

    # --- traffic accounting ------------------------------------------
    bytes_in: int = 0
    bytes_out: int = 0
|
|
|
class TrafficForwarder:
    """Handles packet forwarding and NAT.

    Each outbound IPv4 packet is rewritten by ``NATEngine.translate_outbound``
    and accounted against a ``ForwardSession`` keyed by its 5-tuple. TCP
    packets additionally maintain a small connection-state table and UDP
    packets refresh an endpoint timestamp. All three tables are aged out by
    :meth:`cleanup`.
    """

    # Protocols whose transport header begins with 16-bit source and
    # destination ports. For anything else (ICMP, GRE, ...) those bytes are
    # type/code/checksum fields, and the checksum changes per packet, so
    # treating them as ports would create a new session for every packet.
    _PORTED_PROTOCOLS = frozenset({socket.IPPROTO_TCP, socket.IPPROTO_UDP})

    # TCP flag bits found in byte 13 of the TCP header.
    _TCP_FIN = 0x01
    _TCP_SYN = 0x02
    _TCP_RST = 0x04

    # Minimum (option-less) header sizes.
    _TCP_HEADER_LEN = 20
    _UDP_HEADER_LEN = 8

    def __init__(self, logger: Logger, nat_engine: NATEngine):
        """Create a forwarder.

        Args:
            logger: Project logger used for error reporting.
            nat_engine: Engine whose ``translate_outbound`` rewrites packets.
        """
        self.logger = logger
        self.nat_engine = nat_engine
        # (src_ip, dst_ip, protocol, src_port, dst_port) -> session record.
        self.sessions: Dict[Tuple[str, str, int, int, int], ForwardSession] = {}
        # (src_ip, src_port, dst_ip, dst_port) ->
        #     {'state': str, 'seq': int, 'ack': int, 'last_seen': float}
        self.tcp_connections: Dict[Tuple[str, int, str, int], dict] = {}
        # (src_ip, src_port, dst_ip, dst_port) -> loop time of the last packet.
        self.udp_endpoints: Dict[Tuple[str, int, str, int], float] = {}

    async def forward_packet(self, data: bytes, client_ip: str) -> Optional[bytes]:
        """Translate and forward one outbound IPv4 packet.

        Args:
            data: Raw IPv4 packet to forward.
            client_ip: Address of the sending client (not used by this method).

        Returns:
            The NAT-translated packet, or ``None`` in any of these cases:
            NAT produced nothing, the TCP/UDP header is truncated, or an
            error occurred (the error is logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)

            translated_packet = self.nat_engine.translate_outbound(data)
            if not translated_packet:
                return None

            # Ports are read from the original (pre-NAT) packet, so the
            # session reflects the flow as the client sent it.
            transport = data[ip_header.ihl * 4:]
            src_port = self._get_src_port(transport, ip_header.protocol)
            dst_port = self._get_dst_port(transport, ip_header.protocol)
            session_key = (
                ip_header.src_ip,
                ip_header.dst_ip,
                ip_header.protocol,
                src_port,
                dst_port,
            )

            now = asyncio.get_running_loop().time()
            session = self.sessions.get(session_key)
            if session is None:
                session = ForwardSession(
                    src_ip=ip_header.src_ip,
                    dst_ip=ip_header.dst_ip,
                    src_port=src_port,
                    dst_port=dst_port,
                    protocol=ip_header.protocol,
                    created_at=now,
                    last_seen=now,
                )
                self.sessions[session_key] = session

            session.last_seen = now
            session.bytes_out += len(data)

            if ip_header.protocol == socket.IPPROTO_TCP:
                return await self._forward_tcp(translated_packet, session)
            if ip_header.protocol == socket.IPPROTO_UDP:
                return await self._forward_udp(translated_packet, session)
            # Other protocols need no per-connection bookkeeping.
            return translated_packet

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding packet: {e}")
            return None

    async def _forward_tcp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Update TCP connection state for *session*, then send *data*.

        Args:
            data: NAT-translated IPv4 packet carrying a TCP segment.
            session: Session record the packet belongs to.

        Returns:
            The result of ``_send_packet``. Returns ``None`` if the packet is
            too short to hold a TCP header, or on error (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            tcp_offset = ip_header.ihl * 4
            if len(data) < tcp_offset + self._TCP_HEADER_LEN:
                return None

            flags = data[tcp_offset + 13]
            seq_num = struct.unpack_from('!I', data, tcp_offset + 4)[0]
            conn_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            now = asyncio.get_running_loop().time()

            if flags & self._TCP_RST:
                # RST aborts the connection outright. It is checked first so a
                # RST that also carries SYN/FIN bits is not treated as an
                # open or a graceful close.
                self.tcp_connections.pop(conn_key, None)
            elif flags & self._TCP_SYN:
                if conn_key not in self.tcp_connections:
                    self.tcp_connections[conn_key] = {
                        'state': 'SYN_SENT',
                        'seq': seq_num,
                        'ack': 0,
                        'last_seen': now,
                    }
            elif flags & self._TCP_FIN:
                if conn_key in self.tcp_connections:
                    self.tcp_connections[conn_key]['state'] = 'FIN_WAIT'

            # Refresh the idle timer so cleanup() expires only connections
            # that have actually gone quiet.
            conn = self.tcp_connections.get(conn_key)
            if conn is not None:
                conn['last_seen'] = now

            return await self._send_packet(data)

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding TCP: {e}")
            return None

    async def _forward_udp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Refresh the UDP endpoint timestamp for *session*, then send *data*.

        Returns:
            The result of ``_send_packet``. Returns ``None`` if the packet is
            too short to hold a UDP header, or on error (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            udp_offset = ip_header.ihl * 4
            if len(data) < udp_offset + self._UDP_HEADER_LEN:
                return None

            endpoint_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            self.udp_endpoints[endpoint_key] = asyncio.get_running_loop().time()

            return await self._send_packet(data)

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding UDP: {e}")
            return None

    async def _send_packet(self, data: bytes) -> Optional[bytes]:
        """Send packet to destination.

        NOTE(review): this is currently a pass-through. The packet is
        returned unchanged and no I/O happens here, so it is presumably
        written out by the caller — confirm before relying on it.
        """
        return data

    def _get_src_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the source port from a TCP/UDP header.

        Returns 0 for protocols without ports, or when the header is shorter
        than 2 bytes.
        """
        if protocol not in self._PORTED_PROTOCOLS or len(transport_header) < 2:
            return 0
        return struct.unpack_from('!H', transport_header, 0)[0]

    def _get_dst_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the destination port from a TCP/UDP header.

        Returns 0 for protocols without ports, or when the header is shorter
        than 4 bytes.
        """
        if protocol not in self._PORTED_PROTOCOLS or len(transport_header) < 4:
            return 0
        return struct.unpack_from('!H', transport_header, 2)[0]

    async def cleanup(
        self,
        tcp_timeout: float = 300.0,
        udp_timeout: float = 60.0,
        session_timeout: float = 300.0,
    ):
        """Drop TCP, UDP and session entries idle longer than their timeout.

        Args:
            tcp_timeout: Idle seconds before a TCP connection entry is removed.
            udp_timeout: Idle seconds before a UDP endpoint entry is removed.
            session_timeout: Idle seconds before a session record is removed.
        """
        current_time = asyncio.get_running_loop().time()

        # Entries lacking 'last_seen' are treated as infinitely old.
        for key, conn in list(self.tcp_connections.items()):
            if current_time - conn.get('last_seen', 0) > tcp_timeout:
                del self.tcp_connections[key]

        for key, last_seen in list(self.udp_endpoints.items()):
            if current_time - last_seen > udp_timeout:
                del self.udp_endpoints[key]

        for key, session in list(self.sessions.items()):
            if current_time - session.last_seen > session_timeout:
                del self.sessions[key]
|
|
|