Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Session manager for handling multiple concurrent agent sessions.""" | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from agent.config import load_config | |
| from agent.core.agent_loop import process_submission | |
| from agent.core.session import Event, OpType, Session | |
| from agent.core.session_persistence import get_session_store | |
| from agent.core.tools import ToolRouter | |
| from agent.messaging.gateway import NotificationGateway | |
| # Get project root (parent of backend directory) | |
| PROJECT_ROOT = Path(__file__).parent.parent | |
| DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json") | |
| # These dataclasses match agent/main.py structure | |
| class Operation: | |
| """Operation to be executed by the agent.""" | |
| op_type: OpType | |
| data: Optional[dict[str, Any]] = None | |
| class Submission: | |
| """Submission to the agent loop.""" | |
| id: str | |
| operation: Operation | |
| logger = logging.getLogger(__name__) | |
| class EventBroadcaster: | |
| """Reads from the agent's event queue and fans out to SSE subscribers. | |
| Events that arrive when no subscribers are listening are discarded by | |
| this in-memory fanout. Durable replay is handled by session_persistence. | |
| """ | |
| def __init__(self, event_queue: asyncio.Queue): | |
| self._source = event_queue | |
| self._subscribers: dict[int, asyncio.Queue] = {} | |
| self._counter = 0 | |
| def subscribe(self) -> tuple[int, asyncio.Queue]: | |
| """Create a new subscriber. Returns (id, queue).""" | |
| self._counter += 1 | |
| sub_id = self._counter | |
| q: asyncio.Queue = asyncio.Queue() | |
| self._subscribers[sub_id] = q | |
| return sub_id, q | |
| def unsubscribe(self, sub_id: int) -> None: | |
| self._subscribers.pop(sub_id, None) | |
| async def run(self) -> None: | |
| """Main loop — reads from source queue and broadcasts.""" | |
| while True: | |
| try: | |
| event: Event = await self._source.get() | |
| msg = { | |
| "event_type": event.event_type, | |
| "data": event.data, | |
| "seq": event.seq, | |
| } | |
| for q in self._subscribers.values(): | |
| await q.put(msg) | |
| except asyncio.CancelledError: | |
| break | |
| except Exception as e: | |
| logger.error(f"EventBroadcaster error: {e}") | |
| class AgentSession: | |
| """Wrapper for an agent session with its associated resources.""" | |
| session_id: str | |
| session: Session | |
| tool_router: ToolRouter | |
| submission_queue: asyncio.Queue | |
| user_id: str = "dev" # Owner of this session | |
| hf_username: str | None = None # HF namespace used for personal trace uploads | |
| hf_token: str | None = None # User's HF OAuth token for tool execution | |
| task: asyncio.Task | None = None | |
| created_at: datetime = field(default_factory=datetime.utcnow) | |
| is_active: bool = True | |
| is_processing: bool = False # True while a submission is being executed | |
| broadcaster: Any = None | |
| title: str | None = None | |
| # True once this session has been counted against the user's daily | |
| # Claude quota. Guards double-counting when the user re-selects an | |
| # Anthropic model mid-session. | |
| claude_counted: bool = False | |
| class SessionCapacityError(Exception): | |
| """Raised when no more sessions can be created.""" | |
| def __init__(self, message: str, error_type: str = "global") -> None: | |
| super().__init__(message) | |
| self.error_type = error_type # "global" or "per_user" | |
| # ── Capacity limits ───────────────────────────────────────────────── | |
| # Sized for HF Spaces 8 vCPU / 32 GB RAM. | |
| # Each session uses ~10-20 MB (context, tools, queues, task); 200 × 20 MB | |
| # = 4 GB worst case, leaving plenty of headroom for the Python runtime | |
| # and per-request overhead. | |
| MAX_SESSIONS: int = 200 | |
| MAX_SESSIONS_PER_USER: int = 10 | |
| DEFAULT_YOLO_COST_CAP_USD: float = 5.0 | |
| SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10 | |
| SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0 | |
| class SessionManager: | |
| """Manages multiple concurrent agent sessions.""" | |
| def __init__(self, config_path: str | None = None) -> None: | |
| self.config = load_config(config_path or DEFAULT_CONFIG_PATH) | |
| self.messaging_gateway = NotificationGateway(self.config.messaging) | |
| self.sessions: dict[str, AgentSession] = {} | |
| self._lock = asyncio.Lock() | |
| self.persistence_store = None | |
| async def start(self) -> None: | |
| """Start shared background resources.""" | |
| self.persistence_store = get_session_store() | |
| await self.persistence_store.init() | |
| await self.messaging_gateway.start() | |
| async def close(self) -> None: | |
| """Flush and close shared background resources.""" | |
| await self._cleanup_all_sandboxes_on_close() | |
| await self.messaging_gateway.close() | |
| if self.persistence_store is not None: | |
| await self.persistence_store.close() | |
| def _store(self): | |
| if self.persistence_store is None: | |
| self.persistence_store = get_session_store() | |
| return self.persistence_store | |
| def _count_user_sessions(self, user_id: str) -> int: | |
| """Count active sessions owned by a specific user.""" | |
| return sum( | |
| 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active | |
| ) | |
| def _create_session_sync( | |
| self, | |
| *, | |
| session_id: str, | |
| user_id: str, | |
| hf_username: str | None, | |
| hf_token: str | None, | |
| model: str | None, | |
| event_queue: asyncio.Queue, | |
| notification_destinations: list[str] | None = None, | |
| ) -> tuple[ToolRouter, Session]: | |
| """Build blocking per-session resources in a worker thread.""" | |
| import time as _time | |
| t0 = _time.monotonic() | |
| tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) | |
| # Deep-copy config so each session's model switches independently — | |
| # tab A picking GLM doesn't flip tab B off Claude. | |
| session_config = self.config.model_copy(deep=True) | |
| if model: | |
| session_config.model_name = model | |
| session = Session( | |
| event_queue=event_queue, | |
| config=session_config, | |
| tool_router=tool_router, | |
| hf_token=hf_token, | |
| user_id=user_id, | |
| hf_username=hf_username, | |
| notification_gateway=self.messaging_gateway, | |
| notification_destinations=notification_destinations or [], | |
| session_id=session_id, | |
| persistence_store=self._store(), | |
| ) | |
| t1 = _time.monotonic() | |
| logger.info("Session initialized in %.2fs", t1 - t0) | |
| return tool_router, session | |
| def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: | |
| return [msg.model_dump(mode="json") for msg in session.context_manager.items] | |
| def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: | |
| pending = session.pending_approval or {} | |
| tool_calls = pending.get("tool_calls") or [] | |
| serialized: list[dict[str, Any]] = [] | |
| for tc in tool_calls: | |
| if hasattr(tc, "model_dump"): | |
| serialized.append(tc.model_dump(mode="json")) | |
| elif isinstance(tc, dict): | |
| serialized.append(tc) | |
| return serialized | |
| def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: | |
| pending = session.pending_approval or {} | |
| tool_calls = pending.get("tool_calls") or [] | |
| if not tool_calls: | |
| return None | |
| result: list[dict[str, Any]] = [] | |
| for tc in tool_calls: | |
| try: | |
| args = json.loads(tc.function.arguments) | |
| except (json.JSONDecodeError, AttributeError, TypeError): | |
| args = {} | |
| result.append( | |
| { | |
| "tool": getattr(tc.function, "name", None), | |
| "tool_call_id": getattr(tc, "id", None), | |
| "arguments": args, | |
| } | |
| ) | |
| return result | |
| def _restore_pending_approval( | |
| self, session: Session, pending_approval: list[dict[str, Any]] | None | |
| ) -> None: | |
| if not pending_approval: | |
| session.pending_approval = None | |
| return | |
| from litellm import ChatCompletionMessageToolCall as ToolCall | |
| restored = [] | |
| for raw in pending_approval: | |
| try: | |
| if "function" in raw: | |
| restored.append(ToolCall(**raw)) | |
| else: | |
| restored.append( | |
| ToolCall( | |
| id=raw["tool_call_id"], | |
| type="function", | |
| function={ | |
| "name": raw["tool"], | |
| "arguments": json.dumps(raw.get("arguments") or {}), | |
| }, | |
| ) | |
| ) | |
| except Exception as e: | |
| logger.warning("Dropping malformed pending approval: %s", e) | |
| session.pending_approval = {"tool_calls": restored} if restored else None | |
| def _pending_docs_for_api( | |
| pending_approval: list[dict[str, Any]] | None, | |
| ) -> list[dict[str, Any]] | None: | |
| if not pending_approval: | |
| return None | |
| result: list[dict[str, Any]] = [] | |
| for raw in pending_approval: | |
| if "function" in raw: | |
| function = raw.get("function") or {} | |
| try: | |
| args = json.loads(function.get("arguments") or "{}") | |
| except (json.JSONDecodeError, TypeError): | |
| args = {} | |
| result.append( | |
| { | |
| "tool": function.get("name"), | |
| "tool_call_id": raw.get("id"), | |
| "arguments": args, | |
| } | |
| ) | |
| elif {"tool", "tool_call_id"}.issubset(raw): | |
| result.append( | |
| { | |
| "tool": raw.get("tool"), | |
| "tool_call_id": raw.get("tool_call_id"), | |
| "arguments": raw.get("arguments") or {}, | |
| } | |
| ) | |
| return result or None | |
| def _runtime_state(agent_session: AgentSession) -> str: | |
| if agent_session.session.pending_approval: | |
| return "waiting_approval" | |
| if agent_session.is_processing: | |
| return "processing" | |
| if not agent_session.is_active: | |
| return "ended" | |
| return "idle" | |
| def _auto_approval_summary(session: Session) -> dict[str, Any]: | |
| if hasattr(session, "auto_approval_policy_summary"): | |
| return session.auto_approval_policy_summary() | |
| cap = getattr(session, "auto_approval_cost_cap_usd", None) | |
| estimated = float( | |
| getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 | |
| ) | |
| remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) | |
| return { | |
| "enabled": bool(getattr(session, "auto_approval_enabled", False)), | |
| "cost_cap_usd": cap, | |
| "estimated_spend_usd": round(estimated, 4), | |
| "remaining_usd": remaining, | |
| } | |
| async def _start_agent_session( | |
| self, | |
| *, | |
| agent_session: AgentSession, | |
| event_queue: asyncio.Queue, | |
| tool_router: ToolRouter, | |
| ) -> AgentSession: | |
| async with self._lock: | |
| existing = self.sessions.get(agent_session.session_id) | |
| if existing: | |
| return existing | |
| self.sessions[agent_session.session_id] = agent_session | |
| task = asyncio.create_task( | |
| self._run_session( | |
| agent_session.session_id, | |
| agent_session.submission_queue, | |
| event_queue, | |
| tool_router, | |
| ) | |
| ) | |
| agent_session.task = task | |
| return agent_session | |
| def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None: | |
| """Kick off a best-effort cpu-basic sandbox for the session.""" | |
| try: | |
| from agent.tools.sandbox_tool import start_cpu_sandbox_preload | |
| start_cpu_sandbox_preload(agent_session.session) | |
| except Exception as e: | |
| logger.warning( | |
| "Failed to start CPU sandbox preload for %s: %s", | |
| agent_session.session_id, | |
| e, | |
| ) | |
| def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: | |
| return ( | |
| user_id == "dev" | |
| or agent_session.user_id == "dev" | |
| or agent_session.user_id == user_id | |
| ) | |
| def _update_hf_identity( | |
| agent_session: AgentSession, | |
| *, | |
| hf_token: str | None, | |
| hf_username: str | None, | |
| ) -> None: | |
| if hf_token: | |
| agent_session.hf_token = hf_token | |
| agent_session.session.hf_token = hf_token | |
| if hf_username: | |
| agent_session.hf_username = hf_username | |
| agent_session.session.hf_username = hf_username | |
| def _has_active_sandbox_preload(agent_session: AgentSession) -> bool: | |
| task = getattr(agent_session.session, "sandbox_preload_task", None) | |
| return bool(task and not task.done()) | |
| def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool: | |
| error = getattr(agent_session.session, "sandbox_preload_error", None) | |
| return isinstance(error, str) and error.startswith("No HF token available") | |
| def _restart_cpu_preload_if_token_recovered( | |
| self, | |
| agent_session: AgentSession, | |
| *, | |
| preload_sandbox: bool, | |
| ) -> None: | |
| if not preload_sandbox: | |
| return | |
| session = agent_session.session | |
| if getattr(session, "sandbox", None): | |
| return | |
| if self._has_active_sandbox_preload(agent_session): | |
| return | |
| if not (agent_session.hf_token or getattr(session, "hf_token", None)): | |
| return | |
| if not self._preload_failed_for_missing_hf_token(agent_session): | |
| return | |
| session.sandbox_preload_error = None | |
| session.sandbox_preload_task = None | |
| session.sandbox_preload_cancel_event = None | |
| self._start_cpu_sandbox_preload(agent_session) | |
| async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: | |
| try: | |
| await self._store().update_session_fields( | |
| session_id, | |
| sandbox_space_id=None, | |
| sandbox_hardware=None, | |
| sandbox_owner=None, | |
| sandbox_created_at=None, | |
| sandbox_status="destroyed", | |
| ) | |
| except Exception as e: | |
| logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e) | |
| async def _cleanup_persisted_sandbox( | |
| self, | |
| session_id: str, | |
| metadata: dict[str, Any], | |
| *, | |
| hf_token: str | None, | |
| ) -> None: | |
| """Delete a sandbox recorded by a previous backend process, if any.""" | |
| space_id = metadata.get("sandbox_space_id") | |
| if not isinstance(space_id, str) or not space_id: | |
| return | |
| if metadata.get("sandbox_status") == "destroyed": | |
| return | |
| tokens: list[tuple[str, str]] = [] | |
| seen: set[str] = set() | |
| for label, token in ( | |
| ("user", hf_token), | |
| ("admin", os.environ.get("HF_ADMIN_TOKEN")), | |
| ): | |
| if token and token not in seen: | |
| tokens.append((label, token)) | |
| seen.add(token) | |
| if not tokens: | |
| logger.warning( | |
| "Cannot clean persisted sandbox %s for session %s: no HF token available", | |
| space_id, | |
| session_id, | |
| ) | |
| return | |
| last_err: Exception | None = None | |
| for label, token in tokens: | |
| try: | |
| from huggingface_hub import HfApi | |
| api = HfApi(token=token) | |
| await asyncio.to_thread( | |
| api.delete_repo, | |
| repo_id=space_id, | |
| repo_type="space", | |
| ) | |
| logger.info( | |
| "Deleted persisted sandbox %s for session %s with %s token", | |
| space_id, | |
| session_id, | |
| label, | |
| ) | |
| await self._clear_persisted_sandbox_metadata(session_id) | |
| return | |
| except Exception as e: | |
| status_code = getattr(getattr(e, "response", None), "status_code", None) | |
| if status_code == 404: | |
| logger.info( | |
| "Persisted sandbox %s for session %s is already gone", | |
| space_id, | |
| session_id, | |
| ) | |
| await self._clear_persisted_sandbox_metadata(session_id) | |
| return | |
| last_err = e | |
| logger.warning( | |
| "Failed to delete persisted sandbox %s for session %s: %s", | |
| space_id, | |
| session_id, | |
| last_err, | |
| ) | |
| async def persist_session_snapshot( | |
| self, | |
| agent_session: AgentSession, | |
| *, | |
| runtime_state: str | None = None, | |
| status: str = "active", | |
| ) -> None: | |
| """Persist the current runtime context snapshot.""" | |
| store = self._store() | |
| if not getattr(store, "enabled", False): | |
| return | |
| try: | |
| await store.save_snapshot( | |
| session_id=agent_session.session_id, | |
| user_id=agent_session.user_id, | |
| model=agent_session.session.config.model_name, | |
| title=agent_session.title, | |
| messages=self._serialize_messages(agent_session.session), | |
| runtime_state=runtime_state or self._runtime_state(agent_session), | |
| status=status, | |
| turn_count=agent_session.session.turn_count, | |
| pending_approval=self._serialize_pending_approval( | |
| agent_session.session | |
| ), | |
| claude_counted=agent_session.claude_counted, | |
| created_at=agent_session.created_at, | |
| notification_destinations=list( | |
| agent_session.session.notification_destinations | |
| ), | |
| auto_approval_enabled=bool( | |
| getattr(agent_session.session, "auto_approval_enabled", False) | |
| ), | |
| auto_approval_cost_cap_usd=getattr( | |
| agent_session.session, "auto_approval_cost_cap_usd", None | |
| ), | |
| auto_approval_estimated_spend_usd=float( | |
| getattr( | |
| agent_session.session, | |
| "auto_approval_estimated_spend_usd", | |
| 0.0, | |
| ) | |
| or 0.0 | |
| ), | |
| ) | |
| except Exception as e: | |
| logger.warning( | |
| "Failed to persist snapshot for %s: %s", | |
| agent_session.session_id, | |
| e, | |
| ) | |
| async def ensure_session_loaded( | |
| self, | |
| session_id: str, | |
| user_id: str, | |
| hf_token: str | None = None, | |
| hf_username: str | None = None, | |
| preload_sandbox: bool = True, | |
| ) -> AgentSession | None: | |
| """Return a live runtime session, lazily restoring it from Mongo.""" | |
| async with self._lock: | |
| existing = self.sessions.get(session_id) | |
| if existing: | |
| if self._can_access_session(existing, user_id): | |
| self._update_hf_identity( | |
| existing, | |
| hf_token=hf_token, | |
| hf_username=hf_username, | |
| ) | |
| self._restart_cpu_preload_if_token_recovered( | |
| existing, | |
| preload_sandbox=preload_sandbox, | |
| ) | |
| return existing | |
| return None | |
| store = self._store() | |
| loaded = await store.load_session(session_id) | |
| if not loaded: | |
| return None | |
| async with self._lock: | |
| existing = self.sessions.get(session_id) | |
| if existing: | |
| if self._can_access_session(existing, user_id): | |
| self._update_hf_identity( | |
| existing, | |
| hf_token=hf_token, | |
| hf_username=hf_username, | |
| ) | |
| self._restart_cpu_preload_if_token_recovered( | |
| existing, | |
| preload_sandbox=preload_sandbox, | |
| ) | |
| return existing | |
| return None | |
| meta = loaded.get("metadata") or {} | |
| owner = str(meta.get("user_id") or "") | |
| if user_id != "dev" and owner != "dev" and owner != user_id: | |
| return None | |
| await self._cleanup_persisted_sandbox( | |
| session_id, | |
| meta, | |
| hf_token=hf_token, | |
| ) | |
| from litellm import Message | |
| model = meta.get("model") or self.config.model_name | |
| event_queue: asyncio.Queue = asyncio.Queue() | |
| submission_queue: asyncio.Queue = asyncio.Queue() | |
| tool_router, session = await asyncio.to_thread( | |
| self._create_session_sync, | |
| session_id=session_id, | |
| user_id=owner or user_id, | |
| hf_username=hf_username, | |
| hf_token=hf_token, | |
| model=model, | |
| event_queue=event_queue, | |
| notification_destinations=meta.get("notification_destinations") or [], | |
| ) | |
| restored_messages: list[Message] = [] | |
| for raw in loaded.get("messages") or []: | |
| if not isinstance(raw, dict) or raw.get("role") == "system": | |
| continue | |
| try: | |
| restored_messages.append(Message.model_validate(raw)) | |
| except Exception as e: | |
| logger.warning("Dropping malformed restored message: %s", e) | |
| if restored_messages: | |
| # Keep the freshly-rendered system prompt, then attach the durable | |
| # non-system context so tools/date/user context stay current. | |
| session.context_manager.items = [ | |
| session.context_manager.items[0], | |
| *restored_messages, | |
| ] | |
| self._restore_pending_approval(session, meta.get("pending_approval") or []) | |
| session.turn_count = int(meta.get("turn_count") or 0) | |
| session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) | |
| raw_cap = meta.get("auto_approval_cost_cap_usd") | |
| session.auto_approval_cost_cap_usd = ( | |
| float(raw_cap) if isinstance(raw_cap, int | float) else None | |
| ) | |
| session.auto_approval_estimated_spend_usd = float( | |
| meta.get("auto_approval_estimated_spend_usd") or 0.0 | |
| ) | |
| created_at = meta.get("created_at") | |
| if not isinstance(created_at, datetime): | |
| created_at = datetime.utcnow() | |
| agent_session = AgentSession( | |
| session_id=session_id, | |
| session=session, | |
| tool_router=tool_router, | |
| submission_queue=submission_queue, | |
| user_id=owner or user_id, | |
| hf_username=hf_username, | |
| hf_token=hf_token, | |
| created_at=created_at, | |
| is_active=True, | |
| is_processing=False, | |
| claude_counted=bool(meta.get("claude_counted")), | |
| title=meta.get("title"), | |
| ) | |
| started = await self._start_agent_session( | |
| agent_session=agent_session, | |
| event_queue=event_queue, | |
| tool_router=tool_router, | |
| ) | |
| if started is not agent_session: | |
| self._update_hf_identity( | |
| started, | |
| hf_token=hf_token, | |
| hf_username=hf_username, | |
| ) | |
| return started | |
| if preload_sandbox: | |
| self._start_cpu_sandbox_preload(agent_session) | |
| logger.info("Restored session %s for user %s", session_id, owner or user_id) | |
| return agent_session | |
| async def create_session( | |
| self, | |
| user_id: str = "dev", | |
| hf_username: str | None = None, | |
| hf_token: str | None = None, | |
| model: str | None = None, | |
| is_pro: bool | None = None, | |
| ) -> str: | |
| """Create a new agent session and return its ID. | |
| Session() and ToolRouter() constructors contain blocking I/O | |
| (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are | |
| executed in a thread pool to avoid freezing the async event loop. | |
| Args: | |
| user_id: The ID of the user who owns this session. | |
| hf_username: The HF username/namespace used for personal trace uploads. | |
| hf_token: The user's HF OAuth token, stored for tool execution. | |
| model: Optional model override. When set, replaces ``model_name`` | |
| on the per-session config clone. None falls back to the | |
| config default. | |
| Raises: | |
| SessionCapacityError: If the server or user has reached the | |
| maximum number of concurrent sessions. | |
| """ | |
| # ── Capacity checks ────────────────────────────────────────── | |
| async with self._lock: | |
| active_count = self.active_session_count | |
| if active_count >= MAX_SESSIONS: | |
| raise SessionCapacityError( | |
| f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). " | |
| "Please try again later.", | |
| error_type="global", | |
| ) | |
| if user_id != "dev": | |
| user_count = self._count_user_sessions(user_id) | |
| if user_count >= MAX_SESSIONS_PER_USER: | |
| raise SessionCapacityError( | |
| f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " | |
| "concurrent sessions. Please close an existing session first.", | |
| error_type="per_user", | |
| ) | |
| session_id = str(uuid.uuid4()) | |
| # Create queues for this session | |
| submission_queue: asyncio.Queue = asyncio.Queue() | |
| event_queue: asyncio.Queue = asyncio.Queue() | |
| # Run blocking constructors in a thread to keep the event loop responsive. | |
| tool_router, session = await asyncio.to_thread( | |
| self._create_session_sync, | |
| session_id=session_id, | |
| user_id=user_id, | |
| hf_username=hf_username, | |
| hf_token=hf_token, | |
| model=model, | |
| event_queue=event_queue, | |
| ) | |
| # Create wrapper | |
| agent_session = AgentSession( | |
| session_id=session_id, | |
| session=session, | |
| tool_router=tool_router, | |
| submission_queue=submission_queue, | |
| user_id=user_id, | |
| hf_username=hf_username, | |
| hf_token=hf_token, | |
| ) | |
| await self._start_agent_session( | |
| agent_session=agent_session, | |
| event_queue=event_queue, | |
| tool_router=tool_router, | |
| ) | |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") | |
| self._start_cpu_sandbox_preload(agent_session) | |
| if is_pro is not None and user_id and user_id != "dev": | |
| await self._track_pro_status(agent_session, is_pro=is_pro) | |
| logger.info(f"Created session {session_id} for user {user_id}") | |
| return session_id | |
| async def _track_pro_status( | |
| self, agent_session: AgentSession, *, is_pro: bool | |
| ) -> None: | |
| """Update Mongo per-user Pro state and emit a one-shot conversion | |
| event if the store reports a free→Pro transition. Best-effort: any | |
| Mongo failure is swallowed so we never fail session creation on | |
| telemetry.""" | |
| store = self._store() | |
| if not getattr(store, "enabled", False): | |
| return | |
| try: | |
| result = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro) | |
| except Exception as e: | |
| logger.debug("mark_pro_seen failed: %s", e) | |
| return | |
| if not result or not result.get("converted"): | |
| return | |
| try: | |
| from agent.core import telemetry | |
| await telemetry.record_pro_conversion( | |
| agent_session.session, | |
| first_seen_at=result.get("first_seen_at"), | |
| ) | |
| except Exception as e: | |
| logger.debug("record_pro_conversion failed: %s", e) | |
| async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: | |
| """Rehydrate a session from cached prior messages via summarization. | |
| Runs the standard summarization prompt (same one compaction uses) | |
| over the provided messages, then seeds the new session's context | |
| with that summary. Tool-call pairing concerns disappear because the | |
| output is plain text. Returns the number of messages summarized. | |
| """ | |
| from litellm import Message | |
| from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| raise ValueError(f"Session {session_id} not found") | |
| # Parse into Message objects, tolerating malformed entries. | |
| parsed: list[Message] = [] | |
| for raw in messages: | |
| if raw.get("role") == "system": | |
| continue # the new session has its own system prompt | |
| try: | |
| parsed.append(Message.model_validate(raw)) | |
| except Exception as e: | |
| logger.warning("Dropping malformed message during seed: %s", e) | |
| if not parsed: | |
| return 0 | |
| session = agent_session.session | |
| # Pass the real tool specs so the summarizer sees what the agent | |
| # actually has — otherwise Anthropic's modify_params injects a | |
| # dummy tool and the summarizer editorializes that the original | |
| # tool calls were fabricated. | |
| tool_specs = None | |
| try: | |
| tool_specs = agent_session.tool_router.get_tool_specs_for_llm() | |
| except Exception: | |
| pass | |
| try: | |
| summary, _ = await summarize_messages( | |
| parsed, | |
| model_name=session.config.model_name, | |
| hf_token=session.hf_token, | |
| max_tokens=4000, | |
| prompt=_RESTORE_PROMPT, | |
| tool_specs=tool_specs, | |
| session=session, | |
| kind="restore", | |
| ) | |
| except Exception as e: | |
| logger.error("Summary call failed during seed: %s", e) | |
| raise | |
| seed = Message( | |
| role="user", | |
| content=( | |
| "[SYSTEM: Your prior memory of this conversation — written " | |
| "in your own voice right before restart. Continue from here.]\n\n" | |
| + (summary or "(no summary returned)") | |
| ), | |
| ) | |
| session.context_manager.items.append(seed) | |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") | |
| return len(parsed) | |
| async def _cleanup_sandbox(session: Session) -> None: | |
| """Delete the sandbox Space if one was created for this session. | |
| Retries on transient failures (HF API 5xx, rate-limit, network blips) | |
| with exponential backoff. A single missed delete = a permanently | |
| orphaned Space, so the cost of an extra retry beats the alternative. | |
| """ | |
| from agent.tools.sandbox_tool import teardown_session_sandbox | |
| await teardown_session_sandbox(session) | |
| async def _cleanup_all_sandboxes_on_close(self) -> None: | |
| """Best-effort sandbox cleanup for graceful backend shutdown.""" | |
| async with self._lock: | |
| agent_sessions = list(self.sessions.values()) | |
| if not agent_sessions: | |
| return | |
| semaphore = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY) | |
| async def _cleanup_one(agent_session: AgentSession) -> None: | |
| async with semaphore: | |
| try: | |
| await self._cleanup_sandbox(agent_session.session) | |
| except Exception as e: | |
| logger.warning( | |
| "Shutdown sandbox cleanup failed for %s: %s", | |
| agent_session.session_id, | |
| e, | |
| ) | |
| tasks = [ | |
| asyncio.create_task(_cleanup_one(agent_session)) | |
| for agent_session in agent_sessions | |
| ] | |
| try: | |
| await asyncio.wait_for( | |
| asyncio.gather(*tasks, return_exceptions=True), | |
| timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, | |
| ) | |
| except asyncio.TimeoutError: | |
| logger.warning( | |
| "Timed out after %.0fs cleaning up sandboxes on shutdown; " | |
| "orphan sweeper will handle any stragglers", | |
| SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, | |
| ) | |
| async def _run_session( | |
| self, | |
| session_id: str, | |
| submission_queue: asyncio.Queue, | |
| event_queue: asyncio.Queue, | |
| tool_router: ToolRouter, | |
| ) -> None: | |
| """Run the agent loop for a session and broadcast events via EventBroadcaster.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| logger.error(f"Session {session_id} not found") | |
| return | |
| session = agent_session.session | |
| # Start event broadcaster task | |
| broadcaster = EventBroadcaster(event_queue) | |
| agent_session.broadcaster = broadcaster | |
| broadcast_task = asyncio.create_task(broadcaster.run()) | |
| try: | |
| async with tool_router: | |
| # Send ready event | |
| await session.send_event( | |
| Event(event_type="ready", data={"message": "Agent initialized"}) | |
| ) | |
| while session.is_running: | |
| try: | |
| # Wait for submission with timeout to allow checking is_running | |
| submission = await asyncio.wait_for( | |
| submission_queue.get(), timeout=1.0 | |
| ) | |
| agent_session.is_processing = True | |
| try: | |
| should_continue = await process_submission( | |
| session, submission | |
| ) | |
| finally: | |
| agent_session.is_processing = False | |
| await self.persist_session_snapshot(agent_session) | |
| if not should_continue: | |
| break | |
| except asyncio.TimeoutError: | |
| continue | |
| except asyncio.CancelledError: | |
| logger.info(f"Session {session_id} cancelled") | |
| break | |
| except Exception as e: | |
| logger.error(f"Error in session {session_id}: {e}") | |
| await session.send_event( | |
| Event(event_type="error", data={"error": str(e)}) | |
| ) | |
| finally: | |
| broadcast_task.cancel() | |
| try: | |
| await broadcast_task | |
| except asyncio.CancelledError: | |
| pass | |
| await self._cleanup_sandbox(session) | |
| # Final-flush: always save on session death so we capture ended | |
| # sessions even if the client disconnects without /shutdown. | |
| # Idempotent via session_id key; detached subprocess. | |
| if session.config.save_sessions: | |
| try: | |
| session.save_and_upload_detached( | |
| session.config.session_dataset_repo | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Final-flush failed for {session_id}: {e}") | |
| async with self._lock: | |
| if session_id in self.sessions: | |
| self.sessions[session_id].is_active = False | |
| await self.persist_session_snapshot( | |
| self.sessions[session_id], | |
| runtime_state="ended", | |
| status="ended", | |
| ) | |
| logger.info(f"Session {session_id} ended") | |
| async def submit(self, session_id: str, operation: Operation) -> bool: | |
| """Submit an operation to a session.""" | |
| async with self._lock: | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| logger.warning(f"Session {session_id} not found or inactive") | |
| return False | |
| submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation) | |
| await agent_session.submission_queue.put(submission) | |
| return True | |
| async def submit_user_input(self, session_id: str, text: str) -> bool: | |
| """Submit user input to a session.""" | |
| operation = Operation(op_type=OpType.USER_INPUT, data={"text": text}) | |
| return await self.submit(session_id, operation) | |
| async def submit_approval( | |
| self, session_id: str, approvals: list[dict[str, Any]] | |
| ) -> bool: | |
| """Submit tool approvals to a session.""" | |
| operation = Operation( | |
| op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals} | |
| ) | |
| return await self.submit(session_id, operation) | |
| async def interrupt(self, session_id: str) -> bool: | |
| """Interrupt a session by signalling cancellation directly (bypasses queue).""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| return False | |
| agent_session.session.cancel() | |
| return True | |
| async def undo(self, session_id: str) -> bool: | |
| """Undo last turn in a session.""" | |
| operation = Operation(op_type=OpType.UNDO) | |
| return await self.submit(session_id, operation) | |
| async def truncate(self, session_id: str, user_message_index: int) -> bool: | |
| """Truncate conversation to before a specific user message (direct, no queue).""" | |
| async with self._lock: | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| return False | |
| success = agent_session.session.context_manager.truncate_to_user_message( | |
| user_message_index | |
| ) | |
| if success: | |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") | |
| return success | |
| async def compact(self, session_id: str) -> bool: | |
| """Compact context in a session.""" | |
| operation = Operation(op_type=OpType.COMPACT) | |
| return await self.submit(session_id, operation) | |
| async def shutdown_session(self, session_id: str) -> bool: | |
| """Shutdown a specific session.""" | |
| operation = Operation(op_type=OpType.SHUTDOWN) | |
| success = await self.submit(session_id, operation) | |
| if success: | |
| async with self._lock: | |
| agent_session = self.sessions.get(session_id) | |
| if agent_session and agent_session.task: | |
| # Wait for task to complete | |
| try: | |
| await asyncio.wait_for(agent_session.task, timeout=5.0) | |
| except asyncio.TimeoutError: | |
| agent_session.task.cancel() | |
| return success | |
| async def delete_session(self, session_id: str) -> bool: | |
| """Soft-delete a session and stop its runtime resources.""" | |
| async with self._lock: | |
| agent_session = self.sessions.pop(session_id, None) | |
| if not agent_session: | |
| await self._store().soft_delete_session(session_id) | |
| return True | |
| await self._store().soft_delete_session(session_id) | |
| # Clean up sandbox Space before cancelling the task | |
| await self._cleanup_sandbox(agent_session.session) | |
| # Cancel the task if running | |
| if agent_session.task and not agent_session.task.done(): | |
| agent_session.task.cancel() | |
| try: | |
| await agent_session.task | |
| except asyncio.CancelledError: | |
| pass | |
| return True | |
| async def teardown_sandbox(self, session_id: str) -> bool: | |
| """Delete only this session's sandbox runtime, preserving chat state.""" | |
| async with self._lock: | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| return False | |
| await self._cleanup_sandbox(agent_session.session) | |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") | |
| return True | |
| async def update_session_title(self, session_id: str, title: str | None) -> None: | |
| """Persist a user-visible title for sidebar rehydration.""" | |
| agent_session = self.sessions.get(session_id) | |
| if agent_session: | |
| agent_session.title = title | |
| await self._store().update_session_fields(session_id, title=title) | |
| async def update_session_model(self, session_id: str, model_id: str) -> bool: | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| return False | |
| agent_session.session.update_model(model_id) | |
| await self.persist_session_snapshot(agent_session, runtime_state="idle") | |
| return True | |
| async def update_session_auto_approval( | |
| self, | |
| session_id: str, | |
| *, | |
| enabled: bool, | |
| cost_cap_usd: float | None, | |
| cap_provided: bool = False, | |
| ) -> dict[str, Any]: | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| raise ValueError("Session not found or inactive") | |
| session = agent_session.session | |
| if enabled: | |
| if not cap_provided and cost_cap_usd is None: | |
| cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) | |
| if cost_cap_usd is None: | |
| cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD | |
| elif cost_cap_usd is None: | |
| cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD | |
| else: | |
| if not cap_provided: | |
| cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) | |
| if hasattr(session, "set_auto_approval_policy"): | |
| session.set_auto_approval_policy( | |
| enabled=enabled, | |
| cost_cap_usd=cost_cap_usd, | |
| ) | |
| else: | |
| session.auto_approval_enabled = bool(enabled) | |
| session.auto_approval_cost_cap_usd = cost_cap_usd | |
| await self.persist_session_snapshot(agent_session) | |
| return self._auto_approval_summary(session) | |
| def get_session_owner(self, session_id: str) -> str | None: | |
| """Get the user_id that owns a session, or None if session doesn't exist.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| return None | |
| return agent_session.user_id | |
| def verify_session_access(self, session_id: str, user_id: str) -> bool: | |
| """Check if a user has access to a session. | |
| Returns True if: | |
| - The session exists AND the user owns it | |
| - The user_id is "dev" (dev mode bypass) | |
| """ | |
| owner = self.get_session_owner(session_id) | |
| if owner is None: | |
| return False | |
| if user_id == "dev" or owner == "dev": | |
| return True | |
| return owner == user_id | |
| def get_session_info(self, session_id: str) -> dict[str, Any] | None: | |
| """Get information about a session.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session: | |
| return None | |
| pending_approval = self._pending_tools_for_api(agent_session.session) | |
| return { | |
| "session_id": session_id, | |
| "created_at": agent_session.created_at.isoformat(), | |
| "is_active": agent_session.is_active, | |
| "is_processing": agent_session.is_processing, | |
| "message_count": len(agent_session.session.context_manager.items), | |
| "user_id": agent_session.user_id, | |
| "pending_approval": pending_approval, | |
| "model": agent_session.session.config.model_name, | |
| "title": agent_session.title, | |
| "notification_destinations": list( | |
| agent_session.session.notification_destinations | |
| ), | |
| "auto_approval": self._auto_approval_summary(agent_session.session), | |
| } | |
| def set_notification_destinations( | |
| self, session_id: str, destinations: list[str] | |
| ) -> list[str]: | |
| """Replace the session's opted-in auto-notification destinations.""" | |
| agent_session = self.sessions.get(session_id) | |
| if not agent_session or not agent_session.is_active: | |
| raise ValueError("Session not found or inactive") | |
| normalized: list[str] = [] | |
| seen: set[str] = set() | |
| for raw_name in destinations: | |
| name = raw_name.strip() | |
| if not name: | |
| raise ValueError("Destination names must not be empty") | |
| destination = self.config.messaging.get_destination(name) | |
| if destination is None: | |
| raise ValueError(f"Unknown destination '{name}'") | |
| if not destination.allow_auto_events: | |
| raise ValueError(f"Destination '{name}' is not enabled for auto events") | |
| if name not in seen: | |
| normalized.append(name) | |
| seen.add(name) | |
| agent_session.session.set_notification_destinations(normalized) | |
| return normalized | |
| async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: | |
| """List sessions, optionally filtered by user. | |
| Args: | |
| user_id: If provided, only return sessions owned by this user. | |
| If "dev", return all sessions (dev mode). | |
| """ | |
| results: list[dict[str, Any]] = [] | |
| store = self._store() | |
| if getattr(store, "enabled", False): | |
| for row in await store.list_sessions(user_id or "dev"): | |
| sid = row.get("session_id") or row.get("_id") | |
| if not sid: | |
| continue | |
| runtime_info = self.get_session_info(str(sid)) | |
| if runtime_info: | |
| results.append(runtime_info) | |
| continue | |
| created_at = row.get("created_at") | |
| if isinstance(created_at, datetime): | |
| created_at_str = created_at.isoformat() | |
| else: | |
| created_at_str = str(created_at or datetime.utcnow().isoformat()) | |
| pending = self._pending_docs_for_api(row.get("pending_approval") or []) | |
| results.append( | |
| { | |
| "session_id": str(sid), | |
| "created_at": created_at_str, | |
| "is_active": row.get("status") != "ended", | |
| "is_processing": row.get("runtime_state") == "processing", | |
| "message_count": int(row.get("message_count") or 0), | |
| "user_id": row.get("user_id") or "dev", | |
| "pending_approval": pending or None, | |
| "model": row.get("model"), | |
| "title": row.get("title"), | |
| "notification_destinations": row.get( | |
| "notification_destinations" | |
| ) | |
| or [], | |
| "auto_approval": { | |
| "enabled": bool(row.get("auto_approval_enabled", False)), | |
| "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), | |
| "estimated_spend_usd": float( | |
| row.get("auto_approval_estimated_spend_usd") or 0.0 | |
| ), | |
| "remaining_usd": ( | |
| None | |
| if row.get("auto_approval_cost_cap_usd") is None | |
| else round( | |
| max( | |
| 0.0, | |
| float( | |
| row.get("auto_approval_cost_cap_usd") or 0.0 | |
| ) | |
| - float( | |
| row.get("auto_approval_estimated_spend_usd") | |
| or 0.0 | |
| ), | |
| ), | |
| 4, | |
| ) | |
| ), | |
| }, | |
| } | |
| ) | |
| return results | |
| for sid in self.sessions: | |
| info = self.get_session_info(sid) | |
| if not info: | |
| continue | |
| if user_id and user_id != "dev" and info.get("user_id") != user_id: | |
| continue | |
| results.append(info) | |
| return results | |
| def active_session_count(self) -> int: | |
| """Get count of active sessions.""" | |
| return sum(1 for s in self.sessions.values() if s.is_active) | |
| # Global session manager instance | |
| session_manager = SessionManager() | |