| """
|
| Traffic Router Module
|
| Handles routing of all client traffic with bandwidth monitoring
|
| """
|
|
|
| import asyncio
|
| import socket
|
| import logging
|
| import ipaddress
|
| from typing import Dict, Any, Optional, Tuple
|
| from datetime import datetime
|
|
|
| from .tcp_forward import OutlineTCPForwardingEngine as TCPForwardingEngine
|
| from .nat_engine import NATEngine
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
class BandwidthMonitor:
    """Accumulates byte counters globally and per user.

    Attributes:
        total_bytes_in: Sum of every ``bytes_in`` value passed to ``update``.
        total_bytes_out: Sum of every ``bytes_out`` value passed to ``update``.
        user_bandwidth: Maps a user id to a record with the keys
            ``"bytes_in"``, ``"bytes_out"`` and ``"last_update"``.
        start_time: Creation time of the monitor; the averages reported by
            ``get_stats`` are computed relative to it.
    """

    def __init__(self):
        self.start_time = datetime.now()
        # Per-user records; "last_update" holds a datetime, the rest are ints.
        self.user_bandwidth: Dict[str, Dict[str, Any]] = {}
        self.total_bytes_out = 0
        self.total_bytes_in = 0

    def update(self, user_id: str, bytes_in: int = 0, bytes_out: int = 0):
        """Add the given byte counts to the user's record and to the totals.

        A record is created on first sight of ``user_id``; its
        ``"last_update"`` timestamp is refreshed on every call.
        """
        now = datetime.now()
        record = self.user_bandwidth.setdefault(
            user_id,
            {"bytes_in": 0, "bytes_out": 0, "last_update": now},
        )
        record["bytes_in"] += bytes_in
        record["bytes_out"] += bytes_out
        record["last_update"] = now

        self.total_bytes_in += bytes_in
        self.total_bytes_out += bytes_out

    def get_stats(self) -> Dict:
        """Return the totals, average bytes/second since creation, and user records.

        The averages are reported as 0 when no positive time has elapsed.
        ``"user_stats"`` is the live per-user mapping, not a copy.
        """
        elapsed = (datetime.now() - self.start_time).total_seconds()

        if elapsed > 0:
            rate_in = self.total_bytes_in / elapsed
            rate_out = self.total_bytes_out / elapsed
        else:
            rate_in = 0
            rate_out = 0

        return {
            "total_bytes_in": self.total_bytes_in,
            "total_bytes_out": self.total_bytes_out,
            "avg_speed_in": rate_in,
            "avg_speed_out": rate_out,
            "user_stats": self.user_bandwidth,
        }
