Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import Any, Literal | |
| from sqlalchemy import exists, insert, select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from src import crud, models, schemas | |
| from src.config import settings | |
| from src.dependencies import tracked_db | |
| from src.dreamer.dream_scheduler import get_dream_scheduler | |
| from src.exceptions import ValidationException | |
| from src.models import QueueItem | |
| from src.schemas import MessageConfiguration, ResolvedConfiguration | |
| from src.utils.config_helpers import get_configuration | |
| from src.utils.queue_payload import ( | |
| create_deletion_payload, | |
| create_dream_payload, | |
| create_payload, | |
| ) | |
| from src.utils.work_unit import construct_work_unit_key | |
| logger = logging.getLogger(__name__) | |
async def enqueue(payload: list[dict[str, Any]]) -> None:
    """
    Add message(s) to the deriver queue for processing.

    Pending dreams are cancelled for every distinct (workspace, peer) pair found
    in the payload, then queue records are generated for the whole batch and
    inserted with a single statement. The workspace and session names are read
    from the first message of the batch.

    Enqueueing is best-effort: any error raised while building or inserting the
    queue records is logged (and sent to Sentry when enabled) but is NOT
    re-raised. Errors raised while cancelling dreams do propagate.

    Args:
        payload: List of message payload dictionaries
    """
    # Nothing to do for an empty batch: bail out before touching the dream
    # scheduler or opening a database session.
    if not payload:
        return

    # Cancel any pending dreams for affected collections since user is active again.
    # This cancels dreams for all collections where observed=peer_name, which covers
    # both self-observation and peer-to-peer observation cases.
    dream_scheduler = get_dream_scheduler()
    if dream_scheduler:
        cancelled_dreams: set[str] = set()
        # A batch commonly repeats the same sender; cancel once per pair.
        # Assumes a repeated cancellation for the same pair is a no-op.
        seen_pairs: set[tuple[Any, Any]] = set()
        for message in payload:
            workspace_name = message.get("workspace_name")
            peer_name = message.get("peer_name")
            if not (workspace_name and peer_name):
                continue
            pair = (workspace_name, peer_name)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            cancelled = await dream_scheduler.cancel_dreams_for_observed(
                workspace_name, peer_name
            )
            cancelled_dreams.update(cancelled)
        if cancelled_dreams:
            logger.info(
                "Cancelled %s pending dreams due to new activity",
                len(cancelled_dreams),
            )

    async with tracked_db("message_enqueue") as db_session:
        try:
            workspace_name = payload[0]["workspace_name"]
            session_name = payload[0]["session_name"]
            if session_name is None or workspace_name is None:
                raise ValidationException("Session and workspace are required")

            queue_records = await handle_session(
                db_session, payload, workspace_name, session_name
            )

            if queue_records:
                stmt = insert(QueueItem).returning(QueueItem)
                await db_session.execute(stmt, queue_records)
                await db_session.commit()
        except Exception as e:
            # Deliberately swallowed: a failed enqueue is reported but must not
            # surface to the caller.
            logger.exception("Failed to enqueue message(s)!")
            if settings.SENTRY.ENABLED:
                import sentry_sdk

                sentry_sdk.capture_exception(e)
async def handle_session(
    db_session: AsyncSession,
    payload: list[dict[str, Any]],
    workspace_name: str,
    session_name: str,
) -> list[dict[str, Any]]:
    """
    Build the queue records for every message of a session batch.

    The session is fetched (or created) and the workspace loaded so that a
    configuration can be resolved. Messages carrying their own "configuration"
    entry get a message-specific resolution; all others share the configuration
    resolved once at session level.

    Args:
        db_session: The database session
        payload: List of message payloads
        workspace_name: Name of the workspace
        session_name: Name of the session

    Returns:
        List of queue records to insert
    """
    session_result = await crud.get_or_create_session(
        db_session,
        session=schemas.SessionCreate(name=session_name),
        workspace_name=workspace_name,
    )
    session = session_result.resource

    # The workspace is needed as the outermost layer of configuration resolution.
    workspace = await crud.get_workspace(db_session, workspace_name=workspace_name)

    # Resolved once and reused by every message without its own override.
    session_default_conf = get_configuration(None, session, workspace)

    peers_with_configuration = await get_peers_with_configuration(
        db_session, workspace_name, session_name
    )

    queue_records: list[dict[str, Any]] = []
    for message in payload:
        override: MessageConfiguration | None = message.get("configuration")
        resolved_conf = (
            session_default_conf
            if override is None
            else get_configuration(override, session, workspace)
        )
        message_records = await generate_queue_records(
            db_session,
            message,
            peers_with_configuration,
            session.id,
            resolved_conf,
        )
        queue_records.extend(message_records)

    return queue_records
