| | """Session manager for handling multiple concurrent agent sessions."""
|
| |
|
| | import asyncio
|
| | import logging
|
| | import uuid
|
| | from dataclasses import dataclass, field
|
| | from datetime import datetime
|
| | from pathlib import Path
|
| | from typing import Any, Optional
|
| |
|
| | from websocket import manager as ws_manager
|
| |
|
| | from agent.config import load_config
|
| | from agent.core.agent_loop import process_submission
|
| | from agent.core.session import Event, OpType, Session
|
| | from agent.core.tools import ToolRouter
|
| |
|
| |
|
# Directory two levels above this module's file, treated as the project root.
PROJECT_ROOT = Path(__file__).parent.parent
# Config file used when SessionManager is constructed without an explicit path.
DEFAULT_CONFIG_PATH = str(PROJECT_ROOT.joinpath("configs", "main_agent_config.json"))
|
| |
|
| |
|
| |
|
@dataclass
class Operation:
    """Describe one request for the agent loop: its type plus an optional payload."""

    op_type: OpType
    # Operation-specific payload, e.g. {"text": ...} for user input or
    # {"approvals": [...]} for tool approvals; None when the op needs no data.
    data: dict[str, Any] | None = None
|
| |
|
| |
|
@dataclass
class Submission:
    """A queued unit of work for a session's agent loop.

    Pairs an ``Operation`` with an identifier so the loop can tell
    submissions apart.
    """

    # Identifier for this submission; SessionManager.submit builds it as
    # "sub_" followed by 8 hex characters.
    id: str
    # The operation the agent loop should execute.
    operation: Operation
|
| |
|
| |
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
| |
|
| |
|
@dataclass
class AgentSession:
    """Bundle one agent ``Session`` with the resources the manager tracks for it."""

    # Key of this entry in SessionManager.sessions (a UUID4 string there).
    session_id: str
    # The underlying agent session driven by the agent loop.
    session: Session
    # Tool router handed to the session; entered as an async context by the loop.
    tool_router: ToolRouter
    # Queue of Submission objects consumed by the session's agent loop.
    submission_queue: asyncio.Queue
    # Owner of the session; "dev" is the default / dev-mode identity.
    user_id: str = "dev"
    # Optional token supplied at creation time (also copied onto session.hf_token).
    hf_token: str | None = None
    # Background task running the agent loop; set right after registration.
    task: asyncio.Task | None = None
    # Creation time as a naive UTC datetime.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    created_at: datetime = field(default_factory=datetime.utcnow)
    # Cleared once the session's agent loop has exited.
    is_active: bool = True
