Spaces:
Runtime error
Runtime error
| import asyncio | |
| import signal | |
| from asyncio import Task | |
| from collections.abc import Sequence | |
| from datetime import datetime, timedelta, timezone | |
| from logging import getLogger | |
| from typing import Any, NamedTuple, cast | |
| import sentry_sdk | |
| from dotenv import load_dotenv | |
| from nanoid import generate as generate_nanoid | |
| from sentry_sdk.integrations.asyncio import AsyncioIntegration | |
| from sqlalchemy import and_, delete, or_, select, update | |
| from sqlalchemy.dialects.postgresql import insert | |
| from sqlalchemy.engine import CursorResult | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy.sql import func | |
| from src import models | |
| from src.cache.client import close_cache, init_cache | |
| from src.config import settings | |
| from src.dependencies import tracked_db | |
| from src.deriver.consumer import ( | |
| process_item, | |
| process_representation_batch, | |
| ) | |
| from src.dreamer.dream_scheduler import ( | |
| DreamScheduler, | |
| get_dream_scheduler, | |
| set_dream_scheduler, | |
| ) | |
| from src.models import QueueItem | |
| from src.reconciler import ( | |
| ReconcilerScheduler, | |
| get_reconciler_scheduler, | |
| set_reconciler_scheduler, | |
| ) | |
| from src.schemas import ResolvedConfiguration | |
| from src.telemetry import prometheus_metrics | |
| from src.telemetry.sentry import initialize_sentry | |
| from src.utils.work_unit import parse_work_unit_key | |
| from src.webhooks.events import ( | |
| QueueEmptyEvent, | |
| publish_webhook_event, | |
| ) | |
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
# Load variables from a .env file into the process environment; override=True
# lets values from the file replace variables that are already set.
# NOTE(review): this runs after `src.config.settings` was imported above —
# confirm that settings does not need the .env values at import time.
load_dotenv(override=True)
class WorkerOwnership(NamedTuple):
    """Pair a claimed work unit with the queue-session row backing the claim."""

    # Key of the work unit the worker is processing.
    work_unit_key: str
    # ID of the ActiveQueueSession that the worker is processing.
    aqs_id: str
| def _detach_queue_batch_objects( | |
| db: AsyncSession, | |
| messages_context: list[models.Message], | |
| items_to_process: list[QueueItem], | |
| ) -> None: | |
| """Detach loaded batch objects so they remain usable after tracked_db exits.""" | |
| seen: set[int] = set() | |
| for obj in [*messages_context, *items_to_process]: | |
| obj_id = id(obj) | |
| if obj_id in seen: | |
| continue | |
| db.expunge(obj) | |
| seen.add(obj_id) | |
| def _resolve_batch_configuration( | |
| items_to_process: list[QueueItem], | |
| ) -> tuple[list[QueueItem], ResolvedConfiguration | None]: | |
| """Keep only the initial homogeneous configuration prefix for a batch.""" | |
| if not items_to_process: | |
| return [], None | |
| raw_config = items_to_process[0].payload.get("configuration") | |
| resolved_config = ( | |
| None if raw_config is None else ResolvedConfiguration.model_validate(raw_config) | |
| ) | |
| valid_items: list[QueueItem] = [] | |
| for item in items_to_process: | |
| item_raw_config = item.payload.get("configuration") | |
| item_config = ( | |
| None | |
| if item_raw_config is None | |
| else ResolvedConfiguration.model_validate(item_raw_config) | |
| ) | |
| if item_config != resolved_config: | |
| break | |
| valid_items.append(item) | |
| return valid_items, resolved_config | |
| class QueueManager: | |
def __init__(self):
    """Set up in-memory state, shared schedulers and (optionally) Sentry.

    Reads ``settings.DERIVER.WORKERS`` for the worker count, reuses the
    already-registered dream/reconciler schedulers when present (registering
    new ones otherwise), and initializes Sentry when
    ``settings.SENTRY.ENABLED`` is true.
    """
    # Coordination primitives and bookkeeping.
    self.shutdown_event: asyncio.Event = asyncio.Event()
    self.active_tasks: set[asyncio.Task[None]] = set()
    self.worker_ownership: dict[str, WorkerOwnership] = {}
    self.queue_empty_flag: asyncio.Event = asyncio.Event()

    # Worker count comes from settings; the semaphore is sized to match.
    self.workers: int = settings.DERIVER.WORKERS
    self.semaphore: asyncio.Semaphore = asyncio.Semaphore(self.workers)

    # Dream scheduler: reuse the registered instance or register a new one.
    dream_scheduler = get_dream_scheduler()
    if dream_scheduler is None:
        dream_scheduler = DreamScheduler()
        set_dream_scheduler(dream_scheduler)
    self.dream_scheduler: DreamScheduler = dream_scheduler

    # Reconciler scheduler: same get-or-create handling.
    reconciler_scheduler = get_reconciler_scheduler()
    if reconciler_scheduler is None:
        reconciler_scheduler = ReconcilerScheduler()
        set_reconciler_scheduler(reconciler_scheduler)
    self.reconciler_scheduler: ReconcilerScheduler = reconciler_scheduler

    # Sentry is opt-in via settings.
    if settings.SENTRY.ENABLED:
        initialize_sentry(integrations=[AsyncioIntegration()])
