ml-intern3 / backend /session_manager.py
lewtun's picture
lewtun HF Staff
Route LLM inference through HF Router (#282)
437a804 unverified
"""Session manager for handling multiple concurrent agent sessions."""
import asyncio
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional
from agent.config import load_config
from agent.core.agent_loop import process_submission
from agent.core.model_ids import (
DEFAULT_MODEL_ID,
KIMI_K26_MODEL_ID,
is_known_router_model_id,
strip_huggingface_model_prefix,
)
from agent.core.session import Event, OpType, Session
from agent.core.session_persistence import get_session_store
from agent.core.tools import ToolRouter
from agent.messaging.gateway import NotificationGateway
# Project root: the directory above the ``backend`` package.
PROJECT_ROOT = Path(__file__).parent.parent
# Config file used when SessionManager is constructed without a path.
DEFAULT_CONFIG_PATH = str(PROJECT_ROOT.joinpath("configs", "frontend_agent_config.json"))
# These dataclasses match agent/main.py structure
@dataclass
class Operation:
    """Operation to be executed by the agent.

    Attributes:
        op_type: The agent-side ``OpType`` identifying this operation.
        data: Optional payload for the operation; ``None`` when absent.
    """
    op_type: OpType
    data: Optional[dict[str, Any]] = None
@dataclass
class Submission:
    """Submission to the agent loop.

    Attributes:
        id: Identifier string for this submission.
        operation: The ``Operation`` being submitted.
    """
    id: str
    operation: Operation
# Module logger used by EventBroadcaster and SessionManager below.
logger = logging.getLogger(__name__)
class EventBroadcaster:
    """Reads from the agent's event queue and fans out to SSE subscribers.

    Events that arrive when no subscribers are listening are discarded by
    this in-memory fanout. Durable replay is handled by session_persistence.
    """

    def __init__(self, event_queue: asyncio.Queue):
        # Upstream queue that ``run`` consumes events from.
        self._source = event_queue
        # Subscriber id -> that subscriber's private, unbounded queue.
        self._subscribers: dict[int, asyncio.Queue] = {}
        # Monotonic id source; ids are never reused by one broadcaster.
        self._counter = 0

    def subscribe(self) -> tuple[int, asyncio.Queue]:
        """Create a new subscriber. Returns (id, queue)."""
        self._counter += 1
        sub_id = self._counter
        q: asyncio.Queue = asyncio.Queue()
        self._subscribers[sub_id] = q
        return sub_id, q

    def unsubscribe(self, sub_id: int) -> None:
        """Remove a subscriber; unknown ids are ignored."""
        self._subscribers.pop(sub_id, None)

    async def run(self) -> None:
        """Main loop — reads from source queue and broadcasts.

        Each event is flattened to a dict with ``event_type``, ``data`` and
        ``seq`` keys and put on every current subscriber queue. Runs until
        cancelled; an error while handling one event is logged (with
        traceback) and the loop keeps going.
        """
        while True:
            try:
                event: Event = await self._source.get()
                msg = {
                    "event_type": event.event_type,
                    "data": event.data,
                    "seq": event.seq,
                }
                # Iterate a snapshot so a subscribe/unsubscribe that runs
                # while a put is being awaited cannot mutate the dict
                # mid-iteration.
                for q in list(self._subscribers.values()):
                    await q.put(msg)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the traceback: a bad event must not kill the fan-out,
                # but it should be diagnosable.
                logger.exception("EventBroadcaster error: %s", e)
@dataclass
class AgentSession:
    """Wrapper for an agent session with its associated resources.

    Bundles the agent-side ``Session`` with the ``ToolRouter`` it was built
    with, the submission queue handed to its runtime task, ownership and HF
    identity, and the bookkeeping flags consulted by SessionManager and its
    idle reaper. Field order matters: instances may be built positionally.
    """
    session_id: str
    session: Session
    tool_router: ToolRouter
    submission_queue: asyncio.Queue
    user_id: str = "dev" # Owner of this session ("dev" bypasses access checks)
    hf_username: str | None = None # HF namespace used for personal trace uploads
    hf_token: str | None = None # User's HF OAuth token for tool execution
    # Runtime task for this session; attached when the manager starts it.
    task: asyncio.Task | None = None
    # Naive UTC timestamp (datetime.utcnow); restored sessions reuse the
    # persisted value when it is a datetime.
    created_at: datetime = field(default_factory=datetime.utcnow)
    # Last genuine activity (submit/turn-start/turn-finish/direct user write).
    # Drives the idle reaper. Defaults to load time so a freshly-restored but
    # untouched session isn't reaped for a full idle window.
    last_active_at: datetime = field(default_factory=datetime.utcnow)
    # False once the session has ended; ended sessions don't count toward
    # per-user capacity and are skipped by the reaper.
    is_active: bool = True
    is_processing: bool = False # True while a submission is being executed
    # Set under the lock by the reaper while tearing this session down. Blocks
    # submit() from enqueueing onto a session that's being evicted.
    is_reaping: bool = False
    # NOTE(review): untyped slot, presumably for an EventBroadcaster wired up
    # outside this chunk — confirm.
    broadcaster: Any = None
    # Display title; persisted in snapshots and restored from metadata.
    title: str | None = None
    # True once this session has been counted against the user's daily premium
    # quota. The field name is kept for persistence compatibility.
    claude_counted: bool = False
class SessionCapacityError(Exception):
    """Raised when no more sessions can be created.

    Attributes:
        error_type: Which limit was hit — ``"global"`` for the server-wide
            pool, ``"per_user"`` for one user's concurrent-session cap.
    """

    def __init__(self, message: str, error_type: str = "global") -> None:
        self.error_type = error_type
        super().__init__(message)
# ── Capacity limits ─────────────────────────────────────────────────
# Sized for HF Spaces 8 vCPU / 32 GB RAM. One session costs roughly
# 10-20 MB (context, tools, queues, task), so the global cap of 200 tops
# out near 4 GB — ample headroom for the Python runtime and per-request
# overhead.
MAX_SESSIONS: int = 200
MAX_SESSIONS_PER_USER: int = 10
DEFAULT_YOLO_COST_CAP_USD: float = 5.0
# Graceful-shutdown sandbox teardown: parallelism and total time budget.
SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10
SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0

# ── Idle-session reaper ─────────────────────────────────────────────
# A live session that has been idle for at least REAPER_IDLE_HOURS, with no
# in-flight work, gets its sandbox and RAM released and leaves the live pool.
# It stays fully resumable from Mongo (shown as a normal chat, never
# "ended"), which frees both global capacity and the owner's slots.
# Each knob can be overridden through the environment.
REAPER_IDLE_HOURS: float = float(os.getenv("REAPER_IDLE_HOURS", "2"))
REAPER_INTERVAL_S: float = float(os.getenv("REAPER_INTERVAL_S", "300"))
REAP_TEARDOWN_TIMEOUT_S: float = float(os.getenv("REAP_TEARDOWN_TIMEOUT_S", "30"))
REAPER_IDLE = timedelta(hours=REAPER_IDLE_HOURS)
class SessionManager:
"""Manages multiple concurrent agent sessions."""
def __init__(self, config_path: str | None = None) -> None:
    """Load the agent config and set up empty session bookkeeping.

    Args:
        config_path: Agent config file; ``DEFAULT_CONFIG_PATH`` when omitted.
    """
    self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
    # Keep the default model id in the form strip_huggingface_model_prefix
    # produces; an empty result leaves the configured name alone.
    default_model = strip_huggingface_model_prefix(self.config.model_name)
    if default_model:
        self.config.model_name = default_model
    self.messaging_gateway = NotificationGateway(self.config.messaging)
    # Live sessions keyed by session id.
    self.sessions: dict[str, AgentSession] = {}
    self._lock = asyncio.Lock()
    # Resolved in start(), or lazily by _store().
    self.persistence_store = None
    # create_session calls that passed the capacity check but have not yet
    # inserted their session; counted together with active_session_count so
    # the global pool stays hard-capped under the check-then-create race.
    self._pending_creates: int = 0
    # Idle-session reaper task, created by start() and cancelled by close().
    self._reaper_task: asyncio.Task | None = None
async def start(self) -> None:
    """Start shared background resources.

    Initializes the persistence store, starts the messaging gateway, and
    launches the idle-session reaper task.
    """
    store = get_session_store()
    self.persistence_store = store
    await store.init()
    await self.messaging_gateway.start()
    self._reaper_task = asyncio.create_task(self._reaper_loop())
async def close(self) -> None:
    """Flush and close shared background resources.

    Stops the idle reaper, tears down live sandboxes, then closes the
    messaging gateway and the persistence store. Each later step runs even
    if an earlier one raises, so a failing shutdown step cannot leak the
    gateway or the store; the first error still propagates to the caller.
    """
    reaper = self._reaper_task
    if reaper is not None:
        reaper.cancel()
        try:
            await reaper
        except asyncio.CancelledError:
            pass
        self._reaper_task = None
    try:
        await self._cleanup_all_sandboxes_on_close()
    finally:
        try:
            await self.messaging_gateway.close()
        finally:
            if self.persistence_store is not None:
                await self.persistence_store.close()
def _store(self):
    """Return the persistence store, resolving it on first use."""
    store = self.persistence_store
    if store is None:
        store = self.persistence_store = get_session_store()
    return store
def _count_user_sessions(self, user_id: str) -> int:
    """Return how many live, active sessions *user_id* currently owns."""
    owned = [
        candidate
        for candidate in self.sessions.values()
        if candidate.is_active and candidate.user_id == user_id
    ]
    return len(owned)
@staticmethod
def _touch(agent_session: "AgentSession") -> None:
    """Reset the idle reaper's clock for *agent_session*.

    Use only for genuine activity (submit, turn start/finish, direct
    user-initiated writes). Passive reads and hydration must not call this,
    otherwise an idle session would be kept alive indefinitely.
    """
    now = datetime.utcnow()
    agent_session.last_active_at = now
@staticmethod
def _model_from_saved_metadata(
    model: str | None,
    *,
    premium_user_billed: bool,
    claude_counted: bool,
) -> tuple[str, bool, bool]:
    """Validate a persisted model id, falling back when it is unusable.

    Returns ``(model_id, premium_user_billed, claude_counted)``. A known
    router model id (after prefix stripping) is returned with both flags
    unchanged. Otherwise the fallback is KIMI_K26_MODEL_ID for premium-billed
    sessions and DEFAULT_MODEL_ID for the rest, and a warning is logged;
    whenever the fallback is the Kimi model, both flags are reset to False.
    """
    candidate = strip_huggingface_model_prefix(model)
    if candidate and is_known_router_model_id(candidate):
        return candidate, premium_user_billed, claude_counted

    if premium_user_billed:
        fallback = KIMI_K26_MODEL_ID
    else:
        fallback = DEFAULT_MODEL_ID
    logger.warning(
        "Saved session model %r failed validation; using %r",
        model,
        fallback,
    )
    reset_flags = fallback == KIMI_K26_MODEL_ID
    if reset_flags:
        return fallback, False, False
    return fallback, premium_user_billed, claude_counted
def _create_session_sync(
    self,
    *,
    session_id: str,
    user_id: str,
    hf_username: str | None,
    hf_token: str | None,
    model: str | None,
    event_queue: asyncio.Queue,
    notification_destinations: list[str] | None = None,
) -> tuple[ToolRouter, Session]:
    """Build blocking per-session resources in a worker thread.

    Constructs the ToolRouter and the Session synchronously (callers run
    this through ``asyncio.to_thread``) and logs how long it took.

    Returns:
        ``(tool_router, session)``.
    """
    import time as _time

    started = _time.monotonic()
    router = ToolRouter(self.config.mcpServers, hf_token=hf_token)
    # Each session gets its own deep copy of the config so a model switch in
    # one tab never changes the model used by another.
    per_session_config = self.config.model_copy(deep=True)
    model_override = strip_huggingface_model_prefix(model)
    if model_override:
        per_session_config.model_name = model_override
    session = Session(
        event_queue=event_queue,
        config=per_session_config,
        tool_router=router,
        hf_token=hf_token,
        user_id=user_id,
        hf_username=hf_username,
        notification_gateway=self.messaging_gateway,
        notification_destinations=notification_destinations or [],
        session_id=session_id,
        persistence_store=self._store(),
    )
    logger.info("Session initialized in %.2fs", _time.monotonic() - started)
    return router, session
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
    """Dump every context item of *session* to a JSON-mode dict."""
    context_items = session.context_manager.items
    return [item.model_dump(mode="json") for item in context_items]
def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
    """Serialize the session's pending tool calls for persistence.

    Objects exposing ``model_dump`` are dumped in JSON mode, plain dicts
    pass through unchanged, and anything else is dropped. Returns an empty
    list when nothing is pending.
    """
    tool_calls = (session.pending_approval or {}).get("tool_calls") or []
    out: list[dict[str, Any]] = []
    for call in tool_calls:
        if hasattr(call, "model_dump"):
            out.append(call.model_dump(mode="json"))
            continue
        if isinstance(call, dict):
            out.append(call)
    return out
@staticmethod
def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None:
    """Describe the session's pending tool calls for the API.

    Each entry is ``{"tool", "tool_call_id", "arguments"}`` with the JSON
    argument string decoded (``{}`` when missing or invalid). Tool calls may
    be model objects (attribute access) or plain dicts — the same two shapes
    ``_serialize_pending_approval`` accepts — and a call missing its
    ``function`` part yields ``None`` fields instead of raising.

    Returns:
        The list of tool descriptions, or None when nothing is pending.
    """
    pending = session.pending_approval or {}
    tool_calls = pending.get("tool_calls") or []
    if not tool_calls:
        return None

    def _field(obj: Any, name: str) -> Any:
        # Dicts use key lookup; anything else (including None) uses
        # getattr, so a missing piece yields None rather than an error.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    result: list[dict[str, Any]] = []
    for tc in tool_calls:
        function = _field(tc, "function")
        try:
            args = json.loads(_field(function, "arguments"))
        except (json.JSONDecodeError, TypeError):
            args = {}
        result.append(
            {
                "tool": _field(function, "name"),
                "tool_call_id": _field(tc, "id"),
                "arguments": args,
            }
        )
    return result
def _restore_pending_approval(
    self, session: Session, pending_approval: list[dict[str, Any]] | None
) -> None:
    """Rebuild ``session.pending_approval`` from persisted tool calls.

    Accepts both the full tool-call shape (has a ``function`` key) and the
    compact API shape (``tool`` / ``tool_call_id`` / ``arguments``). Entries
    that cannot be converted are logged and skipped; when none survive, the
    pending approval is cleared to None.
    """
    if not pending_approval:
        session.pending_approval = None
        return

    from litellm import ChatCompletionMessageToolCall as ToolCall

    restored = []
    for raw in pending_approval:
        try:
            if "function" in raw:
                call = ToolCall(**raw)
            else:
                call = ToolCall(
                    id=raw["tool_call_id"],
                    type="function",
                    function={
                        "name": raw["tool"],
                        "arguments": json.dumps(raw.get("arguments") or {}),
                    },
                )
        except Exception as e:
            logger.warning("Dropping malformed pending approval: %s", e)
            continue
        restored.append(call)

    session.pending_approval = {"tool_calls": restored} if restored else None
@staticmethod
def _pending_docs_for_api(
    pending_approval: list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
    """Convert persisted pending-approval docs into the API tool list.

    Full tool-call docs (with ``function``) have their JSON arguments
    decoded (``{}`` on failure); compact docs with ``tool`` and
    ``tool_call_id`` are copied across. Other docs are ignored.

    Returns:
        The tool descriptions, or None when there are none.
    """
    if not pending_approval:
        return None

    tools: list[dict[str, Any]] = []
    for raw in pending_approval:
        if "function" in raw:
            fn = raw.get("function") or {}
            try:
                parsed_args = json.loads(fn.get("arguments") or "{}")
            except (json.JSONDecodeError, TypeError):
                parsed_args = {}
            tools.append(
                {
                    "tool": fn.get("name"),
                    "tool_call_id": raw.get("id"),
                    "arguments": parsed_args,
                }
            )
            continue
        if {"tool", "tool_call_id"}.issubset(raw):
            tools.append(
                {
                    "tool": raw.get("tool"),
                    "tool_call_id": raw.get("tool_call_id"),
                    "arguments": raw.get("arguments") or {},
                }
            )
    return tools or None
@staticmethod
def _runtime_state(agent_session: AgentSession) -> str:
    """Classify a session's runtime state.

    Precedence: ``waiting_approval`` > ``processing`` > ``ended`` > ``idle``.
    """
    if agent_session.session.pending_approval:
        state = "waiting_approval"
    elif agent_session.is_processing:
        state = "processing"
    elif not agent_session.is_active:
        state = "ended"
    else:
        state = "idle"
    return state
@staticmethod
def _auto_approval_summary(session: Session) -> dict[str, Any]:
    """Summarize the session's auto-approval policy and spend.

    Delegates to ``session.auto_approval_policy_summary()`` when the session
    provides it; otherwise builds the summary from the individual
    attributes, with spend rounded to 4 decimals and ``remaining_usd``
    floored at 0 (None when there is no cap).
    """
    if hasattr(session, "auto_approval_policy_summary"):
        return session.auto_approval_policy_summary()

    cap = getattr(session, "auto_approval_cost_cap_usd", None)
    spent = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0)
    if cap is None:
        remaining = None
    else:
        remaining = round(max(0.0, float(cap) - spent), 4)
    return {
        "enabled": bool(getattr(session, "auto_approval_enabled", False)),
        "cost_cap_usd": cap,
        "estimated_spend_usd": round(spent, 4),
        "remaining_usd": remaining,
    }
async def _start_agent_session(
    self,
    *,
    agent_session: AgentSession,
    event_queue: asyncio.Queue,
    tool_router: ToolRouter,
) -> AgentSession:
    """Register *agent_session* and spawn its runtime task, unless one exists.

    Under the manager lock: when a session with the same id is already live,
    that session is returned and *agent_session* is not started. Otherwise
    *agent_session* is inserted, a ``_run_session`` task is created and
    attached to it, and it is returned.
    """
    sid = agent_session.session_id
    async with self._lock:
        already_live = self.sessions.get(sid)
        if already_live:
            return already_live
        self.sessions[sid] = agent_session
        agent_session.task = asyncio.create_task(
            self._run_session(
                sid,
                agent_session.submission_queue,
                event_queue,
                tool_router,
            )
        )
        return agent_session
@staticmethod
def _start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
    """Kick off a best-effort cpu-basic sandbox for the session.

    Any failure, including importing the sandbox tool, is logged as a
    warning and otherwise ignored.
    """
    try:
        from agent.tools.sandbox_tool import start_cpu_sandbox_preload as _preload

        _preload(agent_session.session)
    except Exception as e:
        logger.warning(
            "Failed to start CPU sandbox preload for %s: %s",
            agent_session.session_id,
            e,
        )
@staticmethod
def _can_access_session(agent_session: AgentSession, user_id: str) -> bool:
    """Return True when *user_id* may use *agent_session*.

    The ``"dev"`` identity bypasses ownership on either side: a dev caller
    reaches any session, and a dev-owned session is reachable by anyone.
    Otherwise the caller must be the owner.
    """
    owner = agent_session.user_id
    return "dev" in (user_id, owner) or owner == user_id
@staticmethod
def _update_hf_identity(
    agent_session: AgentSession,
    *,
    hf_token: str | None,
    hf_username: str | None,
) -> None:
    """Copy a fresh HF token/username onto the wrapper and its Session.

    Falsy values are ignored, so a missing credential never erases one that
    is already known.
    """
    inner = agent_session.session
    if hf_token:
        agent_session.hf_token = inner.hf_token = hf_token
    if hf_username:
        agent_session.hf_username = inner.hf_username = hf_username
@staticmethod
def _has_active_sandbox_preload(agent_session: AgentSession) -> bool:
    """Return True while a sandbox preload task exists and is not done."""
    task = getattr(agent_session.session, "sandbox_preload_task", None)
    if not task:
        return False
    return not task.done()
@staticmethod
def _preload_failed_for_missing_hf_token(agent_session: AgentSession) -> bool:
    """Return True if the last preload error reports a missing HF token."""
    error = getattr(agent_session.session, "sandbox_preload_error", None)
    if not isinstance(error, str):
        return False
    return error.startswith("No HF token available")
def _restart_cpu_preload_if_token_recovered(
    self,
    agent_session: AgentSession,
    *,
    preload_sandbox: bool,
) -> None:
    """Retry the CPU sandbox preload once an HF token has become available.

    Acts only when preloading was requested, the session has no sandbox and
    no preload in flight, a token is now known (on the wrapper or the
    Session), and the previous preload failed for lack of a token. Stale
    preload state is then cleared and a new preload started.
    """
    if not preload_sandbox:
        return
    session = agent_session.session
    blocked = (
        getattr(session, "sandbox", None)
        or self._has_active_sandbox_preload(agent_session)
        or not (agent_session.hf_token or getattr(session, "hf_token", None))
        or not self._preload_failed_for_missing_hf_token(agent_session)
    )
    if blocked:
        return
    session.sandbox_preload_error = None
    session.sandbox_preload_task = None
    session.sandbox_preload_cancel_event = None
    self._start_cpu_sandbox_preload(agent_session)
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
    """Mark the persisted sandbox as destroyed; failures are only logged."""
    cleared_fields = {
        "sandbox_space_id": None,
        "sandbox_hardware": None,
        "sandbox_owner": None,
        "sandbox_created_at": None,
        "sandbox_status": "destroyed",
    }
    try:
        await self._store().update_session_fields(session_id, **cleared_fields)
    except Exception as e:
        logger.warning("Failed to clear sandbox metadata for %s: %s", session_id, e)
async def _cleanup_persisted_sandbox(
    self,
    session_id: str,
    metadata: dict[str, Any],
    *,
    hf_token: str | None,
) -> None:
    """Delete a sandbox recorded by a previous backend process, if any.

    Nothing happens when the metadata has no ``sandbox_space_id`` or is
    already marked ``"destroyed"``. Deletion is tried with the caller's
    token first, then ``HF_ADMIN_TOKEN`` (duplicates skipped). A successful
    delete, or a 404 from the last token tried, clears the persisted
    sandbox fields; any other outcome only logs a warning.

    Args:
        session_id: Session whose persisted sandbox should be removed.
        metadata: Persisted session metadata (``sandbox_*`` fields).
        hf_token: The caller's HF token, tried before the admin token.
    """
    space_id = metadata.get("sandbox_space_id")
    if not isinstance(space_id, str) or not space_id:
        return
    if metadata.get("sandbox_status") == "destroyed":
        return

    tokens: list[tuple[str, str]] = []
    seen: set[str] = set()
    for label, token in (
        ("user", hf_token),
        ("admin", os.environ.get("HF_ADMIN_TOKEN")),
    ):
        if token and token not in seen:
            tokens.append((label, token))
            seen.add(token)
    if not tokens:
        logger.warning(
            "Cannot clean persisted sandbox %s for session %s: no HF token available",
            space_id,
            session_id,
        )
        return

    last_err: Exception | None = None
    for index, (label, token) in enumerate(tokens):
        is_last_token = index == len(tokens) - 1
        try:
            from huggingface_hub import HfApi

            api = HfApi(token=token)
            await asyncio.to_thread(
                api.delete_repo,
                repo_id=space_id,
                repo_type="space",
            )
        except Exception as e:
            status_code = getattr(getattr(e, "response", None), "status_code", None)
            # The Hub answers 404 both for a missing repo and for a private
            # repo this token cannot see. Only trust "already gone" from the
            # last candidate token; otherwise let the next (admin) token try,
            # so the Space isn't orphaned with its metadata wiped.
            if status_code == 404 and is_last_token:
                logger.info(
                    "Persisted sandbox %s for session %s is already gone",
                    space_id,
                    session_id,
                )
                await self._clear_persisted_sandbox_metadata(session_id)
                return
            last_err = e
            continue
        logger.info(
            "Deleted persisted sandbox %s for session %s with %s token",
            space_id,
            session_id,
            label,
        )
        await self._clear_persisted_sandbox_metadata(session_id)
        return

    logger.warning(
        "Failed to delete persisted sandbox %s for session %s: %s",
        space_id,
        session_id,
        last_err,
    )
async def persist_session_snapshot(
    self,
    agent_session: AgentSession,
    *,
    runtime_state: str | None = None,
    status: str = "active",
    raise_on_error: bool = False,
) -> None:
    """Write the session's current runtime context to the persistence store.

    Best-effort by default: a disabled store makes this a no-op and write
    failures are logged and swallowed. With ``raise_on_error=True`` (for
    callers such as the reaper that must know the snapshot is durable) a
    disabled store raises RuntimeError and write failures propagate.

    Args:
        agent_session: Session to snapshot.
        runtime_state: Explicit state; derived via _runtime_state if omitted.
        status: Persisted status value (``"active"`` by default).
        raise_on_error: Raise instead of silently dropping the snapshot.
    """
    store = self._store()
    if not getattr(store, "enabled", False):
        if raise_on_error:
            raise RuntimeError("persistence store is disabled")
        return

    try:
        session = agent_session.session
        snapshot = {
            "session_id": agent_session.session_id,
            "user_id": agent_session.user_id,
            "model": session.config.model_name,
            "title": agent_session.title,
            "messages": self._serialize_messages(session),
            "runtime_state": runtime_state or self._runtime_state(agent_session),
            "status": status,
            "turn_count": session.turn_count,
            "pending_approval": self._serialize_pending_approval(session),
            "claude_counted": agent_session.claude_counted,
            "premium_user_billed": getattr(session, "premium_user_billed", False),
            "created_at": agent_session.created_at,
            "notification_destinations": list(session.notification_destinations),
            "auto_approval_enabled": bool(
                getattr(session, "auto_approval_enabled", False)
            ),
            "auto_approval_cost_cap_usd": getattr(
                session, "auto_approval_cost_cap_usd", None
            ),
            "auto_approval_estimated_spend_usd": float(
                getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
            ),
            "raise_on_error": raise_on_error,
        }
        await store.save_snapshot(**snapshot)
    except Exception as e:
        if raise_on_error:
            raise
        logger.warning(
            "Failed to persist snapshot for %s: %s",
            agent_session.session_id,
            e,
        )
def _reuse_loaded_session_if_allowed(
    self,
    existing: AgentSession,
    *,
    user_id: str,
    hf_token: str | None,
    hf_username: str | None,
    preload_sandbox: bool,
) -> AgentSession | None:
    """Return an already-live session if *user_id* may access it.

    On success the caller's HF identity is copied onto the session and a CPU
    sandbox preload that previously failed for lack of a token is restarted.
    Returns None when access is denied.
    """
    if not self._can_access_session(existing, user_id):
        return None
    self._update_hf_identity(
        existing,
        hf_token=hf_token,
        hf_username=hf_username,
    )
    self._restart_cpu_preload_if_token_recovered(
        existing,
        preload_sandbox=preload_sandbox,
    )
    return existing

async def ensure_session_loaded(
    self,
    session_id: str,
    user_id: str,
    hf_token: str | None = None,
    hf_username: str | None = None,
    preload_sandbox: bool = True,
) -> AgentSession | None:
    """Return a live runtime session, lazily restoring it from Mongo.

    A live session is returned directly (after an access check). Otherwise
    the persisted session is loaded and its owner checked. Any sandbox left
    by a previous process is cleaned up. A fresh runtime is then built and
    started from the saved messages, pending approval, and billing and
    auto-approval state.

    Args:
        session_id: Session to look up or restore.
        user_id: Caller identity used for the ownership check.
        hf_token: Caller's HF token; copied onto the session when given.
        hf_username: Caller's HF username; copied onto the session.
        preload_sandbox: Whether to (re)start the CPU sandbox preload.

    Returns:
        The live ``AgentSession``, or None when the session does not exist
        or the caller may not access it.
    """
    async with self._lock:
        existing = self.sessions.get(session_id)
        if existing:
            return self._reuse_loaded_session_if_allowed(
                existing,
                user_id=user_id,
                hf_token=hf_token,
                hf_username=hf_username,
                preload_sandbox=preload_sandbox,
            )

    store = self._store()
    loaded = await store.load_session(session_id)
    if not loaded:
        return None

    async with self._lock:
        # Another caller may have restored it while we awaited the store.
        existing = self.sessions.get(session_id)
        if existing:
            return self._reuse_loaded_session_if_allowed(
                existing,
                user_id=user_id,
                hf_token=hf_token,
                hf_username=hf_username,
                preload_sandbox=preload_sandbox,
            )

    meta = loaded.get("metadata") or {}
    owner = str(meta.get("user_id") or "")
    if user_id != "dev" and owner != "dev" and owner != user_id:
        return None

    await self._cleanup_persisted_sandbox(
        session_id,
        meta,
        hf_token=hf_token,
    )

    from litellm import Message

    model, premium_user_billed, claude_counted = self._model_from_saved_metadata(
        meta.get("model") or self.config.model_name,
        premium_user_billed=bool(meta.get("premium_user_billed", False)),
        claude_counted=bool(meta.get("claude_counted")),
    )
    event_queue: asyncio.Queue = asyncio.Queue()
    submission_queue: asyncio.Queue = asyncio.Queue()
    tool_router, session = await asyncio.to_thread(
        self._create_session_sync,
        session_id=session_id,
        user_id=owner or user_id,
        hf_username=hf_username,
        hf_token=hf_token,
        model=model,
        event_queue=event_queue,
        notification_destinations=meta.get("notification_destinations") or [],
    )

    restored_messages: list[Message] = []
    for raw in loaded.get("messages") or []:
        if not isinstance(raw, dict) or raw.get("role") == "system":
            continue
        try:
            restored_messages.append(Message.model_validate(raw))
        except Exception as e:
            logger.warning("Dropping malformed restored message: %s", e)
    if restored_messages:
        # Keep the freshly-rendered system prompt, then attach the durable
        # non-system context so tools/date/user context stay current.
        session.context_manager.items = [
            session.context_manager.items[0],
            *restored_messages,
        ]

    # If this session ever had a sandbox, its container did not survive the
    # resume (a fresh, empty one is lazily recreated). Tell the agent so it
    # recreates files/packages instead of assuming /app/train.py et al. still
    # exist. Gated on sandbox_status so pure Q&A chats get no note. Mirrors
    # the seed_from_summary note convention.
    #
    # Skip it when an approval is pending: the restored context ends with an
    # assistant tool-call message awaiting results, so injecting a user
    # message here would sit between the tool_calls and their results. On
    # approval the real results get appended after the note, leaving them
    # orphaned (the context manager stubs the "missing" result right after
    # the assistant message) — which the provider rejects. The agent still
    # learns the sandbox is empty when the approved tool runs against it.
    if meta.get("sandbox_status") and not meta.get("pending_approval"):
        session.context_manager.items.append(
            Message(
                role="user",
                content=(
                    "[SYSTEM: This session was resumed and its sandbox was "
                    "reset. Any files, installed packages, or running "
                    "processes from earlier are gone — recreate what you "
                    "need before using the sandbox.]"
                ),
            )
        )

    self._restore_pending_approval(session, meta.get("pending_approval") or [])
    session.turn_count = int(meta.get("turn_count") or 0)
    session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False))
    session.premium_user_billed = premium_user_billed
    raw_cap = meta.get("auto_approval_cost_cap_usd")
    session.auto_approval_cost_cap_usd = (
        float(raw_cap) if isinstance(raw_cap, int | float) else None
    )
    session.auto_approval_estimated_spend_usd = float(
        meta.get("auto_approval_estimated_spend_usd") or 0.0
    )

    created_at = meta.get("created_at")
    if not isinstance(created_at, datetime):
        created_at = datetime.utcnow()

    agent_session = AgentSession(
        session_id=session_id,
        session=session,
        tool_router=tool_router,
        submission_queue=submission_queue,
        user_id=owner or user_id,
        hf_username=hf_username,
        hf_token=hf_token,
        created_at=created_at,
        is_active=True,
        is_processing=False,
        claude_counted=claude_counted,
        title=meta.get("title"),
    )
    started = await self._start_agent_session(
        agent_session=agent_session,
        event_queue=event_queue,
        tool_router=tool_router,
    )
    if started is not agent_session:
        # Lost a concurrent restore of the same id: hand back the winner,
        # refreshed like the early-return paths (identity, then a preload
        # restart if a missing token has since been supplied).
        self._update_hf_identity(
            started,
            hf_token=hf_token,
            hf_username=hf_username,
        )
        self._restart_cpu_preload_if_token_recovered(
            started,
            preload_sandbox=preload_sandbox,
        )
        return started

    if preload_sandbox:
        self._start_cpu_sandbox_preload(agent_session)
    logger.info("Restored session %s for user %s", session_id, owner or user_id)
    return agent_session
