""" Allocation Worker — Consumes booking events from booking_allocation_queue, finds eligible partners, scores them, and creates an offer to the top-ranked partner. Responsibilities: 1. Consume BookingEvent from Redis Stream 2. Enrich minimal retry events (missing city/lat/lng) by re-fetching from DB 3. Idempotency guard — check allocation_status before processing 4. SLA guard — reject bookings whose scheduled_time has already passed 5. Query eligible partners (multi-category, city match, not previously attempted) 6. Score and rank partners using weighted algorithm (uses actual max_load per partner) 7. Create offer to top-ranked partner 8. Handle no-partner and all-exhausted scenarios 9. Reclaim stale PEL messages from crashed instances (XAUTOCLAIM) 10. Graceful shutdown with in-flight task completion """ import asyncio import json import os import signal import socket from datetime import datetime, timezone from typing import Optional, Set from uuid import UUID from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine from app.core.config import settings from app.core.logging import get_logger from app.core.settings_loader import get_allocation_settings from app.models.booking import BookingEvent from app.queue.queue_manager import RedisQueueManager from app.services.notification_service import NotificationService from app.services.offer_service import OfferService from app.services.scoring_service import ScoringService logger = get_logger(__name__) QUEUE = settings.BOOKING_ALLOCATION_QUEUE GROUP = "allocation-workers" PEL_RECLAIM_INTERVAL = 30 # reclaim every N loop iterations (~2.5 min at 5s block) PEL_MIN_IDLE_MS = 60_000 # reclaim messages idle > 1 minute (safe: healthy consumers ack within seconds) class AllocationWorker: """Consumes booking events, scores partners, dispatches offers.""" def __init__( self, queue_manager: RedisQueueManager, offer_service: OfferService, scoring_service: ScoringService, engine: AsyncEngine, consumer_name: str, mongo_db=None, max_retries: int = 3, ): self.queue_manager = queue_manager self.offer_service = offer_service self.scoring_service = scoring_service self.engine = engine self.consumer_name = consumer_name self.mongo_db = mongo_db self.max_retries = max_retries self.notification_service = NotificationService(queue_manager=queue_manager) self._shutdown = False self._in_flight: Set[str] = set() self._loop_counter = 0 def setup_signal_handlers(self): for sig in (signal.SIGTERM, signal.SIGINT): signal.signal(sig, self._handle_signal) def _handle_signal(self, signum, frame): logger.info("Shutdown signal received", extra={"signal": signum}) self._shutdown = True async def run(self): """Main loop — consume events from booking_allocation_queue.""" self.setup_signal_handlers() # Ensure consumer group exists await self.queue_manager.create_consumer_group(QUEUE, GROUP) logger.info("AllocationWorker started", extra={"consumer": self.consumer_name}) while not self._shutdown: self._loop_counter += 1 # Gap 4 fix: own our PEL reclaim — runs every PEL_RECLAIM_INTERVAL iterations if self._loop_counter % PEL_RECLAIM_INTERVAL == 0: await self._reclaim_stale_messages() try: messages = await self.queue_manager.consume_events( QUEUE, GROUP, self.consumer_name, count=1, block=5000, ) if not messages: continue for stream_name, message_list in messages: for message_id, data in message_list: self._in_flight.add(message_id) try: await self._process_message(message_id, data) finally: self._in_flight.discard(message_id) except Exception as e: logger.error("Allocation worker cycle error", extra={"error": str(e)}) await asyncio.sleep(2) # Wait for in-flight tasks to complete if self._in_flight: logger.info("Waiting for in-flight tasks", extra={"count": len(self._in_flight)}) await asyncio.sleep(5) logger.info("AllocationWorker stopped") # ------------------------------------------------------------------------- # Gap 4 fix: PEL reclaim — each worker owns its own queue # ------------------------------------------------------------------------- async def _reclaim_stale_messages(self): """ XAUTOCLAIM messages idle > 5min in our own PEL (crashed instance recovery). The scheduler must NOT do this — it would assign messages to 'scheduler-0', a consumer that never reads from booking_allocation_queue. """ try: claimed, _, _ = await self.queue_manager.redis.xautoclaim( QUEUE, GROUP, self.consumer_name, min_idle_time=PEL_MIN_IDLE_MS, start_id="0-0", count=10, ) if claimed: logger.info( "Reclaimed stale PEL messages", extra={"queue": QUEUE, "count": len(claimed), "consumer": self.consumer_name}, ) except Exception as e: # Consumer group may not exist yet at first startup — not fatal logger.debug("PEL reclaim skipped", extra={"queue": QUEUE, "error": str(e)}) # ------------------------------------------------------------------------- # Gap 1 fix: enrich minimal retry events before parsing # ------------------------------------------------------------------------- async def _fetch_booking_details(self, booking_id: str) -> Optional[dict]: """ Fetch full booking details from DB for enriching minimal retry events. Retry events published by expiry/response workers only contain booking_id, retry_reason, previous_partner_id, previous_offer_id — all booking fields (city, lat, lng, scheduled_time, service_categories) must be re-fetched. Returns dict in the same string format as a Redis Stream message, or None if the booking no longer exists. """ try: async with self.engine.begin() as conn: result = await conn.execute( text(""" SELECT b.booking_id, b.booking_ref, b.city, b.location_lat, b.location_lng, b.scheduled_time, b.payment_mode, b.customer_id, array_agg(DISTINCT i.service_category) AS service_categories FROM trans.spa_bookings b JOIN trans.spa_booking_items i ON i.booking_id = b.booking_id WHERE b.booking_id = :booking_id GROUP BY b.booking_id, b.booking_ref, b.city, b.location_lat, b.location_lng, b.scheduled_time, b.payment_mode, b.customer_id """), {"booking_id": booking_id}, ) row = result.fetchone() if not row: return None return { "booking_id": str(row.booking_id), "booking_ref": row.booking_ref or "", "city": row.city, "location_lat": str(row.location_lat), "location_lng": str(row.location_lng), "scheduled_time": str(row.scheduled_time), "service_categories": json.dumps(list(row.service_categories)), "payment_mode": row.payment_mode or "pay_later", "customer_id": str(row.customer_id) if row.customer_id else "", } except Exception as e: logger.error( "Failed to fetch booking details for retry enrichment", extra={"booking_id": booking_id, "error": str(e)}, ) return None async def _process_message(self, message_id: str, data: dict): """Process a single booking event with full idempotency guard.""" retry_count = 0 booking_id = data.get("booking_id", "unknown") # Gap 1 fix: retry events from expiry/response workers are minimal — # they only carry booking_id + retry metadata, missing city/lat/lng/scheduled_time. # Re-fetch full booking details from DB before attempting to parse. if "retry_reason" in data and "city" not in data: logger.info( "Minimal retry event detected — enriching from DB", extra={"booking_id": booking_id, "retry_reason": data.get("retry_reason")}, ) db_data = await self._fetch_booking_details(booking_id) if not db_data: logger.warning( "Booking not found for retry enrichment — discarding", extra={"booking_id": booking_id}, ) await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id) return # Preserve all retry metadata alongside full booking fields data = { **db_data, "retry_reason": data.get("retry_reason"), "previous_partner_id": data.get("previous_partner_id"), "previous_offer_id": data.get("previous_offer_id"), } while retry_count <= self.max_retries: try: event = self._parse_event(data) await self._allocate(event) await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id) return except ValueError as e: # Invalid event — move to DLQ immediately logger.error("Invalid booking event", extra={"booking_id": booking_id, "error": str(e)}) await self.queue_manager.move_to_dead_letter_queue( data, e, QUEUE, retry_count, booking_id=booking_id, ) await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id) return except Exception as e: retry_count += 1 if retry_count > self.max_retries: logger.error("Max retries exceeded", extra={"booking_id": booking_id, "error": str(e)}) await self.queue_manager.move_to_dead_letter_queue( data, e, QUEUE, retry_count, booking_id=booking_id, ) await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id) return backoff = 2 ** (retry_count - 1) logger.warning( "Retrying allocation", extra={"booking_id": booking_id, "retry": retry_count, "backoff": backoff}, ) await asyncio.sleep(backoff) def _parse_event(self, data: dict) -> BookingEvent: """Parse Redis Stream message into BookingEvent.""" # service_categories comes as JSON string from XADD categories = data.get("service_categories", "[]") if isinstance(categories, str): categories = json.loads(categories) return BookingEvent( booking_id=data["booking_id"], booking_ref=data.get("booking_ref", ""), city=data["city"], location_lat=data["location_lat"], location_lng=data["location_lng"], scheduled_time=int(data["scheduled_time"]), service_categories=categories, payment_mode=data.get("payment_mode", "pay_later"), customer_id=data.get("customer_id", ""), ) async def _allocate(self, event: BookingEvent): """Core allocation logic with idempotency guard and SLA guard.""" alloc_settings = await get_allocation_settings(self.queue_manager.redis, self.mongo_db) # Idempotency guard — check current booking state async with self.engine.begin() as conn: result = await conn.execute( text(""" SELECT allocation_status, booking_status FROM trans.spa_bookings WHERE booking_id = :booking_id FOR UPDATE """), {"booking_id": str(event.booking_id)}, ) row = result.fetchone() if not row: logger.warning("Booking not found — discarding", extra={"booking_id": str(event.booking_id)}) return if row.booking_status == "cancelled": logger.info("Booking cancelled — discarding", extra={"booking_id": str(event.booking_id)}) return if row.allocation_status == "assigned": logger.info("Already assigned — discarding duplicate", extra={"booking_id": str(event.booking_id)}) return if row.allocation_status == "failed": logger.info("Already failed — discarding", extra={"booking_id": str(event.booking_id)}) return # Only proceed if allocation_status is 'offering' if row.allocation_status not in ("offering", "unassigned"): logger.warning( "Unexpected allocation_status", extra={"booking_id": str(event.booking_id), "status": row.allocation_status}, ) return # Gap 2 fix: SLA guard — reject if booking's scheduled time has already passed. # The scheduler's _expire_missed_bookings() also catches this but only on its # heartbeat interval (every N seconds). Between heartbeats, retry events can # carry bookings that crossed their scheduled time mid-allocation. Catch it here # so we never create an offer for a booking that's already overdue. now_utc = datetime.now(tz=timezone.utc) if event.scheduled_dt <= now_utc: logger.warning( "Booking past scheduled time — marking failed", extra={ "booking_id": str(event.booking_id), "scheduled_dt": event.scheduled_dt.isoformat(), "now_utc": now_utc.isoformat(), }, ) await self._mark_booking_failed(event.booking_id) await self.notification_service.notify_admin_allocation_failed(event.booking_id, city=event.city) await self.notification_service.send_customer_booking_failed(event.booking_id, customer_id=event.customer_id) return # Find eligible partners partners = await self._get_eligible_partners(event) if not partners: logger.warning("No eligible partners found", extra={"booking_id": str(event.booking_id)}) await self._mark_booking_failed(event.booking_id) await self.notification_service.notify_admin_allocation_failed(event.booking_id, city=event.city) await self.notification_service.send_customer_booking_failed(event.booking_id, customer_id=event.customer_id) return # Score and rank partners ranked = self.scoring_service.rank_partners( partners=partners, booking_lat=float(event.location_lat), booking_lng=float(event.location_lng), weights=alloc_settings.get("scoring_weights"), max_distance_km=alloc_settings.get("max_partner_distance_km", 50), ) if not ranked: logger.warning("No partners after scoring", extra={"booking_id": str(event.booking_id)}) await self._mark_booking_failed(event.booking_id) await self.notification_service.notify_admin_allocation_failed(event.booking_id, city=event.city) await self.notification_service.send_customer_booking_failed(event.booking_id, customer_id=event.customer_id) return # Create offer to top-ranked partner top = ranked[0] offer = await self.offer_service.create_offer( booking_id=event.booking_id, partner_id=top.partner.partner_id, expiry_seconds=alloc_settings.get("offer_expiry_seconds", settings.OFFER_EXPIRY_SECONDS), ) if offer: # Write score to audit trail await self._write_attempt_score( event.booking_id, top.partner.partner_id, float(top.score), ) logger.info( "Offer created", extra={ "booking_id": str(event.booking_id), "partner_id": str(top.partner.partner_id), "score": float(top.score), }, ) await self.notification_service.send_offer_notification( partner_id=top.partner.partner_id, offer_id=offer.offer_id, booking_id=event.booking_id, scheduled_time=event.scheduled_time, offer_expiry=offer.offer_expiry, booking_ref=event.booking_ref, service_categories=event.service_categories, ) else: logger.warning( "Offer creation returned None — booking already assigned, cancelled, " "pending offer exists, or Redis expiry publish failed", extra={"booking_id": str(event.booking_id), "partner_id": str(top.partner.partner_id)}, ) async def _get_eligible_partners(self, event: BookingEvent) -> list: """ CTE-based query: find partners covering ALL service categories, available at scheduled time, under max load, not previously attempted. Returns Partner objects with actual max_load from spa_partner_availability so the scoring service can compute accurate load scores per partner capacity. """ async with self.engine.begin() as conn: result = await conn.execute( text(""" WITH booking_categories AS ( SELECT service_category FROM trans.spa_booking_items WHERE booking_id = :booking_id ), category_count AS ( SELECT COUNT(DISTINCT service_category) AS total FROM booking_categories ), eligible_partners AS ( SELECT psm.partner_id, COUNT(DISTINCT psm.service_category) AS matched_categories FROM trans.spa_partner_service_map psm JOIN booking_categories bc ON bc.service_category = psm.service_category WHERE psm.city = :city AND psm.is_active = TRUE AND NOT EXISTS ( SELECT 1 FROM trans.spa_allocation_attempts aa WHERE aa.booking_id = :booking_id AND aa.partner_id = psm.partner_id ) GROUP BY psm.partner_id HAVING COUNT(DISTINCT psm.service_category) = (SELECT total FROM category_count) ) SELECT ep.partner_id, pp.rating, pp.completed_bookings, pa.current_load, pa.max_load, pa.location_lat, pa.location_lng FROM eligible_partners ep JOIN trans.spa_partner_availability pa ON pa.partner_id = ep.partner_id AND pa.available_from <= to_timestamp(:scheduled_time_ms / 1000.0) AND pa.available_to >= to_timestamp(:scheduled_time_ms / 1000.0) AND pa.current_load < pa.max_load JOIN trans.spa_partner_profiles pp ON pp.partner_id = ep.partner_id """), { "booking_id": str(event.booking_id), "city": event.city, "scheduled_time_ms": event.scheduled_time, }, ) rows = result.fetchall() # Gap 3 fix: pass actual max_load from DB — scoring service uses partner.max_load, # not a hardcoded constant, so load scores reflect each partner's real capacity. from app.models.partner import Partner partners = [] for row in rows: partners.append(Partner( partner_id=row.partner_id, name=row.partner_id.hex, # Required by Pydantic min_length=1; not used in scoring rating=row.rating, completed_bookings=row.completed_bookings, active_bookings=row.current_load, max_load=row.max_load, # actual capacity, not hardcoded 5 location_lat=row.location_lat, location_lng=row.location_lng, service_categories=[], # already validated by CTE city=event.city, )) return partners async def _write_attempt_score(self, booking_id: UUID, partner_id: UUID, score: float): """Write score to spa_allocation_attempts for audit trail.""" async with self.engine.begin() as conn: await conn.execute( text(""" INSERT INTO trans.spa_allocation_attempts (booking_id, partner_id, attempt_timestamp, attempt_status, score) VALUES (:booking_id, :partner_id, :ts, 'offered', :score) ON CONFLICT (booking_id, partner_id) DO UPDATE SET attempt_status = 'offered', score = :score, attempt_timestamp = :ts """), { "booking_id": str(booking_id), "partner_id": str(partner_id), "ts": datetime.now(tz=timezone.utc), "score": score, }, ) async def _mark_booking_failed(self, booking_id: UUID): """Mark booking as failed when no eligible partners remain or SLA expired.""" async with self.engine.begin() as conn: await conn.execute( text(""" UPDATE trans.spa_bookings SET allocation_status = 'failed', updated_at = NOW() WHERE booking_id = :booking_id AND allocation_status != 'assigned' """), {"booking_id": str(booking_id)}, ) logger.warning("Booking marked as failed", extra={"booking_id": str(booking_id)}) async def main(): """Entry point for allocation worker.""" from app.core.logging import setup_logging setup_logging(settings.LOG_LEVEL) from app.db.postgres import create_pg_engine from app.db.mongo import get_mongo_db from app.queue.redis_client import create_redis_client logger.info("Initializing Allocation Worker") engine = await create_pg_engine() redis_client = await create_redis_client() queue_manager = RedisQueueManager(redis_client) offer_service = OfferService(engine, queue_manager) scoring_service = ScoringService() try: mongo_db = await get_mongo_db() except Exception as e: logger.warning("MongoDB unavailable — using .env defaults", extra={"error": str(e)}) mongo_db = None consumer_name = os.getenv("HOSTNAME", socket.gethostname()) worker = AllocationWorker( queue_manager=queue_manager, offer_service=offer_service, scoring_service=scoring_service, engine=engine, consumer_name=consumer_name, mongo_db=mongo_db, ) try: await worker.run() finally: await queue_manager.close() await engine.dispose() from app.db.mongo import close_mongo await close_mongo() logger.info("Allocation Worker shutdown complete") if __name__ == "__main__": asyncio.run(main())