def add_task(self, task: asyncio.Task[None]) -> None:
    """Remember ``task`` in ``active_tasks`` until it finishes.

    A done-callback removes the task from the set again, so the set only
    holds tasks that have not completed yet.
    """
    in_flight = self.active_tasks
    in_flight.add(task)
    task.add_done_callback(in_flight.discard)
def track_worker_work_unit(
    self, worker_id: str, work_unit_key: str, aqs_id: str
) -> None:
    """Record that ``worker_id`` owns ``work_unit_key`` through session ``aqs_id``.

    Any previous entry for the same worker is replaced.
    """
    claim = WorkerOwnership(work_unit_key=work_unit_key, aqs_id=aqs_id)
    self.worker_ownership[worker_id] = claim
def untrack_worker_work_unit(self, worker_id: str, work_unit_key: str) -> None:
    """Forget the worker's ownership entry, but only if it is for this work unit.

    Nothing happens when the worker has no entry or its entry refers to a
    different work unit.
    """
    current = self.worker_ownership.get(worker_id)
    if not current:
        return
    if current.work_unit_key != work_unit_key:
        return
    self.worker_ownership.pop(worker_id)
def create_worker_id(self) -> str:
    """Return a freshly generated nanoid to identify one processing task."""
    worker_id = generate_nanoid()
    return worker_id
def get_total_owned_work_units(self) -> int:
    """Return how many entries are currently held in ``worker_ownership``."""
    owned = self.worker_ownership
    return len(owned)
async def initialize(self) -> None:
    """Register signal handlers, start the reconciler, and run the polling loop.

    SIGTERM and SIGINT each schedule ``self.shutdown``. A failure to start
    the reconciler scheduler is logged and does not prevent polling. When
    the polling loop ends (normally or through an exception), ``cleanup``
    is always awaited.
    """
    logger.debug(f"Initializing QueueManager with {self.workers} workers")

    loop = asyncio.get_running_loop()

    # The event loop only keeps weak references to tasks, so an unreferenced
    # shutdown task could be garbage-collected before it finishes. Hold a
    # strong reference until the task is done.
    shutdown_tasks: set[asyncio.Task[None]] = set()

    def _on_shutdown_done(finished: asyncio.Task[None]) -> None:
        shutdown_tasks.discard(finished)
        # Surface failures instead of leaving them as an unretrieved
        # task exception.
        if not finished.cancelled() and finished.exception() is not None:
            logger.error("Shutdown handler failed", exc_info=finished.exception())

    def _request_shutdown(received: signal.Signals) -> None:
        shutdown_task = asyncio.create_task(self.shutdown(received))
        shutdown_tasks.add(shutdown_task)
        shutdown_task.add_done_callback(_on_shutdown_done)

    # NOTE: loop.add_signal_handler is documented as Unix-only.
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _request_shutdown, sig)
    logger.debug("Signal handlers registered")

    # Start the reconciler scheduler; polling proceeds even if this fails.
    try:
        await self.reconciler_scheduler.start()
    except Exception:
        logger.exception("Failed to start reconciler scheduler")

    # Run the polling loop directly in this task
    logger.debug("Starting polling loop directly")
    try:
        await self.polling_loop()
    finally:
        await self.cleanup()
async def shutdown(self, sig: signal.Signals) -> None:
    """Signal shutdown, stop both schedulers, and wait for active tasks.

    Each scheduler is shut down best-effort: a failure in one is logged and
    does not prevent stopping the other or waiting for the in-flight tasks.

    Args:
        sig: The signal that triggered the shutdown (used for logging).
    """
    logger.info(f"Received exit signal {sig.name}...")
    self.shutdown_event.set()

    # Dream scheduler first (per the original comment this cancels pending
    # dreams), then the reconciler scheduler.
    schedulers = (
        ("dream", self.dream_scheduler),
        ("reconciler", self.reconciler_scheduler),
    )
    for label, scheduler in schedulers:
        try:
            await scheduler.shutdown()
        except Exception:
            logger.exception(f"Failed to shut down {label} scheduler")

    # Snapshot the set: done-callbacks shrink it while we are waiting.
    in_flight = tuple(self.active_tasks)
    if in_flight:
        logger.info(f"Waiting for {len(in_flight)} active tasks to complete...")
        # return_exceptions=True: one failing task must not abort the wait.
        await asyncio.gather(*in_flight, return_exceptions=True)