|
| |
|
| |
|
class SessionCapacityError(Exception):
    """Signal that a new session cannot be created because a limit was hit.

    Attributes:
        error_type: Which limit was reached. Callers in this module pass
            ``"global"`` (server-wide cap) or ``"per_user"`` (per-user cap);
            defaults to ``"global"``.
    """

    def __init__(self, message: str, error_type: str = "global") -> None:
        self.error_type = error_type
        super().__init__(message)
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Per-user cap on concurrently active sessions (the "dev" user is exempt).
MAX_SESSIONS_PER_USER: int = 10
# Server-wide cap on concurrently active sessions.
MAX_SESSIONS: int = 50
|
| |
|
| |
|
class SessionManager:
    """Manages multiple concurrent agent sessions.

    Each session owns a ``Session``, a ``ToolRouter``, a submission queue and a
    background task running the agent loop. Creation is limited globally by
    ``MAX_SESSIONS`` and per user by ``MAX_SESSIONS_PER_USER`` (the ``"dev"``
    user is exempt from the per-user limit).
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load the agent config and initialise empty session bookkeeping.

        Args:
            config_path: Path to the agent config file; ``DEFAULT_CONFIG_PATH``
                is used when it is None or empty.
        """
        self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
        self.sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()
        # Capacity slots claimed by create_session() calls that passed the
        # capacity check but have not registered their session yet, keyed by
        # user_id. Counting them prevents concurrent creations from all
        # passing the check while construction runs in a worker thread.
        self._pending_creations: dict[str, int] = {}

    def _count_user_sessions(self, user_id: str) -> int:
        """Count active sessions owned by a specific user."""
        return sum(
            1
            for s in self.sessions.values()
            if s.user_id == user_id and s.is_active
        )

    def _reserve_slot(self, user_id: str) -> None:
        """Check the capacity limits and claim a slot for a session being built.

        Called by ``create_session`` while holding ``self._lock``. Slots
        already claimed by in-flight creations count against both limits.

        Raises:
            SessionCapacityError: If the server or user has reached the
                maximum number of concurrent sessions.
        """
        total = self.active_session_count + sum(self._pending_creations.values())
        if total >= MAX_SESSIONS:
            raise SessionCapacityError(
                f"Server is at capacity ({total}/{MAX_SESSIONS} sessions). "
                "Please try again later.",
                error_type="global",
            )
        if user_id != "dev":
            user_count = self._count_user_sessions(
                user_id
            ) + self._pending_creations.get(user_id, 0)
            if user_count >= MAX_SESSIONS_PER_USER:
                raise SessionCapacityError(
                    f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
                    "concurrent sessions. Please close an existing session first.",
                    error_type="per_user",
                )
        self._pending_creations[user_id] = self._pending_creations.get(user_id, 0) + 1

    def _release_slot(self, user_id: str) -> None:
        """Give back one slot claimed by ``_reserve_slot`` for ``user_id``.

        Deliberately synchronous: it contains no ``await``, so it cannot be
        interrupted by cancellation and never yields to other coroutines.
        """
        remaining = self._pending_creations.get(user_id, 0) - 1
        if remaining > 0:
            self._pending_creations[user_id] = remaining
        else:
            self._pending_creations.pop(user_id, None)

    async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str:
        """Create a new agent session and return its ID.

        Session() and ToolRouter() constructors contain blocking I/O
        (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
        executed in a thread pool to avoid freezing the async event loop.

        A capacity slot is reserved before construction starts and is handed
        over to the registered session afterwards, so concurrent calls cannot
        exceed the limits while construction is in progress.

        Args:
            user_id: The ID of the user who owns this session.
            hf_token: Optional token stored on the session (``session.hf_token``)
                and on the ``AgentSession`` wrapper.

        Returns:
            The new session's ID (a UUID4 string).

        Raises:
            SessionCapacityError: If the server or user has reached the
                maximum number of concurrent sessions.
        """
        import time

        async with self._lock:
            self._reserve_slot(user_id)

        try:
            session_id = str(uuid.uuid4())
            submission_queue: asyncio.Queue = asyncio.Queue()
            event_queue: asyncio.Queue = asyncio.Queue()

            def _create_session_sync():
                # Runs in a worker thread (see asyncio.to_thread below).
                started = time.monotonic()
                tool_router = ToolRouter(self.config.mcpServers)
                session = Session(event_queue, config=self.config, tool_router=tool_router)
                logger.info(f"Session initialized in {time.monotonic() - started:.2f}s")
                return tool_router, session

            tool_router, session = await asyncio.to_thread(_create_session_sync)
            session.hf_token = hf_token

            agent_session = AgentSession(
                session_id=session_id,
                session=session,
                tool_router=tool_router,
                submission_queue=submission_queue,
                user_id=user_id,
                hf_token=hf_token,
            )

            async with self._lock:
                self.sessions[session_id] = agent_session
        finally:
            # Hand the reserved slot back. On success the session now counts
            # as active instead; nothing between the insert above and this
            # line suspends, so no other coroutine can observe the slot
            # counted twice or not at all.
            self._release_slot(user_id)

        task = asyncio.create_task(
            self._run_session(session_id, submission_queue, event_queue, tool_router)
        )
        agent_session.task = task

        logger.info(f"Created session {session_id} for user {user_id}")
        return session_id

    async def _run_session(
        self,
        session_id: str,
        submission_queue: asyncio.Queue,
        event_queue: asyncio.Queue,
        tool_router: ToolRouter,
    ) -> None:
        """Run the agent loop for a session and forward events to WebSocket.

        Feeds submissions from ``submission_queue`` to ``process_submission``
        until ``session.is_running`` is false, a submission returns a falsy
        result, or the task is cancelled. A companion task forwards events
        from ``event_queue``. On exit the session is marked inactive.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            logger.error(f"Session {session_id} not found")
            return

        session = agent_session.session

        event_forwarder = asyncio.create_task(
            self._forward_events(session_id, event_queue)
        )

        try:
            async with tool_router:
                await session.send_event(
                    Event(event_type="ready", data={"message": "Agent initialized"})
                )

                while session.is_running:
                    try:
                        # Short timeout so session.is_running is re-checked at
                        # least once per second even when the queue is idle.
                        submission = await asyncio.wait_for(
                            submission_queue.get(), timeout=1.0
                        )
                        should_continue = await process_submission(session, submission)
                        if not should_continue:
                            break
                    except asyncio.TimeoutError:
                        continue
                    except asyncio.CancelledError:
                        # Cancellation (e.g. from delete_session) just ends the
                        # loop; the cleanup in `finally` still runs.
                        logger.info(f"Session {session_id} cancelled")
                        break
                    except Exception as e:
                        logger.exception(f"Error in session {session_id}: {e}")
                        await session.send_event(
                            Event(event_type="error", data={"error": str(e)})
                        )
        finally:
            # Mark inactive first, synchronously: this cannot be interrupted,
            # so even a cancelled cleanup never leaves a finished session
            # counted as active (and occupying capacity).
            agent_session.is_active = False

            event_forwarder.cancel()
            try:
                await event_forwarder
            except asyncio.CancelledError:
                pass

            logger.info(f"Session {session_id} ended")

    async def _forward_events(
        self, session_id: str, event_queue: asyncio.Queue
    ) -> None:
        """Forward events from the agent to the WebSocket until cancelled.

        Send failures are logged and skipped so one bad event does not stop
        forwarding.
        """
        while True:
            try:
                event: Event = await event_queue.get()
                await ws_manager.send_event(session_id, event.event_type, event.data)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception(f"Error forwarding event for {session_id}: {e}")

    async def submit(self, session_id: str, operation: Operation) -> bool:
        """Submit an operation to a session.

        Returns:
            True if the operation was queued, False if the session does not
            exist or is no longer active.
        """
        async with self._lock:
            agent_session = self.sessions.get(session_id)

        if not agent_session or not agent_session.is_active:
            logger.warning(f"Session {session_id} not found or inactive")
            return False

        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
        await agent_session.submission_queue.put(submission)
        return True

    async def submit_user_input(self, session_id: str, text: str) -> bool:
        """Submit user input to a session."""
        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
        return await self.submit(session_id, operation)

    async def submit_approval(
        self, session_id: str, approvals: list[dict[str, Any]]
    ) -> bool:
        """Submit tool approvals to a session."""
        operation = Operation(
            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
        )
        return await self.submit(session_id, operation)

    async def interrupt(self, session_id: str) -> bool:
        """Interrupt a session."""
        operation = Operation(op_type=OpType.INTERRUPT)
        return await self.submit(session_id, operation)

    async def undo(self, session_id: str) -> bool:
        """Undo last turn in a session."""
        operation = Operation(op_type=OpType.UNDO)
        return await self.submit(session_id, operation)

    async def compact(self, session_id: str) -> bool:
        """Compact context in a session."""
        operation = Operation(op_type=OpType.COMPACT)
        return await self.submit(session_id, operation)

    async def shutdown_session(self, session_id: str) -> bool:
        """Ask a session to shut down and wait up to 5 s for its loop to exit.

        If the loop has not finished within the timeout its task is cancelled.

        Returns:
            True if the SHUTDOWN operation was queued, False if the session
            does not exist or is inactive.
        """
        operation = Operation(op_type=OpType.SHUTDOWN)
        success = await self.submit(session_id, operation)
        if not success:
            return success

        async with self._lock:
            agent_session = self.sessions.get(session_id)
            task = agent_session.task if agent_session else None

        # Wait outside the lock: holding it here would block every other
        # manager call (and the loop's own cleanup) for up to the timeout.
        if task is not None:
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except asyncio.TimeoutError:
                # wait_for already cancels on timeout; this is a no-op safeguard.
                task.cancel()

        return success

    async def delete_session(self, session_id: str) -> bool:
        """Remove a session and stop its background task.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        async with self._lock:
            agent_session = self.sessions.pop(session_id, None)

        if not agent_session:
            return False

        task = agent_session.task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        return True

    def get_session_owner(self, session_id: str) -> str | None:
        """Get the user_id that owns a session, or None if session doesn't exist."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None
        return agent_session.user_id

    def verify_session_access(self, session_id: str, user_id: str) -> bool:
        """Check if a user has access to a session.

        Returns False if the session does not exist. Otherwise returns True if:
        - The user owns the session
        - The user_id is "dev" (dev mode bypass)
        - The session is owned by "dev" (any user may access it)
        """
        owner = self.get_session_owner(session_id)
        if owner is None:
            return False
        if user_id == "dev" or owner == "dev":
            return True
        return owner == user_id

    def get_session_info(self, session_id: str) -> dict[str, Any] | None:
        """Get a JSON-friendly summary of a session, or None if it doesn't exist."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None

        return {
            "session_id": session_id,
            "created_at": agent_session.created_at.isoformat(),
            "is_active": agent_session.is_active,
            "message_count": len(agent_session.session.context_manager.items),
            "user_id": agent_session.user_id,
        }

    def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
        """List sessions, optionally filtered by user.

        Args:
            user_id: If provided, only return sessions owned by this user.
                If "dev" (or None/empty), return all sessions.
        """
        results = []
        for sid in self.sessions:
            info = self.get_session_info(sid)
            if not info:
                continue
            if user_id and user_id != "dev" and info.get("user_id") != user_id:
                continue
            results.append(info)
        return results

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions (excludes sessions still being created)."""
        return sum(1 for s in self.sessions.values() if s.is_active)
|
| |
|
| |
|
| |
|
# Process-wide manager instance, built at import time; with no explicit path
# it loads DEFAULT_CONFIG_PATH.
session_manager = SessionManager(config_path=None)
|
| |
|