| """ |
| Message Acknowledgment System for AegisLM Framework |
| |
| Production-ready message acknowledgment system with reliable delivery, |
| retry mechanisms, and comprehensive tracking for WebSocket communications. |
| """ |
|
|
| import asyncio |
| import uuid |
| import json |
| import logging |
| from typing import Dict, Set, Optional, List, Any |
| from datetime import datetime, timedelta |
| from collections import deque |
| from dataclasses import dataclass, field |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class PendingMessage:
    """Tracking record for one sent message that has not been acknowledged yet.

    Field order is part of the constructor signature and must stay stable.
    """

    # Identifier the client must echo back in its ACK/NACK.
    message_id: str
    # Connection the message was addressed to.
    connection_id: str
    # Room to broadcast to; ``None`` means a direct send to ``connection_id``.
    room_id: Optional[str]
    # Caller-supplied payload, kept so the message can be re-sent on retry.
    message: Dict[str, Any]
    # Moment of the first send attempt.
    sent_at: datetime
    # Number of re-sends performed so far.
    retry_count: int = 0
    # Upper bound on re-sends before the message is given up on.
    max_retries: int = 3
    # Seconds to wait for an acknowledgment.
    ack_timeout: int = 30
    # Moment of the most recent re-send; ``None`` until the first retry.
    last_retry: Optional[datetime] = None
|
|
|
|
@dataclass
class AcknowledgmentStats:
    """Running counters and derived figures describing message delivery."""

    # --- raw counters -----------------------------------------------------
    total_sent: int = 0          # messages handed to the transport successfully
    total_acknowledged: int = 0  # messages confirmed by a client ACK
    total_failed: int = 0        # messages dropped (send error, expiry, cancel)
    total_retries: int = 0       # successful re-send attempts

    # --- derived values ---------------------------------------------------
    average_delivery_time: float = 0.0  # mean seconds between send and ACK
    success_rate: float = 0.0           # acknowledged / sent, as a percentage