async def cleanup(self) -> None:
    """Delete the queue-session rows of every tracked work unit, then forget them.

    Database errors are logged (and sent to Sentry when enabled) rather than
    raised; the in-memory ownership map is cleared either way.
    """
    owned_count = self.get_total_owned_work_units()
    if owned_count <= 0:
        return

    logger.debug(f"Cleaning up {owned_count} owned work units...")
    try:
        async with tracked_db("queue_cleanup") as db:
            session_ids = [entry.aqs_id for entry in self.worker_ownership.values()]
            if session_ids:
                remove_sessions = delete(models.ActiveQueueSession).where(
                    models.ActiveQueueSession.id.in_(session_ids)
                )
                await db.execute(remove_sessions)
                await db.commit()
    except Exception as e:
        logger.error(f"Error during cleanup: {str(e)}")
        if settings.SENTRY.ENABLED:
            sentry_sdk.capture_exception(e)
    finally:
        # Drop local bookkeeping even when the delete failed.
        self.worker_ownership.clear()
| ########################## | |
| # Polling and Scheduling # | |
| ########################## | |
async def cleanup_stale_work_units(self) -> None:
    """Delete ActiveQueueSession rows not updated within the stale timeout.

    A row is stale when ``last_updated`` is older than
    ``settings.DERIVER.STALE_SESSION_TIMEOUT_MINUTES`` minutes. Rows are
    selected ``FOR UPDATE SKIP LOCKED``, so rows currently locked by another
    transaction are left alone.
    """
    async with tracked_db("cleanup_stale_work_units") as db:
        now_utc = datetime.now(timezone.utc)
        cutoff = now_utc - timedelta(
            minutes=settings.DERIVER.STALE_SESSION_TIMEOUT_MINUTES
        )

        lock_stale_rows = (
            select(models.ActiveQueueSession.id)
            .where(models.ActiveQueueSession.last_updated < cutoff)
            .order_by(models.ActiveQueueSession.last_updated)
            .with_for_update(skip_locked=True)
        )
        locked = await db.execute(lock_stale_rows)
        stale_ids = locked.scalars().all()

        # Only rows we actually hold locks on are deleted.
        if stale_ids:
            remove_stale = delete(models.ActiveQueueSession).where(
                models.ActiveQueueSession.id.in_(stale_ids)
            )
            await db.execute(remove_stale)
        await db.commit()
async def get_and_claim_work_units(self) -> dict[str, str]:
    """Find unclaimed work units with pending items and claim them.

    At most ``workers - owned`` units are requested; with no free capacity
    the database is not queried at all. A work unit qualifies when it has
    unprocessed queue items and no ActiveQueueSession row.

    Unless ``settings.DERIVER.FLUSH_ENABLED`` is true (or
    ``REPRESENTATION_BATCH_MAX_TOKENS`` is not positive), work units whose
    key starts with ``"representation:"`` additionally need pending message
    tokens summing to at least ``REPRESENTATION_BATCH_MAX_TOKENS``.

    NOTE(review): the query has a LIMIT but no ORDER BY, so which eligible
    units are picked is left to the database.

    Returns:
        Mapping of claimed ``work_unit_key`` to its ``aqs_id``.
    """
    free_slots: int = max(0, self.workers - self.get_total_owned_work_units())
    if free_slots == 0:
        return {}

    batch_max_tokens = settings.DERIVER.REPRESENTATION_BATCH_MAX_TOKENS
    async with tracked_db("get_available_work_units") as db:
        representation_prefix = "representation:"

        # Pending token totals, per representation work unit.
        pending_tokens = (
            select(
                models.QueueItem.work_unit_key,
                func.sum(models.Message.token_count).label("total_tokens"),
            )
            .join(
                models.Message,
                models.QueueItem.message_id == models.Message.id,
            )
            .where(~models.QueueItem.processed)
            .where(models.QueueItem.work_unit_key.startswith(representation_prefix))
            .group_by(models.QueueItem.work_unit_key)
            .subquery()
        )

        # Every work unit that still has unprocessed items.
        pending_units = (
            select(models.QueueItem.work_unit_key)
            .where(~models.QueueItem.processed)
            .group_by(models.QueueItem.work_unit_key)
            .subquery()
        )
        unit_key = pending_units.c.work_unit_key

        # True for work units that already have an active session row.
        has_active_session = (
            select(models.ActiveQueueSession.id)
            .where(models.ActiveQueueSession.work_unit_key == unit_key)
            .exists()
        )

        query = (
            select(unit_key)
            .limit(free_slots)
            .outerjoin(
                pending_tokens,
                unit_key == pending_tokens.c.work_unit_key,
            )
            .where(~has_active_session)
        )

        # Batch threshold: skipped when flushing is enabled.
        if not settings.DERIVER.FLUSH_ENABLED and batch_max_tokens > 0:
            is_representation = unit_key.startswith(representation_prefix)
            enough_tokens = (
                func.coalesce(pending_tokens.c.total_tokens, 0) >= batch_max_tokens
            )
            query = query.where(or_(~is_representation, enough_tokens))

        candidates = (await db.execute(query)).scalars().all()
        if candidates:
            claimed = await self.claim_work_units(db, candidates)
        else:
            claimed = {}
        await db.commit()
        return claimed