def _user_create_reservations(self) -> dict[str, int]:
    """Per-user count of create_session calls holding a capacity reservation.

    Created lazily so this bookkeeping needs no extra ``__init__`` wiring.
    Must be used with ``self._lock`` held.
    """
    reservations = getattr(self, "_pending_user_creates", None)
    if reservations is None:
        reservations = {}
        self._pending_user_creates = reservations
    return reservations

def _release_create_reservation(self, user_id: str) -> None:
    """Give back one global (and, for non-dev users, per-user) create slot.

    Must be called with ``self._lock`` held.
    """
    self._pending_creates -= 1
    if user_id == "dev":
        return
    reservations = self._user_create_reservations()
    remaining = reservations.get(user_id, 0) - 1
    if remaining > 0:
        reservations[user_id] = remaining
    else:
        reservations.pop(user_id, None)

async def create_session(
    self,
    user_id: str = "dev",
    hf_username: str | None = None,
    hf_token: str | None = None,
    model: str | None = None,
    is_pro: bool | None = None,
) -> str:
    """Create a new agent session and return its ID.

    Session() and ToolRouter() constructors contain blocking I/O
    (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
    executed in a thread pool to avoid freezing the async event loop.

    Args:
        user_id: The ID of the user who owns this session.
        hf_username: The HF username/namespace used for personal trace uploads.
        hf_token: The user's HF OAuth token, stored for tool execution.
        model: Optional model override. When set, replaces ``model_name``
            on the per-session config clone. None falls back to the
            config default.
        is_pro: When given for a non-dev user, the user's Pro status is
            recorded via _track_pro_status after creation.

    Raises:
        SessionCapacityError: If the server or user has reached the
            maximum number of concurrent sessions.
    """
    # ── Capacity checks ──────────────────────────────────────────
    # Reserve slots under the lock so concurrent creates can't all pass the
    # check and then over-admit (the build + insert happen later, outside the
    # lock). Neither active_session_count nor _count_user_sessions reflects
    # this session until _start_agent_session inserts it, so in-flight
    # reservations are counted alongside both — globally and per user.
    async with self._lock:
        active_count = self.active_session_count
        projected = active_count + self._pending_creates
        if projected >= MAX_SESSIONS:
            raise SessionCapacityError(
                f"Server is at capacity ({projected}/{MAX_SESSIONS} sessions). "
                "Please try again later.",
                error_type="global",
            )
        if user_id != "dev":
            user_reservations = self._user_create_reservations()
            user_count = self._count_user_sessions(user_id) + user_reservations.get(
                user_id, 0
            )
            if user_count >= MAX_SESSIONS_PER_USER:
                raise SessionCapacityError(
                    f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
                    "concurrent sessions. Please close an existing session first.",
                    error_type="per_user",
                )
            user_reservations[user_id] = user_reservations.get(user_id, 0) + 1
        self._pending_creates += 1

    session_id = str(uuid.uuid4())
    # Create queues for this session
    submission_queue: asyncio.Queue = asyncio.Queue()
    event_queue: asyncio.Queue = asyncio.Queue()

    reserved = True
    try:
        # Run blocking constructors in a thread to keep the event loop responsive.
        tool_router, session = await asyncio.to_thread(
            self._create_session_sync,
            session_id=session_id,
            user_id=user_id,
            hf_username=hf_username,
            hf_token=hf_token,
            model=model,
            event_queue=event_queue,
        )

        agent_session = AgentSession(
            session_id=session_id,
            session=session,
            tool_router=tool_router,
            submission_queue=submission_queue,
            user_id=user_id,
            hf_username=hf_username,
            hf_token=hf_token,
        )

        await self._start_agent_session(
            agent_session=agent_session,
            event_queue=event_queue,
            tool_router=tool_router,
        )
        # The session is now in self.sessions, so the live counts reflect it —
        # release the reservation before the slower (and non-capacity)
        # persistence + preload work.
        async with self._lock:
            self._release_create_reservation(user_id)
            reserved = False

        await self.persist_session_snapshot(agent_session, runtime_state="idle")
        self._start_cpu_sandbox_preload(agent_session)
        if is_pro is not None and user_id and user_id != "dev":
            await self._track_pro_status(agent_session, is_pro=is_pro)

        logger.info(f"Created session {session_id} for user {user_id}")
        return session_id
    finally:
        # Build/start failed before the session was inserted — always
        # release the reservation so a failed create can't permanently
        # shrink the pool.
        if reserved:
            async with self._lock:
                self._release_create_reservation(user_id)
