| """Session manager for handling multiple concurrent agent sessions.""" |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import uuid |
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
| from agent.config import load_config |
| from agent.core.agent_loop import process_submission |
| from agent.core.session import Event, OpType, Session |
| from agent.core.session_persistence import get_session_store |
| from agent.core.tools import ToolRouter |
| from agent.messaging.gateway import NotificationGateway |
|
|
| |
| PROJECT_ROOT = Path(__file__).parent.parent |
| DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json") |
|
|
|
|
| |
| @dataclass |
| class Operation: |
| """Operation to be executed by the agent.""" |
|
|
| op_type: OpType |
| data: Optional[dict[str, Any]] = None |
|
|
|
|
| @dataclass |
| class Submission: |
| """Submission to the agent loop.""" |
|
|
| id: str |
| operation: Operation |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class EventBroadcaster: |
| """Reads from the agent's event queue and fans out to SSE subscribers. |
| |
| Events that arrive when no subscribers are listening are discarded by |
| this in-memory fanout. Durable replay is handled by session_persistence. |
| """ |
|
|
| def __init__(self, event_queue: asyncio.Queue): |
| self._source = event_queue |
| self._subscribers: dict[int, asyncio.Queue] = {} |
| self._counter = 0 |
|
|
| def subscribe(self) -> tuple[int, asyncio.Queue]: |
| """Create a new subscriber. Returns (id, queue).""" |
| self._counter += 1 |
| sub_id = self._counter |
| q: asyncio.Queue = asyncio.Queue() |
| self._subscribers[sub_id] = q |
| return sub_id, q |
|
|
| def unsubscribe(self, sub_id: int) -> None: |
| self._subscribers.pop(sub_id, None) |
|
|
| async def run(self) -> None: |
| """Main loop — reads from source queue and broadcasts.""" |
| while True: |
| try: |
| event: Event = await self._source.get() |
| msg = { |
| "event_type": event.event_type, |
| "data": event.data, |
| "seq": event.seq, |
| } |
| for q in self._subscribers.values(): |
| await q.put(msg) |
| except asyncio.CancelledError: |
| break |
| except Exception as e: |
| logger.error(f"EventBroadcaster error: {e}") |
|
|
|
|
| @dataclass |
| class AgentSession: |
| """Wrapper for an agent session with its associated resources.""" |
|
|
| session_id: str |
| session: Session |
| tool_router: ToolRouter |
| submission_queue: asyncio.Queue |
| user_id: str = "dev" |
| hf_username: str | None = None |
| hf_token: str | None = None |
| task: asyncio.Task | None = None |
| created_at: datetime = field(default_factory=datetime.utcnow) |
| is_active: bool = True |
| is_processing: bool = False |
| broadcaster: Any = None |
| title: str | None = None |
| |
| |
| |
| claude_counted: bool = False |
|
|
|
|
| class SessionCapacityError(Exception): |
| """Raised when no more sessions can be created.""" |
|
|
| def __init__(self, message: str, error_type: str = "global") -> None: |
| super().__init__(message) |
| self.error_type = error_type |
|
|
|
|
| |
| |
| |
| |
| |
| MAX_SESSIONS: int = 200 |
| MAX_SESSIONS_PER_USER: int = 10 |
| DEFAULT_YOLO_COST_CAP_USD: float = 5.0 |
| SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10 |
| SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0 |
|
|
|
|
| class SessionManager: |
| """Manages multiple concurrent agent sessions.""" |
|
|
| def __init__(self, config_path: str | None = None) -> None: |
| self.config = load_config(config_path or DEFAULT_CONFIG_PATH) |
| self.messaging_gateway = NotificationGateway(self.config.messaging) |
| self.sessions: dict[str, AgentSession] = {} |
| self._lock = asyncio.Lock() |
| self.persistence_store = None |
|
|
| async def start(self) -> None: |
| """Start shared background resources.""" |
| self.persistence_store = get_session_store() |
| await self.persistence_store.init() |
| await self.messaging_gateway.start() |
|
|
| async def close(self) -> None: |
| """Flush and close shared background resources.""" |
| await self._cleanup_all_sandboxes_on_close() |
| await self.messaging_gateway.close() |
| if self.persistence_store is not None: |
| await self.persistence_store.close() |
|
|
| def _store(self): |
| if self.persistence_store is None: |
| self.persistence_store = get_session_store() |
| return self.persistence_store |
|
|
| def _count_user_sessions(self, user_id: str) -> int: |
| """Count active sessions owned by a specific user.""" |
| return sum( |
| 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active |
| ) |
|
|
| def _create_session_sync( |
| self, |
| *, |
| session_id: str, |
| user_id: str, |
| hf_username: str | None, |
| hf_token: str | None, |
| model: str | None, |
| event_queue: asyncio.Queue, |
| notification_destinations: list[str] | None = None, |
| ) -> tuple[ToolRouter, Session]: |
| """Build blocking per-session resources in a worker thread.""" |
| import time as _time |
|
|
| t0 = _time.monotonic() |
| tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) |
| |
| |
| session_config = self.config.model_copy(deep=True) |
| if model: |
| session_config.model_name = model |
| session = Session( |
| event_queue=event_queue, |
| config=session_config, |
| tool_router=tool_router, |
| hf_token=hf_token, |
| user_id=user_id, |
| hf_username=hf_username, |
| notification_gateway=self.messaging_gateway, |
| notification_destinations=notification_destinations or [], |
| session_id=session_id, |
| persistence_store=self._store(), |
| ) |
| t1 = _time.monotonic() |
| logger.info("Session initialized in %.2fs", t1 - t0) |
| return tool_router, session |
|
|
| def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: |
| return [msg.model_dump(mode="json") for msg in session.context_manager.items] |
|
|
| def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: |
| pending = session.pending_approval or {} |
| tool_calls = pending.get("tool_calls") or [] |
| serialized: list[dict[str, Any]] = [] |
| for tc in tool_calls: |
| if hasattr(tc, "model_dump"): |
| serialized.append(tc.model_dump(mode="json")) |
| elif isinstance(tc, dict): |
| serialized.append(tc) |
| return serialized |
|
|
| @staticmethod |
| def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: |
| pending = session.pending_approval or {} |
| tool_calls = pending.get("tool_calls") or [] |
| if not tool_calls: |
| return None |
| result: list[dict[str, Any]] = [] |
| for tc in tool_calls: |
| try: |
| args = json.loads(tc.function.arguments) |
| except (json.JSONDecodeError, AttributeError, TypeError): |
| args = {} |
| result.append( |
| { |
| "tool": getattr(tc.function, "name", None), |
| "tool_call_id": getattr(tc, "id", None), |
| "arguments": args, |
| } |
| ) |
| return result |
|
|
| def _restore_pending_approval( |
| self, session: Session, pending_approval: list[dict[str, Any]] | None |
| ) -> None: |
| if not pending_approval: |
| session.pending_approval = None |
| return |
| from litellm import ChatCompletionMessageToolCall as ToolCall |
|
|
| restored = [] |
| for raw in pending_approval: |
| try: |
| if "function" in raw: |
| restored.append(ToolCall(**raw)) |
| else: |
| restored.append( |
| ToolCall( |
| id=raw["tool_call_id"], |
| type="function", |
| function={ |
| "name": raw["tool"], |
| "arguments": json.dumps(raw.get("arguments") or {}), |
| }, |
| ) |
| ) |
| except Exception as e: |
| logger.warning("Dropping malformed pending approval: %s", e) |
| session.pending_approval = {"tool_calls": restored} if restored else None |
|
|
| @staticmethod |
| def _pending_docs_for_api( |
| pending_approval: list[dict[str, Any]] | None, |
| ) -> list[dict[str, Any]] | None: |
| if not pending_approval: |
| return None |
| result: list[dict[str, Any]] = [] |
| for raw in pending_approval: |
| if "function" in raw: |
| function = raw.get("function") or {} |
| try: |
| args = json.loads(function.get("arguments") or "{}") |
| except (json.JSONDecodeError, TypeError): |
| args = {} |
| result.append( |
| { |
| "tool": function.get("name"), |
| "tool_call_id": raw.get("id"), |
| "arguments": args, |
| } |
| ) |
| elif {"tool", "tool_call_id"}.issubset(raw): |
| result.append( |
| { |
| "tool": raw.get("tool"), |
| "tool_call_id": raw.get("tool_call_id"), |
| "arguments": raw.get("arguments") or {}, |
| } |
| ) |
| return result or None |
|
|
| @staticmethod |
| def _runtime_state(agent_session: AgentSession) -> str: |
| if agent_session.session.pending_approval: |
| return "waiting_approval" |
| if agent_session.is_processing: |
| return "processing" |
| if not agent_session.is_active: |
| return "ended" |
| return "idle" |
|
|
| @staticmethod |
| def _auto_approval_summary(session: Session) -> dict[str, Any]: |
| if hasattr(session, "auto_approval_policy_summary"): |
| return session.auto_approval_policy_summary() |
| cap = getattr(session, "auto_approval_cost_cap_usd", None) |
| estimated = float( |
| getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 |
| ) |
| remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) |
| return { |
| "enabled": bool(getattr(session, "auto_approval_enabled", False)), |
| "cost_cap_usd": cap, |
| "estimated_spend_usd": round(estimated, 4), |
| "remaining_usd": remaining, |
| } |
|
|
| async def _start_agent_session( |
| self, |
| *, |
| agent_session: AgentSession, |
| event_queue: asyncio.Queue, |
| tool_router: ToolRouter, |
| ) -> AgentSession: |
| async with self._lock: |
| existing = self.sessions.get(agent_session.session_id) |
| if existing: |
| return existing |
| self.sessions[agent_session.session_id] = agent_session |
|
|
| task = asyncio.create_task( |
| self._run_session( |
| agent_session.session_id, |
| agent_session.submission_queue, |
| event_queue, |
| tool_router, |
| ) |
| ) |
| agent_session.task = task |
| return agent_session |
|
|
| @staticmethod |
| def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| """Kick off a best-effort cpu-basic sandbox for the session.""" |
| try: |
| from agent.tools.sandbox_tool import start_cpu_sandbox_preload |
|
|
| start_cpu_sandbox_preload(agent_session.session) |
| except Exception as e: |
| logger.warning( |
| "Failed to start CPU sandbox preload for %s: %s", |
| agent_session.session_id, |
| e, |
| ) |
|
|
| @staticmethod |
| def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: |
| return ( |
| user_id == "dev" |
| or agent_session.user_id == "dev" |
| or agent_session.user_id == user_id |
| ) |
|
|
| @staticmethod |
| def _update_hf_identity( |
| agent_session: AgentSession, |
| *, |
| hf_token: str | None, |
| hf_username: str | None, |
| ) -> None: |
| if hf_token: |
| agent_session.hf_token = hf_token |
| agent_session.session.hf_token = hf_token |
| if hf_username: |
| agent_session.hf_username = hf_username |
| agent_session.session.hf_username = hf_username |
|
|
| @staticmethod |
| def _has_active_sandbox_preload(agent_session: AgentSession) -> bool: |
| task = getattr(agent_session.session, "sandbox_preload_task", None) |
| return bool(task and not task.done()) |
|
|
| @staticmethod |
| def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool: |
| error = getattr(agent_session.session, "sandbox_preload_error", None) |
| return isinstance(error, str) and error.startswith("No HF token available") |
|
|
| def _restart_cpu_preload_if_token_recovered( |
| self, |
| agent_session: AgentSession, |
| *, |
| preload_sandbox: bool, |
| ) -> None: |
| if not preload_sandbox: |
| return |
| session = agent_session.session |
| if getattr(session, "sandbox", None): |
| return |
| if self._has_active_sandbox_preload(agent_session): |
| return |
| if not (agent_session.hf_token or getattr(session, "hf_token", None)): |
| return |
|
|
| if not self._preload_failed_for_missing_hf_token(agent_session): |
| return |
|
|
| session.sandbox_preload_error = None |
| session.sandbox_preload_task = None |
| session.sandbox_preload_cancel_event = None |
| self._start_cpu_sandbox_preload(agent_session) |
|
|
| async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: |
| try: |
| await self._store().update_session_fields( |
| session_id, |
| sandbox_space_id=None, |
| sandbox_hardware=None, |
| sandbox_owner=None, |
| sandbox_created_at=None, |
| sandbox_status="destroyed", |
| ) |
| except Exception as e: |
| logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e) |
|
|
| async def _cleanup_persisted_sandbox( |
| self, |
| session_id: str, |
| metadata: dict[str, Any], |
| *, |
| hf_token: str | None, |
| ) -> None: |
| """Delete a sandbox recorded by a previous backend process, if any.""" |
| space_id = metadata.get("sandbox_space_id") |
| if not isinstance(space_id, str) or not space_id: |
| return |
| if metadata.get("sandbox_status") == "destroyed": |
| return |
|
|
| tokens: list[tuple[str, str]] = [] |
| seen: set[str] = set() |
| for label, token in ( |
| ("user", hf_token), |
| ("admin", os.environ.get("HF_ADMIN_TOKEN")), |
| ): |
| if token and token not in seen: |
| tokens.append((label, token)) |
| seen.add(token) |
|
|
| if not tokens: |
| logger.warning( |
| "Cannot clean persisted sandbox %s for session %s: no HF token available", |
| space_id, |
| session_id, |
| ) |
| return |
|
|
| last_err: Exception | None = None |
| for label, token in tokens: |
| try: |
| from huggingface_hub import HfApi |
|
|
| api = HfApi(token=token) |
| await asyncio.to_thread( |
| api.delete_repo, |
| repo_id=space_id, |
| repo_type="space", |
| ) |
| logger.info( |
| "Deleted persisted sandbox %s for session %s with %s token", |
| space_id, |
| session_id, |
| label, |
| ) |
| await self._clear_persisted_sandbox_metadata(session_id) |
| return |
| except Exception as e: |
| status_code = getattr(getattr(e, "response", None), "status_code", None) |
| if status_code == 404: |
| logger.info( |
| "Persisted sandbox %s for session %s is already gone", |
| space_id, |
| session_id, |
| ) |
| await self._clear_persisted_sandbox_metadata(session_id) |
| return |
| last_err = e |
|
|
| logger.warning( |
| "Failed to delete persisted sandbox %s for session %s: %s", |
| space_id, |
| session_id, |
| last_err, |
| ) |
|
|
| async def persist_session_snapshot( |
| self, |
| agent_session: AgentSession, |
| *, |
| runtime_state: str | None = None, |
| status: str = "active", |
| ) -> None: |
| """Persist the current runtime context snapshot.""" |
| store = self._store() |
| if not getattr(store, "enabled", False): |
| return |
| try: |
| await store.save_snapshot( |
| session_id=agent_session.session_id, |
| user_id=agent_session.user_id, |
| model=agent_session.session.config.model_name, |
| title=agent_session.title, |
| messages=self._serialize_messages(agent_session.session), |
| runtime_state=runtime_state or self._runtime_state(agent_session), |
| status=status, |
| turn_count=agent_session.session.turn_count, |
| pending_approval=self._serialize_pending_approval( |
| agent_session.session |
| ), |
| claude_counted=agent_session.claude_counted, |
| created_at=agent_session.created_at, |
| notification_destinations=list( |
| agent_session.session.notification_destinations |
| ), |
| auto_approval_enabled=bool( |
| getattr(agent_session.session, "auto_approval_enabled", False) |
| ), |
| auto_approval_cost_cap_usd=getattr( |
| agent_session.session, "auto_approval_cost_cap_usd", None |
| ), |
| auto_approval_estimated_spend_usd=float( |
| getattr( |
| agent_session.session, |
| "auto_approval_estimated_spend_usd", |
| 0.0, |
| ) |
| or 0.0 |
| ), |
| ) |
| except Exception as e: |
| logger.warning( |
| "Failed to persist snapshot for %s: %s", |
| agent_session.session_id, |
| e, |
| ) |
|
|
| async def ensure_session_loaded( |
| self, |
| session_id: str, |
| user_id: str, |
| hf_token: str | None = None, |
| hf_username: str | None = None, |
| preload_sandbox: bool = True, |
| ) -> AgentSession | None: |
| """Return a live runtime session, lazily restoring it from Mongo.""" |
| async with self._lock: |
| existing = self.sessions.get(session_id) |
| if existing: |
| if self._can_access_session(existing, user_id): |
| self._update_hf_identity( |
| existing, |
| hf_token=hf_token, |
| hf_username=hf_username, |
| ) |
| self._restart_cpu_preload_if_token_recovered( |
| existing, |
| preload_sandbox=preload_sandbox, |
| ) |
| return existing |
| return None |
|
|
| store = self._store() |
| loaded = await store.load_session(session_id) |
| if not loaded: |
| return None |
|
|
| async with self._lock: |
| existing = self.sessions.get(session_id) |
| if existing: |
| if self._can_access_session(existing, user_id): |
| self._update_hf_identity( |
| existing, |
| hf_token=hf_token, |
| hf_username=hf_username, |
| ) |
| self._restart_cpu_preload_if_token_recovered( |
| existing, |
| preload_sandbox=preload_sandbox, |
| ) |
| return existing |
| return None |
|
|
| meta = loaded.get("metadata") or {} |
| owner = str(meta.get("user_id") or "") |
| if user_id != "dev" and owner != "dev" and owner != user_id: |
| return None |
|
|
| await self._cleanup_persisted_sandbox( |
| session_id, |
| meta, |
| hf_token=hf_token, |
| ) |
|
|
| from litellm import Message |
|
|
| model = meta.get("model") or self.config.model_name |
| event_queue: asyncio.Queue = asyncio.Queue() |
| submission_queue: asyncio.Queue = asyncio.Queue() |
| tool_router, session = await asyncio.to_thread( |
| self._create_session_sync, |
| session_id=session_id, |
| user_id=owner or user_id, |
| hf_username=hf_username, |
| hf_token=hf_token, |
| model=model, |
| event_queue=event_queue, |
| notification_destinations=meta.get("notification_destinations") or [], |
| ) |
|
|
| restored_messages: list[Message] = [] |
| for raw in loaded.get("messages") or []: |
| if not isinstance(raw, dict) or raw.get("role") == "system": |
| continue |
| try: |
| restored_messages.append(Message.model_validate(raw)) |
| except Exception as e: |
| logger.warning("Dropping malformed restored message: %s", e) |
| if restored_messages: |
| |
| |
| session.context_manager.items = [ |
| session.context_manager.items[0], |
| *restored_messages, |
| ] |
|
|
| self._restore_pending_approval(session, meta.get("pending_approval") or []) |
| session.turn_count = int(meta.get("turn_count") or 0) |
| session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) |
| raw_cap = meta.get("auto_approval_cost_cap_usd") |
| session.auto_approval_cost_cap_usd = ( |
| float(raw_cap) if isinstance(raw_cap, int | float) else None |
| ) |
| session.auto_approval_estimated_spend_usd = float( |
| meta.get("auto_approval_estimated_spend_usd") or 0.0 |
| ) |
|
|
| created_at = meta.get("created_at") |
| if not isinstance(created_at, datetime): |
| created_at = datetime.utcnow() |
|
|
| agent_session = AgentSession( |
| session_id=session_id, |
| session=session, |
| tool_router=tool_router, |
| submission_queue=submission_queue, |
| user_id=owner or user_id, |
| hf_username=hf_username, |
| hf_token=hf_token, |
| created_at=created_at, |
| is_active=True, |
| is_processing=False, |
| claude_counted=bool(meta.get("claude_counted")), |
| title=meta.get("title"), |
| ) |
| started = await self._start_agent_session( |
| agent_session=agent_session, |
| event_queue=event_queue, |
| tool_router=tool_router, |
| ) |
| if started is not agent_session: |
| self._update_hf_identity( |
| started, |
| hf_token=hf_token, |
| hf_username=hf_username, |
| ) |
| return started |
| if preload_sandbox: |
| self._start_cpu_sandbox_preload(agent_session) |
| logger.info("Restored session %s for user %s", session_id, owner or user_id) |
| return agent_session |
|
|
| async def create_session( |
| self, |
| user_id: str = "dev", |
| hf_username: str | None = None, |
| hf_token: str | None = None, |
| model: str | None = None, |
| is_pro: bool | None = None, |
| ) -> str: |
| """Create a new agent session and return its ID. |
| |
| Session() and ToolRouter() constructors contain blocking I/O |
| (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are |
| executed in a thread pool to avoid freezing the async event loop. |
| |
| Args: |
| user_id: The ID of the user who owns this session. |
| hf_username: The HF username/namespace used for personal trace uploads. |
| hf_token: The user's HF OAuth token, stored for tool execution. |
| model: Optional model override. When set, replaces ``model_name`` |
| on the per-session config clone. None falls back to the |
| config default. |
| |
| Raises: |
| SessionCapacityError: If the server or user has reached the |
| maximum number of concurrent sessions. |
| """ |
| |
| async with self._lock: |
| active_count = self.active_session_count |
| if active_count >= MAX_SESSIONS: |
| raise SessionCapacityError( |
| f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). " |
| "Please try again later.", |
| error_type="global", |
| ) |
| if user_id != "dev": |
| user_count = self._count_user_sessions(user_id) |
| if user_count >= MAX_SESSIONS_PER_USER: |
| raise SessionCapacityError( |
| f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " |
| "concurrent sessions. Please close an existing session first.", |
| error_type="per_user", |
| ) |
|
|
| session_id = str(uuid.uuid4()) |
|
|
| |
| submission_queue: asyncio.Queue = asyncio.Queue() |
| event_queue: asyncio.Queue = asyncio.Queue() |
|
|
| |
| tool_router, session = await asyncio.to_thread( |
| self._create_session_sync, |
| session_id=session_id, |
| user_id=user_id, |
| hf_username=hf_username, |
| hf_token=hf_token, |
| model=model, |
| event_queue=event_queue, |
| ) |
|
|
| |
| agent_session = AgentSession( |
| session_id=session_id, |
| session=session, |
| tool_router=tool_router, |
| submission_queue=submission_queue, |
| user_id=user_id, |
| hf_username=hf_username, |
| hf_token=hf_token, |
| ) |
|
|
| await self._start_agent_session( |
| agent_session=agent_session, |
| event_queue=event_queue, |
| tool_router=tool_router, |
| ) |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") |
| self._start_cpu_sandbox_preload(agent_session) |
|
|
| if is_pro is not None and user_id and user_id != "dev": |
| await self._track_pro_status(agent_session, is_pro=is_pro) |
|
|
| logger.info(f"Created session {session_id} for user {user_id}") |
| return session_id |
|
|
| async def _track_pro_status( |
| self, agent_session: AgentSession, *, is_pro: bool |
| ) -> None: |
| """Update Mongo per-user Pro state and emit a one-shot conversion |
| event if the store reports a free→Pro transition. Best-effort: any |
| Mongo failure is swallowed so we never fail session creation on |
| telemetry.""" |
| store = self._store() |
| if not getattr(store, "enabled", False): |
| return |
| try: |
| result = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro) |
| except Exception as e: |
| logger.debug("mark_pro_seen failed: %s", e) |
| return |
| if not result or not result.get("converted"): |
| return |
| try: |
| from agent.core import telemetry |
|
|
| await telemetry.record_pro_conversion( |
| agent_session.session, |
| first_seen_at=result.get("first_seen_at"), |
| ) |
| except Exception as e: |
| logger.debug("record_pro_conversion failed: %s", e) |
|
|
| async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: |
| """Rehydrate a session from cached prior messages via summarization. |
| |
| Runs the standard summarization prompt (same one compaction uses) |
| over the provided messages, then seeds the new session's context |
| with that summary. Tool-call pairing concerns disappear because the |
| output is plain text. Returns the number of messages summarized. |
| """ |
| from litellm import Message |
|
|
| from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages |
|
|
| agent_session = self.sessions.get(session_id) |
| if not agent_session: |
| raise ValueError(f"Session {session_id} not found") |
|
|
| |
| parsed: list[Message] = [] |
| for raw in messages: |
| if raw.get("role") == "system": |
| continue |
| try: |
| parsed.append(Message.model_validate(raw)) |
| except Exception as e: |
| logger.warning("Dropping malformed message during seed: %s", e) |
|
|
| if not parsed: |
| return 0 |
|
|
| session = agent_session.session |
| |
| |
| |
| |
| tool_specs = None |
| try: |
| tool_specs = agent_session.tool_router.get_tool_specs_for_llm() |
| except Exception: |
| pass |
| try: |
| summary, _ = await summarize_messages( |
| parsed, |
| model_name=session.config.model_name, |
| hf_token=session.hf_token, |
| max_tokens=4000, |
| prompt=_RESTORE_PROMPT, |
| tool_specs=tool_specs, |
| session=session, |
| kind="restore", |
| ) |
| except Exception as e: |
| logger.error("Summary call failed during seed: %s", e) |
| raise |
|
|
| seed = Message( |
| role="user", |
| content=( |
| "[SYSTEM: Your prior memory of this conversation — written " |
| "in your own voice right before restart. Continue from here.]\n\n" |
| + (summary or "(no summary returned)") |
| ), |
| ) |
| session.context_manager.items.append(seed) |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") |
| return len(parsed) |
|
|
| @staticmethod |
| async def _cleanup_sandbox(session: Session) -> None: |
| """Delete the sandbox Space if one was created for this session. |
| |
| Retries on transient failures (HF API 5xx, rate-limit, network blips) |
| with exponential backoff. A single missed delete = a permanently |
| orphaned Space, so the cost of an extra retry beats the alternative. |
| """ |
| from agent.tools.sandbox_tool import teardown_session_sandbox |
|
|
| await teardown_session_sandbox(session) |
|
|
| async def _cleanup_all_sandboxes_on_close(self) -> None: |
| """Best-effort sandbox cleanup for graceful backend shutdown.""" |
| async with self._lock: |
| agent_sessions = list(self.sessions.values()) |
| if not agent_sessions: |
| return |
|
|
| semaphore = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY) |
|
|
| async def _cleanup_one(agent_session: AgentSession) -> None: |
| async with semaphore: |
| try: |
| await self._cleanup_sandbox(agent_session.session) |
| except Exception as e: |
| logger.warning( |
| "Shutdown sandbox cleanup failed for %s: %s", |
| agent_session.session_id, |
| e, |
| ) |
|
|
| tasks = [ |
| asyncio.create_task(_cleanup_one(agent_session)) |
| for agent_session in agent_sessions |
| ] |
| try: |
| await asyncio.wait_for( |
| asyncio.gather(*tasks, return_exceptions=True), |
| timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, |
| ) |
| except asyncio.TimeoutError: |
| logger.warning( |
| "Timed out after %.0fs cleaning up sandboxes on shutdown; " |
| "orphan sweeper will handle any stragglers", |
| SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, |
| ) |
|
|
| async def _run_session( |
| self, |
| session_id: str, |
| submission_queue: asyncio.Queue, |
| event_queue: asyncio.Queue, |
| tool_router: ToolRouter, |
| ) -> None: |
| """Run the agent loop for a session and broadcast events via EventBroadcaster.""" |
| agent_session = self.sessions.get(session_id) |
| if not agent_session: |
| logger.error(f"Session {session_id} not found") |
| return |
|
|
| session = agent_session.session |
|
|
| |
| broadcaster = EventBroadcaster(event_queue) |
| agent_session.broadcaster = broadcaster |
| broadcast_task = asyncio.create_task(broadcaster.run()) |
|
|
| try: |
| async with tool_router: |
| |
| await session.send_event( |
| Event(event_type="ready", data={"message": "Agent initialized"}) |
| ) |
|
|
| while session.is_running: |
| try: |
| |
| submission = await asyncio.wait_for( |
| submission_queue.get(), timeout=1.0 |
| ) |
| agent_session.is_processing = True |
| try: |
| should_continue = await process_submission( |
| session, submission |
| ) |
| finally: |
| agent_session.is_processing = False |
| await self.persist_session_snapshot(agent_session) |
| if not should_continue: |
| break |
| except asyncio.TimeoutError: |
| continue |
| except asyncio.CancelledError: |
| logger.info(f"Session {session_id} cancelled") |
| break |
| except Exception as e: |
| logger.error(f"Error in session {session_id}: {e}") |
| await session.send_event( |
| Event(event_type="error", data={"error": str(e)}) |
| ) |
|
|
| finally: |
| broadcast_task.cancel() |
| try: |
| await broadcast_task |
| except asyncio.CancelledError: |
| pass |
|
|
| await self._cleanup_sandbox(session) |
|
|
| |
| |
| |
| if session.config.save_sessions: |
| try: |
| session.save_and_upload_detached( |
| session.config.session_dataset_repo |
| ) |
| except Exception as e: |
| logger.warning(f"Final-flush failed for {session_id}: {e}") |
|
|
| async with self._lock: |
| if session_id in self.sessions: |
| self.sessions[session_id].is_active = False |
| await self.persist_session_snapshot( |
| self.sessions[session_id], |
| runtime_state="ended", |
| status="ended", |
| ) |
|
|
| logger.info(f"Session {session_id} ended") |
|
|
| async def submit(self, session_id: str, operation: Operation) -> bool: |
| """Submit an operation to a session.""" |
| async with self._lock: |
| agent_session = self.sessions.get(session_id) |
|
|
| if not agent_session or not agent_session.is_active: |
| logger.warning(f"Session {session_id} not found or inactive") |
| return False |
|
|
| submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation) |
| await agent_session.submission_queue.put(submission) |
| return True |
|
|
| async def submit_user_input(self, session_id: str, text: str) -> bool: |
| """Submit user input to a session.""" |
| operation = Operation(op_type=OpType.USER_INPUT, data={"text": text}) |
| return await self.submit(session_id, operation) |
|
|
| async def submit_approval( |
| self, session_id: str, approvals: list[dict[str, Any]] |
| ) -> bool: |
| """Submit tool approvals to a session.""" |
| operation = Operation( |
| op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals} |
| ) |
| return await self.submit(session_id, operation) |
|
|
| async def interrupt(self, session_id: str) -> bool: |
| """Interrupt a session by signalling cancellation directly (bypasses queue).""" |
| agent_session = self.sessions.get(session_id) |
| if not agent_session or not agent_session.is_active: |
| return False |
| agent_session.session.cancel() |
| return True |
|
|
| async def undo(self, session_id: str) -> bool: |
| """Undo last turn in a session.""" |
| operation = Operation(op_type=OpType.UNDO) |
| return await self.submit(session_id, operation) |
|
|
| async def truncate(self, session_id: str, user_message_index: int) -> bool: |
| """Truncate conversation to before a specific user message (direct, no queue).""" |
| async with self._lock: |
| agent_session = self.sessions.get(session_id) |
| if not agent_session or not agent_session.is_active: |
| return False |
| success = agent_session.session.context_manager.truncate_to_user_message( |
| user_message_index |
| ) |
| if success: |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") |
| return success |
|
|
| async def compact(self, session_id: str) -> bool: |
| """Compact context in a session.""" |
| operation = Operation(op_type=OpType.COMPACT) |
| return await self.submit(session_id, operation) |
|
|
| async def shutdown_session(self, session_id: str) -> bool: |
| """Shutdown a specific session.""" |
| operation = Operation(op_type=OpType.SHUTDOWN) |
| success = await self.submit(session_id, operation) |
|
|
| if success: |
| async with self._lock: |
| agent_session = self.sessions.get(session_id) |
| if agent_session and agent_session.task: |
| |
| try: |
| await asyncio.wait_for(agent_session.task, timeout=5.0) |
| except asyncio.TimeoutError: |
| agent_session.task.cancel() |
|
|
| return success |
|
|
| async def delete_session(self, session_id: str) -> bool: |
| """Soft-delete a session and stop its runtime resources.""" |
| async with self._lock: |
| agent_session = self.sessions.pop(session_id, None) |
|
|
| if not agent_session: |
| await self._store().soft_delete_session(session_id) |
| return True |
|
|
| await self._store().soft_delete_session(session_id) |
|
|
| |
| await self._cleanup_sandbox(agent_session.session) |
|
|
| |
| if agent_session.task and not agent_session.task.done(): |
| agent_session.task.cancel() |
| try: |
| await agent_session.task |
| except asyncio.CancelledError: |
| pass |
|
|
| return True |
|
|
| async def teardown_sandbox(self, session_id: str) -> bool: |
| """Delete only this session's sandbox runtime, preserving chat state.""" |
| async with self._lock: |
| agent_session = self.sessions.get(session_id) |
|
|
| if not agent_session or not agent_session.is_active: |
| return False |
|
|
| await self._cleanup_sandbox(agent_session.session) |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") |
| return True |
|
|
| async def update_session_title(self, session_id: str, title: str | None) -> None: |
| """Persist a user-visible title for sidebar rehydration.""" |
| agent_session = self.sessions.get(session_id) |
| if agent_session: |
| agent_session.title = title |
| await self._store().update_session_fields(session_id, title=title) |
|
|
| async def update_session_model(self, session_id: str, model_id: str) -> bool: |
| agent_session = self.sessions.get(session_id) |
| if not agent_session or not agent_session.is_active: |
| return False |
| agent_session.session.update_model(model_id) |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") |
| return True |
|
|
| async def update_session_auto_approval( |
| self, |
| session_id: str, |
| *, |
| enabled: bool, |
| cost_cap_usd: float | None, |
| cap_provided: bool = False, |
| ) -> dict[str, Any]: |
| agent_session = self.sessions.get(session_id) |
| if not agent_session or not agent_session.is_active: |
| raise ValueError("Session not found or inactive") |
|
|
| session = agent_session.session |
| if enabled: |
| if not cap_provided and cost_cap_usd is None: |
| cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) |
| if cost_cap_usd is None: |
| cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD |
| elif cost_cap_usd is None: |
| cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD |
| else: |
| if not cap_provided: |
| cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) |
|
|
| if hasattr(session, "set_auto_approval_policy"): |
| session.set_auto_approval_policy( |
| enabled=enabled, |
| cost_cap_usd=cost_cap_usd, |
| ) |
| else: |
| session.auto_approval_enabled = bool(enabled) |
| session.auto_approval_cost_cap_usd = cost_cap_usd |
| await self.persist_session_snapshot(agent_session) |
| return self._auto_approval_summary(session) |
|
|
| def get_session_owner(self, session_id: str) -> str | None: |
| """Get the user_id that owns a session, or None if session doesn't exist.""" |
| agent_session = self.sessions.get(session_id) |
| if not agent_session: |
| return None |
| return agent_session.user_id |
|
|
| def verify_session_access(self, session_id: str, user_id: str) -> bool: |
| """Check if a user has access to a session. |
| |
| Returns True if: |
| - The session exists AND the user owns it |
| - The user_id is "dev" (dev mode bypass) |
| """ |
| owner = self.get_session_owner(session_id) |
| if owner is None: |
| return False |
| if user_id == "dev" or owner == "dev": |
| return True |
| return owner == user_id |
|
|
| def get_session_info(self, session_id: str) -> dict[str, Any] | None: |
| """Get information about a session.""" |
| agent_session = self.sessions.get(session_id) |
| if not agent_session: |
| return None |
|
|
| pending_approval = self._pending_tools_for_api(agent_session.session) |
|
|
| return { |
| "session_id": session_id, |
| "created_at": agent_session.created_at.isoformat(), |
| "is_active": agent_session.is_active, |
| "is_processing": agent_session.is_processing, |
| "message_count": len(agent_session.session.context_manager.items), |
| "user_id": agent_session.user_id, |
| "pending_approval": pending_approval, |
| "model": agent_session.session.config.model_name, |
| "title": agent_session.title, |
| "notification_destinations": list( |
| agent_session.session.notification_destinations |
| ), |
| "auto_approval": self._auto_approval_summary(agent_session.session), |
| } |
|
|
| def set_notification_destinations( |
| self, session_id: str, destinations: list[str] |
| ) -> list[str]: |
| """Replace the session's opted-in auto-notification destinations.""" |
| agent_session = self.sessions.get(session_id) |
| if not agent_session or not agent_session.is_active: |
| raise ValueError("Session not found or inactive") |
|
|
| normalized: list[str] = [] |
| seen: set[str] = set() |
| for raw_name in destinations: |
| name = raw_name.strip() |
| if not name: |
| raise ValueError("Destination names must not be empty") |
| destination = self.config.messaging.get_destination(name) |
| if destination is None: |
| raise ValueError(f"Unknown destination '{name}'") |
| if not destination.allow_auto_events: |
| raise ValueError(f"Destination '{name}' is not enabled for auto events") |
| if name not in seen: |
| normalized.append(name) |
| seen.add(name) |
|
|
| agent_session.session.set_notification_destinations(normalized) |
| return normalized |
|
|
| async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: |
| """List sessions, optionally filtered by user. |
| |
| Args: |
| user_id: If provided, only return sessions owned by this user. |
| If "dev", return all sessions (dev mode). |
| """ |
| results: list[dict[str, Any]] = [] |
| store = self._store() |
| if getattr(store, "enabled", False): |
| for row in await store.list_sessions(user_id or "dev"): |
| sid = row.get("session_id") or row.get("_id") |
| if not sid: |
| continue |
| runtime_info = self.get_session_info(str(sid)) |
| if runtime_info: |
| results.append(runtime_info) |
| continue |
| created_at = row.get("created_at") |
| if isinstance(created_at, datetime): |
| created_at_str = created_at.isoformat() |
| else: |
| created_at_str = str(created_at or datetime.utcnow().isoformat()) |
| pending = self._pending_docs_for_api(row.get("pending_approval") or []) |
| results.append( |
| { |
| "session_id": str(sid), |
| "created_at": created_at_str, |
| "is_active": row.get("status") != "ended", |
| "is_processing": row.get("runtime_state") == "processing", |
| "message_count": int(row.get("message_count") or 0), |
| "user_id": row.get("user_id") or "dev", |
| "pending_approval": pending or None, |
| "model": row.get("model"), |
| "title": row.get("title"), |
| "notification_destinations": row.get( |
| "notification_destinations" |
| ) |
| or [], |
| "auto_approval": { |
| "enabled": bool(row.get("auto_approval_enabled", False)), |
| "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), |
| "estimated_spend_usd": float( |
| row.get("auto_approval_estimated_spend_usd") or 0.0 |
| ), |
| "remaining_usd": ( |
| None |
| if row.get("auto_approval_cost_cap_usd") is None |
| else round( |
| max( |
| 0.0, |
| float( |
| row.get("auto_approval_cost_cap_usd") or 0.0 |
| ) |
| - float( |
| row.get("auto_approval_estimated_spend_usd") |
| or 0.0 |
| ), |
| ), |
| 4, |
| ) |
| ), |
| }, |
| } |
| ) |
| return results |
|
|
| for sid in self.sessions: |
| info = self.get_session_info(sid) |
| if not info: |
| continue |
| if user_id and user_id != "dev" and info.get("user_id") != user_id: |
| continue |
| results.append(info) |
| return results |
|
|
| @property |
| def active_session_count(self) -> int: |
| """Get count of active sessions.""" |
| return sum(1 for s in self.sessions.values() if s.is_active) |
|
|
|
|
| |
| session_manager = SessionManager() |
|
|