async def claim_work_units(
    self, db: AsyncSession, work_unit_keys: Sequence[str]
) -> dict[str, str]:
    """Insert an ActiveQueueSession row per key and report which inserts won.

    Uses ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING``, so keys whose
    insert hit a conflict are simply absent from the result.
    (Presumably the conflict is a unique constraint on ``work_unit_key`` —
    confirm against the model.) The caller is responsible for committing.

    Args:
        db: Session to execute the insert on.
        work_unit_keys: Keys to try to claim; may be empty.

    Returns:
        Mapping of successfully claimed ``work_unit_key`` to the new
        session's ``aqs_id``. Empty when no keys were given.
    """
    # Nothing to claim: avoid sending an INSERT with an empty VALUES list.
    if not work_unit_keys:
        return {}

    values = [{"work_unit_key": key} for key in work_unit_keys]
    stmt = (
        insert(models.ActiveQueueSession)
        .values(values)
        .on_conflict_do_nothing()
        .returning(
            models.ActiveQueueSession.work_unit_key, models.ActiveQueueSession.id
        )
    )
    result = await db.execute(stmt)
    # Each returned row is (work_unit_key, aqs_id).
    claimed_mapping = {
        work_unit_key: aqs_id for work_unit_key, aqs_id in result.all()
    }
    logger.debug(
        f"Claimed {len(claimed_mapping)} work units: {list(claimed_mapping.keys())}"
    )
    return claimed_mapping
async def polling_loop(self) -> None:
    """Poll for claimable work units and spawn a processing task for each.

    Runs until ``shutdown_event`` is set. Each iteration:

    * sleeps one interval and clears ``queue_empty_flag`` if it is set;
    * sleeps one interval if the worker semaphore is fully taken;
    * otherwise removes stale sessions, claims work units and starts one
      ``process_work_unit`` task per claim.

    Claims obtained while shutdown was being signalled are released again
    instead of being left behind as orphaned session rows. Errors inside an
    iteration are logged (and sent to Sentry when enabled) and followed by
    one sleep interval.

    NOTE(review): when nothing is claimed the loop sets ``queue_empty_flag``
    and sleeps, then sleeps once more on the next iteration before clearing
    the flag — i.e. two intervals per empty poll. Confirm this is intended.
    """
    logger.debug("Starting polling loop")
    try:
        while not self.shutdown_event.is_set():
            if self.queue_empty_flag.is_set():
                await asyncio.sleep(settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS)
                self.queue_empty_flag.clear()
                continue

            # Check if we have capacity before querying
            if self.semaphore.locked():
                await asyncio.sleep(settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS)
                continue

            try:
                await self.cleanup_stale_work_units()
                claimed_work_units = await self.get_and_claim_work_units()

                if not claimed_work_units:
                    self.queue_empty_flag.set()
                    await asyncio.sleep(
                        settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS
                    )
                    continue

                for work_unit_key, aqs_id in claimed_work_units.items():
                    if self.shutdown_event.is_set():
                        # Shutdown arrived while we were claiming: this unit
                        # will never be processed by us, so give the claim
                        # back rather than waiting for the stale timeout.
                        try:
                            await self._cleanup_work_unit(aqs_id, work_unit_key)
                        except Exception:
                            logger.exception(
                                f"Failed to release claimed work unit {work_unit_key} during shutdown"
                            )
                        continue

                    # Track ownership before the task starts so capacity
                    # accounting sees the claim immediately.
                    worker_id = self.create_worker_id()
                    self.track_worker_work_unit(worker_id, work_unit_key, aqs_id)
                    task: Task[None] = asyncio.create_task(
                        self.process_work_unit(work_unit_key, worker_id)
                    )
                    self.add_task(task)
            except Exception as e:
                logger.exception("Error in polling loop")
                if settings.SENTRY.ENABLED:
                    sentry_sdk.capture_exception(e)
                # Note: rollback is handled by tracked_db dependency
                await asyncio.sleep(settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS)
    finally:
        logger.info("Polling loop stopped")