async def _track_pro_status(
    self, agent_session: AgentSession, *, is_pro: bool
) -> None:
    """Record the user's Pro state and emit a one-shot conversion event.

    Updates the per-user Pro state via ``store.mark_pro_seen``. When the
    store reports a free→Pro transition (``converted``), a telemetry
    conversion event is recorded. Best-effort: every failure is logged at
    debug level and swallowed so telemetry never breaks session creation.
    """
    store = self._store()
    if not getattr(store, "enabled", False):
        return

    try:
        outcome = await store.mark_pro_seen(agent_session.user_id, is_pro=is_pro)
    except Exception as e:
        logger.debug("mark_pro_seen failed: %s", e)
        return

    if not (outcome and outcome.get("converted")):
        return

    try:
        from agent.core import telemetry

        await telemetry.record_pro_conversion(
            agent_session.session,
            first_seen_at=outcome.get("first_seen_at"),
        )
    except Exception as e:
        logger.debug("record_pro_conversion failed: %s", e)
async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int:
    """Rehydrate a session from cached prior messages via summarization.

    Runs the standard summarization prompt (same one compaction uses)
    over the provided messages, then seeds the new session's context
    with that summary. Tool-call pairing concerns disappear because the
    output is plain text. Returns the number of messages summarized.

    Non-dict entries and system messages are skipped, and entries that fail
    validation are logged and dropped. Returns 0 without calling the
    summarizer when nothing usable remains.

    Raises:
        ValueError: If *session_id* is not a live session.
    """
    from litellm import Message

    from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages

    agent_session = self.sessions.get(session_id)
    if not agent_session:
        raise ValueError(f"Session {session_id} not found")

    # Parse into Message objects, tolerating malformed entries.
    parsed: list[Message] = []
    for raw in messages:
        # A non-dict entry would raise AttributeError on raw.get; skip it
        # the way ensure_session_loaded skips malformed restored messages.
        if not isinstance(raw, dict):
            logger.warning(
                "Dropping malformed message during seed: not a dict (%s)",
                type(raw).__name__,
            )
            continue
        if raw.get("role") == "system":
            continue  # the new session has its own system prompt
        try:
            parsed.append(Message.model_validate(raw))
        except Exception as e:
            logger.warning("Dropping malformed message during seed: %s", e)
    if not parsed:
        return 0

    session = agent_session.session

    # Pass the real tool specs so the summarizer sees what the agent
    # actually has. Without them, the summarizer can editorialize that
    # original tool calls were fabricated.
    tool_specs = None
    try:
        tool_specs = agent_session.tool_router.get_tool_specs_for_llm()
    except Exception as e:
        # Still best-effort, but no longer invisible.
        logger.debug("Could not load tool specs for seed summary: %s", e)

    try:
        summary, _ = await summarize_messages(
            parsed,
            model_name=session.config.model_name,
            hf_token=session.hf_token,
            max_tokens=4000,
            prompt=_RESTORE_PROMPT,
            tool_specs=tool_specs,
            session=session,
            kind="restore",
        )
    except Exception as e:
        logger.error("Summary call failed during seed: %s", e)
        raise

    seed = Message(
        role="user",
        content=(
            "[SYSTEM: Your prior memory of this conversation — written "
            "in your own voice right before restart. Continue from here.]\n\n"
            + (summary or "(no summary returned)")
        ),
    )
    session.context_manager.items.append(seed)
    self._touch(agent_session)
    await self.persist_session_snapshot(agent_session, runtime_state="idle")
    return len(parsed)
