# cuatrolabs-booking-allocation-ms / app/workers/allocation_worker.py
# Commit 0345040 (Michael-Antony):
#   feat: Replace stream queue with Redis hash for partner offer notifications
"""
Allocation Worker — Consumes booking events from booking_allocation_queue,
finds eligible partners, scores them, and creates an offer to the top-ranked partner.
Responsibilities:
1. Consume BookingEvent from Redis Stream
2. Enrich minimal retry events (missing city/lat/lng) by re-fetching from DB
3. Idempotency guard — check allocation_status before processing
4. SLA guard — reject bookings whose scheduled_time has already passed
5. Query eligible partners (multi-category, city match, not previously attempted)
6. Score and rank partners using weighted algorithm (uses actual max_load per partner)
7. Create offer to top-ranked partner
8. Handle no-partner and all-exhausted scenarios
9. Reclaim stale PEL messages from crashed instances (XAUTOCLAIM)
10. Graceful shutdown with in-flight task completion
"""
import asyncio
import json
import os
import signal
import socket
from datetime import datetime, timezone
from typing import Optional, Set
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine
from app.core.config import settings
from app.core.logging import get_logger
from app.core.settings_loader import get_allocation_settings
from app.models.booking import BookingEvent
from app.queue.queue_manager import RedisQueueManager
from app.services.notification_service import NotificationService
from app.services.offer_service import OfferService
from app.services.scoring_service import ScoringService
logger = get_logger(__name__)

# Redis Stream this worker consumes booking events from, and the consumer
# group shared by every allocation-worker instance.
QUEUE = settings.BOOKING_ALLOCATION_QUEUE
GROUP = "allocation-workers"