| ###################### | |
| # Queue Worker Logic # | |
| ###################### | |
async def _handle_processing_error(
    self,
    error: Exception,
    items: list[QueueItem],
    work_unit_key: str,
    context: str,
) -> None:
    """Record a processing failure: mark an item errored, log, notify Sentry.

    Only the first item of ``items`` is marked as errored. The rest of the
    batch stays unprocessed so it can be attempted again, which lets the
    work unit keep making progress one item at a time.

    Args:
        error: The exception that occurred.
        items: The queue items that were being processed.
        work_unit_key: The work unit key for the queue items.
        context: Phrase describing what was being processed (e.g.
            "processing representation batch"); used in the log message.
    """
    error_msg = f"{error.__class__.__name__}: {str(error)}"

    if items:
        try:
            await self.mark_queue_item_as_errored(items[0], work_unit_key, error_msg)
        except Exception as mark_error:
            # Failing to persist the error must not hide the original one.
            logger.error(
                f"Failed to mark queue items as errored for work unit {work_unit_key}: {mark_error}",
                exc_info=True,
            )

    logger.error(
        f"Error {context} for work unit {work_unit_key}: {error}",
        exc_info=True,
    )
    if settings.SENTRY.ENABLED:
        sentry_sdk.capture_exception(error)
async def process_work_unit(self, work_unit_key: str, worker_id: str) -> None:
    """Drain the queue items of one claimed work unit, then release the claim.

    Holds one slot of the worker semaphore for the whole run. While the
    worker still owns the unit and no shutdown was requested:

    * ``representation`` units are fetched and processed in batches via
      ``get_queue_item_batch`` / ``process_representation_batch``;
    * every other task type is processed one item at a time via
      ``get_next_queue_item`` / ``process_item``.

    Processing failures are delegated to ``_handle_processing_error``.
    Failures while fetching work are logged and retried after one polling
    interval instead of immediately. On exit the active session row is
    deleted, the worker is untracked (even if that delete fails), and a
    ``QueueEmptyEvent`` is published when this worker removed the session
    and processed at least one item of a representation/summary unit.

    Args:
        work_unit_key: Key of the claimed work unit.
        worker_id: ID under which the claim is tracked in ``worker_ownership``.
    """
    logger.debug(f"Starting to process work unit {work_unit_key}")
    work_unit = parse_work_unit_key(work_unit_key)
    async with self.semaphore:
        queue_item_count = 0
        try:
            while not self.shutdown_event.is_set():
                # Re-check ownership on every iteration.
                ownership = self.worker_ownership.get(worker_id)
                if not ownership or ownership.work_unit_key != work_unit_key:
                    logger.warning(
                        f"Worker {worker_id} lost ownership of work unit {work_unit_key}, stopping processing {work_unit_key}"
                    )
                    break
                try:
                    if work_unit.task_type == "representation":
                        (
                            messages_context,
                            items_to_process,
                            message_level_configuration,
                        ) = await self.get_queue_item_batch(
                            work_unit.task_type, work_unit_key, ownership.aqs_id
                        )
                        logger.debug(
                            f"Worker {worker_id} retrieved {len(messages_context)} messages and {len(items_to_process)} queue items for work unit {work_unit_key} (AQS ID: {ownership.aqs_id})"
                        )
                        if not items_to_process:
                            logger.debug(
                                f"No more queue items to process for work unit {work_unit_key} for worker {worker_id}"
                            )
                            break
                        try:
                            # Observers come from the first item's payload:
                            # new format is a list under "observers", legacy
                            # format a single string under "observer".
                            payload = items_to_process[0].payload
                            observers = payload.get("observers")
                            if observers is None:
                                legacy_observer = payload.get("observer")
                                observers = [legacy_observer] if legacy_observer else []
                            queue_item_message_ids = [
                                item.message_id
                                for item in items_to_process
                                if item.message_id is not None
                            ]
                            await process_representation_batch(
                                messages_context,
                                message_level_configuration,
                                observers=observers,
                                observed=work_unit.observed,
                                queue_item_message_ids=queue_item_message_ids,
                            )
                            await self.mark_queue_items_as_processed(
                                items_to_process, work_unit_key
                            )
                            queue_item_count += len(items_to_process)
                        except Exception as e:
                            await self._handle_processing_error(
                                e,
                                items_to_process,
                                work_unit_key,
                                f"processing {work_unit.task_type} batch",
                            )
                    else:
                        queue_item = await self.get_next_queue_item(
                            work_unit.task_type, work_unit_key, ownership.aqs_id
                        )
                        if not queue_item:
                            logger.debug(
                                f"No more queue items to process for work unit {work_unit_key} for worker {worker_id}"
                            )
                            break
                        try:
                            await process_item(queue_item)
                            await self.mark_queue_items_as_processed(
                                [queue_item], work_unit_key
                            )
                            queue_item_count += 1
                        except Exception as e:
                            await self._handle_processing_error(
                                e,
                                [queue_item],
                                work_unit_key,
                                "processing queue item",
                            )
                except Exception as e:
                    # Reached when fetching work fails (processing errors are
                    # handled above). No item gets marked here, so the same
                    # fetch is attempted again on the next iteration.
                    # NOTE(review): a fetch that fails permanently is retried
                    # indefinitely; consider a retry cap.
                    logger.error(
                        f"Error in processing loop for work unit {work_unit_key}: {e}",
                        exc_info=True,
                    )
                    if settings.SENTRY.ENABLED:
                        sentry_sdk.capture_exception(e)
                    # Back off before retrying so a persistent failure does
                    # not spin; skip the wait when shutting down.
                    if not self.shutdown_event.is_set():
                        await asyncio.sleep(
                            settings.DERIVER.POLLING_SLEEP_INTERVAL_SECONDS
                        )
                # Check for shutdown after processing each batch
                if self.shutdown_event.is_set():
                    logger.debug(
                        "Shutdown requested, stopping processing for work unit %s",
                        work_unit_key,
                    )
                    break
        finally:
            # Remove work unit from active_queue_sessions when done
            final_ownership: WorkerOwnership | None = self.worker_ownership.get(
                worker_id
            )
            removed = False
            try:
                if final_ownership and final_ownership.work_unit_key == work_unit_key:
                    removed = await self._cleanup_work_unit(
                        final_ownership.aqs_id, work_unit_key
                    )
            finally:
                # Always release the in-memory slot, even if deleting the
                # session row failed; otherwise capacity would shrink for good.
                self.untrack_worker_work_unit(worker_id, work_unit_key)
            if removed and queue_item_count > 0:
                # Only publish webhook if we actually removed an active session
                try:
                    if (
                        work_unit.task_type in ["representation", "summary"]
                        and work_unit.workspace_name is not None
                    ):
                        logger.debug(
                            f"Publishing queue.empty event for {work_unit_key} in workspace {work_unit.workspace_name}"
                        )
                        await publish_webhook_event(
                            QueueEmptyEvent(
                                workspace_id=work_unit.workspace_name,
                                queue_type=work_unit.task_type,
                                session_id=work_unit.session_name,
                                observer=work_unit.observer,
                                observed=work_unit.observed,
                            )
                        )
                except Exception:
                    logger.exception("Error triggering queue_empty webhook")
            else:
                logger.debug(
                    f"Work unit {work_unit_key} already cleaned up by another worker, skipping webhook"
                )
