| | """Session manager for handling multiple concurrent agent sessions.""" |
| |
|
| | import asyncio |
| | import logging |
| | import uuid |
| | from dataclasses import dataclass, field |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import Any, Optional |
| |
|
| | from websocket import manager as ws_manager |
| |
|
| | from agent.config import load_config |
| | from agent.core.agent_loop import process_submission |
| | from agent.core.session import Event, OpType, Session |
| | from agent.core.tools import ToolRouter |
| |
|
| | |
# Directory two levels above this file (the parent of this module's directory).
PROJECT_ROOT = Path(__file__).parent.parent
# Config file SessionManager loads when no explicit config path is supplied.
DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json")
| |
|
| |
|
| | |
@dataclass
class Operation:
    """A single request for the agent: an operation type plus optional payload."""

    # Which kind of operation the agent should perform.
    op_type: OpType
    # Operation-specific arguments; left as None by operations that need none
    # (e.g. interrupt, undo, compact, shutdown as built by SessionManager).
    data: dict[str, Any] | None = None
| |
|
| |
|
@dataclass
class Submission:
    """Submission to the agent loop: an Operation tagged with an identifier."""

    # Identifier for this submission; SessionManager.submit() generates it as
    # "sub_" followed by 8 hex characters.
    id: str
    # The operation the agent loop should process.
    operation: Operation
| |
|
| |
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
| |
|
| |
|
@dataclass
class AgentSession:
    """Wrapper for an agent session with its associated resources."""

    # Key under which SessionManager stores this entry (a UUID4 string).
    session_id: str
    # The underlying agent session object.
    session: Session
    # Tool router built for this session; entered as an async context manager
    # for the lifetime of the session's run loop.
    tool_router: ToolRouter
    # Queue of Submission objects waiting to be processed by the run loop.
    submission_queue: asyncio.Queue
    # Owner of the session; "dev" is the default and is exempt from the
    # per-user session limit.
    user_id: str = "dev"
    # Token supplied at creation time; also copied onto session.hf_token.
    hf_token: str | None = None
    # asyncio task running the session loop; assigned right after creation.
    task: asyncio.Task | None = None
    # Creation timestamp from datetime.utcnow(), i.e. a naive datetime in UTC.
    created_at: datetime = field(default_factory=datetime.utcnow)
    # Set to False by the run loop when it finishes.
    is_active: bool = True
| |
|
| |
|
class SessionCapacityError(Exception):
    """Signals that a new session cannot be created because a limit was hit.

    Attributes:
        error_type: Which limit was reached. SessionManager uses "global" for
            the server-wide cap and "per_user" for the per-user cap.
    """

    def __init__(self, message: str, error_type: str = "global") -> None:
        # Record the limit category, then let Exception store the message.
        self.error_type = error_type
        super().__init__(message)