|
|
|
class TrafficRouter:
    """Manages traffic routing for VPN clients with bandwidth monitoring.

    The router listens on ``vpn_host:vpn_port``, parses a target-address
    header from the data each client sends, registers a NAT session for it
    and hands the connection to the TCP forwarding engine.

    Recognised ``config`` keys (all optional):
        vpn_host: Listen address, default ``"0.0.0.0"``.
        vpn_port: Listen port, default ``9000``.
        virtual_network: CIDR string, default ``"10.0.0.0/24"``.
        buffer_size: Bytes requested per read, default ``8192``.
        user_id: Id that bandwidth is accounted against, default ``"unknown"``.
    """

    def __init__(self, config: Dict[str, Any], logger_instance=None):
        """Build the router and its engines; no sockets are opened here.

        Args:
            config: Router settings (see the class docstring for the keys).
            logger_instance: Logger to use for all router messages; falls
                back to this module's logger when omitted.
        """
        self.config = config
        self.is_running = False

        # Listener endpoint.
        self.vpn_host = self.config.get("vpn_host", "0.0.0.0")
        self.vpn_port = self.config.get("vpn_port", 9000)

        # Virtual network; the gateway is the first host address in it.
        self.virtual_network = ipaddress.ip_network(
            self.config.get("virtual_network", "10.0.0.0/24")
        )
        self.virtual_gateway = str(next(self.virtual_network.hosts()))

        # Collaborating engines.
        self.nat_engine = NATEngine()
        self.tcp_engine = TCPForwardingEngine(access_key="")
        self.bandwidth_monitor = BandwidthMonitor()
        self.logger = logger_instance if logger_instance else logging.getLogger(__name__)

        # Populated by start().
        self.loop = None
        self.vpn_server = None

        # NOTE(review): "nat_sessions" is never updated by this class;
        # NAT figures are reported through nat_engine.get_stats() instead.
        self.stats = {
            "total_connections": 0,
            "active_connections": 0,
            "bytes_forwarded": 0,
            "nat_sessions": 0,
            "errors": 0
        }

    async def start(self):
        """Start the listener and the NAT engine.

        Returns:
            True if the router is running (including when it already was),
            False if start-up failed. On failure any listener that was
            already opened is closed again so the port is not left bound.
        """
        if self.is_running:
            self.logger.warning("Traffic Router is already running")
            return True

        self.is_running = True
        self.loop = asyncio.get_running_loop()

        try:
            self.vpn_server = await asyncio.start_server(
                self._handle_client_connection,
                self.vpn_host,
                self.vpn_port
            )

            self.logger.info(f"Traffic Router started on {self.vpn_host}:{self.vpn_port}")
            self.logger.info(f"Virtual network: {self.virtual_network}")
            self.logger.info(f"Virtual gateway: {self.virtual_gateway}")

            await self.nat_engine.start()

            return True

        except Exception as e:
            self.logger.error(f"Failed to start Traffic Router: {e}")
            # Clear the flag first so client handlers stop looping, then
            # release the listener that may have been opened above.
            self.is_running = False
            try:
                await self._close_server()
            except Exception as close_err:
                self.logger.error(
                    f"Error closing listener after failed start: {close_err}"
                )
            return False

    async def stop(self):
        """Stop the NAT engine and close the listener; no-op when not running.

        The listener is closed even if stopping the NAT engine raises; that
        exception still propagates to the caller.
        """
        if not self.is_running:
            return

        self.is_running = False

        try:
            await self.nat_engine.stop()
        finally:
            await self._close_server()

        self.logger.info("Traffic Router stopped")

    async def _close_server(self):
        """Close the listening server, if any, and wait for it to finish."""
        server, self.vpn_server = self.vpn_server, None
        if server is None:
            return
        server.close()
        await server.wait_closed()

    async def _handle_client_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client: read chunks and forward them until EOF or stop.

        Maintains the ``total_connections`` / ``active_connections``
        counters, accounts forwarded bytes against the configured user id,
        and always closes the writer on exit.
        """
        peer = writer.get_extra_info('peername')
        self.logger.info(f"New client connection from {peer}")

        self.stats["total_connections"] += 1
        self.stats["active_connections"] += 1

        try:
            buffer_size = self.config.get("buffer_size", 8192)
            user_id = self.config.get("user_id", "unknown")

            while self.is_running:
                data = await reader.read(buffer_size)
                if not data:
                    # Empty read means the peer closed the connection.
                    break

                bytes_forwarded = await self._forward_data(data, "client", reader, writer)
                self.bandwidth_monitor.update(user_id, bytes_in=bytes_forwarded)
                self.stats["bytes_forwarded"] += bytes_forwarded

        except Exception as e:
            self.logger.error(f"Error handling client {peer}: {e}")
            self.stats["errors"] += 1
        finally:
            self.stats["active_connections"] -= 1
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                # A peer that reset the connection makes wait_closed() raise;
                # the socket is gone either way, so this is not a failure.
                self.logger.debug(f"Connection to {peer} closed uncleanly: {e}")

    async def _forward_data(self, data: bytes, source: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> int:
        """Forward one chunk of client data and return the bytes accounted.

        Args:
            data: Chunk read from the client; expected to start with a
                target-address header (see ``_extract_target_address``).
            source: Origin of the chunk; only ``"client"`` is forwarded.
            reader: Stream reader of the client connection.
            writer: Stream writer of the client connection.

        Returns:
            ``len(data)`` when the forwarding engine reports a connection,
            otherwise 0 (unknown source, unparsable header, engine failure).
            Never raises; failures are logged and counted in ``errors``.
        """
        if source != "client":
            return 0

        try:
            target_addr = self._extract_target_address(data)
            if not target_addr:
                return 0

            target_host, target_port = target_addr
            peer = writer.get_extra_info('peername')

            # Register the mapping from the client's endpoint to the target.
            self.nat_engine.create_session(
                virtual_ip=peer[0],
                virtual_port=peer[1],
                real_ip=target_host,
                real_port=target_port
            )

            # NOTE(review): a session and a connection are requested for
            # every chunk read; presumably the engine takes over the streams
            # after the first call - confirm against the engine's contract.
            conn = await self.tcp_engine.create_connection(
                reader, writer,
                target_host, target_port
            )

            return len(data) if conn else 0

        except Exception as e:
            self.logger.error(f"Error forwarding data: {e}")
            self.stats["errors"] += 1
            return 0

    def _extract_target_address(self, data: bytes) -> Optional[Tuple[str, int]]:
        """Parse the target address header at the start of ``data``.

        Layout (first byte is the address type, port is big-endian):
            type 1: 4-byte IPv4 address, 2-byte port.
            type 3: 1-byte length, that many bytes of host name, 2-byte port.
        This resembles a SOCKS5-style address header - TODO confirm against
        the client protocol.

        Returns:
            ``(host, port)``, or None when the type is unsupported, the
            header is truncated, the host name is empty or it cannot be
            decoded. Bytes after the header are ignored. Never raises.
        """
        try:
            if not data:
                return None

            addr_type = data[0]

            if addr_type == 1:
                # type(1) + IPv4(4) + port(2)
                if len(data) < 7:
                    return None
                ip = socket.inet_ntoa(data[1:5])
                port = int.from_bytes(data[5:7], 'big')
                return (ip, port)

            if addr_type == 3:
                if len(data) < 2:
                    return None
                domain_len = data[1]
                port_start = 2 + domain_len
                # Require the whole name and both port bytes; a short slice
                # would otherwise yield a clipped host and a bogus port.
                if domain_len == 0 or len(data) < port_start + 2:
                    return None
                domain = data[2:port_start].decode()
                port = int.from_bytes(data[port_start:port_start + 2], 'big')
                return (domain, port)

            return None

        except Exception as e:
            # Best-effort parser: malformed input must not kill the handler.
            self.logger.error(f"Error extracting target address: {e}")
            return None

    def get_stats(self) -> Dict:
        """Return router counters plus bandwidth and NAT engine statistics."""
        return {
            "router_stats": self.stats,
            "bandwidth_stats": self.bandwidth_monitor.get_stats(),
            "nat_stats": self.nat_engine.get_stats()
        }
|
|
|