# Stale-PEL recovery tuning (used by AllocationWorker._reclaim_stale_messages).
# The reclaim pass runs once every PEL_RECLAIM_INTERVAL main-loop iterations;
# with a 5 s blocking read on an idle stream that is roughly every 2.5 minutes.
PEL_RECLAIM_INTERVAL = 30
# Only messages idle for more than one minute are claimed. Healthy consumers
# acknowledge within seconds, so this threshold is safe.
PEL_MIN_IDLE_MS = 60 * 1000
class AllocationWorker:
    """Consumes booking events, scores partners, dispatches offers.

    Events are read one at a time from the ``QUEUE`` stream as consumer
    ``consumer_name`` of group ``GROUP`` and processed sequentially inside the
    ``run()`` coroutine.
    """

    # Retry metadata carried by minimal retry events (see _process_message).
    _RETRY_METADATA_KEYS = ("retry_reason", "previous_partner_id", "previous_offer_id")

    def __init__(
        self,
        queue_manager: RedisQueueManager,
        offer_service: OfferService,
        scoring_service: ScoringService,
        engine: AsyncEngine,
        consumer_name: str,
        mongo_db=None,
        max_retries: int = 3,
    ):
        """
        Args:
            queue_manager: Redis queue wrapper used to consume, acknowledge and
                dead-letter events. Its ``redis`` client is also used directly
                for XAUTOCLAIM and for loading allocation settings.
            offer_service: Creates the offer for the top-ranked partner.
            scoring_service: Ranks eligible partners.
            engine: Async SQLAlchemy engine for booking/partner queries.
            consumer_name: This instance's consumer name within ``GROUP``.
            mongo_db: Optional MongoDB handle forwarded to ``get_allocation_settings``.
            max_retries: Retries for transient allocation failures before the
                event is moved to the dead-letter queue.
        """
        self.queue_manager = queue_manager
        self.offer_service = offer_service
        self.scoring_service = scoring_service
        self.engine = engine
        self.consumer_name = consumer_name
        self.mongo_db = mongo_db
        self.max_retries = max_retries
        self.notification_service = NotificationService(queue_manager=queue_manager)
        self._shutdown = False  # set by the signal handler; checked by run()
        self._in_flight: Set[str] = set()  # message IDs currently being processed
        self._loop_counter = 0  # main-loop iterations; drives the periodic PEL reclaim

    def setup_signal_handlers(self):
        """Route SIGTERM/SIGINT to ``_handle_signal`` for graceful shutdown."""
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """Request shutdown. The main loop exits after its current iteration."""
        logger.info("Shutdown signal received", extra={"signal": signum})
        self._shutdown = True

    async def run(self):
        """Main loop — consume events from booking_allocation_queue until shutdown."""
        self.setup_signal_handlers()
        # Ensure consumer group exists
        await self.queue_manager.create_consumer_group(QUEUE, GROUP)
        logger.info("AllocationWorker started", extra={"consumer": self.consumer_name})

        while not self._shutdown:
            self._loop_counter += 1
            try:
                # Gap 4 fix: own our PEL reclaim — runs every PEL_RECLAIM_INTERVAL
                # iterations. It runs inside the try, so a failure while
                # re-processing a reclaimed message is handled like any cycle error.
                if self._loop_counter % PEL_RECLAIM_INTERVAL == 0:
                    await self._reclaim_stale_messages()

                messages = await self.queue_manager.consume_events(
                    QUEUE, GROUP, self.consumer_name, count=1, block=5000,
                )
                if not messages:
                    continue
                for _stream, message_list in messages:
                    for message_id, data in message_list:
                        await self._handle_message(message_id, data)
            except Exception as e:
                logger.error("Allocation worker cycle error", extra={"error": str(e)})
                await asyncio.sleep(2)

        # Messages are processed sequentially in this coroutine, so this set is
        # normally empty once the loop exits. The wait is kept as a safety margin.
        if self._in_flight:
            logger.info("Waiting for in-flight tasks", extra={"count": len(self._in_flight)})
            await asyncio.sleep(5)
        logger.info("AllocationWorker stopped")

    async def _handle_message(self, message_id, data: dict):
        """Process one stream entry while tracking it in ``_in_flight``."""
        self._in_flight.add(message_id)
        try:
            await self._process_message(message_id, data)
        finally:
            self._in_flight.discard(message_id)

    @staticmethod
    def _decode(value):
        """Return *value* as ``str`` if Redis handed back raw ``bytes``; else unchanged."""
        return value.decode() if isinstance(value, bytes) else value

    # -------------------------------------------------------------------------
    # Gap 4 fix: PEL reclaim — each worker owns its own queue
    # -------------------------------------------------------------------------
    async def _reclaim_stale_messages(self):
        """
        XAUTOCLAIM up to 10 messages that have been idle in the group's PEL for
        longer than PEL_MIN_IDLE_MS (1 minute), e.g. messages left behind by a
        crashed instance, and process them here.

        The scheduler must NOT do this — it would assign messages to 'scheduler-0',
        a consumer that never reads from booking_allocation_queue.

        Claimed messages are processed immediately. A claim only transfers
        ownership, and nothing else in this worker re-delivers pending entries.
        NOTE(review): this assumes consume_events only reads new (">") entries.
        If it also replays the PEL, processing here is still safe because both
        paths run sequentially and acknowledge on completion.
        """
        try:
            reply = await self.queue_manager.redis.xautoclaim(
                QUEUE, GROUP, self.consumer_name,
                min_idle_time=PEL_MIN_IDLE_MS,
                start_id="0-0",
                count=10,
            )
        except Exception as e:
            # Consumer group may not exist yet at first startup — not fatal
            logger.debug("PEL reclaim skipped", extra={"queue": QUEUE, "error": str(e)})
            return

        # Reply is [next_start_id, [(id, fields), ...]]. Redis >= 7 adds a third
        # element listing IDs deleted from the stream. Element 0 is the scan
        # cursor, not the claimed entries.
        claimed = reply[1] if reply and len(reply) > 1 else []
        if not claimed:
            return
        logger.info(
            "Reclaimed stale PEL messages",
            extra={"queue": QUEUE, "count": len(claimed), "consumer": self.consumer_name},
        )

        for raw_id, raw_fields in claimed:
            if self._shutdown:
                # Remaining claims stay in our PEL and are reclaimed again once idle.
                break
            if raw_id is None:
                continue
            message_id = self._decode(raw_id)
            if not raw_fields:
                # Entry no longer exists in the stream — nothing to process,
                # so just clear it from the PEL.
                await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
                continue
            data = {self._decode(k): self._decode(v) for k, v in raw_fields.items()}
            await self._handle_message(message_id, data)

    # -------------------------------------------------------------------------
    # Gap 1 fix: enrich minimal retry events before parsing
    # -------------------------------------------------------------------------
    @staticmethod
    def _format_scheduled_time(value) -> str:
        """Render a DB ``scheduled_time`` as the epoch-milliseconds string ``_parse_event`` expects.

        Datetime values are converted to epoch ms. Anything else (e.g. an
        integer epoch-ms column) is stringified unchanged.
        """
        if isinstance(value, datetime):
            if value.tzinfo is None:
                # NOTE(review): naive timestamps are assumed to be UTC — confirm against the schema.
                value = value.replace(tzinfo=timezone.utc)
            return str(round(value.timestamp() * 1000))
        return str(value)

    async def _fetch_booking_details(self, booking_id: str) -> Optional[dict]:
        """
        Fetch full booking details from DB for enriching minimal retry events.

        Retry events published by expiry/response workers only contain booking_id,
        retry_reason, previous_partner_id, previous_offer_id — all booking fields
        (city, lat, lng, scheduled_time, service_categories) must be re-fetched.

        Returns a dict in the same string format as a Redis Stream message, or
        None if the booking no longer exists (or has no items) or the query fails.
        """
        try:
            async with self.engine.begin() as conn:
                result = await conn.execute(
                    text("""
                        SELECT
                            b.booking_id,
                            b.booking_ref,
                            b.city,
                            b.location_lat,
                            b.location_lng,
                            b.scheduled_time,
                            b.payment_mode,
                            b.customer_id,
                            array_agg(DISTINCT i.service_category) AS service_categories
                        FROM trans.spa_bookings b
                        JOIN trans.spa_booking_items i ON i.booking_id = b.booking_id
                        WHERE b.booking_id = :booking_id
                        GROUP BY b.booking_id, b.booking_ref, b.city,
                                 b.location_lat, b.location_lng,
                                 b.scheduled_time, b.payment_mode, b.customer_id
                    """),
                    {"booking_id": booking_id},
                )
                row = result.fetchone()
            if not row:
                return None
            return {
                "booking_id": str(row.booking_id),
                "booking_ref": row.booking_ref or "",
                "city": row.city,
                "location_lat": str(row.location_lat),
                "location_lng": str(row.location_lng),
                # _parse_event does int(scheduled_time); a datetime must become epoch ms.
                "scheduled_time": self._format_scheduled_time(row.scheduled_time),
                "service_categories": json.dumps(list(row.service_categories)),
                "payment_mode": row.payment_mode or "pay_later",
                "customer_id": str(row.customer_id) if row.customer_id else "",
            }
        except Exception as e:
            logger.error(
                "Failed to fetch booking details for retry enrichment",
                extra={"booking_id": booking_id, "error": str(e)},
            )
            return None

    async def _process_message(self, message_id: str, data: dict):
        """Process a single booking event with full idempotency guard.

        Minimal retry events are enriched from the DB first. Events that cannot
        be parsed are dead-lettered immediately. Allocation failures are retried
        with exponential backoff (1 s, 2 s, 4 s, ...) up to ``max_retries`` times,
        then dead-lettered. A ``ValueError`` from allocation is dead-lettered
        without retrying. The message is acknowledged on every terminal path.
        """
        booking_id = data.get("booking_id", "unknown")

        # Gap 1 fix: retry events from expiry/response workers are minimal —
        # they only carry booking_id + retry metadata, missing city/lat/lng/scheduled_time.
        # Re-fetch full booking details from DB before attempting to parse.
        if "retry_reason" in data and "city" not in data:
            logger.info(
                "Minimal retry event detected — enriching from DB",
                extra={"booking_id": booking_id, "retry_reason": data.get("retry_reason")},
            )
            db_data = await self._fetch_booking_details(booking_id)
            if not db_data:
                logger.warning(
                    "Booking not found for retry enrichment — discarding",
                    extra={"booking_id": booking_id},
                )
                await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
                return
            # Preserve retry metadata alongside the full booking fields. Keys the
            # event did not carry are skipped, so the payload (which may be
            # dead-lettered) contains no None values.
            retry_meta = {
                key: data[key]
                for key in self._RETRY_METADATA_KEYS
                if data.get(key) is not None
            }
            data = {**db_data, **retry_meta}

        # Parsing is deterministic. A malformed event (missing key, bad number or
        # JSON, or a ValueError from BookingEvent validation) is dead-lettered
        # at once instead of being retried with backoff.
        try:
            event = self._parse_event(data)
        except (KeyError, TypeError, ValueError) as e:
            logger.error("Invalid booking event", extra={"booking_id": booking_id, "error": str(e)})
            await self.queue_manager.move_to_dead_letter_queue(
                data, e, QUEUE, 0, booking_id=booking_id,
            )
            await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
            return

        retry_count = 0
        while True:
            try:
                await self._allocate(event)
                await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
                return
            except ValueError as e:
                # Treated as a permanent data problem — move to DLQ without retrying.
                logger.error("Invalid booking event", extra={"booking_id": booking_id, "error": str(e)})
                await self.queue_manager.move_to_dead_letter_queue(
                    data, e, QUEUE, retry_count, booking_id=booking_id,
                )
                await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
                return
            except Exception as e:
                retry_count += 1
                if retry_count > self.max_retries:
                    logger.error("Max retries exceeded", extra={"booking_id": booking_id, "error": str(e)})
                    await self.queue_manager.move_to_dead_letter_queue(
                        data, e, QUEUE, retry_count, booking_id=booking_id,
                    )
                    await self.queue_manager.acknowledge_event(QUEUE, GROUP, message_id)
                    return
                backoff = 2 ** (retry_count - 1)
                logger.warning(
                    "Retrying allocation",
                    extra={"booking_id": booking_id, "retry": retry_count, "backoff": backoff},
                )
                await asyncio.sleep(backoff)

    def _parse_event(self, data: dict) -> BookingEvent:
        """Parse a Redis Stream message into a BookingEvent.

        Raises:
            KeyError: a required field (booking_id, city, location_lat,
                location_lng, scheduled_time) is missing.
            ValueError: scheduled_time is not an integer string, or
                service_categories is not valid JSON.
        """
        # service_categories comes as JSON string from XADD
        categories = data.get("service_categories", "[]")
        if isinstance(categories, str):
            categories = json.loads(categories)
        return BookingEvent(
            booking_id=data["booking_id"],
            booking_ref=data.get("booking_ref", ""),
            city=data["city"],
            location_lat=data["location_lat"],
            location_lng=data["location_lng"],
            scheduled_time=int(data["scheduled_time"]),
            service_categories=categories,
            payment_mode=data.get("payment_mode", "pay_later"),
            customer_id=data.get("customer_id", ""),
        )

    async def _fail_allocation(self, event: BookingEvent):
        """Mark the booking failed and notify both the admin and the customer."""
        await self._mark_booking_failed(event.booking_id)
        await self.notification_service.notify_admin_allocation_failed(event.booking_id, city=event.city)
        await self.notification_service.send_customer_booking_failed(event.booking_id, customer_id=event.customer_id)

    async def _allocate(self, event: BookingEvent):
        """Core allocation logic with idempotency guard and SLA guard.

        The event is discarded when the booking is missing, cancelled, already
        assigned or failed, or in an unexpected allocation_status. The booking
        is marked failed (with notifications) when it is past its scheduled time
        or no partner survives eligibility filtering and scoring. Otherwise an
        offer is created for the top-ranked partner and the offer notification
        is sent.
        """
        booking_id = str(event.booking_id)

        # Idempotency guard — check current booking state.
        # NOTE(review): the FOR UPDATE row lock only lives for this transaction,
        # which commits when the `async with` exits, so it does not serialize the
        # offer creation below. Concurrent duplicates rely on OfferService's own
        # checks (it returns None when e.g. a pending offer already exists).
        async with self.engine.begin() as conn:
            result = await conn.execute(
                text("""
                    SELECT allocation_status, booking_status
                    FROM trans.spa_bookings
                    WHERE booking_id = :booking_id
                    FOR UPDATE
                """),
                {"booking_id": booking_id},
            )
            row = result.fetchone()

        if not row:
            logger.warning("Booking not found — discarding", extra={"booking_id": booking_id})
            return
        if row.booking_status == "cancelled":
            logger.info("Booking cancelled — discarding", extra={"booking_id": booking_id})
            return
        if row.allocation_status == "assigned":
            logger.info("Already assigned — discarding duplicate", extra={"booking_id": booking_id})
            return
        if row.allocation_status == "failed":
            logger.info("Already failed — discarding", extra={"booking_id": booking_id})
            return
        # Only 'offering' and 'unassigned' bookings may (re)enter allocation.
        if row.allocation_status not in ("offering", "unassigned"):
            logger.warning(
                "Unexpected allocation_status",
                extra={"booking_id": booking_id, "status": row.allocation_status},
            )
            return

        # Gap 2 fix: SLA guard — reject if booking's scheduled time has already passed.
        # The scheduler's _expire_missed_bookings() also catches this but only on its
        # heartbeat interval (every N seconds). Between heartbeats, retry events can
        # carry bookings that crossed their scheduled time mid-allocation. Catch it here
        # so we never create an offer for a booking that's already overdue.
        now_utc = datetime.now(tz=timezone.utc)
        if event.scheduled_dt <= now_utc:
            logger.warning(
                "Booking past scheduled time — marking failed",
                extra={
                    "booking_id": booking_id,
                    "scheduled_dt": event.scheduled_dt.isoformat(),
                    "now_utc": now_utc.isoformat(),
                },
            )
            await self._fail_allocation(event)
            return

        # Settings are only needed once the booking is known to need allocation.
        alloc_settings = await get_allocation_settings(self.queue_manager.redis, self.mongo_db)

        # Find eligible partners
        partners = await self._get_eligible_partners(event)
        if not partners:
            logger.warning("No eligible partners found", extra={"booking_id": booking_id})
            await self._fail_allocation(event)
            return

        # Score and rank partners
        ranked = self.scoring_service.rank_partners(
            partners=partners,
            booking_lat=float(event.location_lat),
            booking_lng=float(event.location_lng),
            weights=alloc_settings.get("scoring_weights"),
            max_distance_km=alloc_settings.get("max_partner_distance_km", 50),
        )
        if not ranked:
            logger.warning("No partners after scoring", extra={"booking_id": booking_id})
            await self._fail_allocation(event)
            return

        # Create offer to top-ranked partner
        top = ranked[0]
        partner_id = top.partner.partner_id
        offer = await self.offer_service.create_offer(
            booking_id=event.booking_id,
            partner_id=partner_id,
            expiry_seconds=alloc_settings.get("offer_expiry_seconds", settings.OFFER_EXPIRY_SECONDS),
        )
        if not offer:
            logger.warning(
                "Offer creation returned None — booking already assigned, cancelled, "
                "pending offer exists, or Redis expiry publish failed",
                extra={"booking_id": booking_id, "partner_id": str(partner_id)},
            )
            return

        score = float(top.score)
        # Write score to audit trail
        await self._write_attempt_score(event.booking_id, partner_id, score)
        logger.info(
            "Offer created",
            extra={
                "booking_id": booking_id,
                "partner_id": str(partner_id),
                "score": score,
            },
        )
        await self.notification_service.send_offer_notification(
            partner_id=partner_id,
            offer_id=offer.offer_id,
            booking_id=event.booking_id,
            scheduled_time=event.scheduled_time,
            offer_expiry=offer.offer_expiry,
            booking_ref=event.booking_ref,
            service_categories=event.service_categories,
        )

    async def _get_eligible_partners(self, event: BookingEvent) -> list:
        """
        CTE-based query: find partners covering ALL service categories,
        available at scheduled time, under max load, not previously attempted.

        Returns Partner objects with actual max_load from spa_partner_availability
        so the scoring service can compute accurate load scores per partner capacity.
        """
        async with self.engine.begin() as conn:
            result = await conn.execute(
                text("""
                    WITH booking_categories AS (
                        SELECT service_category
                        FROM trans.spa_booking_items
                        WHERE booking_id = :booking_id
                    ),
                    category_count AS (
                        SELECT COUNT(DISTINCT service_category) AS total
                        FROM booking_categories
                    ),
                    eligible_partners AS (
                        SELECT
                            psm.partner_id,
                            COUNT(DISTINCT psm.service_category) AS matched_categories
                        FROM trans.spa_partner_service_map psm
                        JOIN booking_categories bc ON bc.service_category = psm.service_category
                        WHERE psm.city = :city
                          AND psm.is_active = TRUE
                          AND NOT EXISTS (
                              SELECT 1 FROM trans.spa_allocation_attempts aa
                              WHERE aa.booking_id = :booking_id
                                AND aa.partner_id = psm.partner_id
                          )
                        GROUP BY psm.partner_id
                        HAVING COUNT(DISTINCT psm.service_category) = (SELECT total FROM category_count)
                    )
                    SELECT
                        ep.partner_id,
                        pp.rating,
                        pp.completed_bookings,
                        pa.current_load,
                        pa.max_load,
                        pa.location_lat,
                        pa.location_lng
                    FROM eligible_partners ep
                    JOIN trans.spa_partner_availability pa
                      ON pa.partner_id = ep.partner_id
                     AND pa.available_from <= to_timestamp(:scheduled_time_ms / 1000.0)
                     AND pa.available_to >= to_timestamp(:scheduled_time_ms / 1000.0)
                     AND pa.current_load < pa.max_load
                    JOIN trans.spa_partner_profiles pp
                      ON pp.partner_id = ep.partner_id
                """),
                {
                    "booking_id": str(event.booking_id),
                    "city": event.city,
                    "scheduled_time_ms": event.scheduled_time,
                },
            )
            rows = result.fetchall()

        # Gap 3 fix: pass actual max_load from DB — scoring service uses partner.max_load,
        # not a hardcoded constant, so load scores reflect each partner's real capacity.
        from app.models.partner import Partner

        return [
            Partner(
                partner_id=row.partner_id,
                name=row.partner_id.hex,  # Required by Pydantic min_length=1; not used in scoring
                rating=row.rating,
                completed_bookings=row.completed_bookings,
                active_bookings=row.current_load,
                max_load=row.max_load,  # actual capacity, not hardcoded 5
                location_lat=row.location_lat,
                location_lng=row.location_lng,
                service_categories=[],  # already validated by CTE
                city=event.city,
            )
            for row in rows
        ]

    async def _write_attempt_score(self, booking_id: UUID, partner_id: UUID, score: float):
        """Upsert an 'offered' row with the score into spa_allocation_attempts for audit trail."""
        async with self.engine.begin() as conn:
            await conn.execute(
                text("""
                    INSERT INTO trans.spa_allocation_attempts
                        (booking_id, partner_id, attempt_timestamp, attempt_status, score)
                    VALUES
                        (:booking_id, :partner_id, :ts, 'offered', :score)
                    ON CONFLICT (booking_id, partner_id)
                    DO UPDATE SET attempt_status = 'offered', score = :score, attempt_timestamp = :ts
                """),
                {
                    "booking_id": str(booking_id),
                    "partner_id": str(partner_id),
                    "ts": datetime.now(tz=timezone.utc),
                    "score": score,
                },
            )

    async def _mark_booking_failed(self, booking_id: UUID):
        """Mark booking as failed when no eligible partners remain or SLA expired.

        An already-assigned booking is never downgraded. The log reflects
        whether a row was actually updated.
        """
        async with self.engine.begin() as conn:
            result = await conn.execute(
                text("""
                    UPDATE trans.spa_bookings
                    SET allocation_status = 'failed', updated_at = NOW()
                    WHERE booking_id = :booking_id
                      AND allocation_status != 'assigned'
                """),
                {"booking_id": str(booking_id)},
            )
        if result.rowcount:
            logger.warning("Booking marked as failed", extra={"booking_id": str(booking_id)})
        else:
            logger.info(
                "Booking not marked as failed — already assigned or not found",
                extra={"booking_id": str(booking_id)},
            )
