Spaces:
Running
Running
| """Session manager for handling multiple concurrent agent sessions with user isolation.""" | |
| import asyncio | |
| import logging | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from event_manager import event_manager | |
| from lifecycle import lifecycle_manager | |
| from agent.config import load_config | |
| from agent.core.agent_loop import process_submission | |
| from agent.core.session import Event, OpType, Session | |
| from agent.core.tools import ToolRouter | |
| # Get project root (parent of backend directory) | |
| PROJECT_ROOT = Path(__file__).parent.parent | |
| DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json") | |
| # These dataclasses match agent/main.py structure | |
| class Operation: | |
| """Operation to be executed by the agent.""" | |
| op_type: OpType | |
| data: Optional[dict[str, Any]] = None | |
| class Submission: | |
| """Submission to the agent loop.""" | |
| id: str | |
| operation: Operation | |
| logger = logging.getLogger(__name__) | |
| class UserContext: | |
| """Context for the authenticated user of a session.""" | |
| user_id: str # HF username | |
| hf_token: str # HF OAuth access token | |
| anthropic_key: Optional[str] = None # User's Anthropic API key | |
| class AgentSession: | |
| """Wrapper for an agent session with its associated resources.""" | |
| session_id: str | |
| session: Session | |
| tool_router: ToolRouter | |
| submission_queue: asyncio.Queue | |
| task: asyncio.Task | None = None | |
| created_at: datetime = field(default_factory=datetime.utcnow) | |
| is_active: bool = True | |
| # User isolation | |
| user_id: Optional[str] = None # Owner of this session | |
| user_context: Optional[UserContext] = None # User's auth context | |
| class SessionManager: | |
| """Manages multiple concurrent agent sessions with user isolation.""" | |
| def __init__(self, config_path: str | None = None) -> None: | |
| self.config = load_config(config_path or DEFAULT_CONFIG_PATH) | |
| self.sessions: dict[str, AgentSession] = {} | |
| self._lock = asyncio.Lock() | |
| async def create_session( | |
| self, | |
| user_id: Optional[str] = None, | |
| hf_token: Optional[str] = None, | |
| anthropic_key: Optional[str] = None, | |
| ) -> str: | |
| """Create a new agent session and return its ID. | |
| Args: | |
| user_id: Optional owner user ID (HF username) | |
| hf_token: Optional HF OAuth token for the user | |
| anthropic_key: Optional Anthropic API key for the user | |
| Returns: | |
| Session ID (UUID) | |
| """ | |
| session_id = str(uuid.uuid4()) | |
| # Create queues for this session | |
| submission_queue: asyncio.Queue = asyncio.Queue() | |
| event_queue: asyncio.Queue = asyncio.Queue() | |
| # Create user context if user is authenticated | |
| user_context = None | |
| if user_id and hf_token: | |
| user_context = UserContext( | |
| user_id=user_id, | |
| hf_token=hf_token, | |
| anthropic_key=anthropic_key, | |
| ) | |
| # Create tool router with user context for token injection | |
| tool_router = ToolRouter( | |
| self.config.mcpServers, | |
| hf_token=hf_token, # Pass user's HF token | |
| ) | |
| # Create the agent session with user's keys if provided | |
| session = Session( | |
| event_queue, | |
| config=self.config, | |
| tool_router=tool_router, | |
| anthropic_key=anthropic_key, | |
| hf_token=hf_token, | |
| ) | |
| # Create wrapper | |
| agent_session = AgentSession( | |
| session_id=session_id, | |
| session=session, | |
| tool_router=tool_router, | |
| submission_queue=submission_queue, | |
| user_id=user_id, | |
| user_context=user_context, | |
| ) | |
| async with self._lock: | |
| self.sessions[session_id] = agent_session | |
| # Start the agent loop task | |
| task = asyncio.create_task( | |
| self._run_session(session_id, submission_queue, event_queue, tool_router) | |
| ) | |
| agent_session.task = task | |
| return session_id | |
| async def create_session_with_id( | |
| self, | |
| session_id: str, | |
| user_id: Optional[str] = None, | |
| hf_token: Optional[str] = None, | |
| anthropic_key: Optional[str] = None, | |
| history: Optional[list[dict]] = None, | |
| ) -> str: | |
| """Create an agent session with a specific ID (for resuming). | |
| Args: | |
| session_id: The session ID to use | |
| user_id: Optional owner user ID (HF username) | |
| hf_token: Optional HF OAuth token for the user | |
| anthropic_key: Optional Anthropic API key for the user | |
| Returns: | |
| Session ID | |
| """ | |
| # Check if session already exists in memory | |
| if session_id in self.sessions: | |
| return session_id | |
| # Create queues for this session | |
| submission_queue: asyncio.Queue = asyncio.Queue() | |
| event_queue: asyncio.Queue = asyncio.Queue() | |
| # Create user context if user is authenticated | |
| user_context = None | |
| if user_id and hf_token: | |
| user_context = UserContext( | |
| user_id=user_id, | |
| hf_token=hf_token, | |
| anthropic_key=anthropic_key, | |
| ) | |
| # Create tool router with user context for token injection | |
| tool_router = ToolRouter( | |
| self.config.mcpServers, | |
| hf_token=hf_token, | |
| ) | |
| # Create the agent session | |
| session = Session( | |
| event_queue, | |
| config=self.config, | |
| tool_router=tool_router, | |
| anthropic_key=anthropic_key, | |
| hf_token=hf_token, | |
| ) | |
| # Restore conversation history if provided | |
| if history: | |
| from litellm import Message | |
| for msg in history: | |
| if msg.get("role") != "system": # Skip system, we have our own | |
| session.context_manager.items.append(Message(**msg)) | |
| # Create wrapper with the specified session_id | |
| agent_session = AgentSession( | |
| session_id=session_id, | |
| session=session, | |
| tool_router=tool_router, | |
| submission_queue=submission_queue, | |
| user_id=user_id, | |
| user_context=user_context, | |
| ) | |
| async with self._lock: | |
| self.sessions[session_id] = agent_session | |
| # Start the agent loop task | |
| task = asyncio.create_task( | |
| self._run_session(session_id, submission_queue, event_queue, tool_router) | |
| ) | |
| agent_session.task = task | |
| return session_id | |
| def _check_session_ownership( | |
| self, session_id: str, user_id: Optional[str] | |
| ) -> AgentSession | None: | |
| """Check if user owns the session and return it if so. | |
| Args: | |
| session_id: Session to check | |
| user_id: User to verify ownership | |
| Returns: | |
| AgentSession if user owns it or session has no owner, None otherwise | |
| """ | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| return None | |
| # If session has an owner, verify it matches | |
| if agent_session.user_id and agent_session.user_id != user_id: | |
| logger.warning( | |
| f"User {user_id} attempted to access session {session_id} " | |
| f"owned by {agent_session.user_id}" | |
| ) | |
| return None | |
| return agent_session | |
| async def _run_session( | |
| self, | |
| session_id: str, | |
| submission_queue: asyncio.Queue, | |
| event_queue: asyncio.Queue, | |
| tool_router: ToolRouter, | |
| ) -> None: | |
| """Run the agent loop for a session and forward events to SSE clients.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| logger.error(f"Session {session_id} not found") | |
| return | |
| session = agent_session.session | |
| # Start event forwarder task | |
| event_forwarder = asyncio.create_task( | |
| self._forward_events(session_id, event_queue) | |
| ) | |
| try: | |
| async with tool_router: | |
| # Send ready event | |
| await session.send_event( | |
| Event(event_type="ready", data={"message": "Agent initialized"}) | |
| ) | |
| while session.is_running: | |
| try: | |
| # Wait for submission with timeout to allow checking is_running | |
| submission = await asyncio.wait_for( | |
| submission_queue.get(), timeout=1.0 | |
| ) | |
| should_continue = await process_submission(session, submission) | |
| # Persist session after each turn | |
| await self._persist_session(session_id) | |
| if not should_continue: | |
| break | |
| except asyncio.TimeoutError: | |
| continue | |
| except asyncio.CancelledError: | |
| break | |
| except Exception as e: | |
| logger.error(f"Error in session {session_id}: {e}") | |
| await session.send_event( | |
| Event(event_type="error", data={"error": str(e)}) | |
| ) | |
| finally: | |
| event_forwarder.cancel() | |
| try: | |
| await event_forwarder | |
| except asyncio.CancelledError: | |
| pass | |
| async with self._lock: | |
| if session_id in self.sessions: | |
| self.sessions[session_id].is_active = False | |
| async def _persist_session(self, session_id: str) -> None: | |
| """Persist session state to HF Dataset.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| return | |
| # Serialize full message objects (preserves tool_calls, tool_call_id, etc.) | |
| messages = [ | |
| item.model_dump() for item in agent_session.session.context_manager.items | |
| ] | |
| await lifecycle_manager.persist_session( | |
| session_id=session_id, | |
| user_id=agent_session.user_id or "anonymous", | |
| messages=messages, | |
| config={"model_name": self.config.model_name}, | |
| title=f"Chat {session_id[:8]}", | |
| status="active", | |
| ) | |
| async def _forward_events( | |
| self, session_id: str, event_queue: asyncio.Queue | |
| ) -> None: | |
| """Forward events from the agent to SSE clients.""" | |
| while True: | |
| try: | |
| event: Event = await event_queue.get() | |
| await event_manager.send_event(session_id, event.event_type, event.data) | |
| except asyncio.CancelledError: | |
| break | |
| except Exception as e: | |
| logger.error(f"Error forwarding event for {session_id}: {e}") | |
| async def submit( | |
| self, session_id: str, operation: Operation, user_id: Optional[str] = None | |
| ) -> bool: | |
| """Submit an operation to a session. | |
| Args: | |
| session_id: Target session | |
| operation: Operation to submit | |
| user_id: User making the request (for ownership check) | |
| Returns: | |
| True if submitted successfully | |
| """ | |
| async with self._lock: | |
| agent_session = self._check_session_ownership(session_id, user_id) | |
| if not agent_session or not agent_session.is_active: | |
| logger.warning( | |
| f"Session {session_id} not found, inactive, or access denied" | |
| ) | |
| return False | |
| submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation) | |
| await agent_session.submission_queue.put(submission) | |
| return True | |
| async def submit_user_input( | |
| self, session_id: str, text: str, user_id: Optional[str] = None | |
| ) -> bool: | |
| """Submit user input to a session.""" | |
| operation = Operation(op_type=OpType.USER_INPUT, data={"text": text}) | |
| return await self.submit(session_id, operation, user_id) | |
| async def submit_approval( | |
| self, | |
| session_id: str, | |
| approvals: list[dict[str, Any]], | |
| user_id: Optional[str] = None, | |
| ) -> bool: | |
| """Submit tool approvals to a session.""" | |
| operation = Operation( | |
| op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals} | |
| ) | |
| return await self.submit(session_id, operation, user_id) | |
| async def interrupt(self, session_id: str, user_id: Optional[str] = None) -> bool: | |
| """Interrupt a session.""" | |
| operation = Operation(op_type=OpType.INTERRUPT) | |
| return await self.submit(session_id, operation, user_id) | |
| async def undo(self, session_id: str, user_id: Optional[str] = None) -> bool: | |
| """Undo last turn in a session.""" | |
| operation = Operation(op_type=OpType.UNDO) | |
| return await self.submit(session_id, operation, user_id) | |
| async def compact(self, session_id: str, user_id: Optional[str] = None) -> bool: | |
| """Compact context in a session.""" | |
| operation = Operation(op_type=OpType.COMPACT) | |
| return await self.submit(session_id, operation, user_id) | |
| async def shutdown_session( | |
| self, session_id: str, user_id: Optional[str] = None | |
| ) -> bool: | |
| """Shutdown a specific session.""" | |
| operation = Operation(op_type=OpType.SHUTDOWN) | |
| success = await self.submit(session_id, operation, user_id) | |
| if success: | |
| async with self._lock: | |
| agent_session = self._check_session_ownership(session_id, user_id) | |
| if agent_session and agent_session.task: | |
| # Wait for task to complete | |
| try: | |
| await asyncio.wait_for(agent_session.task, timeout=5.0) | |
| except asyncio.TimeoutError: | |
| agent_session.task.cancel() | |
| # Close and persist the session | |
| agent_session = self.sessions.get(session_id) | |
| if agent_session: | |
| messages = [ | |
| item.model_dump() | |
| for item in agent_session.session.context_manager.items | |
| ] | |
| await lifecycle_manager.close_session( | |
| session_id=session_id, | |
| user_id=agent_session.user_id or "anonymous", | |
| messages=messages, | |
| config={"model_name": self.config.model_name}, | |
| title=f"Chat {session_id[:8]}", | |
| ) | |
| return success | |
| async def delete_session( | |
| self, session_id: str, user_id: Optional[str] = None | |
| ) -> bool: | |
| """Delete a session entirely. | |
| Args: | |
| session_id: Session to delete | |
| user_id: User making the request (for ownership check) | |
| Returns: | |
| True if deleted successfully | |
| """ | |
| async with self._lock: | |
| agent_session = self._check_session_ownership(session_id, user_id) | |
| if not agent_session: | |
| return False | |
| # Remove from sessions | |
| self.sessions.pop(session_id, None) | |
| # Cancel the task if running | |
| if agent_session.task and not agent_session.task.done(): | |
| agent_session.task.cancel() | |
| try: | |
| await agent_session.task | |
| except asyncio.CancelledError: | |
| pass | |
| return True | |
| async def update_session_model( | |
| self, session_id: str, model_name: str, user_id: Optional[str] = None | |
| ) -> bool: | |
| """Update the model for an active session. | |
| Args: | |
| session_id: Target session | |
| model_name: New model name | |
| user_id: User making the request (for ownership check) | |
| Returns: | |
| True if updated successfully | |
| """ | |
| async with self._lock: | |
| agent_session = self._check_session_ownership(session_id, user_id) | |
| if not agent_session: | |
| return False | |
| # Update the model in session config | |
| agent_session.session.config.model_name = model_name | |
| # Persist the change | |
| await self._persist_session(session_id) | |
| return True | |
| def get_session_info( | |
| self, session_id: str, user_id: Optional[str] = None | |
| ) -> dict[str, Any] | None: | |
| """Get information about a session. | |
| Args: | |
| session_id: Session to get info for | |
| user_id: User making the request (for ownership check) | |
| Returns: | |
| Session info dict or None if not found/access denied | |
| """ | |
| agent_session = self._check_session_ownership(session_id, user_id) | |
| if not agent_session: | |
| return None | |
| return { | |
| "session_id": session_id, | |
| "created_at": agent_session.created_at.isoformat(), | |
| "is_active": agent_session.is_active, | |
| "message_count": len(agent_session.session.context_manager.items), | |
| "user_id": agent_session.user_id, | |
| "model_name": agent_session.session.config.model_name, | |
| } | |
| def list_sessions(self, user_id: Optional[str] = None) -> list[dict[str, Any]]: | |
| """List sessions, optionally filtered by user. | |
| Args: | |
| user_id: If provided, only return sessions owned by this user | |
| Returns: | |
| List of session info dicts | |
| """ | |
| results = [] | |
| for sid, agent_session in self.sessions.items(): | |
| # If user_id provided, only include sessions owned by that user | |
| if user_id and agent_session.user_id != user_id: | |
| continue | |
| info = self.get_session_info(sid, user_id) | |
| if info: | |
| results.append(info) | |
| return results | |
| def active_session_count(self) -> int: | |
| """Get count of active sessions.""" | |
| return sum(1 for s in self.sessions.values() if s.is_active) | |
| # Global session manager instance | |
| session_manager = SessionManager() | |