async def get_peers_with_configuration(
    db_session: AsyncSession, workspace_name: str, session_name: str
) -> dict[str, list[dict[str, Any]]]:
    """
    Load the per-peer configuration for every peer row of a session.

    Args:
        db_session: The database session
        workspace_name: Name of the workspace
        session_name: Name of the session

    Returns:
        Dictionary mapping each peer name to a three-element list:
        [peer_configuration, session_peer_configuration, is_active]
    """
    query = await crud.get_session_peer_configuration(
        workspace_name=workspace_name, session_name=session_name
    )
    rows = (await db_session.execute(query)).all()

    peers: dict[str, list[dict[str, Any]]] = {}
    for row in rows:
        # Positional layout is relied upon by the consumers of this mapping.
        peers[row.peer_name] = [
            row.peer_configuration,
            row.session_peer_configuration,
            row.is_active,
        ]
    return peers
def create_representation_record(
    message: dict[str, Any],
    conf: ResolvedConfiguration,
    session_id: str | None = None,
    *,
    observers: list[str],
    observed: str,
) -> dict[str, Any]:
    """
    Build the queue record for a "representation" task.

    Args:
        message: The message payload; must contain a string "workspace_name"
            and an integer "message_id"
        conf: Resolved configuration for this particular message
        session_id: Optional session ID
        observers: List of observer peer names
        observed: Name of the sender

    Returns:
        Queue record dictionary with workspace_name and message_id as separate fields

    Raises:
        TypeError: If workspace_name is not a string or message_id is not an integer
    """
    workspace_name = message.get("workspace_name")
    if not isinstance(workspace_name, str):
        raise TypeError("workspace_name is required and must be a string")

    message_id = message.get("message_id")
    if not isinstance(message_id, int):
        raise TypeError("message_id is required and must be an integer")

    task_payload: dict[str, Any] = create_payload(
        message=message,
        configuration=conf,
        task_type="representation",
        observers=observers,
        observed=observed,
    )
    work_unit_key = construct_work_unit_key(workspace_name, task_payload)

    return {
        "work_unit_key": work_unit_key,
        "payload": task_payload,
        "session_id": session_id,
        "task_type": "representation",
        "workspace_name": workspace_name,
        "message_id": message_id,
    }
def create_summary_record(
    message: dict[str, Any],
    configuration: ResolvedConfiguration,
    session_id: str,
    message_seq_in_session: int,
) -> dict[str, Any]:
    """
    Build the queue record for a "summary" task.

    Args:
        message: The message payload; must contain a string "workspace_name"
            and an integer "message_id"
        configuration: Resolved configuration for this particular message
        session_id: Session ID
        message_seq_in_session: The sequence number of the message in the session

    Returns:
        Queue record dictionary with workspace_name and message_id as separate fields

    Raises:
        ValueError: If workspace_name is not a string or message_id is not an integer
    """
    workspace_name = message.get("workspace_name")
    if not isinstance(workspace_name, str):
        raise ValueError("workspace_name is required and must be a string")

    message_id = message.get("message_id")
    if not isinstance(message_id, int):
        raise ValueError("message_id is required and must be an integer")

    task_payload = create_payload(
        message=message,
        configuration=configuration,
        task_type="summary",
        message_seq_in_session=message_seq_in_session,
    )
    work_unit_key = construct_work_unit_key(workspace_name, task_payload)

    return {
        "work_unit_key": work_unit_key,
        "payload": task_payload,
        "session_id": session_id,
        "task_type": "summary",
        "workspace_name": workspace_name,
        "message_id": message_id,
    }
def get_effective_observe_me(
    observed: str, peers_with_configuration: dict[str, list[dict[str, Any]]]
) -> bool:
    """
    Resolve whether the sender of a message should be observed.

    Precedence: an explicit observe_me on the session-level peer config wins,
    then the peer-level config, and finally a default of True.

    Args:
        observed: Name of the sender
        peers_with_configuration: Dictionary of peer configurations

    Returns:
        True if observe_me is enabled, False otherwise
    """
    # A sender missing from the mapping left the session after sending the
    # message; fall back to empty peer-level and session-level configs so the
    # default (observe the sender) applies.
    entry: list[Any] = peers_with_configuration.get(observed, [{}, {}])
    raw_peer_config = entry[0]
    raw_session_peer_config = entry[1]

    session_peer_config = None
    if raw_session_peer_config:
        session_peer_config = schemas.SessionPeerConfig(**raw_session_peer_config)

    if raw_peer_config:
        peer_config = schemas.PeerConfig(**raw_peer_config)
    else:
        peer_config = schemas.PeerConfig()

    # Session-level setting takes precedence when it is explicitly set.
    if session_peer_config and session_peer_config.observe_me is not None:
        return session_peer_config.observe_me

    # Otherwise the peer-level setting decides, defaulting to True when unset.
    if peer_config.observe_me is None:
        return True
    return peer_config.observe_me