async def main():
    """Entry point for allocation worker.

    Initializes logging, Postgres, Redis and the services. It optionally
    connects to MongoDB, falling back to .env defaults when MongoDB is
    unavailable, then runs the worker until it stops. Cleanup runs on every
    exit path (including startup failures after the engine is created), and
    each cleanup step is attempted even if an earlier one fails.
    """
    from app.core.logging import setup_logging
    setup_logging(settings.LOG_LEVEL)
    from app.db.postgres import create_pg_engine
    from app.db.mongo import close_mongo, get_mongo_db
    from app.queue.redis_client import create_redis_client

    async def _close(resource: str, closer) -> None:
        # Best-effort: log and continue so one failing close cannot skip the rest.
        try:
            await closer()
        except Exception as e:
            logger.warning(
                "Failed to close resource during shutdown",
                extra={"resource": resource, "error": str(e)},
            )

    logger.info("Initializing Allocation Worker")
    engine = await create_pg_engine()
    queue_manager = None
    try:
        redis_client = await create_redis_client()
        queue_manager = RedisQueueManager(redis_client)
        offer_service = OfferService(engine, queue_manager)
        scoring_service = ScoringService()

        try:
            mongo_db = await get_mongo_db()
        except Exception as e:
            logger.warning("MongoDB unavailable — using .env defaults", extra={"error": str(e)})
            mongo_db = None

        consumer_name = os.getenv("HOSTNAME", socket.gethostname())
        worker = AllocationWorker(
            queue_manager=queue_manager,
            offer_service=offer_service,
            scoring_service=scoring_service,
            engine=engine,
            consumer_name=consumer_name,
            mongo_db=mongo_db,
        )
        await worker.run()
    finally:
        if queue_manager is not None:
            await _close("redis", queue_manager.close)
        await _close("postgres", engine.dispose)
        await _close("mongo", close_mongo)
        logger.info("Allocation Worker shutdown complete")
if __name__ == "__main__":
    # Script entry point: run main() in a new asyncio event loop.
    asyncio.run(main())