|
|
|
|
class MessageAcknowledgmentSystem:
    """
    Ensure reliable message delivery with acknowledgments.

    Every message sent through ``send_with_ack`` is stamped with a generated
    ``message_id`` and kept in ``pending_acks`` until the client acknowledges
    it, it is cancelled, or its retries are exhausted. A background task
    (see ``start_ack_system``) periodically re-sends messages whose
    acknowledgment timed out and drops those with no retries left.

    NOTE: all timestamps are naive UTC values from ``datetime.utcnow()``;
    do not mix them with timezone-aware datetimes.
    """

    def __init__(self, websocket_manager, ack_timeout: int = 30, max_retries: int = 3, cleanup_interval: int = 60):
        """
        Initialize message acknowledgment system.

        Args:
            websocket_manager: WebSocket manager instance; must provide the
                coroutines ``send_to_connection(connection_id, message)`` and
                ``broadcast_to_room(room_id, message)``
            ack_timeout: Timeout for message acknowledgment in seconds
            max_retries: Maximum number of retry attempts
            cleanup_interval: Seconds between two passes of the background
                cleanup task
        """
        self.websocket_manager = websocket_manager
        self.ack_timeout = ack_timeout
        self.max_retries = max_retries
        self.cleanup_interval = cleanup_interval
        # message_id -> tracking record of every message still awaiting an ACK.
        self.pending_acks: Dict[str, PendingMessage] = {}
        self.ack_stats = AcknowledgmentStats()
        self._cleanup_task = None
        # Sliding window of the most recent delivery times (seconds).
        self._delivery_times: deque = deque(maxlen=1000)

    async def start_ack_system(self):
        """Start the background cleanup task unless one is already running."""
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._periodic_ack_cleanup())
            logger.info("Message acknowledgment system started")

    async def stop_ack_system(self):
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
            logger.info("Message acknowledgment system stopped")

    async def send_with_ack(self, connection_id: str, message: Dict, room_id: Optional[str] = None) -> str:
        """
        Send message with acknowledgment tracking.

        Args:
            connection_id: Target connection ID
            message: Message to send
            room_id: Optional room ID; when given the message is broadcast to
                the room instead of being sent to ``connection_id``

        Returns:
            str: Message ID for tracking

        Raises:
            Exception: If send fails (the original exception is re-raised
                after the tracking record has been removed)
        """
        message_id = str(uuid.uuid4())
        # One timestamp for the payload and the tracking record keeps them consistent.
        now = datetime.utcnow()

        ack_message = {
            **message,
            'message_id': message_id,
            'requires_ack': True,
            'timestamp': now.isoformat(),
            'sent_at': now.isoformat(),
        }

        # Register before sending so an ACK that arrives while the send is
        # still awaiting is already known.
        self.pending_acks[message_id] = PendingMessage(
            message_id=message_id,
            connection_id=connection_id,
            room_id=room_id,
            message=message,
            sent_at=now,
            retry_count=0,
            max_retries=self.max_retries,
            ack_timeout=self.ack_timeout,
        )

        try:
            if room_id:
                await self.websocket_manager.broadcast_to_room(room_id, ack_message)
            else:
                await self.websocket_manager.send_to_connection(connection_id, ack_message)
        except Exception as e:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            logger.error(f"Failed to send message {message_id}: {e}")
            raise  # bare raise keeps the original traceback intact

        self.ack_stats.total_sent += 1
        logger.debug(f"Message sent with acknowledgment tracking: {message_id}")
        return message_id

    async def send_batch_with_ack(self, messages: List[Dict[str, Any]], connection_id: str, room_id: Optional[str] = None) -> List[str]:
        """
        Send multiple messages with acknowledgment tracking.

        Messages are sent one after another. A message that fails to send is
        logged and skipped, so the returned list can be shorter than
        ``messages``.

        Args:
            messages: List of messages to send
            connection_id: Target connection ID
            room_id: Optional room ID for broadcasting

        Returns:
            List[str]: Message IDs of the messages that were sent
        """
        message_ids = []

        for message in messages:
            try:
                message_ids.append(await self.send_with_ack(connection_id, message, room_id))
            except Exception as e:
                # Best effort: one failing message must not abort the batch.
                logger.error(f"Failed to send batch message: {e}")

        return message_ids

    async def handle_acknowledgment(self, connection_id: str, message_id: str) -> bool:
        """
        Handle message acknowledgment from client.

        Args:
            connection_id: Connection ID that sent acknowledgment (currently
                not checked against the tracked message)
            message_id: Message ID being acknowledged

        Returns:
            bool: True if the message was pending and is now acknowledged,
            False if the ID is unknown
        """
        pending_msg = self.pending_acks.pop(message_id, None)
        if pending_msg is None:
            logger.warning(f"Received ack for unknown message: {message_id}")
            return False

        # Delivery time is measured from the first send, retries included.
        delivery_time = (datetime.utcnow() - pending_msg.sent_at).total_seconds()
        self._delivery_times.append(delivery_time)

        self.ack_stats.total_acknowledged += 1
        self._update_success_rate()

        logger.debug(f"Message {message_id} acknowledged after {delivery_time:.2f}s")
        return True

    async def handle_nack(self, connection_id: str, message_id: str, reason: str) -> bool:
        """
        Handle negative acknowledgment from client.

        The message is re-sent immediately while retries remain; otherwise it
        is dropped and counted as failed.

        Args:
            connection_id: Connection ID that sent NACK (currently not checked
                against the tracked message)
            message_id: Message ID being NACKed
            reason: Reason for NACK

        Returns:
            bool: True if the NACK referred to a pending message
        """
        pending_msg = self.pending_acks.get(message_id)
        if pending_msg is None:
            logger.warning(f"Received NACK for unknown message: {message_id}")
            return False

        logger.warning(f"Message {message_id} NACKed: {reason}")

        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
        else:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            self._update_success_rate()
            logger.error(f"Message {message_id} failed after {pending_msg.max_retries} retries")

        return True

    async def _periodic_ack_cleanup(self):
        """Run ``cleanup_expired_acks`` every ``cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                # Honour the configured interval instead of a fixed delay.
                await asyncio.sleep(self.cleanup_interval)
                cleanup_result = await self.cleanup_expired_acks()
                if cleanup_result['cleaned'] > 0 or cleanup_result['retried'] > 0:
                    logger.debug(f"Ack cleanup: {cleanup_result}")
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; a single bad pass must not stop cleanup.
                logger.exception(f"Ack cleanup error: {e}")

    async def cleanup_expired_acks(self) -> Dict[str, Any]:
        """
        Clean up expired acknowledgments and retry failed messages.

        A message has timed out when more than its ``ack_timeout`` seconds
        have passed since its latest send attempt (the last retry, or the
        first send if it was never retried). Timed-out messages are re-sent
        while retries remain and dropped as failed otherwise.

        Returns:
            Dict with ``total_pending`` (messages still pending afterwards),
            ``expired_count`` / ``cleaned`` (messages dropped) and
            ``retry_count`` / ``retried`` (messages re-sent in this pass)
        """
        now = datetime.utcnow()
        expired_ids: List[str] = []
        retry_ids: List[str] = []

        # Classify first and mutate afterwards: the dict must not change
        # size while it is being iterated.
        for message_id, pending_msg in self.pending_acks.items():
            # Each attempt gets its own full timeout window.
            last_attempt = pending_msg.last_retry or pending_msg.sent_at
            if (now - last_attempt).total_seconds() <= pending_msg.ack_timeout:
                continue
            if pending_msg.retry_count < pending_msg.max_retries:
                retry_ids.append(message_id)
            else:
                expired_ids.append(message_id)

        for message_id in expired_ids:
            pending_msg = self.pending_acks.pop(message_id, None)
            if pending_msg is None:
                continue
            self.ack_stats.total_failed += 1
            logger.error(f"Message {message_id} expired after {pending_msg.max_retries} retries")

        retried = 0
        for message_id in retry_ids:
            # An ACK may have arrived while an earlier retry was awaiting.
            pending_msg = self.pending_acks.get(message_id)
            if pending_msg is None:
                continue
            await self._retry_message(message_id, pending_msg)
            retried += 1

        self._update_success_rate()

        return {
            'total_pending': len(self.pending_acks),
            'expired_count': len(expired_ids),
            'retry_count': retried,
            'cleaned': len(expired_ids),
            'retried': retried,
        }

    async def _retry_message(self, message_id: str, pending_msg: "PendingMessage"):
        """
        Retry sending a message.

        Bumps the retry counter, re-sends the original payload with retry
        metadata and, if the send raises, drops the message as failed. Send
        errors are logged, not propagated.
        """
        try:
            attempt_time = datetime.utcnow()
            pending_msg.retry_count += 1
            pending_msg.last_retry = attempt_time

            retry_message = {
                **pending_msg.message,
                'message_id': message_id,
                'requires_ack': True,
                'retry': pending_msg.retry_count,
                'timestamp': attempt_time.isoformat(),
                'original_sent_at': pending_msg.sent_at.isoformat(),
            }

            if pending_msg.room_id:
                await self.websocket_manager.broadcast_to_room(pending_msg.room_id, retry_message)
            else:
                await self.websocket_manager.send_to_connection(pending_msg.connection_id, retry_message)

            self.ack_stats.total_retries += 1
            logger.info(f"Retried message {message_id} (attempt {pending_msg.retry_count})")

        except Exception as e:
            logger.error(f"Failed to retry message {message_id}: {e}")
            # Count the failure only if the message was still pending; it may
            # have been acknowledged or cancelled while the send was awaiting.
            if self.pending_acks.pop(message_id, None) is not None:
                self.ack_stats.total_failed += 1

    def _update_success_rate(self):
        """Recompute ``success_rate`` (percent) and ``average_delivery_time`` (seconds)."""
        total = self.ack_stats.total_sent
        if total > 0:
            self.ack_stats.success_rate = (self.ack_stats.total_acknowledged / total) * 100

        if self._delivery_times:
            self.ack_stats.average_delivery_time = sum(self._delivery_times) / len(self._delivery_times)

    async def get_pending_ack_count(self) -> int:
        """Get number of pending acknowledgments."""
        return len(self.pending_acks)

    async def get_pending_acks_for_connection(self, connection_id: str) -> List[Dict[str, Any]]:
        """
        Get pending acknowledgments for a specific connection.

        Args:
            connection_id: Connection ID to look up

        Returns:
            One dict per pending message with ``message_id``, ``sent_at``
            (ISO string), ``retry_count``, ``max_retries`` and
            ``time_since_sent`` (seconds)
        """
        now = datetime.utcnow()
        return [
            {
                'message_id': message_id,
                'sent_at': pending_msg.sent_at.isoformat(),
                'retry_count': pending_msg.retry_count,
                'max_retries': pending_msg.max_retries,
                'time_since_sent': (now - pending_msg.sent_at).total_seconds(),
            }
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

    async def get_ack_statistics(self) -> "AcknowledgmentStats":
        """Get acknowledgment statistics with the derived values refreshed."""
        self._update_success_rate()
        return self.ack_stats

    async def force_retry_message(self, message_id: str) -> bool:
        """
        Force retry of a specific message.

        Args:
            message_id: Message ID to retry

        Returns:
            bool: True if a retry was attempted; False if the message is not
            pending or has no retries left
        """
        pending_msg = self.pending_acks.get(message_id)
        if pending_msg is None:
            return False

        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
            return True

        return False

    async def cancel_pending_message(self, message_id: str) -> bool:
        """
        Cancel a pending message.

        A cancelled message is counted as failed.

        Args:
            message_id: Message ID to cancel

        Returns:
            bool: True if message was cancelled
        """
        if self.pending_acks.pop(message_id, None) is None:
            return False

        self.ack_stats.total_failed += 1
        self._update_success_rate()
        logger.info(f"Cancelled pending message: {message_id}")
        return True

    async def get_delivery_report(self) -> Dict[str, Any]:
        """
        Get comprehensive delivery report.

        Returns:
            Dict with ``statistics`` (counters and derived rates),
            ``pending_by_age`` (pending messages bucketed by seconds since the
            first send), ``retry_distribution`` and an ISO ``timestamp``
        """
        current_time = datetime.utcnow()
        # Refresh derived figures so the report never shows stale rates.
        self._update_success_rate()

        age_buckets = {
            '0-10s': 0,
            '10-30s': 0,
            '30-60s': 0,
            '60s+': 0,
        }

        for pending_msg in self.pending_acks.values():
            age = (current_time - pending_msg.sent_at).total_seconds()
            if age <= 10:
                age_buckets['0-10s'] += 1
            elif age <= 30:
                age_buckets['10-30s'] += 1
            elif age <= 60:
                age_buckets['30-60s'] += 1
            else:
                age_buckets['60s+'] += 1

        return {
            'statistics': {
                'total_sent': self.ack_stats.total_sent,
                'total_acknowledged': self.ack_stats.total_acknowledged,
                'total_failed': self.ack_stats.total_failed,
                'total_retries': self.ack_stats.total_retries,
                'pending_count': len(self.pending_acks),
                'success_rate': self.ack_stats.success_rate,
                'average_delivery_time': self.ack_stats.average_delivery_time,
            },
            'pending_by_age': age_buckets,
            'retry_distribution': self._get_retry_distribution(),
            'timestamp': current_time.isoformat(),
        }

    def _get_retry_distribution(self) -> Dict[str, int]:
        """
        Get distribution of retry counts.

        Returns:
            Dict keyed by ``"0"`` .. ``str(max_retries)`` plus an overflow
            key ``"<max_retries>+"`` for messages retried more often than the
            system-wide limit
        """
        overflow_key = f'{self.max_retries}+'
        distribution = {str(i): 0 for i in range(self.max_retries + 1)}
        distribution[overflow_key] = 0

        for pending_msg in self.pending_acks.values():
            if pending_msg.retry_count <= self.max_retries:
                distribution[str(pending_msg.retry_count)] += 1
            else:
                distribution[overflow_key] += 1

        return distribution

    async def reset_statistics(self):
        """Reset acknowledgment statistics; pending messages are left untouched."""
        self.ack_stats = AcknowledgmentStats()
        self._delivery_times.clear()
        logger.info("Message acknowledgment statistics reset")

    async def cleanup_connection_acks(self, connection_id: str) -> int:
        """
        Clean up all pending acknowledgments for a connection.

        Every removed message is counted as failed.

        Args:
            connection_id: Connection ID to clean up

        Returns:
            int: Number of acknowledgments cleaned up
        """
        # Collect the IDs first; the dict cannot be mutated while iterating.
        to_remove = [
            message_id
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

        for message_id in to_remove:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1

        logger.info(f"Cleaned up {len(to_remove)} pending acknowledgments for connection {connection_id}")

        return len(to_remove)
|
|
|
|
| |
def create_message_acknowledgment_system(websocket_manager, ack_timeout: int = 30, max_retries: int = 3, cleanup_interval: int = 60) -> MessageAcknowledgmentSystem:
    """
    Build a ready-to-use message acknowledgment system.

    Thin factory around the ``MessageAcknowledgmentSystem`` constructor; the
    returned instance has not been started yet.

    Args:
        websocket_manager: WebSocket manager instance used for sending
        ack_timeout: Seconds to wait for an acknowledgment
        max_retries: Maximum number of retry attempts per message
        cleanup_interval: Cleanup interval in seconds

    Returns:
        MessageAcknowledgmentSystem: Configured system
    """
    system = MessageAcknowledgmentSystem(
        websocket_manager=websocket_manager,
        ack_timeout=ack_timeout,
        max_retries=max_retries,
        cleanup_interval=cleanup_interval,
    )
    return system
|
|