"""
Session Tracker Module

Manages and tracks all network sessions across the virtual ISP stack:
- Unified session management across all modules
- Session lifecycle tracking
- Performance metrics and analytics
- Session correlation and debugging
"""

import json
import logging
import threading
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple

from .tcp_engine import TCPConnection
from .nat_engine import NATSession

logger = logging.getLogger(__name__)


class SessionType(Enum):
    """Kind of underlying object a UnifiedSession represents."""

    NAT_SESSION = "NAT_SESSION"
    TCP_CONNECTION = "TCP_CONNECTION"
    SOCKET_CONNECTION = "SOCKET_CONNECTION"


class SessionState(Enum):
    """Lifecycle state of a UnifiedSession."""

    INITIALIZING = "INITIALIZING"
    ACTIVE = "ACTIVE"
    IDLE = "IDLE"
    CLOSING = "CLOSING"
    CLOSED = "CLOSED"
    ERROR = "ERROR"


@dataclass
class SessionMetrics:
    """Session performance metrics (traffic counters and RTT samples)."""

    # Upper bound on retained RTT samples; the oldest are dropped first.
    # ClassVar keeps this out of the dataclass fields / __init__ signature.
    MAX_RTT_SAMPLES: ClassVar[int] = 100

    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0
    errors: int = 0
    retransmits: int = 0
    rtt_samples: List[float] = field(default_factory=list)

    @property
    def total_bytes(self) -> int:
        """Sum of inbound and outbound bytes."""
        return self.bytes_in + self.bytes_out

    @property
    def total_packets(self) -> int:
        """Sum of inbound and outbound packets."""
        return self.packets_in + self.packets_out

    @property
    def average_rtt(self) -> float:
        """Mean of the retained RTT samples, or 0.0 when there are none."""
        if not self.rtt_samples:
            return 0.0
        return sum(self.rtt_samples) / len(self.rtt_samples)

    def update_bytes(self, bytes_in: int = 0, bytes_out: int = 0):
        """Add to the byte counters."""
        self.bytes_in += bytes_in
        self.bytes_out += bytes_out

    def update_packets(self, packets_in: int = 0, packets_out: int = 0):
        """Add to the packet counters."""
        self.packets_in += packets_in
        self.packets_out += packets_out

    def add_rtt_sample(self, rtt: float):
        """Record an RTT sample, keeping only the newest MAX_RTT_SAMPLES."""
        self.rtt_samples.append(rtt)
        overflow = len(self.rtt_samples) - self.MAX_RTT_SAMPLES
        if overflow > 0:
            # Trim the oldest entries in place. Unlike slicing with
            # [-limit:], this is also correct when the limit is 0.
            del self.rtt_samples[:overflow]

    def to_dict(self) -> Dict:
        """Convert to a plain dictionary, including derived totals."""
        return {
            'bytes_in': self.bytes_in,
            'bytes_out': self.bytes_out,
            'packets_in': self.packets_in,
            'packets_out': self.packets_out,
            'total_bytes': self.total_bytes,
            'total_packets': self.total_packets,
            'errors': self.errors,
            'retransmits': self.retransmits,
            'average_rtt': self.average_rtt,
            'rtt_samples_count': len(self.rtt_samples),
        }


@dataclass
class UnifiedSession:
    """Unified, module-independent representation of one network session."""

    id: str
    type: SessionType
    state: SessionState
    start_time: float  # time.time() when the session was created
    last_active: float  # time.time() of the most recent activity
    source_ip: str
    source_port: int
    dest_ip: str
    dest_port: int
    metrics: SessionMetrics = field(default_factory=SessionMetrics)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration(self) -> float:
        """Seconds elapsed since start_time (measured against now)."""
        return time.time() - self.start_time

    @property
    def idle_time(self) -> float:
        """Seconds elapsed since last_active."""
        return time.time() - self.last_active

    def update_activity(self):
        """Update the last-activity timestamp to now."""
        self.last_active = time.time()

    def to_dict(self) -> Dict:
        """Convert to a plain dictionary.

        The metadata is returned as a shallow copy, so callers cannot
        mutate the session's own metadata through the result.
        """
        return {
            'id': self.id,
            'type': self.type.value,
            'state': self.state.value,
            'start_time': self.start_time,
            'last_active': self.last_active,
            'duration': self.duration,
            'idle_time': self.idle_time,
            'source_ip': self.source_ip,
            'source_port': self.source_port,
            'dest_ip': self.dest_ip,
            'dest_port': self.dest_port,
            'metrics': self.metrics.to_dict(),
            'metadata': dict(self.metadata),
        }


