File size: 8,060 Bytes
6a5b8d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""

Traffic Router Module

Handles routing of all client traffic with bandwidth monitoring

"""

import asyncio
import socket
import logging
import ipaddress
from typing import Dict, Any, Optional, Tuple
from datetime import datetime

from .tcp_forward import OutlineTCPForwardingEngine as TCPForwardingEngine
from .nat_engine import NATEngine

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class BandwidthMonitor:
    """Accumulate transferred-byte counters, both overall and per user.

    Counters only ever grow; nothing in this class resets or expires them.
    """

    def __init__(self):
        self.total_bytes_in = 0
        self.total_bytes_out = 0
        # user_id -> {"bytes_in": int, "bytes_out": int, "last_update": datetime}
        self.user_bandwidth: Dict[str, Dict[str, Any]] = {}
        # Wall-clock start time; used to compute average speeds in get_stats().
        self.start_time = datetime.now()

    def update(self, user_id: str, bytes_in: int = 0, bytes_out: int = 0):
        """Add *bytes_in* / *bytes_out* to *user_id*'s counters and the totals.

        A per-user entry is created on first use, and its ``last_update``
        timestamp is refreshed on every call.
        """
        now = datetime.now()  # one timestamp per update, used for both paths
        entry = self.user_bandwidth.get(user_id)
        if entry is None:
            entry = {"bytes_in": 0, "bytes_out": 0, "last_update": now}
            self.user_bandwidth[user_id] = entry

        entry["bytes_in"] += bytes_in
        entry["bytes_out"] += bytes_out
        entry["last_update"] = now

        self.total_bytes_in += bytes_in
        self.total_bytes_out += bytes_out

    def get_stats(self) -> Dict:
        """Return a snapshot of the bandwidth statistics.

        Returns:
            A dict with ``total_bytes_in``, ``total_bytes_out``, average speeds
            (bytes per second since construction; 0 if no time has elapsed),
            and ``user_stats``, a copy of the per-user counters. Mutating the
            returned data does not affect the monitor.
        """
        uptime = (datetime.now() - self.start_time).total_seconds()

        return {
            "total_bytes_in": self.total_bytes_in,
            "total_bytes_out": self.total_bytes_out,
            "avg_speed_in": self.total_bytes_in / uptime if uptime > 0 else 0,
            "avg_speed_out": self.total_bytes_out / uptime if uptime > 0 else 0,
            # Copy each entry so callers cannot alter the live counters.
            "user_stats": {
                uid: dict(entry) for uid, entry in self.user_bandwidth.items()
            },
        }

class TrafficRouter:
    """Manages traffic routing for VPN clients with bandwidth monitoring.

    Listens on ``vpn_host:vpn_port``. For data received from a client it
    parses a target-address header, registers a NAT session, hands the client
    streams to the TCP forwarding engine and records byte counts.
    """

    def __init__(self, config: Dict[str, Any], logger_instance=None):
        """Create a router from *config*.

        Args:
            config: Settings mapping. Keys read by this class: ``vpn_host``
                (default ``"0.0.0.0"``), ``vpn_port`` (default ``9000``),
                ``virtual_network`` (default ``"10.0.0.0/24"``),
                ``buffer_size`` (default ``8192``) and ``user_id``
                (default ``"unknown"``).
            logger_instance: Logger used for all router messages; defaults to
                this module's logger.
        """
        self.config = config
        self.is_running = False

        # VPN server configuration
        self.vpn_host = self.config.get("vpn_host", "0.0.0.0")
        self.vpn_port = self.config.get("vpn_port", 9000)

        # Virtual network configuration; the first usable host is the gateway.
        self.virtual_network = ipaddress.ip_network(
            self.config.get("virtual_network", "10.0.0.0/24")
        )
        self.virtual_gateway = str(next(self.virtual_network.hosts()))

        # Initialize engines
        # NOTE(review): the TCP engine is built with an empty access key — confirm.
        self.nat_engine = NATEngine()
        self.tcp_engine = TCPForwardingEngine(access_key="")
        self.bandwidth_monitor = BandwidthMonitor()
        self.logger = logger_instance if logger_instance else logging.getLogger(__name__)

        # Server instances
        self.loop = None
        self.vpn_server = None

        # Statistics
        self.stats = {
            "total_connections": 0,
            "active_connections": 0,
            "bytes_forwarded": 0,
            "nat_sessions": 0,
            "errors": 0,
        }

    async def start(self) -> bool:
        """Start the client listener and the NAT engine.

        Returns:
            True if the router is running (already, or newly started), False
            if startup failed. On failure the listener opened here is closed
            again, so no half-started server keeps accepting clients.
        """
        if self.is_running:
            self.logger.warning("Traffic Router is already running")
            return True

        self.is_running = True
        # We are inside a coroutine, so a running loop is guaranteed.
        self.loop = asyncio.get_running_loop()

        try:
            # Start VPN server
            self.vpn_server = await asyncio.start_server(
                self._handle_client_connection,
                self.vpn_host,
                self.vpn_port,
            )

            self.logger.info(f"Traffic Router started on {self.vpn_host}:{self.vpn_port}")
            self.logger.info(f"Virtual network: {self.virtual_network}")
            self.logger.info(f"Virtual gateway: {self.virtual_gateway}")

            # Start NAT engine
            await self.nat_engine.start()

            return True

        except Exception as e:
            self.logger.error(f"Failed to start Traffic Router: {e}")
            self.is_running = False
            try:
                await self._close_server()
            except Exception as close_err:
                self.logger.error(f"Error closing VPN server after failed start: {close_err}")
            return False

    async def stop(self):
        """Stop the NAT engine and close the client listener (no-op if stopped).

        The listener is closed even if stopping the NAT engine raises; that
        exception still propagates to the caller.
        """
        if not self.is_running:
            return

        self.is_running = False

        try:
            # Stop NAT engine
            await self.nat_engine.stop()
        finally:
            # Close VPN server
            await self._close_server()

        self.logger.info("Traffic Router stopped")

    async def _close_server(self):
        """Close the listening server, if any, and drop the reference to it."""
        server, self.vpn_server = self.vpn_server, None
        if server is not None:
            server.close()
            await server.wait_closed()

    async def _handle_client_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client connection until EOF, an error, or router stop.

        Each chunk read is passed to ``_forward_data``. The byte count it
        returns is added to the bandwidth monitor and the router stats. The
        connection counters are kept up to date, and the writer is always
        closed on exit.
        """
        peer = writer.get_extra_info('peername')
        self.logger.info(f"New client connection from {peer}")
        self.stats["total_connections"] += 1
        self.stats["active_connections"] += 1

        try:
            buffer_size = self.config.get("buffer_size", 8192)
            # NOTE(review): every connection is attributed to the single
            # configured user_id, not to the connecting client — confirm.
            user_id = self.config.get("user_id", "unknown")

            while self.is_running:
                data = await reader.read(buffer_size)
                if not data:
                    break

                # NOTE(review): every chunk, not just the first one, is parsed
                # as an address header by _forward_data — confirm whether
                # tcp_engine.create_connection takes over the streams.
                bytes_forwarded = await self._forward_data(data, "client", reader, writer)
                self.bandwidth_monitor.update(user_id, bytes_in=bytes_forwarded)
                self.stats["bytes_forwarded"] += bytes_forwarded

        except Exception as e:
            self.logger.error(f"Error handling client {peer}: {e}")
            self.stats["errors"] += 1
        finally:
            self.stats["active_connections"] -= 1
            writer.close()
            try:
                await writer.wait_closed()
            except Exception as e:
                # E.g. the peer already reset the connection; nothing to recover.
                self.logger.debug(f"Error closing connection to {peer}: {e}")

    async def _forward_data(self, data: bytes, source: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> int:
        """Forward *data* towards the target encoded in its header.

        Only ``source == "client"`` is handled. A NAT session is registered
        for the client's address, then the TCP engine is asked to connect to
        the target.

        Returns:
            ``len(data)`` if the TCP engine returned a connection, otherwise 0.
            Errors are logged and reported as 0.
        """
        try:
            if source != "client":
                return 0

            target_addr = self._extract_target_address(data)
            if not target_addr:
                return 0
            target_host, target_port = target_addr

            # peername may be a 4-tuple for IPv6; only host and port are used.
            peer_host, peer_port = writer.get_extra_info('peername')[:2]

            # Create NAT session
            # NOTE(review): the returned session is unused and
            # stats["nat_sessions"] is never incremented — confirm intent.
            self.nat_engine.create_session(
                virtual_ip=peer_host,
                virtual_port=peer_port,
                real_ip=target_host,
                real_port=target_port,
            )

            # Forward through TCP engine
            conn = await self.tcp_engine.create_connection(
                reader, writer,
                target_host, target_port,
            )

            return len(data) if conn else 0

        except Exception as e:
            self.logger.error(f"Error forwarding data: {e}")
            return 0

    def _extract_target_address(self, data: bytes) -> Optional[Tuple[str, int]]:
        """Parse the target address header at the start of *data*.

        The layout is one type byte, then either 4 IPv4 bytes (type 1) or a
        length byte plus that many domain bytes (type 3). A 2-byte big-endian
        port follows.

        Returns:
            ``(host, port)``, or None if the type is unsupported or the header
            is truncated/malformed.
        """
        try:
            if len(data) < 7:
                return None

            addr_type = data[0]
            if addr_type == 1:  # IPv4
                ip = socket.inet_ntoa(data[1:5])
                port = int.from_bytes(data[5:7], 'big')
                return (ip, port)
            if addr_type == 3:  # Domain
                domain_len = data[1]
                port_offset = 2 + domain_len
                # The 7-byte minimum above doesn't cover a long domain; without
                # this check a short packet yields a truncated name or port.
                if len(data) < port_offset + 2:
                    return None
                domain = data[2:port_offset].decode()
                port = int.from_bytes(data[port_offset:port_offset + 2], 'big')
                return (domain, port)
            # NOTE(review): other type bytes (e.g. 4, IPv6 in SOCKS5) are
            # rejected — confirm this is intended.
            return None
        except Exception as e:
            self.logger.error(f"Error extracting target address: {e}")
            return None

    def get_stats(self) -> Dict:
        """Return router, bandwidth and NAT statistics.

        ``router_stats`` is a copy, so callers cannot alter the live counters.
        """
        return {
            "router_stats": dict(self.stats),
            "bandwidth_stats": self.bandwidth_monitor.get_stats(),
            "nat_stats": self.nat_engine.get_stats(),
        }