async def get_next_queue_item(
    self, task_type: str, work_unit_key: str, aqs_id: str
) -> QueueItem | None:
    """Fetch the oldest unprocessed queue item of a non-representation work unit.

    The item is only returned while an ActiveQueueSession row with this
    ``work_unit_key`` and ``aqs_id`` exists, i.e. while the caller still
    holds the claim.

    Args:
        task_type: Task type of the work unit; must not be "representation".
        work_unit_key: Key of the work unit.
        aqs_id: ID of the caller's active queue session.

    Returns:
        The unprocessed item with the lowest id, or ``None`` if there is
        none (or the claim no longer exists).

    Raises:
        ValueError: If ``task_type`` is "representation".
    """
    if task_type == "representation":
        raise ValueError(
            "representation tasks are not supported for get_next_queue_item"
        )

    async with tracked_db("get_next_queue_item") as db:
        next_item_query = (
            select(models.QueueItem)
            .join(
                models.ActiveQueueSession,
                models.QueueItem.work_unit_key
                == models.ActiveQueueSession.work_unit_key,
            )
            .where(
                models.QueueItem.work_unit_key == work_unit_key,
                ~models.QueueItem.processed,
                # Ownership check: the session row must be ours.
                models.ActiveQueueSession.work_unit_key == work_unit_key,
                models.ActiveQueueSession.id == aqs_id,
            )
            .order_by(models.QueueItem.id)
            .limit(1)
        )
        queue_item = (await db.execute(next_item_query)).scalar_one_or_none()
        # Commit on purpose: per the original note, tracked_db's rollback
        # would expire the instance, and expire_on_commit=False keeps its
        # attributes readable after the session closes.
        await db.commit()
        return queue_item