@staticmethod
async def _cleanup_sandbox(session: Session) -> None:
    """Delete the sandbox Space if one was created for this session.

    Thin wrapper over ``agent.tools.sandbox_tool.teardown_session_sandbox``.
    This wrapper adds no retries of its own. Retrying transient HF API
    failures (5xx, rate limits, network blips) with backoff is presumably
    handled inside that helper — confirm there. A missed delete leaves an
    orphaned Space.
    """
    from agent.tools.sandbox_tool import teardown_session_sandbox as _teardown

    await _teardown(session)
async def _cleanup_all_sandboxes_on_close(self) -> None:
    """Best-effort sandbox cleanup for graceful backend shutdown.

    Tears down the sandbox of every live session, at most
    SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY at a time. Per-session failures are
    logged, and the whole pass is abandoned after
    SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S.
    """
    async with self._lock:
        snapshot = list(self.sessions.values())
    if not snapshot:
        return

    gate = asyncio.Semaphore(SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY)

    async def _teardown_one(agent_session: AgentSession) -> None:
        async with gate:
            try:
                await self._cleanup_sandbox(agent_session.session)
            except Exception as e:
                logger.warning(
                    "Shutdown sandbox cleanup failed for %s: %s",
                    agent_session.session_id,
                    e,
                )

    pending = [asyncio.create_task(_teardown_one(entry)) for entry in snapshot]
    try:
        await asyncio.wait_for(
            asyncio.gather(*pending, return_exceptions=True),
            timeout=SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S,
        )
    except asyncio.TimeoutError:
        logger.warning(
            "Timed out after %.0fs cleaning up sandboxes on shutdown; "
            "orphan sweeper will handle any stragglers",
            SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S,
        )
