ml-agent / backend / session_manager.py
akseljoonas's picture
akseljoonas HF Staff
deploy
850e85a
"""Session manager for handling multiple concurrent agent sessions with user isolation."""
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from event_manager import event_manager
from lifecycle import lifecycle_manager
from agent.config import load_config
from agent.core.agent_loop import process_submission
from agent.core.session import Event, OpType, Session
from agent.core.tools import ToolRouter
# Project root is the directory that contains the backend/ directory.
PROJECT_ROOT = Path(__file__).parent.parent
# Kept as a plain string: it is handed to load_config() as the fallback path.
DEFAULT_CONFIG_PATH = str(
    PROJECT_ROOT.joinpath("configs", "main_agent_config.json")
)
# These dataclasses match agent/main.py structure
@dataclass
class Operation:
    """A single operation for the agent to execute."""

    # Which kind of operation this is (user input, interrupt, shutdown, ...).
    op_type: OpType
    # Optional payload; operations that need no arguments leave it as None.
    data: dict[str, Any] | None = None
@dataclass
class Submission:
    """An operation paired with an identifier, as queued for the agent loop."""

    # Identifier assigned by the submitter.
    id: str
    # The operation the agent loop should process.
    operation: Operation
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class UserContext:
"""Context for the authenticated user of a session."""
user_id: str # HF username
hf_token: str # HF OAuth access token
anthropic_key: Optional[str] = None # User's Anthropic API key
@dataclass
class AgentSession:
    """Bundle an agent session with the resources tracked alongside it."""

    session_id: str
    session: Session
    tool_router: ToolRouter
    # Queue of pending submissions for this session's agent loop.
    submission_queue: asyncio.Queue
    # Background task running the agent loop; None until one is attached.
    task: Optional[asyncio.Task] = None
    # Naive UTC timestamp taken when this wrapper is constructed.
    created_at: datetime = field(default_factory=datetime.utcnow)
    is_active: bool = True
    # User isolation
    # Owner of this session (None when the session has no owner).
    user_id: str | None = None
    # Auth context of the owning user, when the user is authenticated.
    user_context: UserContext | None = None