async def generate_queue_records(
    db_session: AsyncSession,
    message: dict[str, Any],
    peers_with_configuration: dict[str, list[dict[str, Any]]],
    session_id: str,
    conf: ResolvedConfiguration,
) -> list[dict[str, Any]]:
    """
    Process a single message and generate queue records based on configurations.

    At most two records are produced: one summary record (when summaries are
    enabled and the message's sequence number is a multiple of either summary
    interval) and one representation record listing every observer (when
    reasoning is enabled and at least one observer exists).

    Args:
        db_session: The database session
        message: The message payload; must contain "peer_name" and "message_id"
        peers_with_configuration: Dictionary mapping peer name to
            [peer_configuration, session_peer_configuration, is_active]
        session_id: Session ID
        conf: Resolved configuration for this particular message

    Returns:
        List of queue records for this message

    Raises:
        KeyError: If the message lacks "peer_name" or "message_id"
    """
    observed = message["peer_name"]
    message_id: int = message["message_id"]

    records: list[dict[str, Any]] = []

    # The sequence number is only needed for the summary decision, so it is
    # resolved (possibly via a database query) only when summaries are enabled.
    if conf.summary.enabled:
        # Prefer the sequence captured during message creation; fallback only if missing
        message_seq_in_session = int(message.get("seq_in_session") or 0)
        if message_seq_in_session <= 0:
            message_seq_in_session = await crud.get_message_seq_in_session(
                db_session,
                workspace_name=message["workspace_name"],
                session_name=message["session_name"],
                message_id=message_id,
            )

        if (
            message_seq_in_session % conf.summary.messages_per_short_summary == 0
            or message_seq_in_session % conf.summary.messages_per_long_summary == 0
        ):
            records.append(
                create_summary_record(
                    message,
                    configuration=conf,
                    session_id=session_id,
                    message_seq_in_session=message_seq_in_session,
                )
            )

    # Without reasoning there is no representation task, so skip resolving
    # observer settings entirely.
    if not conf.reasoning.enabled:
        return records

    # Collect all observers into a single list
    observers: list[str] = []

    # Self-observation: the sender observes themselves when observe_me resolves to True
    if get_effective_observe_me(observed, peers_with_configuration):
        observers.append(observed)

    # Other peers who want to observe
    for peer_name, peer_conf in peers_with_configuration.items():
        if peer_name == observed:
            continue

        # If the observer peer has left the session, skip them
        if not peer_conf[2]:
            continue

        session_peer_config = (
            schemas.SessionPeerConfig(**peer_conf[1]) if peer_conf[1] else None
        )
        # Observing others is opt-in via the session-level peer config.
        if session_peer_config is None or not session_peer_config.observe_others:
            continue

        observers.append(peer_name)

    # Create a single record with all observers (if any)
    if observers:
        records.append(
            create_representation_record(
                message,
                conf,
                observed=observed,
                observers=observers,
                session_id=session_id,
            )
        )

    logger.debug(
        "message %s from %s created %s queue items with %s observers",
        message_id,
        observed,
        len(records),
        len(observers),
    )

    return records
def create_dream_record(
    workspace_name: str,
    *,
    observer: str,
    observed: str,
    dream_type: schemas.DreamType,
    session_name: str | None = None,
) -> dict[str, Any]:
    """
    Build the queue record for a "dream" task.

    Dream records are not tied to a message or session row, so both
    "session_id" and "message_id" are None.

    Args:
        workspace_name: Name of the workspace
        observer: Name of the observer peer
        observed: Name of the observed peer
        dream_type: Type of dream to execute
        session_name: Name of the session to scope the dream to if specified

    Returns:
        Queue record dictionary with workspace_name and other fields
    """
    payload = create_dream_payload(
        dream_type,
        observer=observer,
        observed=observed,
        session_name=session_name,
    )
    work_unit_key = construct_work_unit_key(workspace_name, payload)

    return {
        "work_unit_key": work_unit_key,
        "payload": payload,
        "session_id": None,
        "task_type": "dream",
        "workspace_name": workspace_name,
        "message_id": None,
    }
