# ALM-2 / backend/websocket/message_acknowledgment.py
# Uploaded by ACA050 — commit 2ed8996 ("Upload 520 files", verified)
"""
Message Acknowledgment System for AegisLM Framework
Production-ready message acknowledgment system with reliable delivery,
retry mechanisms, and comprehensive tracking for WebSocket communications.
"""
import asyncio
import uuid
import json
import logging
from typing import Dict, Set, Optional, List, Any
from datetime import datetime, timedelta
from collections import deque
from dataclasses import dataclass, field
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
@dataclass
class PendingMessage:
    """Bookkeeping record for a single message that is awaiting acknowledgment.

    The declaration order of the fields below is also the positional order
    of the generated ``__init__``, so it must not be rearranged.
    """

    message_id: str  # identifier attached to the outgoing message
    connection_id: str  # connection the message was addressed to
    room_id: Optional[str]  # room the message was broadcast to, or None
    message: Dict[str, Any]  # payload as supplied by the caller
    sent_at: datetime  # moment of the first send
    retry_count: int = field(default=0)  # resends performed so far
    max_retries: int = field(default=3)  # upper bound on resends
    ack_timeout: int = field(default=30)  # seconds
    last_retry: Optional[datetime] = field(default=None)  # moment of latest resend
@dataclass
class AcknowledgmentStats:
    """Running counters plus derived rates describing acknowledged delivery."""

    total_sent: int = field(default=0)
    total_acknowledged: int = field(default=0)
    total_failed: int = field(default=0)
    total_retries: int = field(default=0)
    average_delivery_time: float = field(default=0.0)  # seconds
    success_rate: float = field(default=0.0)  # percentage (acknowledged / sent * 100)
