File size: 7,368 Bytes
6a5b8d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""

Traffic Forwarding Engine

Handles IP packet forwarding and NAT for VPN tunnels

"""

import asyncio
import socket
import struct
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
import os
from .ip_parser import IPv4Header, IPParser
from .logger import Logger, LogCategory
from .nat_engine import NATEngine

@dataclass()
class ForwardSession:
    """Bookkeeping record for one forwarded traffic flow.

    The first five fields identify the flow; ``created_at`` and ``last_seen``
    are timestamps supplied by whoever creates/updates the record (in this
    module, the running event loop's clock). Byte counters default to zero.
    """

    # --- flow identity ---
    src_ip: str
    dst_ip: str
    src_port: int
    dst_port: int
    protocol: int  # IP protocol number (e.g. 6 = TCP, 17 = UDP)

    # --- timestamps ---
    created_at: float
    last_seen: float

    # --- traffic counters ---
    bytes_in: int = 0
    bytes_out: int = 0

class TrafficForwarder:
    """Track per-flow state for outbound VPN traffic and hand packets on.

    For every packet from a client the forwarder applies outbound NAT via the
    injected ``NATEngine``, records a ``ForwardSession`` keyed on the packet's
    pre-NAT addresses/protocol/ports, and keeps lightweight TCP/UDP
    bookkeeping that :meth:`cleanup` expires. Packets are not transmitted
    here; they are returned to the caller (see :meth:`_send_packet`).
    """

    # IP protocol numbers whose transport header starts with a 16-bit source
    # port followed by a 16-bit destination port:
    # TCP (6), UDP (17), DCCP (33), SCTP (132), UDP-Lite (136).
    _PORTED_PROTOCOLS = frozenset({socket.IPPROTO_TCP, socket.IPPROTO_UDP, 33, 132, 136})

    # TCP flag bits, found in byte 13 of the TCP header.
    _TCP_FIN = 0x01
    _TCP_SYN = 0x02
    _TCP_RST = 0x04

    # Header sizes in bytes.
    _TCP_MIN_HEADER_LEN = 20
    _UDP_HEADER_LEN = 8

    def __init__(self, logger: Logger, nat_engine: NATEngine):
        """
        Args:
            logger: Project logger used for error reporting.
            nat_engine: Engine whose ``translate_outbound`` rewrites packets.
        """
        self.logger = logger
        self.nat_engine = nat_engine
        # (src_ip, dst_ip, protocol, src_port, dst_port) -> session record
        self.sessions: Dict[Tuple[str, str, int, int, int], ForwardSession] = {}
        # (src_ip, src_port, dst_ip, dst_port) -> {'state', 'seq', 'ack', 'last_seen'}
        self.tcp_connections: Dict[Tuple[str, int, str, int], dict] = {}
        # (src_ip, src_port, dst_ip, dst_port) -> loop time of the last packet
        self.udp_endpoints: Dict[Tuple[str, int, str, int], float] = {}

    async def forward_packet(self, data: bytes, client_ip: str) -> Optional[bytes]:
        """Apply NAT to one outbound IPv4 packet, track its flow, and forward it.

        If ``nat_engine.translate_outbound`` returns a falsy value the packet
        is dropped. Otherwise the flow's ``ForwardSession`` is created or
        refreshed and the translated packet goes to the protocol-specific
        forwarder; non-TCP/UDP packets are returned unchanged after NAT.

        Args:
            data: Raw IPv4 packet as received from the client.
            client_ip: Address of the sending client (currently unused).

        Returns:
            The packet to emit, or ``None`` if it was dropped or an error
            occurred (errors are logged, not raised).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)

            translated_packet = self.nat_engine.translate_outbound(data)
            if not translated_packet:
                return None

            # Ports come from the original (pre-NAT) packet so the session
            # reflects the client's view of the flow. Slice the transport
            # header once rather than once per port.
            # NOTE(review): non-first IPv4 fragments carry no transport
            # header; if IPv4Header exposes the fragment offset, consider
            # skipping port extraction for those.
            transport_header = data[ip_header.ihl * 4:]
            protocol = ip_header.protocol
            src_port = self._get_src_port(transport_header, protocol)
            dst_port = self._get_dst_port(transport_header, protocol)
            session_key = (ip_header.src_ip, ip_header.dst_ip, protocol, src_port, dst_port)

            # One clock read so created_at == last_seen for a new session.
            now = asyncio.get_running_loop().time()
            session = self.sessions.get(session_key)
            if session is None:
                session = ForwardSession(
                    src_ip=ip_header.src_ip,
                    dst_ip=ip_header.dst_ip,
                    src_port=src_port,
                    dst_port=dst_port,
                    protocol=protocol,
                    created_at=now,
                    last_seen=now,
                )
                self.sessions[session_key] = session

            session.last_seen = now
            session.bytes_out += len(data)

            if protocol == socket.IPPROTO_TCP:
                return await self._forward_tcp(translated_packet, session)
            if protocol == socket.IPPROTO_UDP:
                return await self._forward_udp(translated_packet, session)
            # Forward other IP protocols directly
            return translated_packet

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding packet: {e}")
            return None

    async def _forward_tcp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Update TCP connection bookkeeping for *data* and forward it.

        A SYN creates an entry (state ``SYN_SENT``) if none exists, a FIN
        marks an existing entry ``FIN_WAIT``, and an RST removes it. Every
        packet on a tracked connection refreshes its ``last_seen`` time, which
        :meth:`cleanup` uses for expiry.

        Returns:
            The forwarded packet, or ``None`` if the TCP header is truncated or
            an error occurred (errors are logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            tcp_offset = ip_header.ihl * 4

            if len(data) < tcp_offset + self._TCP_MIN_HEADER_LEN:
                return None

            flags = data[tcp_offset + 13]
            seq_num = struct.unpack_from('!I', data, tcp_offset + 4)[0]
            now = asyncio.get_running_loop().time()

            conn_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)

            if flags & self._TCP_RST:
                # RST aborts the connection regardless of other flags.
                self.tcp_connections.pop(conn_key, None)
            else:
                if flags & self._TCP_SYN:
                    if conn_key not in self.tcp_connections:
                        self.tcp_connections[conn_key] = {
                            'state': 'SYN_SENT',
                            'seq': seq_num,
                            'ack': 0,
                        }
                elif flags & self._TCP_FIN:
                    if conn_key in self.tcp_connections:
                        self.tcp_connections[conn_key]['state'] = 'FIN_WAIT'

                # Refresh activity so cleanup() expires only idle connections.
                # (Without this, cleanup() compared against a default of 0 and
                # purged every connection once the monotonic clock passed 300.)
                conn = self.tcp_connections.get(conn_key)
                if conn is not None:
                    conn['last_seen'] = now

            return await self._send_packet(data)

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding TCP: {e}")
            return None

    async def _forward_udp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Record UDP endpoint activity for *data* and forward it.

        Returns:
            The forwarded packet, or ``None`` if the UDP header is truncated or
            an error occurred (errors are logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            udp_offset = ip_header.ihl * 4

            if len(data) < udp_offset + self._UDP_HEADER_LEN:
                return None

            endpoint_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            self.udp_endpoints[endpoint_key] = asyncio.get_running_loop().time()

            return await self._send_packet(data)

        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding UDP: {e}")
            return None

    async def _send_packet(self, data: bytes) -> Optional[bytes]:
        """Hand the packet back to the caller for transmission.

        Placeholder: no I/O happens here yet; the packet is returned for the
        VPN server to handle. If real sending is added, catch and log send
        errors and return ``None``, as the other forwarders do.
        """
        return data

    def _get_src_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the source port, or 0 if *protocol* has no ports or the header is truncated."""
        if protocol in self._PORTED_PROTOCOLS and len(transport_header) >= 2:
            return struct.unpack_from('!H', transport_header, 0)[0]
        return 0

    def _get_dst_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the destination port, or 0 if *protocol* has no ports or the header is truncated."""
        if protocol in self._PORTED_PROTOCOLS and len(transport_header) >= 4:
            return struct.unpack_from('!H', transport_header, 2)[0]
        return 0

    async def cleanup(
        self,
        tcp_timeout: float = 300.0,
        udp_timeout: float = 60.0,
        session_timeout: float = 300.0,
    ) -> None:
        """Remove state that has been idle longer than the given timeouts.

        Args:
            tcp_timeout: Seconds of inactivity before a TCP entry is dropped.
            udp_timeout: Seconds of inactivity before a UDP endpoint is dropped.
            session_timeout: Seconds of inactivity before a session is dropped.
        """
        now = asyncio.get_running_loop().time()

        # Collect keys first; deleting from a dict while iterating it raises.
        stale_tcp = [
            key for key, conn in self.tcp_connections.items()
            if now - conn.get('last_seen', 0) > tcp_timeout
        ]
        for key in stale_tcp:
            del self.tcp_connections[key]

        stale_udp = [
            key for key, last_seen in self.udp_endpoints.items()
            if now - last_seen > udp_timeout
        ]
        for key in stale_udp:
            del self.udp_endpoints[key]

        stale_sessions = [
            key for key, session in self.sessions.items()
            if now - session.last_seen > session_timeout
        ]
        for key in stale_sessions:
            del self.sessions[key]