| |
|
| |
|
| | |
| | |
| | |
# Server-wide cap on simultaneously active sessions, enforced in
# SessionManager.create_session().
MAX_SESSIONS: int = 50
# Cap on simultaneously active sessions per user; create_session() skips this
# check for the "dev" user.
MAX_SESSIONS_PER_USER: int = 10
| |
|
| |
|
class SessionManager:
    """Manages multiple concurrent agent sessions.

    Each session is stored in ``self.sessions`` under its ID and is driven by
    its own asyncio task (``_run_session``). ``self._lock`` guards access to
    the registry. The lock must never be held while waiting on a session
    task, because the task acquires the same lock when it winds down.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load the agent configuration and start with an empty registry.

        Args:
            config_path: Path to the agent config file. Falls back to
                ``DEFAULT_CONFIG_PATH`` when omitted or empty.
        """
        self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
        self.sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()
        # Capacity slots reserved by create_session() calls that passed the
        # limit checks but have not registered their session yet, keyed by
        # user_id. Counted together with active sessions so concurrent
        # creations cannot overshoot the limits.
        self._pending: dict[str, int] = {}

    def _count_user_sessions(self, user_id: str) -> int:
        """Count active sessions owned by a specific user."""
        return sum(
            1
            for s in self.sessions.values()
            if s.user_id == user_id and s.is_active
        )

    def _release_reservation(self, user_id: str) -> None:
        """Give back one capacity slot reserved by create_session()."""
        remaining = self._pending.get(user_id, 0) - 1
        if remaining > 0:
            self._pending[user_id] = remaining
        else:
            self._pending.pop(user_id, None)

    async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str:
        """Create a new agent session and return its ID.

        Session() and ToolRouter() constructors contain blocking I/O
        (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
        executed in a thread pool to avoid freezing the async event loop.

        Because construction happens outside the lock, a capacity slot is
        reserved while the limits are checked and released once the session
        is registered (or construction fails). Sessions that are still being
        built therefore count towards both limits.

        Args:
            user_id: The ID of the user who owns this session.
            hf_token: Optional token stored on both the Session and the
                AgentSession wrapper.

        Returns:
            The ID (UUID4 string) of the newly created session.

        Raises:
            SessionCapacityError: If the server or user has reached the
                maximum number of concurrent sessions.
        """
        async with self._lock:
            active_count = self.active_session_count + sum(self._pending.values())
            if active_count >= MAX_SESSIONS:
                raise SessionCapacityError(
                    f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). "
                    "Please try again later.",
                    error_type="global",
                )
            # The "dev" user is exempt from the per-user limit.
            if user_id != "dev":
                user_count = self._count_user_sessions(user_id) + self._pending.get(
                    user_id, 0
                )
                if user_count >= MAX_SESSIONS_PER_USER:
                    raise SessionCapacityError(
                        f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
                        "concurrent sessions. Please close an existing session first.",
                        error_type="per_user",
                    )
            # Reserve the slot before releasing the lock; see the docstring.
            self._pending[user_id] = self._pending.get(user_id, 0) + 1

        try:
            session_id = str(uuid.uuid4())

            # One queue feeds submissions to the agent loop, the other carries
            # events from the agent back to the WebSocket forwarder.
            submission_queue: asyncio.Queue = asyncio.Queue()
            event_queue: asyncio.Queue = asyncio.Queue()

            import time as _time

            def _create_session_sync():
                # Runs in a worker thread: both constructors may block.
                t0 = _time.monotonic()
                tool_router = ToolRouter(self.config.mcpServers)
                session = Session(event_queue, config=self.config, tool_router=tool_router)
                t1 = _time.monotonic()
                logger.info(f"Session initialized in {t1 - t0:.2f}s")
                return tool_router, session

            tool_router, session = await asyncio.to_thread(_create_session_sync)

            session.hf_token = hf_token

            agent_session = AgentSession(
                session_id=session_id,
                session=session,
                tool_router=tool_router,
                submission_queue=submission_queue,
                user_id=user_id,
                hf_token=hf_token,
            )

            async with self._lock:
                self.sessions[session_id] = agent_session
        finally:
            # The session is registered (and counted as active) or creation
            # failed; either way the reservation is no longer needed.
            self._release_reservation(user_id)

        # Start the agent loop only after the session is registered, because
        # _run_session() looks the session up by ID.
        task = asyncio.create_task(
            self._run_session(session_id, submission_queue, event_queue, tool_router)
        )
        agent_session.task = task

        logger.info(f"Created session {session_id} for user {user_id}")
        return session_id

    async def _run_session(
        self,
        session_id: str,
        submission_queue: asyncio.Queue,
        event_queue: asyncio.Queue,
        tool_router: ToolRouter,
    ) -> None:
        """Run the agent loop for a session and forward events to WebSocket.

        Emits a "ready" event, then processes submissions until the session
        stops running, a submission asks to stop, or the task is cancelled.
        On exit the event forwarder is stopped and the session is marked
        inactive.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            logger.error(f"Session {session_id} not found")
            return

        session = agent_session.session

        # Relay agent events to the WebSocket for as long as the loop runs.
        event_forwarder = asyncio.create_task(
            self._forward_events(session_id, event_queue)
        )

        try:
            async with tool_router:
                await session.send_event(
                    Event(event_type="ready", data={"message": "Agent initialized"})
                )

                while session.is_running:
                    try:
                        # Poll with a timeout so is_running is re-checked
                        # regularly even when no submissions arrive.
                        submission = await asyncio.wait_for(
                            submission_queue.get(), timeout=1.0
                        )
                        should_continue = await process_submission(session, submission)
                        if not should_continue:
                            break
                    except asyncio.TimeoutError:
                        continue
                    except asyncio.CancelledError:
                        # Cancellation ends the loop; cleanup below still runs.
                        logger.info(f"Session {session_id} cancelled")
                        break
                    except Exception as e:
                        # Keep the session alive after a failed submission and
                        # report the error to the client. logger.exception
                        # records the traceback as well.
                        logger.exception(f"Error in session {session_id}: {e}")
                        await session.send_event(
                            Event(event_type="error", data={"error": str(e)})
                        )

        finally:
            event_forwarder.cancel()
            try:
                await event_forwarder
            except asyncio.CancelledError:
                pass

            async with self._lock:
                if session_id in self.sessions:
                    self.sessions[session_id].is_active = False

            logger.info(f"Session {session_id} ended")

    async def _forward_events(
        self, session_id: str, event_queue: asyncio.Queue
    ) -> None:
        """Forward events from the agent to the WebSocket.

        Runs until cancelled. A failure while sending one event is logged and
        does not stop the forwarding of later events.
        """
        while True:
            try:
                event: Event = await event_queue.get()
                await ws_manager.send_event(session_id, event.event_type, event.data)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error forwarding event for {session_id}: {e}")

    async def submit(self, session_id: str, operation: Operation) -> bool:
        """Submit an operation to a session.

        Returns:
            True if the operation was queued, False if the session does not
            exist or is no longer active.
        """
        async with self._lock:
            agent_session = self.sessions.get(session_id)

        if not agent_session or not agent_session.is_active:
            logger.warning(f"Session {session_id} not found or inactive")
            return False

        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
        await agent_session.submission_queue.put(submission)
        return True

    async def submit_user_input(self, session_id: str, text: str) -> bool:
        """Submit user input to a session."""
        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
        return await self.submit(session_id, operation)

    async def submit_approval(
        self, session_id: str, approvals: list[dict[str, Any]]
    ) -> bool:
        """Submit tool approvals to a session."""
        operation = Operation(
            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
        )
        return await self.submit(session_id, operation)

    async def interrupt(self, session_id: str) -> bool:
        """Interrupt a session."""
        operation = Operation(op_type=OpType.INTERRUPT)
        return await self.submit(session_id, operation)

    async def undo(self, session_id: str) -> bool:
        """Undo last turn in a session."""
        operation = Operation(op_type=OpType.UNDO)
        return await self.submit(session_id, operation)

    async def compact(self, session_id: str) -> bool:
        """Compact context in a session."""
        operation = Operation(op_type=OpType.COMPACT)
        return await self.submit(session_id, operation)

    async def shutdown_session(self, session_id: str) -> bool:
        """Shutdown a specific session.

        Queues a SHUTDOWN operation and waits up to 5 seconds for the session
        task to finish; after the timeout the task is cancelled.

        Returns:
            True if the shutdown operation was queued, False if the session
            does not exist or is already inactive.
        """
        operation = Operation(op_type=OpType.SHUTDOWN)
        success = await self.submit(session_id, operation)
        if not success:
            return False

        # Only read the registry under the lock. The wait below must happen
        # without it: the session task acquires the same lock in its cleanup,
        # so holding it here would stall every shutdown until the timeout and
        # could leave the session marked active.
        async with self._lock:
            agent_session = self.sessions.get(session_id)
        task = agent_session.task if agent_session else None

        if task:
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except asyncio.TimeoutError:
                task.cancel()

        return True

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session entirely.

        Removes the session from the registry and cancels its task if it is
        still running.

        Returns:
            True if a session was removed, False if the ID was unknown.
        """
        async with self._lock:
            agent_session = self.sessions.pop(session_id, None)

        if not agent_session:
            return False

        # Stop the agent loop and wait for its cleanup to finish.
        if agent_session.task and not agent_session.task.done():
            agent_session.task.cancel()
            try:
                await agent_session.task
            except asyncio.CancelledError:
                pass

        return True

    def get_session_owner(self, session_id: str) -> str | None:
        """Get the user_id that owns a session, or None if session doesn't exist."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None
        return agent_session.user_id

    def verify_session_access(self, session_id: str, user_id: str) -> bool:
        """Check if a user has access to a session.

        Returns True only if the session exists AND at least one of:
        - the requesting user_id is "dev" (dev mode bypass),
        - the session is owned by "dev" (such sessions are open to everyone),
        - the requesting user owns the session.
        """
        owner = self.get_session_owner(session_id)
        if owner is None:
            return False
        if user_id == "dev" or owner == "dev":
            return True
        return owner == user_id

    def get_session_info(self, session_id: str) -> dict[str, Any] | None:
        """Get information about a session.

        Returns:
            A dict with the keys "session_id", "created_at" (ISO-8601 string),
            "is_active", "message_count" and "user_id", or None if the session
            does not exist.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None

        return {
            "session_id": session_id,
            "created_at": agent_session.created_at.isoformat(),
            "is_active": agent_session.is_active,
            "message_count": len(agent_session.session.context_manager.items),
            "user_id": agent_session.user_id,
        }

    def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
        """List sessions, optionally filtered by user.

        Args:
            user_id: If provided, only return sessions owned by this user.
                If "dev", return all sessions (dev mode).

        Returns:
            One info dict per matching session, as produced by
            get_session_info().
        """
        # No filter is applied for a missing/empty user_id or for "dev".
        show_all = not user_id or user_id == "dev"
        results = []
        for sid in self.sessions:
            info = self.get_session_info(sid)
            if not info:
                continue
            if show_all or info.get("user_id") == user_id:
                results.append(info)
        return results

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions."""
        return sum(1 for s in self.sessions.values() if s.is_active)
| |
|
| |
|
| | |
# Shared module-level instance. Constructing it calls load_config() on the
# default config path at import time.
session_manager = SessionManager()
| |
|