async def get_queue_item_batch(
    self,
    task_type: str,
    work_unit_key: str,
    aqs_id: str,
) -> tuple[list[models.Message], list[QueueItem], ResolvedConfiguration | None]:
    """
    Load the next token-bounded batch for a representation work unit.

    Args:
        task_type: Must be "representation".
        work_unit_key: Key of the work unit; parsed for session, workspace
            and observed peer.
        aqs_id: ID of the caller's ActiveQueueSession (ownership check).

    Returns:
        A tuple of (messages_context, items_to_process, configuration):
        - messages_context: unique Message rows forming the context window,
          in message-id order, trimmed to end at the last kept queue item
        - items_to_process: unprocessed QueueItems of this work_unit_key
          inside that window, limited to the leading run sharing one
          configuration
        - configuration: the resolved configuration of that run, or None
        ([], [], None) is returned when the session row is gone or the
        query yields no rows.

    Raises:
        ValueError: If task_type is not "representation".
    """
    if task_type != "representation":
        raise ValueError(
            f"{task_type} tasks are not supported for get_queue_item_batch"
        )
    batch_max_tokens = settings.DERIVER.REPRESENTATION_BATCH_MAX_TOKENS
    parsed_key = parse_work_unit_key(work_unit_key)
    messages_context: list[models.Message] = []
    items_to_process: list[QueueItem] = []
    async with tracked_db("get_queue_item_batch") as db:
        # Step 1: verify the worker still owns the work_unit_key, i.e. the
        # session row with this key and id still exists.
        ownership_check = await db.execute(
            select(models.ActiveQueueSession.id)
            .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
            .where(models.ActiveQueueSession.id == aqs_id)
        )
        if not ownership_check.scalar_one_or_none():
            return [], [], None
        # Step 2: build a single SQL query that
        #   1. finds the earliest message with an unprocessed queue item
        #      for this work_unit_key,
        #   2. optionally starts one message earlier, when that preceding
        #      message is from a peer other than the observed one,
        #   3. walks ALL session messages from that start, with a running
        #      token total,
        #   4. keeps messages whose running total fits the token budget,
        #      plus the first unprocessed message regardless of budget.

        # Smallest message id that has an unprocessed queue item for this
        # work unit within the session/workspace.
        min_unprocessed_message_id_subq = (
            select(func.min(models.Message.id))
            .select_from(models.QueueItem)
            .join(
                models.Message,
                models.QueueItem.message_id == models.Message.id,
            )
            .where(~models.QueueItem.processed)
            .where(models.Message.session_name == parsed_key.session_name)
            .where(models.Message.workspace_name == parsed_key.workspace_name)
            .where(models.QueueItem.work_unit_key == work_unit_key)
            .scalar_subquery()
        )
        # Id of the session message immediately before min_unprocessed.
        immediately_preceding_id_subq = (
            select(func.max(models.Message.id))
            .where(models.Message.session_name == parsed_key.session_name)
            .where(models.Message.workspace_name == parsed_key.workspace_name)
            .where(models.Message.id < min_unprocessed_message_id_subq)
            .scalar_subquery()
        )
        # Keep that preceding message only if it is from a different peer
        # than the observed one; it then serves as conversational context
        # (e.g. the question that prompted the response).
        preceding_message_id_subq = (
            select(models.Message.id)
            .where(models.Message.id == immediately_preceding_id_subq)
            .where(models.Message.peer_name != parsed_key.observed)
            .scalar_subquery()
        )
        # Effective start: the qualifying preceding message if there is
        # one, otherwise (via COALESCE) min_unprocessed itself.
        effective_start_id = func.coalesce(
            preceding_message_id_subq, min_unprocessed_message_id_subq
        )
        # CTE over ALL session messages from effective_start_id onward,
        # each with the running sum of token_count in message-id order.
        cte = (
            select(
                models.Message.id.label("message_id"),
                models.Message.token_count.label("token_count"),
                models.Message.peer_name.label("peer_name"),
                func.sum(models.Message.token_count)
                .over(order_by=models.Message.id)
                .label("cumulative_token_count"),
            )
            .where(models.Message.session_name == parsed_key.session_name)
            .where(models.Message.workspace_name == parsed_key.workspace_name)
            .where(models.Message.id >= effective_start_id)
            .order_by(models.Message.id)
            .cte()
        )
        # A message is allowed when the running total fits the budget, or
        # when it is the first unprocessed message (so a batch is never
        # empty just because that one message is over budget).
        allowed_condition = (
            (cte.c.cumulative_token_count <= batch_max_tokens)
            | (
                cte.c.message_id == min_unprocessed_message_id_subq
            )  # always include the first unprocessed message
        )
        # Outer join: context messages without a matching unprocessed
        # queue item for this work unit still appear, with a NULL item.
        query = (
            select(models.Message, models.QueueItem)
            .select_from(cte)
            .join(models.Message, models.Message.id == cte.c.message_id)
            .outerjoin(
                models.QueueItem,
                and_(
                    models.QueueItem.work_unit_key == work_unit_key,
                    ~models.QueueItem.processed,
                    models.QueueItem.message_id == models.Message.id,
                ),
            )
            .where(allowed_condition)
            .order_by(models.Message.id, models.QueueItem.id)
        )
        result = await db.execute(query)
        rows = result.all()
        if not rows:
            return [], [], None
        # A message can appear in several rows (one per queue item), so
        # collect each message once and every non-NULL queue item.
        seen_messages: set[int] = set()
        for m, qi in rows:
            if m.id not in seen_messages:
                messages_context.append(m)
                seen_messages.add(m.id)
            if qi is not None:
                items_to_process.append(qi)
        # Detach while the session is still open so the objects can be
        # used after tracked_db exits.
        _detach_queue_batch_objects(db, messages_context, items_to_process)
    # Keep only the leading run of items that share one configuration.
    items_to_process, resolved_config = _resolve_batch_configuration(
        items_to_process
    )
    if items_to_process:
        # Drop context messages that come after the last kept queue item.
        max_queue_item_message_id = max(
            qi.message_id for qi in items_to_process if qi.message_id is not None
        )
        messages_context = [
            m for m in messages_context if m.id <= max_queue_item_message_id
        ]
    return messages_context, items_to_process, resolved_config
