""" Session Tracker Module Manages and tracks all network sessions across the virtual ISP stack: - Unified session management across all modules - Session lifecycle tracking - Performance metrics and analytics - Session correlation and debugging """ import time import threading import uuid from typing import Dict, List, Optional, Set, Any, Tuple from dataclasses import dataclass, field from enum import Enum import json from .dhcp_server import DHCPLease from .nat_engine import NATSession from .tcp_engine import TCPConnection from .socket_translator import SocketConnection class SessionType(Enum): DHCP_LEASE = "DHCP_LEASE" NAT_SESSION = "NAT_SESSION" TCP_CONNECTION = "TCP_CONNECTION" SOCKET_CONNECTION = "SOCKET_CONNECTION" BRIDGE_CLIENT = "BRIDGE_CLIENT" class SessionState(Enum): INITIALIZING = "INITIALIZING" ACTIVE = "ACTIVE" IDLE = "IDLE" CLOSING = "CLOSING" CLOSED = "CLOSED" ERROR = "ERROR" @dataclass class SessionMetrics: """Session performance metrics""" bytes_in: int = 0 bytes_out: int = 0 packets_in: int = 0 packets_out: int = 0 errors: int = 0 retransmits: int = 0 rtt_samples: List[float] = field(default_factory=list) @property def total_bytes(self) -> int: return self.bytes_in + self.bytes_out @property def total_packets(self) -> int: return self.packets_in + self.packets_out @property def average_rtt(self) -> float: return sum(self.rtt_samples) / len(self.rtt_samples) if self.rtt_samples else 0.0 def update_bytes(self, bytes_in: int = 0, bytes_out: int = 0): """Update byte counters""" self.bytes_in += bytes_in self.bytes_out += bytes_out def update_packets(self, packets_in: int = 0, packets_out: int = 0): """Update packet counters""" self.packets_in += packets_in self.packets_out += packets_out def add_rtt_sample(self, rtt: float): """Add RTT sample""" self.rtt_samples.append(rtt) # Keep only last 100 samples if len(self.rtt_samples) > 100: self.rtt_samples = self.rtt_samples[-100:] def to_dict(self) -> Dict: """Convert to dictionary""" return { 'bytes_in': self.bytes_in, 'bytes_out': self.bytes_out, 'packets_in': self.packets_in, 'packets_out': self.packets_out, 'total_bytes': self.total_bytes, 'total_packets': self.total_packets, 'errors': self.errors, 'retransmits': self.retransmits, 'average_rtt': self.average_rtt, 'rtt_samples_count': len(self.rtt_samples) } @dataclass class UnifiedSession: """Unified session representation""" session_id: str session_type: SessionType state: SessionState created_time: float last_activity: float # Session identifiers virtual_ip: Optional[str] = None virtual_port: Optional[int] = None real_ip: Optional[str] = None real_port: Optional[int] = None protocol: Optional[str] = None # Related sessions (for correlation) related_sessions: Set[str] = field(default_factory=set) parent_session: Optional[str] = None child_sessions: Set[str] = field(default_factory=set) # Metrics metrics: SessionMetrics = field(default_factory=SessionMetrics) # Additional data metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): if not self.session_id: self.session_id = str(uuid.uuid4()) if self.created_time == 0: self.created_time = time.time() if self.last_activity == 0: self.last_activity = time.time() def update_activity(self): """Update last activity timestamp""" self.last_activity = time.time() def add_related_session(self, session_id: str): """Add related session""" self.related_sessions.add(session_id) def add_child_session(self, session_id: str): """Add child session""" self.child_sessions.add(session_id) def set_parent_session(self, session_id: str): """Set parent session""" self.parent_session = session_id @property def duration(self) -> float: """Get session duration in seconds""" return time.time() - self.created_time @property def idle_time(self) -> float: """Get idle time in seconds""" return time.time() - self.last_activity def to_dict(self) -> Dict: """Convert to dictionary""" return { 'session_id': self.session_id, 'session_type': self.session_type.value, 'state': self.state.value, 'created_time': self.created_time, 'last_activity': self.last_activity, 'duration': self.duration, 'idle_time': self.idle_time, 'virtual_ip': self.virtual_ip, 'virtual_port': self.virtual_port, 'real_ip': self.real_ip, 'real_port': self.real_port, 'protocol': self.protocol, 'related_sessions': list(self.related_sessions), 'parent_session': self.parent_session, 'child_sessions': list(self.child_sessions), 'metrics': self.metrics.to_dict(), 'metadata': self.metadata } class SessionTracker: """Unified session tracker""" def __init__(self, config: Dict): self.config = config self.sessions: Dict[str, UnifiedSession] = {} self.session_index: Dict[Tuple[str, str], Set[str]] = {} # (type, key) -> session_ids self.lock = threading.Lock() # Configuration self.max_sessions = config.get('max_sessions', 10000) self.session_timeout = config.get('session_timeout', 3600) self.cleanup_interval = config.get('cleanup_interval', 300) self.metrics_retention = config.get('metrics_retention', 86400) # 24 hours # Statistics self.stats = { 'total_sessions': 0, 'active_sessions': 0, 'expired_sessions': 0, 'session_types': {t.value: 0 for t in SessionType}, 'session_states': {s.value: 0 for s in SessionState}, 'cleanup_runs': 0, 'correlations_created': 0 } # Background tasks self.running = False self.cleanup_thread = None def _generate_session_key(self, session_type: SessionType, **kwargs) -> str: """Generate session key for indexing""" if session_type == SessionType.DHCP_LEASE: return f"dhcp_{kwargs.get('mac_address', 'unknown')}" elif session_type == SessionType.NAT_SESSION: return f"nat_{kwargs.get('virtual_ip', '')}_{kwargs.get('virtual_port', 0)}_{kwargs.get('protocol', '')}" elif session_type == SessionType.TCP_CONNECTION: return f"tcp_{kwargs.get('local_ip', '')}_{kwargs.get('local_port', 0)}_{kwargs.get('remote_ip', '')}_{kwargs.get('remote_port', 0)}" elif session_type == SessionType.SOCKET_CONNECTION: return f"socket_{kwargs.get('connection_id', 'unknown')}" elif session_type == SessionType.BRIDGE_CLIENT: return f"bridge_{kwargs.get('client_id', 'unknown')}" else: return f"unknown_{time.time()}" def _add_to_index(self, session: UnifiedSession): """Add session to search index""" # Index by type type_key = (session.session_type.value, 'all') if type_key not in self.session_index: self.session_index[type_key] = set() self.session_index[type_key].add(session.session_id) # Index by IP addresses if session.virtual_ip: ip_key = ('virtual_ip', session.virtual_ip) if ip_key not in self.session_index: self.session_index[ip_key] = set() self.session_index[ip_key].add(session.session_id) if session.real_ip: ip_key = ('real_ip', session.real_ip) if ip_key not in self.session_index: self.session_index[ip_key] = set() self.session_index[ip_key].add(session.session_id) # Index by protocol if session.protocol: proto_key = ('protocol', session.protocol) if proto_key not in self.session_index: self.session_index[proto_key] = set() self.session_index[proto_key].add(session.session_id) def _remove_from_index(self, session: UnifiedSession): """Remove session from search index""" for key, session_set in self.session_index.items(): session_set.discard(session.session_id) def create_session(self, session_type: SessionType, **kwargs) -> str: """Create new session""" with self.lock: # Check session limit if len(self.sessions) >= self.max_sessions: # Remove oldest expired session self._cleanup_expired_sessions() if len(self.sessions) >= self.max_sessions: return None # Create session session = UnifiedSession( session_id=kwargs.get('session_id', str(uuid.uuid4())), session_type=session_type, state=SessionState.INITIALIZING, virtual_ip=kwargs.get('virtual_ip'), virtual_port=kwargs.get('virtual_port'), real_ip=kwargs.get('real_ip'), real_port=kwargs.get('real_port'), protocol=kwargs.get('protocol'), metadata=kwargs.get('metadata', {}) ) # Add to sessions self.sessions[session.session_id] = session self._add_to_index(session) # Update statistics self.stats['total_sessions'] += 1 self.stats['active_sessions'] = len(self.sessions) self.stats['session_types'][session_type.value] += 1 self.stats['session_states'][SessionState.INITIALIZING.value] += 1 return session.session_id def update_session(self, session_id: str, **kwargs) -> bool: """Update session""" with self.lock: session = self.sessions.get(session_id) if not session: return False # Update fields old_state = session.state for key, value in kwargs.items(): if hasattr(session, key): setattr(session, key, value) session.update_activity() # Update state statistics if 'state' in kwargs and kwargs['state'] != old_state: self.stats['session_states'][old_state.value] -= 1 self.stats['session_states'][kwargs['state'].value] += 1 return True def close_session(self, session_id: str, reason: str = "") -> bool: """Close session""" with self.lock: session = self.sessions.get(session_id) if not session: return False old_state = session.state session.state = SessionState.CLOSED session.update_activity() if reason: session.metadata['close_reason'] = reason # Update statistics self.stats['session_states'][old_state.value] -= 1 self.stats['session_states'][SessionState.CLOSED.value] += 1 return True def remove_session(self, session_id: str) -> bool: """Remove session completely""" with self.lock: session = self.sessions.get(session_id) if not session: return False # Remove from index self._remove_from_index(session) # Remove from sessions del self.sessions[session_id] # Update statistics self.stats['active_sessions'] = len(self.sessions) self.stats['session_types'][session.session_type.value] -= 1 self.stats['session_states'][session.state.value] -= 1 return True def get_session(self, session_id: str) -> Optional[UnifiedSession]: """Get session by ID""" with self.lock: return self.sessions.get(session_id) def find_sessions(self, **criteria) -> List[UnifiedSession]: """Find sessions by criteria""" with self.lock: matching_sessions = [] # Use index if possible if 'session_type' in criteria: type_key = (criteria['session_type'].value if isinstance(criteria['session_type'], SessionType) else criteria['session_type'], 'all') candidate_ids = self.session_index.get(type_key, set()) elif 'virtual_ip' in criteria: ip_key = ('virtual_ip', criteria['virtual_ip']) candidate_ids = self.session_index.get(ip_key, set()) elif 'real_ip' in criteria: ip_key = ('real_ip', criteria['real_ip']) candidate_ids = self.session_index.get(ip_key, set()) elif 'protocol' in criteria: proto_key = ('protocol', criteria['protocol']) candidate_ids = self.session_index.get(proto_key, set()) else: candidate_ids = set(self.sessions.keys()) # Filter candidates for session_id in candidate_ids: session = self.sessions.get(session_id) if not session: continue match = True for key, value in criteria.items(): if hasattr(session, key): session_value = getattr(session, key) if isinstance(value, (SessionType, SessionState)): if session_value != value: match = False break elif session_value != value: match = False break else: match = False break if match: matching_sessions.append(session) return matching_sessions def correlate_sessions(self, session_id1: str, session_id2: str, relationship: str = 'related') -> bool: """Create correlation between sessions""" with self.lock: session1 = self.sessions.get(session_id1) session2 = self.sessions.get(session_id2) if not session1 or not session2: return False if relationship == 'parent_child': session1.add_child_session(session_id2) session2.set_parent_session(session_id1) else: session1.add_related_session(session_id2) session2.add_related_session(session_id1) self.stats['correlations_created'] += 1 return True def update_metrics(self, session_id: str, **metrics) -> bool: """Update session metrics""" with self.lock: session = self.sessions.get(session_id) if not session: return False session.update_activity() # Update metrics if 'bytes_in' in metrics or 'bytes_out' in metrics: session.metrics.update_bytes( metrics.get('bytes_in', 0), metrics.get('bytes_out', 0) ) if 'packets_in' in metrics or 'packets_out' in metrics: session.metrics.update_packets( metrics.get('packets_in', 0), metrics.get('packets_out', 0) ) if 'rtt' in metrics: session.metrics.add_rtt_sample(metrics['rtt']) if 'errors' in metrics: session.metrics.errors += metrics['errors'] if 'retransmits' in metrics: session.metrics.retransmits += metrics['retransmits'] return True def _cleanup_expired_sessions(self): """Clean up expired sessions""" current_time = time.time() expired_sessions = [] for session_id, session in self.sessions.items(): # Check if session is expired if (session.state == SessionState.CLOSED and current_time - session.last_activity > self.cleanup_interval): expired_sessions.append(session_id) elif (session.state != SessionState.CLOSED and current_time - session.last_activity > self.session_timeout): expired_sessions.append(session_id) # Remove expired sessions for session_id in expired_sessions: self.remove_session(session_id) self.stats['expired_sessions'] += 1 def _cleanup_loop(self): """Background cleanup loop""" while self.running: try: with self.lock: self._cleanup_expired_sessions() self.stats['cleanup_runs'] += 1 time.sleep(self.cleanup_interval) except Exception as e: print(f"Session tracker cleanup error: {e}") time.sleep(60) def get_sessions(self, limit: int = 100, offset: int = 0, **filters) -> List[Dict]: """Get sessions with pagination and filtering""" with self.lock: if filters: sessions = self.find_sessions(**filters) else: sessions = list(self.sessions.values()) # Sort by last activity (most recent first) sessions.sort(key=lambda s: s.last_activity, reverse=True) # Apply pagination paginated_sessions = sessions[offset:offset + limit] return [session.to_dict() for session in paginated_sessions] def get_session_summary(self) -> Dict: """Get session summary statistics""" with self.lock: summary = { 'total_sessions': len(self.sessions), 'by_type': {}, 'by_state': {}, 'by_protocol': {}, 'active_sessions_by_age': { 'last_hour': 0, 'last_day': 0, 'older': 0 } } current_time = time.time() hour_ago = current_time - 3600 day_ago = current_time - 86400 for session in self.sessions.values(): # Count by type session_type = session.session_type.value summary['by_type'][session_type] = summary['by_type'].get(session_type, 0) + 1 # Count by state session_state = session.state.value summary['by_state'][session_state] = summary['by_state'].get(session_state, 0) + 1 # Count by protocol if session.protocol: summary['by_protocol'][session.protocol] = summary['by_protocol'].get(session.protocol, 0) + 1 # Count by age if session.last_activity > hour_ago: summary['active_sessions_by_age']['last_hour'] += 1 elif session.last_activity > day_ago: summary['active_sessions_by_age']['last_day'] += 1 else: summary['active_sessions_by_age']['older'] += 1 return summary def get_stats(self) -> Dict: """Get tracker statistics""" with self.lock: stats = self.stats.copy() stats['active_sessions'] = len(self.sessions) return stats def reset_stats(self): """Reset statistics""" self.stats = { 'total_sessions': len(self.sessions), 'active_sessions': len(self.sessions), 'expired_sessions': 0, 'session_types': {t.value: 0 for t in SessionType}, 'session_states': {s.value: 0 for s in SessionState}, 'cleanup_runs': 0, 'correlations_created': 0 } # Recalculate current counts with self.lock: for session in self.sessions.values(): self.stats['session_types'][session.session_type.value] += 1 self.stats['session_states'][session.state.value] += 1 def export_sessions(self, format: str = 'json') -> str: """Export sessions data""" with self.lock: sessions_data = [session.to_dict() for session in self.sessions.values()] if format == 'json': return json.dumps(sessions_data, indent=2, default=str) else: raise ValueError(f"Unsupported export format: {format}") def start(self): """Start session tracker""" self.running = True self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() print("Session tracker started") def stop(self): """Stop session tracker""" self.running = False if self.cleanup_thread: self.cleanup_thread.join() print("Session tracker stopped")