Spaces:
Runtime error
Runtime error
| """ | |
| Traffic Router Module | |
| Handles routing of all client traffic with bandwidth monitoring | |
| """ | |
| import asyncio | |
| import socket | |
| import logging | |
| import ipaddress | |
| from typing import Dict, Any, Optional, Tuple | |
| from datetime import datetime | |
| from .tcp_forward import OutlineTCPForwardingEngine as TCPForwardingEngine | |
| from .nat_engine import NATEngine | |
| logger = logging.getLogger(__name__) | |
class BandwidthMonitor:
    """Accumulate inbound/outbound byte counts, both in total and per user."""

    def __init__(self):
        self.total_bytes_in = 0
        self.total_bytes_out = 0
        # user_id -> {"bytes_in": int, "bytes_out": int, "last_update": datetime}
        self.user_bandwidth: Dict[str, Dict[str, Any]] = {}
        # Wall-clock construction time; used as the origin for average speeds.
        self.start_time = datetime.now()

    def update(self, user_id: str, bytes_in: int = 0, bytes_out: int = 0):
        """Add *bytes_in* / *bytes_out* to *user_id*'s counters and to the totals.

        A per-user entry is created on first use; its ``last_update`` is set to
        the time of this call.
        """
        now = datetime.now()
        entry = self.user_bandwidth.setdefault(
            user_id, {"bytes_in": 0, "bytes_out": 0, "last_update": now}
        )
        entry["bytes_in"] += bytes_in
        entry["bytes_out"] += bytes_out
        entry["last_update"] = now
        self.total_bytes_in += bytes_in
        self.total_bytes_out += bytes_out

    def get_stats(self) -> Dict:
        """Return a snapshot of totals, average speeds and per-user counters.

        ``avg_speed_in`` / ``avg_speed_out`` are bytes per second since the
        monitor was created (0 if no time has elapsed). ``user_stats`` is a
        copy, so callers cannot modify the monitor's internal state through it.
        """
        uptime = (datetime.now() - self.start_time).total_seconds()
        return {
            "total_bytes_in": self.total_bytes_in,
            "total_bytes_out": self.total_bytes_out,
            "avg_speed_in": self.total_bytes_in / uptime if uptime > 0 else 0,
            "avg_speed_out": self.total_bytes_out / uptime if uptime > 0 else 0,
            "user_stats": {
                uid: dict(entry) for uid, entry in self.user_bandwidth.items()
            },
        }