async def _reaper_loop(self) -> None:
    """Periodically release resources held by idle sessions.

    Long-lived task, started by start() and cancelled by close(). It sleeps
    REAPER_INTERVAL_S between sweeps and exits quietly on cancellation.
    Errors in a sweep are logged so that one bad sweep never stops the loop.
    """
    try:
        while True:
            try:
                await asyncio.sleep(REAPER_INTERVAL_S)
                await self._reap_idle_sessions()
            except Exception as e:
                logger.error("Idle-session reaper sweep failed: %s", e)
    except asyncio.CancelledError:
        return
async def _reap_idle_sessions(self) -> None:
    """Pick idle sessions under the lock, then tear each one down.

    A candidate meets all of these conditions:
    - it is a live, active, non-dev session;
    - it is not processing and not already being reaped;
    - it is not waiting on a tool approval (an "approve later" session still
      needs its sandbox);
    - it has been untouched for the whole idle window.

    Only ids are collected under the lock. _reap_one takes the lock again
    for the teardown itself, because the lock is not reentrant.
    """
    # Eviction is only safe while sessions remain resumable from the store;
    # without one it would destroy non-dev chats, so skip the sweep.
    if not getattr(self._store(), "enabled", False):
        return

    cutoff = datetime.utcnow() - REAPER_IDLE

    def _is_idle_candidate(entry: AgentSession) -> bool:
        return (
            entry.is_active
            and not entry.is_processing
            and not entry.is_reaping
            and entry.user_id != "dev"
            and not entry.session.pending_approval
            and entry.last_active_at <= cutoff
        )

    async with self._lock:
        candidate_ids = [
            entry.session_id
            for entry in self.sessions.values()
            if _is_idle_candidate(entry)
        ]
    if not candidate_ids:
        return

    reaped = 0
    for sid in candidate_ids:
        try:
            was_reaped = await self._reap_one(sid, cutoff)
        except Exception as e:
            logger.warning("Failed to reap idle session %s: %s", sid, e)
            continue
        if was_reaped:
            reaped += 1
    if reaped:
        logger.info("Reaped %d idle session(s)", reaped)