async def mark_queue_items_as_processed(
    self, items: list[QueueItem], work_unit_key: str
) -> None:
    """Flag queue items as processed and refresh the session's timestamp.

    Does nothing for an empty ``items`` list. After the commit, a metric is
    recorded for representation/summary work units that have a workspace
    name, when ``settings.METRICS.ENABLED`` is true.

    Args:
        items: Queue items that were processed successfully.
        work_unit_key: Work unit the items belong to.
    """
    if not items:
        return

    async with tracked_db("process_queue_item_batch") as db:
        work_unit = parse_work_unit_key(work_unit_key)
        processed_ids = [queue_item.id for queue_item in items]

        flag_items = (
            update(models.QueueItem)
            .where(models.QueueItem.id.in_(processed_ids))
            .where(models.QueueItem.work_unit_key == work_unit_key)
            .values(processed=True)
        )
        # Touching last_updated keeps the session from looking stale.
        touch_session = (
            update(models.ActiveQueueSession)
            .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
            .values(last_updated=func.now())
        )
        await db.execute(flag_items)
        await db.execute(touch_session)
        await db.commit()

        if (
            work_unit.task_type in ["representation", "summary"]
            and work_unit.workspace_name is not None
            and settings.METRICS.ENABLED
        ):
            prometheus_metrics.record_deriver_queue_item(
                count=len(items),
                workspace_name=work_unit.workspace_name,
                task_type=work_unit.task_type,
            )
async def mark_queue_item_as_errored(
    self, item: QueueItem, work_unit_key: str, error: str
) -> None:
    """Mark one queue item as processed, storing the error text with it.

    Also refreshes ``last_updated`` on the work unit's active session.
    Does nothing when ``item`` is falsy.

    Args:
        item: The queue item that failed.
        work_unit_key: Work unit the item belongs to.
        error: Error description; stored truncated to 65535 characters.
    """
    if not item:
        return

    async with tracked_db("mark_queue_item_as_errored") as db:
        record_failure = (
            update(models.QueueItem)
            .where(models.QueueItem.id == item.id)
            .where(models.QueueItem.work_unit_key == work_unit_key)
            # Cap the stored message length (original note: TEXT limit).
            .values(processed=True, error=error[:65535])
        )
        touch_session = (
            update(models.ActiveQueueSession)
            .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
            .values(last_updated=func.now())
        )
        await db.execute(record_failure)
        await db.execute(touch_session)
        await db.commit()
async def _cleanup_work_unit(
    self,
    aqs_id: str,
    work_unit_key: str,
) -> bool:
    """Delete the active session row matching both ``aqs_id`` and ``work_unit_key``.

    Returns:
        True if a row was deleted, False if no matching row existed.
    """
    release_claim = (
        delete(models.ActiveQueueSession)
        .where(models.ActiveQueueSession.id == aqs_id)
        .where(models.ActiveQueueSession.work_unit_key == work_unit_key)
    )
    async with tracked_db("cleanup_work_unit") as db:
        # cast() only informs the type checker that rowcount is available.
        outcome = cast(CursorResult[Any], await db.execute(release_claim))
        await db.commit()
        return outcome.rowcount > 0
async def main() -> None:
    """Entry point: initialize the cache, run a QueueManager, close the cache.

    A cache initialization failure is logged as a warning and the manager
    runs without it. Errors escaping ``QueueManager.initialize`` are logged
    with their traceback and, when Sentry is enabled in settings, reported
    to Sentry; they are not re-raised. The cache is closed in all cases.
    """
    logger.debug("Starting queue manager")
    try:
        await init_cache()
    except Exception as e:
        logger.warning(
            "Error initializing cache in queue manager; proceeding without cache: %s", e
        )

    manager = QueueManager()
    try:
        await manager.initialize()
    except Exception as e:
        # exc_info keeps the traceback, which the message alone would lose.
        logger.error(f"Error in main: {str(e)}", exc_info=True)
        # Same Sentry guard as the other capture sites in this module.
        if settings.SENTRY.ENABLED:
            sentry_sdk.capture_exception(e)
    finally:
        await close_cache()
        logger.debug("Main function exiting")