class MessageAcknowledgmentSystem:
    """
    Ensure reliable message delivery with acknowledgments.

    Provides message tracking, retry mechanisms, delivery guarantees,
    and comprehensive statistics for WebSocket communications.

    Every message sent through ``send_with_ack`` is kept in ``pending_acks``
    until it is acknowledged, runs out of retries (by NACK or timeout), is
    cancelled, or its connection is cleaned up. A background task started by
    ``start_ack_system`` periodically resends messages whose acknowledgment
    window has elapsed.
    """

    def __init__(self, websocket_manager, ack_timeout: int = 30, max_retries: int = 3, cleanup_interval: int = 60):
        """
        Initialize message acknowledgment system.

        Args:
            websocket_manager: WebSocket manager instance. It must provide the
                coroutines ``send_to_connection(connection_id, message)`` and
                ``broadcast_to_room(room_id, message)``.
            ack_timeout: Seconds to wait for an acknowledgment after each
                (re)send before retrying or giving up.
            max_retries: Maximum number of retry attempts per message.
            cleanup_interval: Seconds between background timeout sweeps.
        """
        self.websocket_manager = websocket_manager
        self.ack_timeout = ack_timeout
        self.max_retries = max_retries
        self.cleanup_interval = cleanup_interval
        self.pending_acks: Dict[str, PendingMessage] = {}
        self.ack_stats = AcknowledgmentStats()
        self._cleanup_task: Optional[asyncio.Task] = None
        # Only the most recent 1000 delivery times feed the running average.
        self._delivery_times: deque = deque(maxlen=1000)

    async def start_ack_system(self):
        """Start the background timeout/retry sweep (no-op if already started)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._periodic_ack_cleanup())
            logger.info("Message acknowledgment system started")

    async def stop_ack_system(self):
        """Cancel the background sweep and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
            logger.info("Message acknowledgment system stopped")

    async def send_with_ack(self, connection_id: str, message: Dict, room_id: Optional[str] = None) -> str:
        """
        Send message with acknowledgment tracking.

        Args:
            connection_id: Target connection ID
            message: Message to send (not mutated; ack metadata is added to a copy)
            room_id: Optional room ID; when truthy the message is broadcast to
                the room instead of sent to ``connection_id``

        Returns:
            str: Message ID for tracking

        Raises:
            Exception: Whatever the websocket manager raised if the send fails;
                the message is then no longer tracked.
        """
        message_id = str(uuid.uuid4())
        # A single clock reading keeps the wire metadata and the record in sync.
        now = datetime.utcnow()
        now_iso = now.isoformat()

        ack_message = {
            **message,
            'message_id': message_id,
            'requires_ack': True,
            'timestamp': now_iso,
            'sent_at': now_iso
        }

        # Registered before the await below, so an ack that is processed while
        # the send is still in flight finds the record.
        self.pending_acks[message_id] = PendingMessage(
            message_id=message_id,
            connection_id=connection_id,
            room_id=room_id,
            message=message,
            sent_at=now,
            retry_count=0,
            max_retries=self.max_retries,
            ack_timeout=self.ack_timeout
        )

        try:
            if room_id:
                await self.websocket_manager.broadcast_to_room(room_id, ack_message)
            else:
                await self.websocket_manager.send_to_connection(connection_id, ack_message)
        except Exception as e:
            # Stop tracking a message that never left.
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            logger.error(f"Failed to send message {message_id}: {e}")
            raise

        self.ack_stats.total_sent += 1
        logger.debug(f"Message sent with acknowledgment tracking: {message_id}")
        return message_id

    async def send_batch_with_ack(self, messages: List[Dict[str, Any]], connection_id: str, room_id: Optional[str] = None) -> List[str]:
        """
        Send multiple messages with acknowledgment tracking.

        A failed send is logged and skipped; the remaining messages are still sent.

        Args:
            messages: List of messages to send
            connection_id: Target connection ID
            room_id: Optional room ID for broadcasting

        Returns:
            List[str]: Message IDs of the messages that were sent successfully
        """
        message_ids = []
        for message in messages:
            try:
                message_id = await self.send_with_ack(connection_id, message, room_id)
                message_ids.append(message_id)
            except Exception as e:
                logger.error(f"Failed to send batch message: {e}")
        return message_ids

    async def handle_acknowledgment(self, connection_id: str, message_id: str) -> bool:
        """
        Handle message acknowledgment from client.

        Args:
            connection_id: Connection ID that sent acknowledgment
            message_id: Message ID being acknowledged

        Returns:
            bool: True if the message was pending and is now acknowledged,
                False if the ID is unknown (already settled or never sent)
        """
        # NOTE(review): connection_id is not compared with the message's target,
        # so any connection can acknowledge any pending message. Confirm this is
        # intended (e.g. for room broadcasts).
        pending_msg = self.pending_acks.pop(message_id, None)
        if pending_msg is None:
            logger.warning(f"Received ack for unknown message: {message_id}")
            return False

        # Measured from the first send, so retries are included in the delay.
        delivery_time = (datetime.utcnow() - pending_msg.sent_at).total_seconds()
        self._delivery_times.append(delivery_time)

        self.ack_stats.total_acknowledged += 1
        self._update_success_rate()

        logger.debug(f"Message {message_id} acknowledged after {delivery_time:.2f}s")
        return True

    async def handle_nack(self, connection_id: str, message_id: str, reason: str) -> bool:
        """
        Handle negative acknowledgment from client.

        Resends the message if its retry budget allows; otherwise marks it failed.

        Args:
            connection_id: Connection ID that sent NACK
            message_id: Message ID being NACKed
            reason: Reason for NACK

        Returns:
            bool: True if the message was pending (whether it was retried or
                failed), False if the ID is unknown
        """
        pending_msg = self.pending_acks.get(message_id)
        if pending_msg is None:
            logger.warning(f"Received NACK for unknown message: {message_id}")
            return False

        logger.warning(f"Message {message_id} NACKed: {reason}")

        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
        else:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            self._update_success_rate()
            logger.error(f"Message {message_id} failed after {pending_msg.max_retries} retries")

        return True

    async def _periodic_ack_cleanup(self):
        """Every ``cleanup_interval`` seconds, expire or retry timed-out messages."""
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                cleanup_result = await self.cleanup_expired_acks()
                if cleanup_result['cleaned'] > 0 or cleanup_result['retried'] > 0:
                    logger.debug(f"Ack cleanup: {cleanup_result}")
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the sweep alive; one bad pass must not stop future ones.
                logger.error(f"Ack cleanup error: {e}")

    async def cleanup_expired_acks(self) -> Dict[str, Any]:
        """
        Clean up expired acknowledgments and retry failed messages.

        A message times out when more than ``ack_timeout`` seconds have passed
        since it was last transmitted (its latest retry, or the original send).
        A timed-out message is resent if it has retries left; otherwise it is
        dropped and counted as failed.

        Returns:
            Dict containing cleanup statistics
        """
        current_time = datetime.utcnow()
        expired_messages: List[str] = []
        retry_messages: List[str] = []

        for message_id, pending_msg in self.pending_acks.items():
            # The ack window restarts on every resend; without this, a message
            # would be resent on every sweep once the first window lapsed.
            last_transmitted = pending_msg.last_retry or pending_msg.sent_at
            if (current_time - last_transmitted).total_seconds() > pending_msg.ack_timeout:
                if pending_msg.retry_count < pending_msg.max_retries:
                    retry_messages.append(message_id)
                else:
                    expired_messages.append(message_id)

        for message_id in expired_messages:
            pending_msg = self.pending_acks.pop(message_id, None)
            if pending_msg is None:
                continue
            self.ack_stats.total_failed += 1
            logger.error(f"Message {message_id} expired after {pending_msg.max_retries} retries")

        for message_id in retry_messages:
            # Each retry awaits network I/O, during which other coroutines may
            # acknowledge or cancel messages collected above; skip those.
            pending_msg = self.pending_acks.get(message_id)
            if pending_msg is not None:
                await self._retry_message(message_id, pending_msg)

        self._update_success_rate()

        return {
            'total_pending': len(self.pending_acks),
            'expired_count': len(expired_messages),
            'retry_count': len(retry_messages),
            'cleaned': len(expired_messages),
            'retried': len(retry_messages)
        }

    async def _retry_message(self, message_id: str, pending_msg: PendingMessage):
        """Resend a pending message; if the resend raises, drop it as failed."""
        try:
            pending_msg.retry_count += 1
            pending_msg.last_retry = datetime.utcnow()

            retry_message = {
                **pending_msg.message,
                'message_id': message_id,
                'requires_ack': True,
                'retry': pending_msg.retry_count,
                'timestamp': datetime.utcnow().isoformat(),
                'original_sent_at': pending_msg.sent_at.isoformat()
            }

            if pending_msg.room_id:
                await self.websocket_manager.broadcast_to_room(pending_msg.room_id, retry_message)
            else:
                await self.websocket_manager.send_to_connection(pending_msg.connection_id, retry_message)

            self.ack_stats.total_retries += 1
            logger.info(f"Retried message {message_id} (attempt {pending_msg.retry_count})")

        except Exception as e:
            logger.error(f"Failed to retry message {message_id}: {e}")
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1
            self._update_success_rate()

    def _update_success_rate(self):
        """Recompute ``success_rate`` (percent of sends acknowledged) and the average delivery time."""
        total = self.ack_stats.total_sent
        if total > 0:
            self.ack_stats.success_rate = (self.ack_stats.total_acknowledged / total) * 100

        if self._delivery_times:
            self.ack_stats.average_delivery_time = sum(self._delivery_times) / len(self._delivery_times)

    async def get_pending_ack_count(self) -> int:
        """Get number of pending acknowledgments."""
        return len(self.pending_acks)

    async def get_pending_acks_for_connection(self, connection_id: str) -> List[Dict[str, Any]]:
        """Summarize the pending messages addressed to ``connection_id``."""
        now = datetime.utcnow()
        return [
            {
                'message_id': message_id,
                'sent_at': pending_msg.sent_at.isoformat(),
                'retry_count': pending_msg.retry_count,
                'max_retries': pending_msg.max_retries,
                'time_since_sent': (now - pending_msg.sent_at).total_seconds()
            }
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

    async def get_ack_statistics(self) -> AcknowledgmentStats:
        """Return the live statistics object after refreshing its derived fields."""
        self._update_success_rate()
        return self.ack_stats

    async def force_retry_message(self, message_id: str) -> bool:
        """
        Force retry of a specific message.

        Args:
            message_id: Message ID to retry

        Returns:
            bool: True if a retry was attempted, False if the message is unknown
                or has exhausted its retries
        """
        pending_msg = self.pending_acks.get(message_id)
        if pending_msg is None:
            return False
        if pending_msg.retry_count < pending_msg.max_retries:
            await self._retry_message(message_id, pending_msg)
            return True
        return False

    async def cancel_pending_message(self, message_id: str) -> bool:
        """
        Cancel a pending message (counted as failed).

        Args:
            message_id: Message ID to cancel

        Returns:
            bool: True if message was cancelled
        """
        if self.pending_acks.pop(message_id, None) is None:
            return False
        self.ack_stats.total_failed += 1
        self._update_success_rate()
        logger.info(f"Cancelled pending message: {message_id}")
        return True

    async def get_delivery_report(self) -> Dict[str, Any]:
        """Get comprehensive delivery report: counters, pending-age buckets and retry distribution."""
        current_time = datetime.utcnow()

        # Ages are measured from the first send.
        age_buckets = {
            '0-10s': 0,
            '10-30s': 0,
            '30-60s': 0,
            '60s+': 0
        }

        for pending_msg in self.pending_acks.values():
            age = (current_time - pending_msg.sent_at).total_seconds()
            if age <= 10:
                age_buckets['0-10s'] += 1
            elif age <= 30:
                age_buckets['10-30s'] += 1
            elif age <= 60:
                age_buckets['30-60s'] += 1
            else:
                age_buckets['60s+'] += 1

        return {
            'statistics': {
                'total_sent': self.ack_stats.total_sent,
                'total_acknowledged': self.ack_stats.total_acknowledged,
                'total_failed': self.ack_stats.total_failed,
                'total_retries': self.ack_stats.total_retries,
                'pending_count': len(self.pending_acks),
                'success_rate': self.ack_stats.success_rate,
                'average_delivery_time': self.ack_stats.average_delivery_time
            },
            'pending_by_age': age_buckets,
            'retry_distribution': self._get_retry_distribution(),
            'timestamp': current_time.isoformat()
        }

    def _get_retry_distribution(self) -> Dict[str, int]:
        """Count pending messages by retry count; counts above ``max_retries`` go to the '+' bucket."""
        distribution = {str(i): 0 for i in range(self.max_retries + 1)}
        distribution[f'{self.max_retries}+'] = 0

        for pending_msg in self.pending_acks.values():
            if pending_msg.retry_count <= self.max_retries:
                distribution[str(pending_msg.retry_count)] += 1
            else:
                distribution[f'{self.max_retries}+'] += 1

        return distribution

    async def reset_statistics(self):
        """Reset acknowledgment statistics and the delivery-time window."""
        self.ack_stats = AcknowledgmentStats()
        self._delivery_times.clear()
        logger.info("Message acknowledgment statistics reset")

    async def cleanup_connection_acks(self, connection_id: str) -> int:
        """
        Clean up all pending acknowledgments for a connection.

        Each removed message is counted as failed.

        Args:
            connection_id: Connection ID to clean up

        Returns:
            int: Number of acknowledgments cleaned up
        """
        to_remove = [
            message_id
            for message_id, pending_msg in self.pending_acks.items()
            if pending_msg.connection_id == connection_id
        ]

        for message_id in to_remove:
            self.pending_acks.pop(message_id, None)
            self.ack_stats.total_failed += 1

        if to_remove:
            self._update_success_rate()

        logger.info(f"Cleaned up {len(to_remove)} pending acknowledgments for connection {connection_id}")
        return len(to_remove)
# Factory function
def create_message_acknowledgment_system(websocket_manager, ack_timeout: int = 30, max_retries: int = 3, cleanup_interval: int = 60) -> MessageAcknowledgmentSystem:
    """
    Build a MessageAcknowledgmentSystem from the given settings.

    The background sweep is not started; call ``start_ack_system`` on the result.

    Args:
        websocket_manager: WebSocket manager the system sends through
        ack_timeout: Seconds to wait for an acknowledgment
        max_retries: Maximum number of resend attempts per message
        cleanup_interval: Seconds between cleanup sweeps

    Returns:
        MessageAcknowledgmentSystem: Newly constructed system
    """
    system = MessageAcknowledgmentSystem(
        websocket_manager,
        ack_timeout=ack_timeout,
        max_retries=max_retries,
        cleanup_interval=cleanup_interval,
    )
    return system