async def _reap_one(self, session_id: str, cutoff: datetime) -> bool:
    """Tear down one idle session, leaving it resumable from Mongo.

    Runs in three phases:

    1. Under the lock, re-check every idle condition (a user may have
       become active in the gap since selection) and mark the session
       ``is_reaping`` so ``submit`` refuses new work.
    2. Outside the lock, persist a resumable snapshot
       (``runtime_state="idle"``, ``status="active"``). If that fails,
       clear ``is_reaping`` (only if this exact object is still
       registered) and give up.
    3. Under the lock, re-check once more and pop the session. The runtime
       task is then cancelled *outside* the lock. Its own ``finally`` frees
       the sandbox, and its identity-gated persist no-ops because the
       session is already popped. So it can't overwrite our resumable
       snapshot with ``"ended"``, and there's no deadlock. If there is no
       live task, the sandbox is freed here directly.

    Args:
        session_id: ID of a candidate chosen by ``_reap_idle_sessions``.
        cutoff: Idle threshold. A session whose ``last_active_at`` is later
            than this is treated as active again and is not reaped.

    Returns:
        True if the session was reaped; False if any re-check failed or the
        snapshot could not be persisted.
    """
    async with self._lock:
        agent_session = self.sessions.get(session_id)
        # submission_queue.empty() catches work that submit() enqueued but
        # the runtime loop has not picked up yet (is_processing still False).
        if (
            agent_session is None
            or not agent_session.is_active
            or agent_session.is_processing
            or agent_session.is_reaping
            or agent_session.session.pending_approval
            or agent_session.last_active_at > cutoff
            or not agent_session.submission_queue.empty()
        ):
            return False
        agent_session.is_reaping = True

    # Persist a resumable snapshot *before* eviction so a concurrent reopen
    # reloads clean state. status="active" (never "ended") keeps it a normal
    # chat in the sidebar. Do this outside the manager lock: Mongo writes can
    # take network round trips, and is_reaping=True is enough to block submit
    # from enqueueing while the snapshot is in flight.
    try:
        await self.persist_session_snapshot(
            agent_session,
            runtime_state="idle",
            status="active",
            raise_on_error=True,
        )
    except Exception as e:
        async with self._lock:
            # Only un-mark the object we marked; if the entry was replaced or
            # removed meanwhile, leave the new state alone.
            if self.sessions.get(session_id) is agent_session:
                agent_session.is_reaping = False
        logger.warning(
            "Skipping reap of %s: could not persist resumable snapshot: %s",
            session_id,
            e,
        )
        return False

    async with self._lock:
        current = self.sessions.get(session_id)
        if current is not agent_session:
            return False
        # Final re-check: activity may have arrived while the snapshot was
        # being written (e.g. a touch bumping last_active_at). is_reaping is
        # not re-checked because this call set it.
        if (
            not agent_session.is_active
            or agent_session.is_processing
            or agent_session.session.pending_approval
            or agent_session.last_active_at > cutoff
            or not agent_session.submission_queue.empty()
        ):
            agent_session.is_reaping = False
            return False
        self.sessions.pop(session_id, None)
        task = agent_session.task
        session = agent_session.session

    if task is not None and not task.done():
        task.cancel()
        # Use asyncio.wait, not wait_for: wait_for re-raises the cancelled
        # task's CancelledError, which we'd have to swallow — and that same
        # bare except would also eat an *outer* cancel aimed at the reaper
        # itself (close() cancelling _reaper_task), hanging shutdown.
        # asyncio.wait returns the cancelled task in `done` and lets an
        # outer cancel propagate cleanly.
        done, _pending = await asyncio.wait({task}, timeout=REAP_TEARDOWN_TIMEOUT_S)
        if not done:
            logger.warning(
                "Reaper teardown timed out after %.0fs for %s; orphan "
                "sweeper will handle any sandbox straggler",
                REAP_TEARDOWN_TIMEOUT_S,
                session_id,
            )
        elif not task.cancelled():
            # Surface (and retrieve, to avoid "exception never retrieved")
            # any non-cancellation teardown error.
            exc = task.exception()
            if exc is not None:
                logger.warning("Reaper teardown error for %s: %s", session_id, exc)
    else:
        # No live task to run the cleanup finally — free the sandbox here so
        # a reaped session never leaves an orphaned Space behind.
        await self._cleanup_sandbox(session)
    return True
