"""Session manager for handling multiple concurrent agent sessions.""" import asyncio import json import logging import os import uuid from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Optional from agent.config import load_config from agent.core.agent_loop import process_submission from agent.core.session import Event, OpType, Session from agent.core.session_persistence import get_session_store from agent.core.tools import ToolRouter from agent.messaging.gateway import NotificationGateway # Get project root (parent of backend directory) PROJECT_ROOT = Path(__file__).parent.parent DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "frontend_agent_config.json") # These dataclasses match agent/main.py structure @dataclass class Operation: """Operation to be executed by the agent.""" op_type: OpType data: Optional[dict[str, Any]] = None @dataclass class Submission: """Submission to the agent loop.""" id: str operation: Operation logger = logging.getLogger(__name__) class EventBroadcaster: """Reads from the agent's event queue and fans out to SSE subscribers. Events that arrive when no subscribers are listening are discarded by this in-memory fanout. Durable replay is handled by session_persistence. """ def __init__(self, event_queue: asyncio.Queue): self._source = event_queue self._subscribers: dict[int, asyncio.Queue] = {} self._counter = 0 def subscribe(self) -> tuple[int, asyncio.Queue]: """Create a new subscriber. Returns (id, queue).""" self._counter += 1 sub_id = self._counter q: asyncio.Queue = asyncio.Queue() self._subscribers[sub_id] = q return sub_id, q def unsubscribe(self, sub_id: int) -> None: self._subscribers.pop(sub_id, None) async def run(self) -> None: """Main loop — reads from source queue and broadcasts.""" while True: try: event: Event = await self._source.get() msg = { "event_type": event.event_type, "data": event.data, "seq": event.seq, } for q in self._subscribers.values(): await q.put(msg) except asyncio.CancelledError: break except Exception as e: logger.error(f"EventBroadcaster error: {e}") @dataclass class AgentSession: """Wrapper for an agent session with its associated resources.""" session_id: str session: Session tool_router: ToolRouter submission_queue: asyncio.Queue user_id: str = "dev" # Owner of this session hf_username: str | None = None # HF namespace used for personal trace uploads hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True is_processing: bool = False # True while a submission is being executed broadcaster: Any = None title: str | None = None # True once this session has been counted against the user's daily # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. claude_counted: bool = False class SessionCapacityError(Exception): """Raised when no more sessions can be created.""" def __init__(self, message: str, error_type: str = "global") -> None: super().__init__(message) self.error_type = error_type # "global" or "per_user" # ── Capacity limits ───────────────────────────────────────────────── # Sized for HF Spaces 8 vCPU / 32 GB RAM. # Each session uses ~10-20 MB (context, tools, queues, task); 200 × 20 MB # = 4 GB worst case, leaving plenty of headroom for the Python runtime # and per-request overhead. MAX_SESSIONS: int = 200 MAX_SESSIONS_PER_USER: int = 10 DEFAULT_YOLO_COST_CAP_USD: float = 5.0 SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10 SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0 class SessionManager: """Manages multiple concurrent agent sessions.""" def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) self.messaging_gateway = NotificationGateway(self.config.messaging) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() self.persistence_store = None async def start(self) -> None: """Start shared background resources.""" self.persistence_store = get_session_store() await self.persistence_store.init() await self.messaging_gateway.start() async def close(self) -> None: """Flush and close shared background resources.""" await self._cleanup_all_sandboxes_on_close() await self.messaging_gateway.close() if self.persistence_store is not None: await self.persistence_store.close() def _store(self): if self.persistence_store is None: self.persistence_store = get_session_store() return self.persistence_store def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active ) def _create_session_sync( self, *, session_id: str, user_id: str, hf_username: str | None, hf_token: str | None, model: str | None, event_queue: asyncio.Queue, notification_destinations: list[str] | None = None, ) -> tuple[ToolRouter, Session]: """Build blocking per-session resources in a worker thread.""" import time as _time t0 = _time.monotonic() tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) # Deep-copy config so each session's model switches independently — # tab A picking GLM doesn't flip tab B off Claude. session_config = self.config.model_copy(deep=True) if model: session_config.model_name = model session = Session( event_queue=event_queue, config=session_config, tool_router=tool_router, hf_token=hf_token, user_id=user_id, hf_username=hf_username, notification_gateway=self.messaging_gateway, notification_destinations=notification_destinations or [], session_id=session_id, persistence_store=self._store(), ) t1 = _time.monotonic() logger.info("Session initialized in %.2fs", t1 - t0) return tool_router, session def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: return [msg.model_dump(mode="json") for msg in session.context_manager.items] def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: pending = session.pending_approval or {} tool_calls = pending.get("tool_calls") or [] serialized: list[dict[str, Any]] = [] for tc in tool_calls: if hasattr(tc, "model_dump"): serialized.append(tc.model_dump(mode="json")) elif isinstance(tc, dict): serialized.append(tc) return serialized @staticmethod def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: pending = session.pending_approval or {} tool_calls = pending.get("tool_calls") or [] if not tool_calls: return None result: list[dict[str, Any]] = [] for tc in tool_calls: try: args = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError, TypeError): args = {} result.append( { "tool": getattr(tc.function, "name", None), "tool_call_id": getattr(tc, "id", None), "arguments": args, } ) return result def _restore_pending_approval( self, session: Session, pending_approval: list[dict[str, Any]] | None ) -> None: if not pending_approval: session.pending_approval = None return from litellm import ChatCompletionMessageToolCall as ToolCall restored = [] for raw in pending_approval: try: if "function" in raw: restored.append(ToolCall(**raw)) else: restored.append( ToolCall( id=raw["tool_call_id"], type="function", function={ "name": raw["tool"], "arguments": json.dumps(raw.get("arguments") or {}), }, ) ) except Exception as e: logger.warning("Dropping malformed pending approval: %s", e) session.pending_approval = {"tool_calls": restored} if restored else None @staticmethod def _pending_docs_for_api( pending_approval: list[dict[str, Any]] | None, ) -> list[dict[str, Any]] | None: if not pending_approval: return None result: list[dict[str, Any]] = [] for raw in pending_approval: if "function" in raw: function = raw.get("function") or {} try: args = json.loads(function.get("arguments") or "{}") except (json.JSONDecodeError, TypeError): args = {} result.append( { "tool": function.get("name"), "tool_call_id": raw.get("id"), "arguments": args, } ) elif {"tool", "tool_call_id"}.issubset(raw): result.append( { "tool": raw.get("tool"), "tool_call_id": raw.get("tool_call_id"), "arguments": raw.get("arguments") or {}, } ) return result or None @staticmethod def _runtime_state(agent_session: AgentSession) -> str: if agent_session.session.pending_approval: return "waiting_approval" if agent_session.is_processing: return "processing" if not agent_session.is_active: return "ended" return "idle" @staticmethod def _auto_approval_summary(session: Session) -> dict[str, Any]: if hasattr(session, "auto_approval_policy_summary"): return session.auto_approval_policy_summary() cap = getattr(session, "auto_approval_cost_cap_usd", None) estimated = float( getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 ) remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) return { "enabled": bool(getattr(session, "auto_approval_enabled", False)), "cost_cap_usd": cap, "estimated_spend_usd": round(estimated, 4), "remaining_usd": remaining, } async def _start_agent_session( self, *, agent_session: AgentSession, event_queue: asyncio.Queue, tool_router: ToolRouter, ) -> AgentSession: async with self._lock: existing = self.sessions.get(agent_session.session_id) if existing: return existing self.sessions[agent_session.session_id] = agent_session task = asyncio.create_task( self._run_session( agent_session.session_id, agent_session.submission_queue, event_queue, tool_router, ) ) agent_session.task = task return agent_session @staticmethod def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None: """Kick off a best-effort cpu-basic sandbox for the session.""" try: from agent.tools.sandbox_tool import start_cpu_sandbox_preload start_cpu_sandbox_preload(agent_session.session) except Exception as e: logger.warning( "Failed to start CPU sandbox preload for %s: %s", agent_session.session_id, e, ) @staticmethod def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: return ( user_id == "dev" or agent_session.user_id == "dev" or agent_session.user_id == user_id ) @staticmethod def _update_hf_identity( agent_session: AgentSession, *, hf_token: str | None, hf_username: str | None, ) -> None: if hf_token: agent_session.hf_token = hf_token agent_session.session.hf_token = hf_token if hf_username: agent_session.hf_username = hf_username agent_session.session.hf_username = hf_username @staticmethod def _has_active_sandbox_preload(agent_session: AgentSession) -> bool: task = getattr(agent_session.session, "sandbox_preload_task", None) return bool(task and not task.done()) @staticmethod def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool: error = getattr(agent_session.session, "sandbox_preload_error", None) return isinstance(error, str) and error.startswith("No HF token available") def _restart_cpu_preload_if_token_recovered( self, agent_session: AgentSession, *, preload_sandbox: bool, ) -> None: if not preload_sandbox: return session = agent_session.session if getattr(session, "sandbox", None): return if self._has_active_sandbox_preload(agent_session): return if not (agent_session.hf_token or getattr(session, "hf_token", None)): return if not self._preload_failed_for_missing_hf_token(agent_session): return session.sandbox_preload_error = None session.sandbox_preload_task = None session.sandbox_preload_cancel_event = None self._start_cpu_sandbox_preload(agent_session) async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: try: await self._store().update_session_fields( session_id, sandbox_space_id=None, sandbox_hardware=None, sandbox_owner=None, sandbox_created_at=None, sandbox_status="destroyed", ) except Exception as e: logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e) async def _cleanup_persisted_sandbox( self, session_id: str, metadata: dict[str, Any], *, hf_token: str | None, ) -> None: """Delete a sandbox recorded by a previous backend process, if any.""" space_id = metadata.get("sandbox_space_id") if not isinstance(space_id, str) or not space_id: return if metadata.get("sandbox_status") == "destroyed": return tokens: list[tuple[str, str]] = [] seen: set[str] = set() for label, token in ( ("user", hf_token), ("admin", os.environ.get("HF_ADMIN_TOKEN")), ): if token and token not in seen: tokens.append((label, token)) seen.add(token) if not tokens: logger.warning( "Cannot clean persisted sandbox %s for session %s: no HF token available", space_id, session_id, ) return last_err: Exception | None = None for label, token in tokens: try: from huggingface_hub import HfApi api = HfApi(token=token) await asyncio.to_thread( api.delete_repo, repo_id=space_id, repo_type="space", ) logger.info( "Deleted persisted sandbox %s for session %s with %s token", space_id, session_id, label, ) await self._clear_persisted_sandbox_metadata(session_id) return except Exception as e: status_code = getattr(getattr(e, "response", None), "status_code", None) if status_code == 404: logger.info( "Persisted sandbox %s for session %s is already gone", space_id, session_id, ) await self._clear_persisted_sandbox_metadata(session_id) return last_err = e logger.warning( "Failed to delete persisted sandbox %s for session %s: %s", space_id, session_id, last_err, ) async def persist_session_snapshot( self, agent_session: AgentSession, *, runtime_state: str | None = None, status: str = "active", ) -> None: """Persist the current runtime context snapshot.""" store = self._store() if not getattr(store, "enabled", False): return try: await store.save_snapshot( session_id=agent_session.session_id, user_id=agent_session.user_id, model=agent_session.session.config.model_name, title=agent_session.title, messages=self._serialize_messages(agent_session.session), runtime_state=runtime_state or self._runtime_state(agent_session), status=status, turn_count=agent_session.session.turn_count, pending_approval=self._serialize_pending_approval( agent_session.session ), claude_counted=agent_session.claude_counted, created_at=agent_session.created_at, notification_destinations=list( agent_session.session.notification_destinations ), auto_approval_enabled=bool( getattr(agent_session.session, "auto_approval_enabled", False) ), auto_approval_cost_cap_usd=getattr( agent_session.session, "auto_approval_cost_cap_usd", None ), auto_approval_estimated_spend_usd=float( getattr( agent_session.session, "auto_approval_estimated_spend_usd", 0.0, ) or 0.0 ), ) except Exception as e: logger.warning( "Failed to persist snapshot for %s: %s", agent_session.session_id, e, ) async def ensure_session_loaded( self, session_id: str, user_id: str, hf_token: str | None = None, hf_username: str | None = None, preload_sandbox: bool = True, ) -> AgentSession | None: """Return a live runtime session, lazily restoring it from Mongo.""" async with self._lock: existing = self.sessions.get(session_id) if existing: if self._can_access_session(existing, user_id): self._update_hf_identity( existing, hf_token=hf_token, hf_username=hf_username, ) self._restart_cpu_preload_if_token_recovered( existing, preload_sandbox=preload_sandbox, ) return existing return None store = self._store() loaded = await store.load_session(session_id) if not loaded: return None async with self._lock: existing = self.sessions.get(session_id) if existing: if self._can_access_session(existing, user_id): self._update_hf_identity( existing, hf_token=hf_token, hf_username=hf_username, ) self._restart_cpu_preload_if_token_recovered( existing, preload_sandbox=preload_sandbox, ) return existing return None meta = loaded.get("metadata") or {} owner = str(meta.get("user_id") or "") if user_id != "dev" and owner != "dev" and owner != user_id: return None await self._cleanup_persisted_sandbox( session_id, meta, hf_token=hf_token, ) from litellm import Message model = meta.get("model") or self.config.model_name event_queue: asyncio.Queue = asyncio.Queue() submission_queue: asyncio.Queue = asyncio.Queue() tool_router, session = await asyncio.to_thread( self._create_session_sync, session_id=session_id, user_id=owner or user_id, hf_username=hf_username, hf_token=hf_token, model=model, event_queue=event_queue, notification_destinations=meta.get("notification_destinations") or [], ) restored_messages: list[Message] = [] for raw in loaded.get("messages") or []: if not isinstance(raw, dict) or raw.get("role") == "system": continue try: restored_messages.append(Message.model_validate(raw)) except Exception as e: logger.warning("Dropping malformed restored message: %s", e) if restored_messages: # Keep the freshly-rendered system prompt, then attach the durable # non-system context so tools/date/user context stay current. session.context_manager.items = [ session.context_manager.items[0], *restored_messages, ] self._restore_pending_approval(session, meta.get("pending_approval") or []) session.turn_count = int(meta.get("turn_count") or 0) session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False)) raw_cap = meta.get("auto_approval_cost_cap_usd") session.auto_approval_cost_cap_usd = ( float(raw_cap) if isinstance(raw_cap, int | float) else None ) session.auto_approval_estimated_spend_usd = float( meta.get("auto_approval_estimated_spend_usd") or 0.0 ) created_at = meta.get("created_at") if not isinstance(created_at, datetime): created_at = datetime.utcnow() agent_session = AgentSession( session_id=session_id, session=session, tool_router=tool_router, submission_queue=submission_queue, user_id=owner or user_id, hf_username=hf_username, hf_token=hf_token, created_at=created_at, is_active=True, is_processing=False, claude_counted=bool(meta.get("claude_counted")), title=meta.get("title"), ) started = await self._start_agent_session( agent_session=agent_session, event_queue=event_queue, tool_router=tool_router, ) if started is not agent_session: self._update_hf_identity( started, hf_token=hf_token, hf_username=hf_username, ) return started if preload_sandbox: self._start_cpu_sandbox_preload(agent_session) logger.info("Restored session %s for user %s", session_id, owner or user_id) return agent_session async def create_session( self, user_id: str = "dev", hf_username: str | None = None, hf_token: str | None = None, model: str | None = None, is_pro: bool | None = None, ) -> str: """Create a new agent session and return its ID. Session() and ToolRouter() constructors contain blocking I/O (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are executed in a thread pool to avoid freezing the async event loop. Args: user_id: The ID of the user who owns this session. hf_username: The HF username/namespace used for personal trace uploads. hf_token: The user's HF OAuth token, stored for tool execution. model: Optional model override. When set, replaces ``model_name`` on the per-session config clone. None falls back to the config default. Raises: SessionCapacityError: If the server or user has reached the maximum number of concurrent sessions. """ # ── Capacity checks ────────────────────────────────────────── async with self._lock: active_count = self.active_session_count if active_count >= MAX_SESSIONS: raise SessionCapacityError( f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). " "Please try again later.", error_type="global", ) if user_id != "dev": user_count = self._count_user_sessions(user_id) if user_count >= MAX_SESSIONS_PER_USER: raise SessionCapacityError( f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " "concurrent sessions. Please close an existing session first.", error_type="per_user", ) session_id = str(uuid.uuid4()) # Create queues for this session submission_queue: asyncio.Queue = asyncio.Queue() event_queue: asyncio.Queue = asyncio.Queue() # Run blocking constructors in a thread to keep the event loop responsive. tool_router, session = await asyncio.to_thread( self._create_session_sync, session_id=session_id, user_id=user_id, hf_username=hf_username, hf_token=hf_token, model=model, event_queue=event_queue, ) # Create wrapper agent_session = AgentSession( session_id=session_id, session=session, tool_router=tool_router, submission_queue=submission_queue, user_id=user_id, hf_username=hf_username, hf_token=hf_token, ) await self._start_agent_session( agent_session=agent_session, event_queue=event_queue, tool_router=tool_router, ) await self.persist_session_snapshot(agent_session, runtime_state="idle") self._start_cpu_sandbox_preload(agent_session) if is_pro is not None and user_id and user_id != "dev": await self._track_pro_status(agent_session, is_pro=is_pro) logger.info(f"Created session {session_id} for user {user_id}") return session_id async def _track_pro_status( self, agent_session: AgentSession, *, is_pro: bool ) -> None: """Update Mongo per-user Pro state and emit a one-shot conversion event if the store reports a free→Pro transition. Best-effort: any Mongo failure is swallowed so we never fail session creation on telemetry.""" store = self._store() if not getattr(store, "enabled", False): return try: result = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro) except Exception as e: logger.debug("mark_pro_seen failed: %s", e) return if not result or not result.get("converted"): return try: from agent.core import telemetry await telemetry.record_pro_conversion( agent_session.session, first_seen_at=result.get("first_seen_at"), ) except Exception as e: logger.debug("record_pro_conversion failed: %s", e) async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: """Rehydrate a session from cached prior messages via summarization. Runs the standard summarization prompt (same one compaction uses) over the provided messages, then seeds the new session's context with that summary. Tool-call pairing concerns disappear because the output is plain text. Returns the number of messages summarized. """ from litellm import Message from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages agent_session = self.sessions.get(session_id) if not agent_session: raise ValueError(f"Session {session_id} not found") # Parse into Message objects, tolerating malformed entries. parsed: list[Message] = [] for raw in messages: if raw.get("role") == "system": continue # the new session has its own system prompt try: parsed.append(Message.model_validate(raw)) except Exception as e: logger.warning("Dropping malformed message during seed: %s", e) if not parsed: return 0 session = agent_session.session # Pass the real tool specs so the summarizer sees what the agent # actually has — otherwise Anthropic's modify_params injects a # dummy tool and the summarizer editorializes that the original # tool calls were fabricated. tool_specs = None try: tool_specs = agent_session.tool_router.get_tool_specs_for_llm() except Exception: pass try: summary, _ = await summarize_messages( parsed, model_name=session.config.model_name, hf_token=session.hf_token, max_tokens=4000, prompt=_RESTORE_PROMPT, tool_specs=tool_specs, session=session, kind="restore", ) except Exception as e: logger.error("Summary call failed during seed: %s", e) raise seed = Message( role="user", content=( "[SYSTEM: Your prior memory of this conversation — written " "in your own voice right before restart. Continue from here.]\n\n" + (summary or "(no summary returned)") ), ) session.context_manager.items.append(seed) await self.persist_session_snapshot(agent_session, runtime_state="idle") return len(parsed) @staticmethod async def _cleanup_sandbox(session: Session) -> None: """Delete the sandbox Space if one was created for this session. Retries on transient failures (HF API 5xx, rate-limit, network blips) with exponential backoff. A single missed delete = a permanently orphaned Space, so the cost of an extra retry beats the alternative. """ from agent.tools.sandbox_tool import teardown_session_sandbox await teardown_session_sandbox(session) async def _cleanup_all_sandboxes_on_close(self) -> None: """Best-effort sandbox cleanup for graceful backend shutdown.""" async with self._lock: agent_sessions = list(self.sessions.values()) if not agent_sessions: return semaphore = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY) async def _cleanup_one(agent_session: AgentSession) -> None: async with semaphore: try: await self._cleanup_sandbox(agent_session.session) except Exception as e: logger.warning( "Shutdown sandbox cleanup failed for %s: %s", agent_session.session_id, e, ) tasks = [ asyncio.create_task(_cleanup_one(agent_session)) for agent_session in agent_sessions ] try: await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, ) except asyncio.TimeoutError: logger.warning( "Timed out after %.0fs cleaning up sandboxes on shutdown; " "orphan sweeper will handle any stragglers", SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S, ) async def _run_session( self, session_id: str, submission_queue: asyncio.Queue, event_queue: asyncio.Queue, tool_router: ToolRouter, ) -> None: """Run the agent loop for a session and broadcast events via EventBroadcaster.""" agent_session = self.sessions.get(session_id) if not agent_session: logger.error(f"Session {session_id} not found") return session = agent_session.session # Start event broadcaster task broadcaster = EventBroadcaster(event_queue) agent_session.broadcaster = broadcaster broadcast_task = asyncio.create_task(broadcaster.run()) try: async with tool_router: # Send ready event await session.send_event( Event(event_type="ready", data={"message": "Agent initialized"}) ) while session.is_running: try: # Wait for submission with timeout to allow checking is_running submission = await asyncio.wait_for( submission_queue.get(), timeout=1.0 ) agent_session.is_processing = True try: should_continue = await process_submission( session, submission ) finally: agent_session.is_processing = False await self.persist_session_snapshot(agent_session) if not should_continue: break except asyncio.TimeoutError: continue except asyncio.CancelledError: logger.info(f"Session {session_id} cancelled") break except Exception as e: logger.error(f"Error in session {session_id}: {e}") await session.send_event( Event(event_type="error", data={"error": str(e)}) ) finally: broadcast_task.cancel() try: await broadcast_task except asyncio.CancelledError: pass await self._cleanup_sandbox(session) # Final-flush: always save on session death so we capture ended # sessions even if the client disconnects without /shutdown. # Idempotent via session_id key; detached subprocess. if session.config.save_sessions: try: session.save_and_upload_detached( session.config.session_dataset_repo ) except Exception as e: logger.warning(f"Final-flush failed for {session_id}: {e}") async with self._lock: if session_id in self.sessions: self.sessions[session_id].is_active = False await self.persist_session_snapshot( self.sessions[session_id], runtime_state="ended", status="ended", ) logger.info(f"Session {session_id} ended") async def submit(self, session_id: str, operation: Operation) -> bool: """Submit an operation to a session.""" async with self._lock: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: logger.warning(f"Session {session_id} not found or inactive") return False submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation) await agent_session.submission_queue.put(submission) return True async def submit_user_input(self, session_id: str, text: str) -> bool: """Submit user input to a session.""" operation = Operation(op_type=OpType.USER_INPUT, data={"text": text}) return await self.submit(session_id, operation) async def submit_approval( self, session_id: str, approvals: list[dict[str, Any]] ) -> bool: """Submit tool approvals to a session.""" operation = Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals} ) return await self.submit(session_id, operation) async def interrupt(self, session_id: str) -> bool: """Interrupt a session by signalling cancellation directly (bypasses queue).""" agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False agent_session.session.cancel() return True async def undo(self, session_id: str) -> bool: """Undo last turn in a session.""" operation = Operation(op_type=OpType.UNDO) return await self.submit(session_id, operation) async def truncate(self, session_id: str, user_message_index: int) -> bool: """Truncate conversation to before a specific user message (direct, no queue).""" async with self._lock: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False success = agent_session.session.context_manager.truncate_to_user_message( user_message_index ) if success: await self.persist_session_snapshot(agent_session, runtime_state="idle") return success async def compact(self, session_id: str) -> bool: """Compact context in a session.""" operation = Operation(op_type=OpType.COMPACT) return await self.submit(session_id, operation) async def shutdown_session(self, session_id: str) -> bool: """Shutdown a specific session.""" operation = Operation(op_type=OpType.SHUTDOWN) success = await self.submit(session_id, operation) if success: async with self._lock: agent_session = self.sessions.get(session_id) if agent_session and agent_session.task: # Wait for task to complete try: await asyncio.wait_for(agent_session.task, timeout=5.0) except asyncio.TimeoutError: agent_session.task.cancel() return success async def delete_session(self, session_id: str) -> bool: """Soft-delete a session and stop its runtime resources.""" async with self._lock: agent_session = self.sessions.pop(session_id, None) if not agent_session: await self._store().soft_delete_session(session_id) return True await self._store().soft_delete_session(session_id) # Clean up sandbox Space before cancelling the task await self._cleanup_sandbox(agent_session.session) # Cancel the task if running if agent_session.task and not agent_session.task.done(): agent_session.task.cancel() try: await agent_session.task except asyncio.CancelledError: pass return True async def teardown_sandbox(self, session_id: str) -> bool: """Delete only this session's sandbox runtime, preserving chat state.""" async with self._lock: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False await self._cleanup_sandbox(agent_session.session) await self.persist_session_snapshot(agent_session, runtime_state="idle") return True async def update_session_title(self, session_id: str, title: str | None) -> None: """Persist a user-visible title for sidebar rehydration.""" agent_session = self.sessions.get(session_id) if agent_session: agent_session.title = title await self._store().update_session_fields(session_id, title=title) async def update_session_model(self, session_id: str, model_id: str) -> bool: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False agent_session.session.update_model(model_id) await self.persist_session_snapshot(agent_session, runtime_state="idle") return True async def update_session_auto_approval( self, session_id: str, *, enabled: bool, cost_cap_usd: float | None, cap_provided: bool = False, ) -> dict[str, Any]: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise ValueError("Session not found or inactive") session = agent_session.session if enabled: if not cap_provided and cost_cap_usd is None: cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) if cost_cap_usd is None: cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD elif cost_cap_usd is None: cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD else: if not cap_provided: cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) if hasattr(session, "set_auto_approval_policy"): session.set_auto_approval_policy( enabled=enabled, cost_cap_usd=cost_cap_usd, ) else: session.auto_approval_enabled = bool(enabled) session.auto_approval_cost_cap_usd = cost_cap_usd await self.persist_session_snapshot(agent_session) return self._auto_approval_summary(session) def get_session_owner(self, session_id: str) -> str | None: """Get the user_id that owns a session, or None if session doesn't exist.""" agent_session = self.sessions.get(session_id) if not agent_session: return None return agent_session.user_id def verify_session_access(self, session_id: str, user_id: str) -> bool: """Check if a user has access to a session. Returns True if: - The session exists AND the user owns it - The user_id is "dev" (dev mode bypass) """ owner = self.get_session_owner(session_id) if owner is None: return False if user_id == "dev" or owner == "dev": return True return owner == user_id def get_session_info(self, session_id: str) -> dict[str, Any] | None: """Get information about a session.""" agent_session = self.sessions.get(session_id) if not agent_session: return None pending_approval = self._pending_tools_for_api(agent_session.session) return { "session_id": session_id, "created_at": agent_session.created_at.isoformat(), "is_active": agent_session.is_active, "is_processing": agent_session.is_processing, "message_count": len(agent_session.session.context_manager.items), "user_id": agent_session.user_id, "pending_approval": pending_approval, "model": agent_session.session.config.model_name, "title": agent_session.title, "notification_destinations": list( agent_session.session.notification_destinations ), "auto_approval": self._auto_approval_summary(agent_session.session), } def set_notification_destinations( self, session_id: str, destinations: list[str] ) -> list[str]: """Replace the session's opted-in auto-notification destinations.""" agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise ValueError("Session not found or inactive") normalized: list[str] = [] seen: set[str] = set() for raw_name in destinations: name = raw_name.strip() if not name: raise ValueError("Destination names must not be empty") destination = self.config.messaging.get_destination(name) if destination is None: raise ValueError(f"Unknown destination '{name}'") if not destination.allow_auto_events: raise ValueError(f"Destination '{name}' is not enabled for auto events") if name not in seen: normalized.append(name) seen.add(name) agent_session.session.set_notification_destinations(normalized) return normalized async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. Args: user_id: If provided, only return sessions owned by this user. If "dev", return all sessions (dev mode). """ results: list[dict[str, Any]] = [] store = self._store() if getattr(store, "enabled", False): for row in await store.list_sessions(user_id or "dev"): sid = row.get("session_id") or row.get("_id") if not sid: continue runtime_info = self.get_session_info(str(sid)) if runtime_info: results.append(runtime_info) continue created_at = row.get("created_at") if isinstance(created_at, datetime): created_at_str = created_at.isoformat() else: created_at_str = str(created_at or datetime.utcnow().isoformat()) pending = self._pending_docs_for_api(row.get("pending_approval") or []) results.append( { "session_id": str(sid), "created_at": created_at_str, "is_active": row.get("status") != "ended", "is_processing": row.get("runtime_state") == "processing", "message_count": int(row.get("message_count") or 0), "user_id": row.get("user_id") or "dev", "pending_approval": pending or None, "model": row.get("model"), "title": row.get("title"), "notification_destinations": row.get( "notification_destinations" ) or [], "auto_approval": { "enabled": bool(row.get("auto_approval_enabled", False)), "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), "estimated_spend_usd": float( row.get("auto_approval_estimated_spend_usd") or 0.0 ), "remaining_usd": ( None if row.get("auto_approval_cost_cap_usd") is None else round( max( 0.0, float( row.get("auto_approval_cost_cap_usd") or 0.0 ) - float( row.get("auto_approval_estimated_spend_usd") or 0.0 ), ), 4, ) ), }, } ) return results for sid in self.sessions: info = self.get_session_info(sid) if not info: continue if user_id and user_id != "dev" and info.get("user_id") != user_id: continue results.append(info) return results @property def active_session_count(self) -> int: """Get count of active sessions.""" return sum(1 for s in self.sessions.values() if s.is_active) # Global session manager instance session_manager = SessionManager()