import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional

import redis
import structlog
from prometheus_client import Counter, Histogram

from .config import BrokerConfig
from .message import Message

logger = structlog.get_logger()

# Counter of processed messages, labelled by queue and outcome ("success"/"error").
messages_processed = Counter(
    "messages_processed_total",
    "Total number of messages processed",
    ["queue", "status"]
)
# Wall-clock time spent inside the user handler (plus error handling), per queue.
processing_time = Histogram(
    "message_processing_seconds",
    "Time spent processing messages",
    ["queue"]
)


class MessageConsumer:
    """Redis-backed message consumer with retries and a dead-letter queue.

    Messages are popped from the Redis list ``queue:<name>`` and passed to the
    user-supplied handler on a thread pool. Failed messages are scheduled on
    the sorted set ``retry:<name>`` with exponential backoff; messages that
    exhaust ``max_retries`` are pushed to the list ``dead_letter:<name>``.
    """

    def __init__(self, config: BrokerConfig, queue: str,
                 handler: Callable[[Message], None]):
        """Create the consumer and its Redis connection pool.

        Args:
            config: Broker configuration (Redis connection, worker count,
                batch size, retry policy).
            queue: Logical queue name (used to build the Redis key names).
            handler: Callable invoked once per message; raising an exception
                marks the message as failed and triggers retry handling.
        """
        self.config = config
        self.queue = queue
        self.handler = handler

        logger.info("Creating Redis connection pool",
                    host=config.redis.host,
                    port=config.redis.port,
                    ssl=config.redis.ssl)

        connection_params = {
            "host": config.redis.host,
            "port": config.redis.port,
            "db": config.redis.db,
            "password": config.redis.password,
            # decode_responses=True so queue payloads come back as str,
            # matching Message.from_json / to_json round-trips.
            "decode_responses": True,
            "max_connections": config.redis.connection_pool_size
        }
        if config.redis.ssl:
            # NOTE(review): ssl_cert_reqs=None disables certificate
            # verification — presumably intentional for internal deployments,
            # but confirm before using against untrusted networks.
            connection_params.update({
                "ssl": True,
                "ssl_cert_reqs": None,
                "ssl_ca_certs": None
            })

        connection_pool = redis.ConnectionPool(**connection_params)
        self._redis = redis.Redis(connection_pool=connection_pool)
        self._executor = ThreadPoolExecutor(max_workers=config.num_workers)
        self._running = False

        logger.info("Message consumer initialized",
                    queue=queue,
                    config=config.dict())

    def start(self) -> None:
        """Start consuming messages.

        Launches one retry-queue pump plus ``num_workers`` consumer loops on
        the thread pool. Idempotency is not guaranteed — call once.
        """
        self._running = True
        self._executor.submit(self._process_retry_queue)
        for _ in range(self.config.num_workers):
            self._executor.submit(self._consume)
        logger.info("Consumer started", queue=self.queue)

    def stop(self) -> None:
        """Stop consuming messages and wait for worker threads to drain."""
        self._running = False
        self._executor.shutdown(wait=True)
        logger.info("Consumer stopped", queue=self.queue)

    def _consume(self) -> None:
        """Consume messages from the queue until ``stop()`` is called."""
        logger.info("Consumer thread started", queue=self.queue)
        while self._running:
            try:
                messages = self._batch_pop_messages()
                if messages:
                    logger.info("Received messages",
                                queue=self.queue,
                                count=len(messages))
                    for message_data in messages:
                        self._process_message(Message.from_json(message_data))
                else:
                    # Small sleep to prevent CPU spinning when queue is empty
                    time.sleep(0.1)
            except Exception as e:
                # Keep the worker alive on unexpected errors (e.g. Redis
                # connectivity blips); back off briefly before retrying.
                logger.error("Error in consumer loop",
                             error=str(e),
                             queue=self.queue)
                time.sleep(1)

    def _batch_pop_messages(self) -> List[str]:
        """Pop up to ``batch_size`` messages from the queue.

        Blocks up to 1 second for the first message (BRPOP), then drains
        additional messages non-blockingly (RPOP) to fill the batch.

        Returns:
            A list of raw JSON message payloads; empty on timeout or error.
        """
        messages = []
        try:
            # Using brpop instead of rpop for blocking operation
            result = self._redis.brpop([f"queue:{self.queue}"], timeout=1)
            if result:
                messages.append(result[1])  # brpop returns (key, value) tuple
                # Try to get more messages up to batch size
                for _ in range(self.config.batch_size - 1):
                    msg = self._redis.rpop(f"queue:{self.queue}")
                    if msg:
                        messages.append(msg)
                    else:
                        break
            logger.debug("Batch pop result",
                         queue=self.queue,
                         messages_count=len(messages))
            return messages
        except Exception as e:
            logger.error("Error in batch pop", error=str(e), queue=self.queue)
            return []

    def _process_message(self, message: Message) -> None:
        """Run the handler on one message, recording metrics and failures."""
        with processing_time.labels(queue=self.queue).time():
            try:
                self.handler(message)
                messages_processed.labels(
                    queue=self.queue,
                    status="success"
                ).inc()
                logger.info(
                    "Message processed successfully",
                    message_id=message.id,
                    queue=self.queue
                )
            except Exception as e:
                messages_processed.labels(
                    queue=self.queue,
                    status="error"
                ).inc()
                # Log the handler failure here so the reason is visible even
                # if retry/dead-letter handling subsequently fails.
                logger.error(
                    "Message handler raised",
                    message_id=message.id,
                    queue=self.queue,
                    error=str(e)
                )
                message.error = str(e)
                self._handle_processing_error(message)

    def _handle_processing_error(self, message: Message) -> None:
        """Route a failed message to retry or the dead-letter queue."""
        if message.retry_count < message.max_retries:
            self._retry_message(message)
        else:
            self._move_to_dead_letter(message)

    def _retry_message(self, message: Message) -> None:
        """Move a message to the retry queue with exponential backoff.

        The retry sorted set scores each payload by the absolute time at
        which it becomes eligible for re-delivery.
        """
        message.retry_count += 1
        # delay = initial * factor^(retries-1), capped at max_delay.
        delay = min(
            self.config.retry.initial_delay *
            (self.config.retry.backoff_factor ** (message.retry_count - 1)),
            self.config.retry.max_delay
        )
        # NOTE: identical serialized payloads share one sorted-set member,
        # so a duplicate retry overwrites the earlier score.
        self._redis.zadd(
            f"retry:{self.queue}",
            {message.to_json(): time.time() + delay}
        )
        logger.info(
            "Message scheduled for retry",
            message_id=message.id,
            queue=self.queue,
            retry_count=message.retry_count,
            delay=delay
        )

    def _process_retry_queue(self) -> None:
        """Re-enqueue retry-queue messages whose backoff delay has elapsed."""
        while self._running:
            try:
                # Get messages that are ready to be retried
                messages = self._redis.zrangebyscore(
                    f"retry:{self.queue}",
                    "-inf",
                    time.time(),
                    start=0,
                    num=self.config.batch_size
                )
                if not messages:
                    time.sleep(self.config.polling_interval)
                    continue

                # Atomically (per pipeline) remove each due message from the
                # retry set and push it back onto the main queue.
                # NOTE(review): zrangebyscore + zrem is not atomic across
                # multiple consumer processes — safe for a single consumer;
                # confirm deployment topology before scaling out.
                pipeline = self._redis.pipeline()
                for message_data in messages:
                    pipeline.zrem(f"retry:{self.queue}", message_data)
                    pipeline.lpush(f"queue:{self.queue}", message_data)
                pipeline.execute()
            except Exception as e:
                logger.error("Error processing retry queue", error=str(e))
                time.sleep(1)

    def _move_to_dead_letter(self, message: Message) -> None:
        """Move a message that exhausted its retries to the dead letter queue."""
        try:
            self._redis.lpush(f"dead_letter:{self.queue}", message.to_json())
            logger.warning(
                "Message moved to dead letter queue",
                message_id=message.id,
                queue=self.queue,
                error=message.error
            )
        except redis.RedisError as e:
            # Best-effort: losing the dead-letter write is logged, not raised,
            # so the consumer loop keeps running.
            logger.error("Failed to move message to dead letter queue",
                         error=str(e))