async def _run_session(
    self,
    session_id: str,
    submission_queue: asyncio.Queue,
    event_queue: asyncio.Queue,
    tool_router: ToolRouter,
) -> None:
    """Run the agent loop for a session and broadcast events via EventBroadcaster.

    Attaches an EventBroadcaster over ``event_queue`` and enters
    ``tool_router``. It then emits a ``ready`` event and hands each
    submission pulled from ``submission_queue`` to ``process_submission``.
    The loop stops when one of these happens:

    - ``session.is_running`` turns false;
    - a submission returns a falsy ``should_continue``;
    - the loop is cancelled.

    The ``finally`` block always tears down: it stops the broadcaster,
    frees the sandbox, and optionally final-flushes the session. Then, only
    if this exact ``agent_session`` is still registered, it marks the
    session inactive and persists an ``"ended"`` snapshot. The identity
    check means a session already popped from ``self.sessions`` (reaped or
    deleted) keeps whatever state was stored for it.

    Args:
        session_id: Key used to look up the AgentSession in ``self.sessions``.
        submission_queue: Queue the loop reads submissions from.
        event_queue: Agent event queue consumed by the EventBroadcaster.
        tool_router: Entered as an async context manager for the whole loop.
    """
    agent_session = self.sessions.get(session_id)
    if not agent_session:
        logger.error(f"Session {session_id} not found")
        return

    session = agent_session.session

    # Start event broadcaster task (fans agent events out to SSE subscribers)
    broadcaster = EventBroadcaster(event_queue)
    agent_session.broadcaster = broadcaster
    broadcast_task = asyncio.create_task(broadcaster.run())

    try:
        async with tool_router:
            # Send ready event
            await session.send_event(
                Event(event_type="ready", data={"message": "Agent initialized"})
            )

            while session.is_running:
                try:
                    # Wait for submission with timeout to allow checking is_running
                    submission = await asyncio.wait_for(
                        submission_queue.get(), timeout=1.0
                    )
                    # is_processing shields the session from the idle reaper
                    # for the duration of the turn.
                    agent_session.is_processing = True
                    self._touch(agent_session)
                    try:
                        should_continue = await process_submission(
                            session, submission
                        )
                    finally:
                        agent_session.is_processing = False
                        # Stamp on turn finish too: a turn that ran longer
                        # than the idle window would otherwise be reaped the
                        # instant it completes.
                        self._touch(agent_session)
                    await self.persist_session_snapshot(agent_session)
                    if not should_continue:
                        break
                except asyncio.TimeoutError:
                    continue
                except asyncio.CancelledError:
                    # Cancellation is absorbed here: the loop just exits, and
                    # the task finishes normally (not in the cancelled state)
                    # after the finally-block teardown below.
                    logger.info(f"Session {session_id} cancelled")
                    break
                except Exception as e:
                    # A failing turn is reported to the client as an "error"
                    # event; the loop keeps serving later submissions.
                    logger.error(f"Error in session {session_id}: {e}")
                    await session.send_event(
                        Event(event_type="error", data={"error": str(e)})
                    )

    finally:
        broadcast_task.cancel()
        try:
            await broadcast_task
        except asyncio.CancelledError:
            pass

        await self._cleanup_sandbox(session)

        # Final-flush: always save on session death so we capture ended
        # sessions even if the client disconnects without /shutdown.
        # Idempotent via session_id key; detached subprocess.
        if session.config.save_sessions:
            try:
                session.save_and_upload_detached(
                    session.config.session_dataset_repo
                )
            except Exception as e:
                logger.warning(f"Final-flush failed for {session_id}: {e}")

        async with self._lock:
            # Identity gate: if this session was popped from self.sessions
            # (reaped or deleted) while we were tearing down, don't mark it
            # ended or overwrite its persisted snapshot.
            if self.sessions.get(session_id) is agent_session:
                agent_session.is_active = False
                await self.persist_session_snapshot(
                    agent_session,
                    runtime_state="ended",
                    status="ended",
                )

        logger.info(f"Session {session_id} ended")
async def submit(self, session_id: str, operation: Operation) -> bool:
    """Queue one operation for a session's agent loop.

    The enqueue happens while holding the manager lock, and sessions flagged
    ``is_reaping`` are refused, so submit and the reaper cannot interleave.

    - A message enqueued first is seen by the reaper's final
      ``submission_queue.empty()`` re-check, which then aborts the reap.
    - Otherwise the session is being reaped or already popped. We return
      False and the caller reloads a fresh runtime from Mongo.

    The queue is unbounded, so ``put_nowait`` never blocks under the lock.

    Returns:
        True if the operation was enqueued, False otherwise.
    """
    submission = Submission(
        id="sub_" + uuid.uuid4().hex[:8],
        operation=operation,
    )
    async with self._lock:
        target = self.sessions.get(session_id)
        if target and target.is_active and not target.is_reaping:
            target.submission_queue.put_nowait(submission)
            self._touch(target)
            return True
    logger.warning(f"Session {session_id} not found or inactive")
    return False
async def submit_user_input(self, session_id: str, text: str) -> bool:
    """Queue a user chat message for the session; returns submit()'s result."""
    return await self.submit(
        session_id,
        Operation(op_type=OpType.USER_INPUT, data={"text": text}),
    )
async def submit_approval(
    self, session_id: str, approvals: list[dict[str, Any]]
) -> bool:
    """Queue the user's tool-approval decisions; returns submit()'s result."""
    payload = {"approvals": approvals}
    return await self.submit(
        session_id,
        Operation(op_type=OpType.EXEC_APPROVAL, data=payload),
    )
async def interrupt(self, session_id: str) -> bool:
    """Cancel a session's current work immediately, without going through the queue.

    Returns:
        True if the session was found and active (and was signalled),
        False otherwise.
    """
    target = self.sessions.get(session_id)
    if target and target.is_active:
        target.session.cancel()
        return True
    return False
async def undo(self, session_id: str) -> bool:
    """Queue an UNDO of the session's last turn; returns submit()'s result."""
    undo_op = Operation(op_type=OpType.UNDO)
    return await self.submit(session_id, undo_op)
async def truncate(self, session_id: str, user_message_index: int) -> bool:
    """Cut the conversation back to just before a given user message.

    Applied directly to the context manager under the manager lock rather
    than through the submission queue. On success the session is touched
    and an idle snapshot is persisted.

    Returns:
        The context manager's truncation result, or False when the session
        is unknown or inactive.
    """
    async with self._lock:
        target = self.sessions.get(session_id)
        if not target or not target.is_active:
            return False
        context = target.session.context_manager
        truncated = context.truncate_to_user_message(user_message_index)
        if truncated:
            self._touch(target)
            await self.persist_session_snapshot(target, runtime_state="idle")
        return truncated
async def compact(self, session_id: str) -> bool:
    """Queue a COMPACT of the session's context; returns submit()'s result."""
    return await self.submit(session_id, Operation(op_type=OpType.COMPACT))
async def shutdown_session(self, session_id: str) -> bool:
    """Ask a session's agent loop to shut down and wait briefly for it to exit.

    Enqueues a SHUTDOWN operation via ``submit``. If that succeeds, waits up
    to 5 seconds for the session's runtime task to finish, and cancels it if
    it is still running after that.

    The wait happens *outside* the manager lock. The runtime task's own
    ``finally`` acquires ``self._lock`` to mark the session ended and persist
    its final snapshot. Waiting while holding that non-reentrant lock would
    stall the task until the timeout fired, then cancel it before the
    "ended" snapshot could be written.

    Args:
        session_id: ID of the session to shut down.

    Returns:
        True if the SHUTDOWN operation was enqueued; False if ``submit``
        rejected it (unknown, inactive, or being-reaped session).
    """
    operation = Operation(op_type=OpType.SHUTDOWN)
    success = await self.submit(session_id, operation)
    if not success:
        return success

    # Only snapshot the task under the lock; never await it while locked.
    async with self._lock:
        agent_session = self.sessions.get(session_id)
        task = agent_session.task if agent_session else None

    if not task:
        return success

    # asyncio.wait (unlike wait_for) does not re-raise the task's own
    # exception into our caller, and it lets a cancel aimed at *this*
    # coroutine propagate instead of needing a CancelledError swallow.
    done, _pending = await asyncio.wait({task}, timeout=5.0)
    if not done:
        task.cancel()
    elif not task.cancelled():
        # Retrieve (and surface) any teardown error so it isn't reported as
        # "exception never retrieved".
        exc = task.exception()
        if exc is not None:
            logger.warning(
                "Session %s runtime task failed during shutdown: %s",
                session_id,
                exc,
            )
    return success