async def enqueue_dream(
    workspace_name: str,
    observer: str,
    observed: str,
    dream_type: schemas.DreamType,
    session_name: str | None = None,
) -> None:
    """
    Enqueue a dream task for immediate processing by the deriver.

    This function only reads and inserts queue rows; it does not write to
    collection.internal_metadata["dream"] (per the original note, both guard
    fields are written by process_dream on successful completion).

    Deduplication: the enqueue is skipped when a dream with the same
    work_unit_key is already in progress (an ActiveQueueSession row exists) or
    is still pending (an unprocessed QueueItem row exists).

    NOTE(review): the two existence checks and the insert are separate
    statements, so concurrent callers could both pass the checks unless the
    database enforces uniqueness elsewhere - confirm.

    Args:
        workspace_name: Name of the workspace
        observer: Name of the observer peer
        observed: Name of the observed peer
        dream_type: Type of dream to execute
        session_name: Name of the session to scope the dream to if specified

    Raises:
        Exception: Any failure is logged, reported to Sentry when enabled,
            and re-raised.
    """
    async with tracked_db("dream_enqueue") as db_session:
        try:
            record = create_dream_record(
                workspace_name,
                observer=observer,
                observed=observed,
                dream_type=dream_type,
                session_name=session_name,
            )
            key = record["work_unit_key"]

            # 1) Already claimed by a worker?
            active_rows = select(models.ActiveQueueSession.id).where(
                models.ActiveQueueSession.work_unit_key == key
            )
            if await db_session.scalar(select(exists(active_rows))):
                logger.info(
                    "Skipping dream enqueue - already in progress: %s/%s/%s (type: %s)",
                    workspace_name,
                    observer,
                    observed,
                    dream_type.value,
                )
                return

            # 2) Already waiting in the queue?
            unprocessed_rows = select(QueueItem.id).where(
                QueueItem.work_unit_key == key,
                QueueItem.processed == False,  # noqa: E712
            )
            if await db_session.scalar(select(exists(unprocessed_rows))):
                logger.info(
                    "Dream already pending in queue: %s/%s/%s (type: %s)",
                    workspace_name,
                    observer,
                    observed,
                    dream_type.value,
                )
                return

            # 3) Neither: insert the new queue item and commit.
            await db_session.execute(
                insert(QueueItem).returning(QueueItem), [record]
            )
            await db_session.commit()

            logger.info(
                "Enqueued dream task for %s/%s/%s (type: %s)",
                workspace_name,
                observer,
                observed,
                dream_type.value,
            )
        except Exception as e:
            logger.exception("Failed to enqueue dream task!")
            if settings.SENTRY.ENABLED:
                import sentry_sdk

                sentry_sdk.capture_exception(e)
            raise
def create_deletion_record(
    workspace_name: str,
    deletion_type: Literal["session", "observation", "workspace"],
    resource_id: str,
) -> dict[str, Any]:
    """
    Build the queue record for a "deletion" task.

    Deletion records are not tied to a message or session row, so both
    "session_id" and "message_id" are None.

    Args:
        workspace_name: Name of the workspace
        deletion_type: Type of resource to delete
            ("session", "observation" or "workspace")
        resource_id: ID of the resource to delete

    Returns:
        Queue record dictionary for insertion into the queue
    """
    payload = create_deletion_payload(
        deletion_type=deletion_type,
        resource_id=resource_id,
    )
    work_unit_key = construct_work_unit_key(workspace_name, payload)

    return {
        "work_unit_key": work_unit_key,
        "payload": payload,
        "session_id": None,
        "task_type": "deletion",
        "workspace_name": workspace_name,
        "message_id": None,
    }
async def enqueue_deletion(
    workspace_name: str,
    deletion_type: Literal["session", "observation", "workspace"],
    resource_id: str,
    db_session: AsyncSession | None = None,
) -> None:
    """
    Enqueue a deletion task for processing by the deriver.

    A single deletion queue item is inserted. With a caller-supplied session the
    insert joins the caller's transaction and is NOT committed here; without
    one, a dedicated session is opened and committed.

    Args:
        workspace_name: Name of the workspace
        deletion_type: Type of resource to delete
            ("session", "observation" or "workspace")
        resource_id: ID of the resource to delete
        db_session: Optional database session. If provided, uses this session
            instead of creating a new one. The caller is responsible for committing.

    Raises:
        Exception: Any failure is logged, reported to Sentry when enabled,
            and re-raised.
    """

    async def _insert_and_log(target: AsyncSession, commit: bool) -> None:
        # Shared by both paths so that the record is built, inserted and logged
        # identically regardless of who owns the session.
        record = create_deletion_record(
            workspace_name,
            deletion_type,
            resource_id,
        )
        await target.execute(insert(QueueItem).returning(QueueItem), [record])
        if commit:
            await target.commit()

        logger.info(
            "Enqueued deletion task: type=%s, resource_id=%s, workspace=%s",
            deletion_type,
            resource_id,
            workspace_name,
        )

    try:
        if db_session is None:
            # No session supplied: own the session and the commit.
            async with tracked_db("deletion_enqueue") as own_session:
                await _insert_and_log(own_session, commit=True)
        else:
            # Caller's session: leave the commit to the caller.
            await _insert_and_log(db_session, commit=False)
    except Exception as e:
        logger.exception("Failed to enqueue deletion task!")
        if settings.SENTRY.ENABLED:
            import sentry_sdk

            sentry_sdk.capture_exception(e)
        raise