"""
Message Acknowledgment System for AegisLM Framework

Production-ready message acknowledgment system with reliable delivery,
retry mechanisms, and comprehensive tracking for WebSocket communications.
"""

import asyncio
import uuid
import json
import logging
from typing import Dict, Set, Optional, List, Any
from datetime import datetime, timedelta
from collections import deque
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class PendingMessage:
    """Pending message information for acknowledgment tracking."""
    message_id: str
    connection_id: str
    room_id: Optional[str]
    message: Dict[str, Any]       # original payload, without ack metadata
    sent_at: datetime             # naive UTC time of the first transmission
    retry_count: int = 0
    max_retries: int = 3
    ack_timeout: int = 30  # seconds
    last_retry: Optional[datetime] = None  # naive UTC time of the latest retry


@dataclass
class AcknowledgmentStats:
    """Message acknowledgment statistics."""
    total_sent: int = 0
    total_acknowledged: int = 0
    total_failed: int = 0
    total_retries: int = 0
    average_delivery_time: float = 0.0
    success_rate: float = 0.0  # percentage of sent messages acknowledged


class MessageAcknowledgmentSystem:
    """
    Ensure reliable message delivery with acknowledgments.

    Provides message tracking, retry mechanisms, delivery guarantees,
    and comprehensive statistics for WebSocket communications.
    """

    def __init__(self, websocket_manager, ack_timeout: int = 30,
                 max_retries: int = 3, cleanup_interval: int = 60):
        """
        Initialize message acknowledgment system.

        Args:
            websocket_manager: WebSocket manager instance; must provide the
                coroutines ``send_to_connection(connection_id, message)`` and
                ``broadcast_to_room(room_id, message)``.
            ack_timeout: Timeout for message acknowledgment in seconds,
                measured from the most recent (re)transmission.
            max_retries: Maximum number of retry attempts
            cleanup_interval: Interval for cleanup operations in seconds.
                NOTE(review): currently stored but not used; the periodic
                cleanup loop polls every 10 seconds.
        """
        self.websocket_manager = websocket_manager
        self.ack_timeout = ack_timeout
        self.max_retries = max_retries
        self.cleanup_interval = cleanup_interval
        self.pending_acks: Dict[str, PendingMessage] = {}
        self.ack_stats = AcknowledgmentStats()
        self._cleanup_task = None
        # Rolling window of the most recent delivery times (seconds).
        self._delivery_times: deque = deque(maxlen=1000)

    async def start_ack_system(self):
        """Start acknowledgment cleanup task (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._periodic_ack_cleanup())
            logger.info("Message acknowledgment system started")

    async def stop_ack_system(self):
        """Stop acknowledgment cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
            logger.info("Message acknowledgment system stopped")

    async def send_with_ack(self, connection_id: str, message: Dict,
                            room_id: Optional[str] = None) -> str:
        """
        Send message with acknowledgment tracking.

        Args:
            connection_id: Target connection ID
            message: Message to send
            room_id: Optional room ID for broadcasting; when truthy the
                message is broadcast to the room instead of sent to
                ``connection_id``.

        Returns:
            str: Message ID for tracking

        Raises:
            Exception: Whatever the websocket manager raised if the send
                fails; the message is then no longer tracked.
        """
        message_id = str(uuid.uuid4())
        # One timestamp, so the stored sent_at matches what goes on the wire.
        now = datetime.utcnow()
        now_iso = now.isoformat()

        # Add acknowledgment metadata
        ack_message = {
            **message,
            'message_id': message_id,
            'requires_ack': True,
            'timestamp': now_iso,
            'sent_at': now_iso,
        }

        # Register before sending so an ack that arrives while the send is
        # still awaiting can be matched.
        self.pending_acks[message_id] = PendingMessage(
            message_id=message_id,
            connection_id=connection_id,
            room_id=room_id,
            message=message,
            sent_at=now,
            retry_count=0,
            max_retries=self.max_retries,
            ack_timeout=self.ack_timeout,
        )

        try:
            if room_id:
                await self.websocket_manager.broadcast_to_room(room_id, ack_message)
            else:
                await self.websocket_manager.send_to_connection(connection_id, ack_message)
        except Exception as e:
            # Remove from pending if send failed
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            logger.error(f"Failed to send message {message_id}: {e}")
            raise

        self.ack_stats.total_sent += 1
        logger.debug(f"Message sent with acknowledgment tracking: {message_id}")
        return message_id

    async def send_batch_with_ack(self, messages: List[Dict[str, Any]],
                                  connection_id: str,
                                  room_id: Optional[str] = None) -> List[str]:
        """
        Send multiple messages with acknowledgment tracking.

        Failures are logged and skipped; the remaining messages are still sent.

        Args:
            messages: List of messages to send
            connection_id: Target connection ID
            room_id: Optional room ID for broadcasting

        Returns:
            List[str]: Message IDs of the messages that were sent successfully
        """
        message_ids = []
        for message in messages:
            try:
                message_id = await self.send_with_ack(connection_id, message, room_id)
                message_ids.append(message_id)
            except Exception as e:
                logger.error(f"Failed to send batch message: {e}")
                # Continue with other messages
        return message_ids

    async def handle_acknowledgment(self, connection_id: str, message_id: str) -> bool:
        """
        Handle message acknowledgment from client.

        Args:
            connection_id: Connection ID that sent acknowledgment.
                NOTE(review): not validated against the pending message's
                connection; any connection can acknowledge any message.
            message_id: Message ID being acknowledged

        Returns:
            bool: True if the message was pending and is now acknowledged,
                False if the message ID is unknown.
        """
        pending_msg = self.pending_acks.pop(message_id, None)
        if pending_msg is None:
            logger.warning(f"Received ack for unknown message: {message_id}")
            return False

        # Delivery time covers all retries: it is measured from the first send.
        delivery_time = (datetime.utcnow() - pending_msg.sent_at).total_seconds()
        self._delivery_times.append(delivery_time)

        # Update statistics
        self.ack_stats.total_acknowledged += 1
        self._update_success_rate()

        logger.debug(f"Message {message_id} acknowledged after {delivery_time:.2f}s")
        return True

    async def handle_nack(self, connection_id: str, message_id: str, reason: str) -> bool:
        """
        Handle negative acknowledgment from client.

        The message is resent immediately if it has retries left; otherwise it
        is dropped and counted as failed.

        Args:
            connection_id: Connection ID that sent NACK
            message_id: Message ID being NACKed
            reason: Reason for NACK

        Returns:
            bool: True if NACK was processed, False if the message ID is unknown
        """
        if message_id not in self.pending_acks:
            logger.warning(f"Received NACK for unknown message: {message_id}")
            return False

        pending_msg = self.pending_acks[message_id]
        logger.warning(f"Message {message_id} NACKed: {reason}")

        # Try to resend if retries available
        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
        else:
            # Max retries exceeded, mark as failed
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            self._update_success_rate()
            logger.error(f"Message {message_id} failed after {pending_msg.max_retries} retries")

        return True

    async def _periodic_ack_cleanup(self):
        """Periodic cleanup of expired acknowledgments and retry failed messages."""
        while True:
            try:
                # NOTE(review): fixed 10 s poll; self.cleanup_interval is not
                # consulted. Honoring it (default 60 s) would stretch the
                # effective ack timeout, so confirm intent before changing.
                await asyncio.sleep(10)  # Check every 10 seconds
                cleanup_result = await self.cleanup_expired_acks()
                if cleanup_result['cleaned'] > 0 or cleanup_result['retried'] > 0:
                    logger.debug(f"Ack cleanup: {cleanup_result}")
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Ack cleanup error: {e}")

    async def cleanup_expired_acks(self) -> Dict[str, Any]:
        """
        Clean up expired acknowledgments and retry timed-out messages.

        A message times out once ``ack_timeout`` seconds have passed since its
        most recent transmission (the last retry, or the original send if it
        has not been retried). Timed-out messages with retries left are
        resent. The others are dropped and counted as failed.

        Returns:
            Dict containing cleanup statistics
        """
        current_time = datetime.utcnow()
        expired_messages = []
        retry_messages = []

        for message_id, pending_msg in self.pending_acks.items():
            # Measure from the latest attempt. Using sent_at for every attempt
            # would re-fire retries on every poll after the first timeout.
            last_attempt = pending_msg.last_retry or pending_msg.sent_at
            if (current_time - last_attempt).total_seconds() <= pending_msg.ack_timeout:
                continue
            if pending_msg.retry_count < pending_msg.max_retries:
                retry_messages.append(message_id)
            else:
                expired_messages.append(message_id)

        # Handle expired messages (no await since the scan, so all still present)
        for message_id in expired_messages:
            pending_msg = self.pending_acks.pop(message_id, None)
            if pending_msg is None:
                continue
            self.ack_stats.total_failed += 1
            logger.error(f"Message {message_id} expired after {pending_msg.max_retries} retries")

        # Retry messages. Each retry awaits a send, during which other entries
        # may be acknowledged, cancelled or failed, so look them up defensively.
        retried = 0
        for message_id in retry_messages:
            pending_msg = self.pending_acks.get(message_id)
            if pending_msg is None:
                continue
            await self._retry_message(message_id, pending_msg)
            retried += 1

        # Update statistics
        self._update_success_rate()

        return {
            'total_pending': len(self.pending_acks),
            'expired_count': len(expired_messages),
            'retry_count': retried,
            'cleaned': len(expired_messages),
            'retried': retried,
        }

    async def _retry_message(self, message_id: str, pending_msg: PendingMessage):
        """
        Resend a tracked message, tagging it with retry metadata.

        On send failure the message is dropped from tracking and counted as
        failed, provided it was still being tracked.
        """
        try:
            # Update retry count
            pending_msg.retry_count += 1
            pending_msg.last_retry = datetime.utcnow()

            # Add retry metadata
            retry_message = {
                **pending_msg.message,
                'message_id': message_id,
                'requires_ack': True,
                'retry': pending_msg.retry_count,
                'timestamp': datetime.utcnow().isoformat(),
                'original_sent_at': pending_msg.sent_at.isoformat(),
            }

            # Resend message
            if pending_msg.room_id:
                await self.websocket_manager.broadcast_to_room(pending_msg.room_id, retry_message)
            else:
                await self.websocket_manager.send_to_connection(pending_msg.connection_id, retry_message)
        except Exception as e:
            logger.error(f"Failed to retry message {message_id}: {e}")
            # Mark as failed, unless it was already acknowledged or removed
            # while the send was awaiting.
            if self.pending_acks.pop(message_id, None) is not None:
                self.ack_stats.total_failed += 1
            return

        self.ack_stats.total_retries += 1
        logger.info(f"Retried message {message_id} (attempt {pending_msg.retry_count})")

    def _update_success_rate(self):
        """Recompute success rate (%) and average delivery time (s)."""
        total = self.ack_stats.total_sent
        if total > 0:
            self.ack_stats.success_rate = (self.ack_stats.total_acknowledged / total) * 100

        # Update average delivery time
        if self._delivery_times:
            self.ack_stats.average_delivery_time = sum(self._delivery_times) / len(self._delivery_times)

    async def get_pending_ack_count(self) -> int:
        """Get number of pending acknowledgments."""
        return len(self.pending_acks)

    async def get_pending_acks_for_connection(self, connection_id: str) -> List[Dict[str, Any]]:
        """Get pending acknowledgments for a specific connection."""
        now = datetime.utcnow()
        return [
            {
                'message_id': message_id,
                'sent_at': pending_msg.sent_at.isoformat(),
                'retry_count': pending_msg.retry_count,
                'max_retries': pending_msg.max_retries,
                'time_since_sent': (now - pending_msg.sent_at).total_seconds(),
            }
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

    async def get_ack_statistics(self) -> AcknowledgmentStats:
        """Get acknowledgment statistics (refreshed before returning)."""
        self._update_success_rate()
        return self.ack_stats

    async def force_retry_message(self, message_id: str) -> bool:
        """
        Force retry of a specific message.

        Args:
            message_id: Message ID to retry

        Returns:
            bool: True if retry was initiated, False if the message is unknown
                or has no retries left
        """
        pending_msg = self.pending_acks.get(message_id)
        if pending_msg is None:
            return False

        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
            return True

        return False

    async def cancel_pending_message(self, message_id: str) -> bool:
        """
        Cancel a pending message (counted as failed).

        Args:
            message_id: Message ID to cancel

        Returns:
            bool: True if message was cancelled
        """
        if message_id in self.pending_acks:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            self._update_success_rate()
            logger.info(f"Cancelled pending message: {message_id}")
            return True
        return False

    async def get_delivery_report(self) -> Dict[str, Any]:
        """Get comprehensive delivery report."""
        current_time = datetime.utcnow()

        # Analyze pending messages by age (measured from the original send)
        age_buckets = {
            '0-10s': 0,
            '10-30s': 0,
            '30-60s': 0,
            '60s+': 0,
        }

        for pending_msg in self.pending_acks.values():
            age = (current_time - pending_msg.sent_at).total_seconds()
            if age <= 10:
                age_buckets['0-10s'] += 1
            elif age <= 30:
                age_buckets['10-30s'] += 1
            elif age <= 60:
                age_buckets['30-60s'] += 1
            else:
                age_buckets['60s+'] += 1

        return {
            'statistics': {
                'total_sent': self.ack_stats.total_sent,
                'total_acknowledged': self.ack_stats.total_acknowledged,
                'total_failed': self.ack_stats.total_failed,
                'total_retries': self.ack_stats.total_retries,
                'pending_count': len(self.pending_acks),
                'success_rate': self.ack_stats.success_rate,
                'average_delivery_time': self.ack_stats.average_delivery_time,
            },
            'pending_by_age': age_buckets,
            'retry_distribution': self._get_retry_distribution(),
            'timestamp': current_time.isoformat(),
        }

    def _get_retry_distribution(self) -> Dict[str, int]:
        """Get distribution of retry counts among pending messages."""
        distribution = {str(i): 0 for i in range(self.max_retries + 1)}
        distribution[f'{self.max_retries}+'] = 0

        for pending_msg in self.pending_acks.values():
            if pending_msg.retry_count <= self.max_retries:
                distribution[str(pending_msg.retry_count)] += 1
            else:
                distribution[f'{self.max_retries}+'] += 1

        return distribution

    async def reset_statistics(self):
        """Reset acknowledgment statistics."""
        self.ack_stats = AcknowledgmentStats()
        self._delivery_times.clear()
        logger.info("Message acknowledgment statistics reset")

    async def cleanup_connection_acks(self, connection_id: str) -> int:
        """
        Clean up all pending acknowledgments for a connection.

        Each removed message is counted as failed.

        Args:
            connection_id: Connection ID to clean up

        Returns:
            int: Number of acknowledgments cleaned up
        """
        to_remove = [
            message_id
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

        for message_id in to_remove:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1

        logger.info(f"Cleaned up {len(to_remove)} pending acknowledgments for connection {connection_id}")
        return len(to_remove)


# Factory function
def create_message_acknowledgment_system(websocket_manager, ack_timeout: int = 30,
                                         max_retries: int = 3,
                                         cleanup_interval: int = 60) -> MessageAcknowledgmentSystem:
    """
    Create a message acknowledgment system instance.

    Args:
        websocket_manager: WebSocket manager instance
        ack_timeout: Timeout for acknowledgments in seconds
        max_retries: Maximum retry attempts
        cleanup_interval: Cleanup interval in seconds

    Returns:
        MessageAcknowledgmentSystem: Configured system
    """
    return MessageAcknowledgmentSystem(websocket_manager, ack_timeout, max_retries, cleanup_interval)