async def delete_session(self, session_id: str) -> bool:
    """Soft-delete a session and stop its runtime resources.

    The session is removed from ``self.sessions`` under the lock. It is then
    soft-deleted in the store in every case, including sessions that exist
    only in the store. If a live runtime existed, its sandbox is cleaned up
    and its runtime task is cancelled and awaited.

    Everything after the pop runs outside the lock. The runtime task's
    ``finally`` acquires ``self._lock``, so awaiting it while holding the
    lock would deadlock. Because the session is already popped, that
    ``finally`` also skips its identity-gated "ended" persist, leaving the
    soft-deleted record alone.

    Args:
        session_id: ID of the session to delete.

    Returns:
        Always True.
    """
    async with self._lock:
        agent_session = self.sessions.pop(session_id, None)

    await self._store().soft_delete_session(session_id)
    if not agent_session:
        return True

    # Clean up sandbox Space before cancelling the task
    await self._cleanup_sandbox(agent_session.session)

    # Cancel the task if running
    task = agent_session.task
    if task and not task.done():
        task.cancel()
        # asyncio.wait instead of `await task` + `except CancelledError: pass`:
        # that pattern also swallows a cancellation aimed at *this* coroutine,
        # and it re-raises any teardown error after the soft-delete has
        # already succeeded.
        await asyncio.wait({task})
        if not task.cancelled():
            exc = task.exception()
            if exc is not None:
                logger.warning(
                    "Runtime teardown error while deleting %s: %s",
                    session_id,
                    exc,
                )

    return True
async def teardown_sandbox(self, session_id: str) -> bool:
    """Remove a session's sandbox runtime but keep the chat itself.

    Returns:
        False when the session is unknown or inactive; otherwise frees the
        sandbox, persists the session as idle, and returns True.
    """
    async with self._lock:
        target = self.sessions.get(session_id)
        if not target or not target.is_active:
            return False
    await self._cleanup_sandbox(target.session)
    await self.persist_session_snapshot(target, runtime_state="idle")
    return True
async def update_session_title(self, session_id: str, title: str | None) -> None:
"""Persist a user-visible title for sidebar rehydration."""
agent_session = self.sessions.get(session_id)
if agent_session:
agent_session.title = title
await self._store().update_session_fields(session_id, title=title)
async def update_session_model(self, session_id: str, model_id: str) -> bool:
    """Switch a live session to ``model_id`` and persist an idle snapshot.

    Returns:
        True on success; False (changing nothing) when the session is
        unknown or inactive.
    """
    target = self.sessions.get(session_id)
    if target and target.is_active:
        target.session.update_model(model_id)
        self._touch(target)
        await self.persist_session_snapshot(target, runtime_state="idle")
        return True
    return False
async def update_session_auto_approval(
self,
session_id: str,
*,
enabled: bool,
cost_cap_usd: float | None,
cap_provided: bool = False,
) -> dict[str, Any]:
agent_session = self.sessions.get(session_id)
if not agent_session or not agent_session.is_active:
raise ValueError("Session not found or inactive")
session = agent_session.session
if enabled:
if not cap_provided and cost_cap_usd is None:
cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
if cost_cap_usd is None:
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
elif cost_cap_usd is None:
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
else:
if not cap_provided:
cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
if hasattr(session, "set_auto_approval_policy"):
session.set_auto_approval_policy(
enabled=enabled,
cost_cap_usd=cost_cap_usd,
)
else:
session.auto_approval_enabled = bool(enabled)
session.auto_approval_cost_cap_usd = cost_cap_usd
self._touch(agent_session)
await self.persist_session_snapshot(agent_session)
return self._auto_approval_summary(session)
def get_session_owner(self, session_id: str) -> str | None:
"""Get the user_id that owns a session, or None if session doesn't exist."""
agent_session = self.sessions.get(session_id)
if not agent_session:
return None
return agent_session.user_id
def verify_session_access(self, session_id: str, user_id: str) -> bool:
    """Decide whether ``user_id`` may use the session ``session_id``.

    An unknown session always denies. For a loaded session, access is
    granted to its owner, and also whenever either the caller or the owner
    is "dev" (dev-mode bypass).
    """
    owner = self.get_session_owner(session_id)
    if owner is None:
        return False
    return "dev" in (user_id, owner) or owner == user_id
def get_session_info(self, session_id: str) -> dict[str, Any] | None:
"""Get information about a session."""
agent_session = self.sessions.get(session_id)
if not agent_session:
return None
pending_approval = self._pending_tools_for_api(agent_session.session)
return {
"session_id": session_id,
"created_at": agent_session.created_at.isoformat(),
"is_active": agent_session.is_active,
"is_processing": agent_session.is_processing,
"message_count": len(agent_session.session.context_manager.items),
"user_id": agent_session.user_id,
"pending_approval": pending_approval,
"model": agent_session.session.config.model_name,
"title": agent_session.title,
"notification_destinations": list(
agent_session.session.notification_destinations
),
"auto_approval": self._auto_approval_summary(agent_session.session),
"premium_user_billed": getattr(
agent_session.session, "premium_user_billed", False
),
"premium_quota_counted": agent_session.claude_counted,
}
def set_notification_destinations(
    self, session_id: str, destinations: list[str]
) -> list[str]:
    """Replace the session's opted-in auto-notification destinations.

    Each name is stripped and must refer to a configured messaging
    destination that allows auto events. Duplicates are dropped, keeping
    first-seen order.

    Raises:
        ValueError: if the session is unknown or inactive, or if any name is
            empty, unknown, or not enabled for auto events.

    Returns:
        The de-duplicated list that was applied to the session.
    """
    entry = self.sessions.get(session_id)
    if not entry or not entry.is_active:
        raise ValueError("Session not found or inactive")

    lookup = self.config.messaging.get_destination
    validated: list[str] = []
    for raw_name in destinations:
        name = raw_name.strip()
        if not name:
            raise ValueError("Destination names must not be empty")
        target = lookup(name)
        if target is None:
            raise ValueError(f"Unknown destination '{name}'")
        if not target.allow_auto_events:
            raise ValueError(f"Destination '{name}' is not enabled for auto events")
        validated.append(name)

    # dict.fromkeys de-duplicates while preserving insertion order.
    unique = list(dict.fromkeys(validated))
    entry.session.set_notification_destinations(unique)
    self._touch(entry)
    return unique
async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
"""List sessions, optionally filtered by user.
Args:
user_id: If provided, only return sessions owned by this user.
If "dev", return all sessions (dev mode).
"""
results: list[dict[str, Any]] = []
store = self._store()
if getattr(store, "enabled", False):
for row in await store.list_sessions(user_id or "dev"):
sid = row.get("session_id") or row.get("_id")
if not sid:
continue
runtime_info = self.get_session_info(str(sid))
if runtime_info:
results.append(runtime_info)
continue
created_at = row.get("created_at")
if isinstance(created_at, datetime):
created_at_str = created_at.isoformat()
else:
created_at_str = str(created_at or datetime.utcnow().isoformat())
pending = self._pending_docs_for_api(row.get("pending_approval") or [])
results.append(
{
"session_id": str(sid),
"created_at": created_at_str,
"is_active": row.get("status") != "ended",
"is_processing": row.get("runtime_state") == "processing",
"message_count": int(row.get("message_count") or 0),
"user_id": row.get("user_id") or "dev",
"pending_approval": pending or None,
"model": row.get("model"),
"title": row.get("title"),
"premium_user_billed": bool(
row.get("premium_user_billed", False)
),
"premium_quota_counted": bool(row.get("claude_counted", False)),
"notification_destinations": row.get(
"notification_destinations"
)
or [],
"auto_approval": {
"enabled": bool(row.get("auto_approval_enabled", False)),
"cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
"estimated_spend_usd": float(
row.get("auto_approval_estimated_spend_usd") or 0.0
),
"remaining_usd": (
None
if row.get("auto_approval_cost_cap_usd") is None
else round(
max(
0.0,
float(
row.get("auto_approval_cost_cap_usd") or 0.0
)
- float(
row.get("auto_approval_estimated_spend_usd")
or 0.0
),
),
4,
)
),
},
}
)
return results
for sid in self.sessions:
info = self.get_session_info(sid)
if not info:
continue
if user_id and user_id != "dev" and info.get("user_id") != user_id:
continue
results.append(info)
return results
@property
def active_session_count(self) -> int:
    """Number of in-memory sessions currently flagged active."""
    return len([entry for entry in self.sessions.values() if entry.is_active])
# Global session manager instance. It is constructed once, when this module
# is first imported; because Python caches imported modules, every importer
# shares this same SessionManager object.
session_manager = SessionManager()