class SessionTracker:
    """Process-wide singleton that tracks all active network sessions.

    A daemon thread periodically removes closed sessions and expires idle
    ones. The timings are class attributes and can be overridden by a
    subclass.
    """

    # Seconds between background cleanup passes.
    CLEANUP_INTERVAL: float = 60
    # CLOSED sessions are dropped once idle (by last_active) for this long.
    CLOSED_SESSION_TTL: float = 300
    # Non-closed sessions idle for this long are marked CLOSED and dropped.
    IDLE_SESSION_TIMEOUT: float = 1800

    _instance = None
    _lock = threading.Lock()  # guards singleton creation and initialization

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self):
        # Python runs __init__ on every SessionTracker() call. Check and
        # set _initialized under the class lock, so that concurrent first
        # constructions cannot start two cleanup threads or swap out the
        # sessions dict.
        with type(self)._lock:
            if self._initialized:
                return
            self.sessions: Dict[str, UnifiedSession] = {}
            self.lock = threading.Lock()  # guards self.sessions and session mutation
            self.running = True
            # Setting this event wakes the cleanup loop immediately, so
            # shutdown() does not wait out the full cleanup interval.
            self._stop_event = threading.Event()
            self.cleanup_thread = threading.Thread(
                target=self._cleanup_loop,
                name="SessionTrackerCleanup",
                daemon=True,
            )
            self.cleanup_thread.start()
            self._initialized = True

    def create_session(self, session_type: SessionType, source_ip: str,
                       source_port: int, dest_ip: str, dest_port: int,
                       **kwargs) -> UnifiedSession:
        """Create, register and return a new session in INITIALIZING state.

        Any extra keyword arguments become the session's metadata.
        """
        now = time.time()  # a single timestamp, so start_time == last_active
        session = UnifiedSession(
            id=str(uuid.uuid4()),
            type=session_type,
            state=SessionState.INITIALIZING,
            start_time=now,
            last_active=now,
            source_ip=source_ip,
            source_port=source_port,
            dest_ip=dest_ip,
            dest_port=dest_port,
            metadata=kwargs,
        )
        with self.lock:
            self.sessions[session.id] = session
        return session

    def get_session(self, session_id: str) -> Optional[UnifiedSession]:
        """Return the session with this ID, or None if it is not tracked."""
        with self.lock:
            return self.sessions.get(session_id)

    def update_session(self, session_id: str, state: Optional[SessionState] = None,
                       metrics_update: Optional[Dict] = None,
                       metadata_update: Optional[Dict] = None) -> bool:
        """Update a session's state, metrics and/or metadata.

        Args:
            session_id: ID returned by create_session.
            state: New state, if any.
            metrics_update: Deltas to apply. Recognized keys are 'bytes_in',
                'bytes_out', 'packets_in', 'packets_out', 'errors' and
                'retransmits', plus 'rtt' (a single RTT sample to record).
            metadata_update: Entries merged into the session metadata.

        Returns:
            True if the session exists and was updated, otherwise False.
            A successful update also refreshes last_active.
        """
        with self.lock:
            # Look up under the lock, so cleanup cannot remove the session
            # between the lookup and the mutation.
            session = self.sessions.get(session_id)
            if session is None:
                return False
            if state is not None:
                session.state = state
            if metrics_update:
                metrics = session.metrics
                metrics.update_bytes(
                    metrics_update.get('bytes_in', 0),
                    metrics_update.get('bytes_out', 0),
                )
                metrics.update_packets(
                    metrics_update.get('packets_in', 0),
                    metrics_update.get('packets_out', 0),
                )
                metrics.errors += metrics_update.get('errors', 0)
                metrics.retransmits += metrics_update.get('retransmits', 0)
                if 'rtt' in metrics_update:
                    metrics.add_rtt_sample(metrics_update['rtt'])
            if metadata_update:
                session.metadata.update(metadata_update)
            session.update_activity()
        return True

    def close_session(self, session_id: str) -> bool:
        """Mark a session CLOSED; background cleanup removes it later.

        Returns:
            True if the session existed, otherwise False.
        """
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return False
            session.state = SessionState.CLOSED
        return True

    def get_all_sessions(self) -> List[UnifiedSession]:
        """Return all tracked sessions whose state is not CLOSED."""
        with self.lock:
            return [s for s in self.sessions.values()
                    if s.state != SessionState.CLOSED]

    def get_sessions_by_type(self, session_type: SessionType) -> List[UnifiedSession]:
        """Return the non-CLOSED sessions of the given type."""
        with self.lock:
            return [s for s in self.sessions.values()
                    if s.type == session_type and s.state != SessionState.CLOSED]

    def get_sessions_by_ip(self, ip_address: str) -> List[UnifiedSession]:
        """Return non-CLOSED sessions whose source or destination IP matches."""
        with self.lock:
            return [s for s in self.sessions.values()
                    if ip_address in (s.source_ip, s.dest_ip)
                    and s.state != SessionState.CLOSED]

    def _cleanup_loop(self):
        """Background loop that runs _cleanup_sessions every CLEANUP_INTERVAL."""
        while self.running:
            # wait() returns True as soon as shutdown() sets the event.
            if self._stop_event.wait(self.CLEANUP_INTERVAL):
                break
            try:
                self._cleanup_sessions()
            except Exception as e:
                # This is the thread's top-level boundary. Log with the
                # traceback and keep looping rather than letting the
                # thread die.
                logger.exception("Error in cleanup loop: %s", e)

    def _cleanup_sessions(self):
        """Remove stale CLOSED sessions and expire long-idle ones."""
        now = time.time()
        with self.lock:
            to_remove = []
            for session_id, session in self.sessions.items():
                idle = now - session.last_active
                if session.state == SessionState.CLOSED:
                    if idle > self.CLOSED_SESSION_TTL:
                        to_remove.append(session_id)
                elif idle > self.IDLE_SESSION_TIMEOUT:
                    # Mark it closed first, so any caller still holding a
                    # reference sees the final state.
                    session.state = SessionState.CLOSED
                    to_remove.append(session_id)
            # Delete after the scan: a dict cannot change size while
            # it is being iterated.
            for session_id in to_remove:
                del self.sessions[session_id]

    def shutdown(self):
        """Stop the cleanup thread (without waiting an interval) and drop all sessions."""
        self.running = False
        self._stop_event.set()
        if self.cleanup_thread.is_alive():
            self.cleanup_thread.join()
        with self.lock:
            self.sessions.clear()