class SessionManager:
    """Manage multiple concurrent agent sessions with user isolation.

    Every session gets its own submission queue, event queue, ``ToolRouter``
    and a background task running :meth:`_run_session`. ``self._lock``
    serialises changes to ``self.sessions``.

    The lock is never held while awaiting a session task or persistence I/O:
    the cleanup at the end of :meth:`_run_session` acquires the same
    (non-reentrant) lock, so waiting on the task with the lock held would
    stall that cleanup.
    """

    def __init__(self, config_path: str | None = None) -> None:
        """Load the agent config and start with an empty session registry.

        Args:
            config_path: Path of the config file to load. Falls back to
                ``DEFAULT_CONFIG_PATH`` when omitted or empty.
        """
        self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
        self.sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()

    async def create_session(
        self,
        user_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        anthropic_key: Optional[str] = None,
    ) -> str:
        """Create a new agent session and return its ID.

        Args:
            user_id: Optional owner user ID (HF username)
            hf_token: Optional HF OAuth token for the user
            anthropic_key: Optional Anthropic API key for the user

        Returns:
            Session ID (UUID)
        """
        return await self._start_session(
            str(uuid.uuid4()),
            user_id=user_id,
            hf_token=hf_token,
            anthropic_key=anthropic_key,
        )

    async def create_session_with_id(
        self,
        session_id: str,
        user_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        anthropic_key: Optional[str] = None,
        history: Optional[list[dict]] = None,
    ) -> str:
        """Create an agent session with a specific ID (for resuming).

        If a session with this ID is already registered, nothing is created
        and the ID is returned as-is.

        Args:
            session_id: The session ID to use
            user_id: Optional owner user ID (HF username)
            hf_token: Optional HF OAuth token for the user
            anthropic_key: Optional Anthropic API key for the user
            history: Optional list of message dicts to restore into the
                session's context; entries whose ``role`` is ``"system"``
                are skipped.

        Returns:
            Session ID
        """
        # Fast path: already resident in memory, nothing to build.
        if session_id in self.sessions:
            return session_id
        return await self._start_session(
            session_id,
            user_id=user_id,
            hf_token=hf_token,
            anthropic_key=anthropic_key,
            history=history,
        )

    async def _start_session(
        self,
        session_id: str,
        user_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        anthropic_key: Optional[str] = None,
        history: Optional[list[dict]] = None,
    ) -> str:
        """Build a session, register it under ``session_id`` and start its loop.

        Shared implementation of :meth:`create_session` and
        :meth:`create_session_with_id`.

        Returns:
            ``session_id``. If the ID turns out to be registered already, the
            existing session is kept and the freshly built one is dropped.
        """
        # Create user context if user is authenticated
        user_context = None
        if user_id and hf_token:
            user_context = UserContext(
                user_id=user_id,
                hf_token=hf_token,
                anthropic_key=anthropic_key,
            )

        # Create queues for this session
        submission_queue: asyncio.Queue = asyncio.Queue()
        event_queue: asyncio.Queue = asyncio.Queue()

        # Create tool router, passing the user's HF token
        tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token)

        # Create the agent session with the user's keys if provided
        session = Session(
            event_queue,
            config=self.config,
            tool_router=tool_router,
            anthropic_key=anthropic_key,
            hf_token=hf_token,
        )

        # Restore conversation history if provided
        if history:
            from litellm import Message

            for msg in history:
                if msg.get("role") != "system":  # Skip system, we have our own
                    session.context_manager.items.append(Message(**msg))

        agent_session = AgentSession(
            session_id=session_id,
            session=session,
            tool_router=tool_router,
            submission_queue=submission_queue,
            user_id=user_id,
            user_context=user_context,
        )

        async with self._lock:
            # Acquiring the lock may suspend, so another coroutine can have
            # registered this ID since the caller's check. Keep theirs: ours
            # has no task yet, and replacing theirs would leave two loops
            # running for one ID.
            if session_id in self.sessions:
                return session_id
            self.sessions[session_id] = agent_session
            # Start the agent loop task
            agent_session.task = asyncio.create_task(
                self._run_session(
                    session_id, submission_queue, event_queue, tool_router
                )
            )
        return session_id

    def _check_session_ownership(
        self, session_id: str, user_id: Optional[str]
    ) -> AgentSession | None:
        """Check if user owns the session and return it if so.

        Args:
            session_id: Session to check
            user_id: User to verify ownership

        Returns:
            AgentSession if user owns it or session has no owner, None if the
            session is unknown or owned by someone else (the latter is logged
            as a warning).
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None

        owner = agent_session.user_id
        # A session without an owner is accessible to any caller.
        if owner and owner != user_id:
            logger.warning(
                "User %s attempted to access session %s owned by %s",
                user_id,
                session_id,
                owner,
            )
            return None
        return agent_session

    async def _run_session(
        self,
        session_id: str,
        submission_queue: asyncio.Queue,
        event_queue: asyncio.Queue,
        tool_router: ToolRouter,
    ) -> None:
        """Run the agent loop for a session and forward events to SSE clients.

        Emits a ``ready`` event, then processes submissions until the session
        stops running, a submission asks to stop, or the task is cancelled.
        The session is persisted after every processed submission. On exit the
        event forwarder is cancelled and the session is marked inactive.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            logger.error("Session %s not found", session_id)
            return
        session = agent_session.session

        # Start event forwarder task
        event_forwarder = asyncio.create_task(
            self._forward_events(session_id, event_queue)
        )

        try:
            async with tool_router:
                # Send ready event
                await session.send_event(
                    Event(event_type="ready", data={"message": "Agent initialized"})
                )

                while session.is_running:
                    # Poll with a timeout so is_running is re-checked regularly.
                    # Only the poll is guarded against TimeoutError: a timeout
                    # raised while processing a submission is a real failure
                    # and must reach the error handler below.
                    try:
                        submission = await asyncio.wait_for(
                            submission_queue.get(), timeout=1.0
                        )
                    except asyncio.TimeoutError:
                        continue
                    except asyncio.CancelledError:
                        break

                    try:
                        should_continue = await process_submission(
                            session, submission
                        )
                        # Persist session after each turn
                        await self._persist_session(session_id)
                    except asyncio.CancelledError:
                        break
                    except Exception as e:
                        logger.exception("Error in session %s: %s", session_id, e)
                        await session.send_event(
                            Event(event_type="error", data={"error": str(e)})
                        )
                        continue

                    if not should_continue:
                        break
        finally:
            event_forwarder.cancel()
            try:
                await event_forwarder
            except asyncio.CancelledError:
                pass

            # Needs self._lock, which is why callers must not await this task
            # while holding it.
            async with self._lock:
                if session_id in self.sessions:
                    self.sessions[session_id].is_active = False

    def _snapshot(
        self, session_id: str, agent_session: AgentSession
    ) -> dict[str, Any]:
        """Collect the keyword arguments shared by persist and close calls."""
        return {
            "session_id": session_id,
            "user_id": agent_session.user_id or "anonymous",
            # Serialize full message objects (preserves tool_calls,
            # tool_call_id, etc.)
            "messages": [
                item.model_dump()
                for item in agent_session.session.context_manager.items
            ],
            # Read the model from the session's own config: the attribute that
            # update_session_model writes and get_session_info reports.
            "config": {"model_name": agent_session.session.config.model_name},
            "title": f"Chat {session_id[:8]}",
        }

    async def _persist_session(self, session_id: str) -> None:
        """Persist session state through the lifecycle manager.

        Does nothing when the session is not registered.
        """
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return
        await lifecycle_manager.persist_session(
            **self._snapshot(session_id, agent_session),
            status="active",
        )

    async def _forward_events(
        self, session_id: str, event_queue: asyncio.Queue
    ) -> None:
        """Forward events from the agent to SSE clients until cancelled.

        A failure to forward one event is logged and does not stop the loop.
        """
        while True:
            try:
                event: Event = await event_queue.get()
                await event_manager.send_event(
                    session_id, event.event_type, event.data
                )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.exception(
                    "Error forwarding event for %s: %s", session_id, e
                )

    async def submit(
        self, session_id: str, operation: Operation, user_id: Optional[str] = None
    ) -> bool:
        """Submit an operation to a session.

        Args:
            session_id: Target session
            operation: Operation to submit
            user_id: User making the request (for ownership check)

        Returns:
            True if submitted successfully, False if the session is unknown,
            inactive, or not accessible to ``user_id``
        """
        async with self._lock:
            agent_session = self._check_session_ownership(session_id, user_id)
            if not agent_session or not agent_session.is_active:
                logger.warning(
                    "Session %s not found, inactive, or access denied", session_id
                )
                return False

        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
        await agent_session.submission_queue.put(submission)
        return True

    async def submit_user_input(
        self, session_id: str, text: str, user_id: Optional[str] = None
    ) -> bool:
        """Submit user input to a session."""
        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
        return await self.submit(session_id, operation, user_id)

    async def submit_approval(
        self,
        session_id: str,
        approvals: list[dict[str, Any]],
        user_id: Optional[str] = None,
    ) -> bool:
        """Submit tool approvals to a session."""
        operation = Operation(
            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
        )
        return await self.submit(session_id, operation, user_id)

    async def interrupt(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """Interrupt a session."""
        operation = Operation(op_type=OpType.INTERRUPT)
        return await self.submit(session_id, operation, user_id)

    async def undo(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """Undo last turn in a session."""
        operation = Operation(op_type=OpType.UNDO)
        return await self.submit(session_id, operation, user_id)

    async def compact(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """Compact context in a session."""
        operation = Operation(op_type=OpType.COMPACT)
        return await self.submit(session_id, operation, user_id)

    async def shutdown_session(
        self, session_id: str, user_id: Optional[str] = None
    ) -> bool:
        """Shutdown a specific session.

        Submits a SHUTDOWN operation, waits up to five seconds for the agent
        loop to finish, then closes the session through the lifecycle manager.

        Args:
            session_id: Session to shut down
            user_id: User making the request (for ownership check)

        Returns:
            True if the shutdown operation was submitted, False otherwise
        """
        operation = Operation(op_type=OpType.SHUTDOWN)
        success = await self.submit(session_id, operation, user_id)
        if not success:
            return success

        async with self._lock:
            agent_session = self._check_session_ownership(session_id, user_id)

        # Wait for the task only after releasing the lock: the task's cleanup
        # in _run_session acquires self._lock before it can finish.
        if agent_session and agent_session.task:
            try:
                await asyncio.wait_for(agent_session.task, timeout=5.0)
            except asyncio.TimeoutError:
                agent_session.task.cancel()

        # Close and persist the session (skipped if it was removed meanwhile)
        agent_session = self.sessions.get(session_id)
        if agent_session:
            await lifecycle_manager.close_session(
                **self._snapshot(session_id, agent_session)
            )
        return success

    async def delete_session(
        self, session_id: str, user_id: Optional[str] = None
    ) -> bool:
        """Delete a session entirely.

        Args:
            session_id: Session to delete
            user_id: User making the request (for ownership check)

        Returns:
            True if deleted successfully
        """
        async with self._lock:
            agent_session = self._check_session_ownership(session_id, user_id)
            if not agent_session:
                return False
            # Remove from sessions
            self.sessions.pop(session_id, None)

        # Cancel the task if running. This happens outside the lock because
        # the task's cleanup in _run_session acquires self._lock.
        task = agent_session.task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        return True

    async def update_session_model(
        self, session_id: str, model_name: str, user_id: Optional[str] = None
    ) -> bool:
        """Update the model for an active session.

        Args:
            session_id: Target session
            model_name: New model name
            user_id: User making the request (for ownership check)

        Returns:
            True if updated successfully
        """
        async with self._lock:
            agent_session = self._check_session_ownership(session_id, user_id)
            if not agent_session:
                return False
            # Update the model in session config
            # NOTE(review): Session is built with config=self.config; if it
            # keeps that reference instead of copying, this assignment changes
            # the model for every session — confirm against Session.
            agent_session.session.config.model_name = model_name

        # Persist the change (I/O, so done after releasing the lock)
        await self._persist_session(session_id)
        return True

    @staticmethod
    def _session_info(session_id: str, agent_session: AgentSession) -> dict[str, Any]:
        """Build the public info dict for a session, without access checks."""
        return {
            "session_id": session_id,
            "created_at": agent_session.created_at.isoformat(),
            "is_active": agent_session.is_active,
            "message_count": len(agent_session.session.context_manager.items),
            "user_id": agent_session.user_id,
            "model_name": agent_session.session.config.model_name,
        }

    def get_session_info(
        self, session_id: str, user_id: Optional[str] = None
    ) -> dict[str, Any] | None:
        """Get information about a session.

        Args:
            session_id: Session to get info for
            user_id: User making the request (for ownership check)

        Returns:
            Session info dict or None if not found/access denied
        """
        agent_session = self._check_session_ownership(session_id, user_id)
        if not agent_session:
            return None
        return self._session_info(session_id, agent_session)

    def list_sessions(self, user_id: Optional[str] = None) -> list[dict[str, Any]]:
        """List the sessions visible to a user.

        Args:
            user_id: If provided, only sessions owned by this user are
                returned. If omitted, only sessions that have no owner are
                returned.

        Returns:
            List of session info dicts
        """
        results = []
        for sid, agent_session in self.sessions.items():
            owner = agent_session.user_id
            # Filter here rather than via _check_session_ownership so that
            # merely listing does not log an access-denied warning for every
            # session that belongs to someone else.
            visible = owner == user_id if user_id else not owner
            if visible:
                results.append(self._session_info(sid, agent_session))
        return results

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions."""
        return sum(1 for s in self.sessions.values() if s.is_active)
# Global session manager instance.
# Constructed at import time with no config_path, so SessionManager.__init__
# loads the config from DEFAULT_CONFIG_PATH.
session_manager = SessionManager()