class TrafficRouter:
    """Manages traffic routing for VPN clients with bandwidth monitoring.

    Listens on ``vpn_host:vpn_port``. For each chunk a client sends, it parses
    a target address, registers a NAT session and asks the TCP forwarding
    engine to connect to the target. It also records byte counts for the
    configured ``user_id``.
    """

    def __init__(self, config: Dict[str, Any], logger_instance=None):
        """Build the router and its NAT, TCP-forwarding and bandwidth components.

        Args:
            config: Settings dict. Keys read by this class: ``vpn_host``
                (default ``"0.0.0.0"``), ``vpn_port`` (default ``9000``),
                ``virtual_network`` (default ``"10.0.0.0/24"``),
                ``buffer_size`` (default ``8192``) and ``user_id``
                (default ``"unknown"``).
            logger_instance: Logger to use. Defaults to
                ``logging.getLogger(__name__)``.

        Raises:
            ValueError: if ``virtual_network`` is not a valid IP network.
        """
        self.config = config
        self.is_running = False
        # VPN server configuration
        self.vpn_host = self.config.get("vpn_host", "0.0.0.0")
        self.vpn_port = self.config.get("vpn_port", 9000)
        # Virtual network configuration; the gateway is the first host address.
        self.virtual_network = ipaddress.ip_network(
            self.config.get("virtual_network", "10.0.0.0/24")
        )
        self.virtual_gateway = str(next(self.virtual_network.hosts()))
        # Initialize engines
        self.nat_engine = NATEngine()
        self.tcp_engine = TCPForwardingEngine(access_key="")
        self.bandwidth_monitor = BandwidthMonitor()
        self.logger = logger_instance if logger_instance else logging.getLogger(__name__)
        # Server instances (populated by start())
        self.loop = None
        self.vpn_server = None
        # Statistics
        self.stats = {
            "total_connections": 0,
            "active_connections": 0,
            "bytes_forwarded": 0,
            "nat_sessions": 0,
            "errors": 0,
        }

    async def start(self) -> bool:
        """Start listening for clients and start the NAT engine.

        Returns:
            True if the router is running, including when it was already
            running. False if startup failed. On failure, a server that was
            already opened is closed again.
        """
        if self.is_running:
            self.logger.warning("Traffic Router is already running")
            return True
        self.is_running = True
        # get_running_loop() is the supported call from inside a coroutine.
        self.loop = asyncio.get_running_loop()
        try:
            # Start VPN server
            self.vpn_server = await asyncio.start_server(
                self._handle_client_connection,
                self.vpn_host,
                self.vpn_port,
            )
            self.logger.info(f"Traffic Router started on {self.vpn_host}:{self.vpn_port}")
            self.logger.info(f"Virtual network: {self.virtual_network}")
            self.logger.info(f"Virtual gateway: {self.virtual_gateway}")
            # Start NAT engine
            await self.nat_engine.start()
            return True
        except Exception as e:
            self.logger.error(f"Failed to start Traffic Router: {e}")
            self.is_running = False
            # Don't leave the listening socket open when a later step failed.
            await self._close_server()
            return False

    async def stop(self) -> None:
        """Stop the NAT engine and close the listening server. No-op if not running."""
        if not self.is_running:
            return
        self.is_running = False
        try:
            await self.nat_engine.stop()
        finally:
            # Close the server even if the NAT engine failed to stop cleanly.
            await self._close_server()
        self.logger.info("Traffic Router stopped")

    async def _close_server(self) -> None:
        """Close the listening server, if any, wait for it, and forget it."""
        server, self.vpn_server = self.vpn_server, None
        if server is not None:
            server.close()
            await server.wait_closed()

    async def _handle_client_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client until EOF, an error, or router shutdown.

        Each chunk read (up to ``buffer_size`` bytes) is passed to
        ``_forward_data``. The number of bytes it reports as forwarded is
        credited to ``user_id`` as inbound traffic and added to
        ``stats["bytes_forwarded"]``. The connection counters in ``self.stats``
        are maintained for the lifetime of the call.
        """
        peer = writer.get_extra_info('peername')
        self.logger.info(f"New client connection from {peer}")
        self.stats["total_connections"] += 1
        self.stats["active_connections"] += 1
        buffer_size = self.config.get("buffer_size", 8192)
        user_id = self.config.get("user_id", "unknown")
        try:
            while self.is_running:
                data = await reader.read(buffer_size)
                if not data:
                    break
                # Forward data and track bandwidth
                bytes_forwarded = await self._forward_data(data, "client", reader, writer)
                self.bandwidth_monitor.update(user_id, bytes_in=bytes_forwarded)
                self.stats["bytes_forwarded"] += bytes_forwarded
        except Exception as e:
            self.logger.error(f"Error handling client {peer}: {e}")
            self.stats["errors"] += 1
        finally:
            self.stats["active_connections"] -= 1
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                # The peer may already have reset the connection. The writer
                # has been closed regardless, so only note it.
                self.logger.debug(f"Error while closing connection to {peer}: {e}")

    async def _forward_data(self, data: bytes, source: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> int:
        """Forward a chunk read from a client and return the bytes forwarded.

        Only ``source == "client"`` is handled; any other source returns 0.
        The target address is parsed from the start of *data*. A NAT session
        is created for the client's peer address, and the TCP engine is asked
        to connect to the target.

        Returns:
            ``len(data)`` when the engine returns a truthy connection, else 0.
            Exceptions are logged, counted in ``stats["errors"]`` and
            reported as 0.
        """
        # NOTE(review): this runs for every chunk the client sends, so every
        # chunk is parsed as a fresh address header and opens a new NAT session
        # and upstream connection. If the TCP engine relays reader/writer
        # itself, it also competes with the caller's read loop. Confirm against
        # OutlineTCPForwardingEngine.create_connection.
        if source != "client":
            return 0
        try:
            target_addr = self._extract_target_address(data)
            if not target_addr:
                return 0
            peername = writer.get_extra_info('peername')
            if not peername:
                self.logger.warning("Cannot forward data: client peer address is unavailable")
                return 0
            target_host, target_port = target_addr
            # Create NAT session
            session = self.nat_engine.create_session(
                virtual_ip=peername[0],
                virtual_port=peername[1],
                real_ip=target_host,
                real_port=target_port,
            )
            # Presumably create_session returns None on failure — verify.
            if session is not None:
                self.stats["nat_sessions"] += 1
            # Forward through TCP engine
            conn = await self.tcp_engine.create_connection(
                reader, writer,
                target_host, target_port,
            )
            return len(data) if conn else 0
        except Exception as e:
            self.logger.error(f"Error forwarding data: {e}")
            self.stats["errors"] += 1
            return 0

    def _extract_target_address(self, data: bytes) -> Optional[Tuple[str, int]]:
        """Parse a target-address header from the start of *data*.

        The first byte selects the layout:
          * ``1`` (IPv4): 4 address bytes, then a 2-byte big-endian port.
          * ``3`` (domain): 1 length byte, that many UTF-8 name bytes, then a
            2-byte big-endian port.

        Returns:
            ``(host, port)``, or None if the type is unsupported, *data* is
            shorter than the declared layout, the domain is empty, or the
            domain is not valid UTF-8.
        """
        if not data:
            return None
        addr_type = data[0]
        try:
            if addr_type == 1:  # IPv4
                if len(data) < 7:
                    return None
                ip = socket.inet_ntoa(data[1:5])
                port = int.from_bytes(data[5:7], 'big')
                return (ip, port)
            if addr_type == 3:  # Domain
                if len(data) < 2:
                    return None
                domain_len = data[1]
                end = 2 + domain_len
                # Reject truncated headers. Slicing would otherwise silently
                # yield a partial name and an empty port field (parsed as 0).
                if domain_len == 0 or len(data) < end + 2:
                    return None
                domain = data[2:end].decode()
                port = int.from_bytes(data[end:end + 2], 'big')
                return (domain, port)
            return None
        except (ValueError, OSError) as e:
            # ValueError covers UnicodeDecodeError from an invalid domain name.
            self.logger.error(f"Error extracting target address: {e}")
            return None

    def get_stats(self) -> Dict:
        """Return router, bandwidth and NAT statistics.

        ``router_stats`` is a copy, so callers cannot alter the live counters.
        """
        return {
            "router_stats": dict(self.stats),
            "bandwidth_stats": self.bandwidth_monitor.get_stats(),
            "nat_stats": self.nat_engine.get_stats(),
        }