Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Fix: Enforce session capacity on restore and prevent session-creation race
#41
by omshrivastava - opened
- backend/session_manager.py +364 -151
- tests/unit/test_session_capacity.py +152 -0
backend/session_manager.py
CHANGED
|
@@ -6,7 +6,7 @@ import logging
|
|
| 6 |
import os
|
| 7 |
import uuid
|
| 8 |
from dataclasses import dataclass, field
|
| 9 |
-
from datetime import datetime
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any, Optional
|
| 12 |
|
|
@@ -100,6 +100,8 @@ class AgentSession:
|
|
| 100 |
is_processing: bool = False # True while a submission is being executed
|
| 101 |
broadcaster: Any = None
|
| 102 |
title: str | None = None
|
|
|
|
|
|
|
| 103 |
# True once this session has been counted against the user's daily
|
| 104 |
# Claude quota. Guards double-counting when the user re-selects an
|
| 105 |
# Anthropic model mid-session.
|
|
@@ -122,6 +124,8 @@ class SessionCapacityError(Exception):
|
|
| 122 |
MAX_SESSIONS: int = 200
|
| 123 |
MAX_SESSIONS_PER_USER: int = 10
|
| 124 |
DEFAULT_YOLO_COST_CAP_USD: float = 5.0
|
|
|
|
|
|
|
| 125 |
SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10
|
| 126 |
SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0
|
| 127 |
|
|
@@ -141,10 +145,19 @@ class SessionManager:
|
|
| 141 |
self.persistence_store = get_session_store()
|
| 142 |
await self.persistence_store.init()
|
| 143 |
await self.messaging_gateway.start()
|
|
|
|
|
|
|
| 144 |
|
| 145 |
async def close(self) -> None:
|
| 146 |
"""Flush and close shared background resources."""
|
| 147 |
await self._cleanup_all_sandboxes_on_close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
await self.messaging_gateway.close()
|
| 149 |
if self.persistence_store is not None:
|
| 150 |
await self.persistence_store.close()
|
|
@@ -197,6 +210,40 @@ class SessionManager:
|
|
| 197 |
logger.info("Session initialized in %.2fs", t1 - t0)
|
| 198 |
return tool_router, session
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
|
| 201 |
return [msg.model_dump(mode="json") for msg in session.context_manager.items]
|
| 202 |
|
|
@@ -327,8 +374,17 @@ class SessionManager:
|
|
| 327 |
async with self._lock:
|
| 328 |
existing = self.sessions.get(agent_session.session_id)
|
| 329 |
if existing:
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
task = asyncio.create_task(
|
| 334 |
self._run_session(
|
|
@@ -339,6 +395,8 @@ class SessionManager:
|
|
| 339 |
)
|
| 340 |
)
|
| 341 |
agent_session.task = task
|
|
|
|
|
|
|
| 342 |
return agent_session
|
| 343 |
|
| 344 |
@staticmethod
|
|
@@ -494,6 +552,82 @@ class SessionManager:
|
|
| 494 |
last_err,
|
| 495 |
)
|
| 496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
async def persist_session_snapshot(
|
| 498 |
self,
|
| 499 |
agent_session: AgentSession,
|
|
@@ -554,131 +688,171 @@ class SessionManager:
|
|
| 554 |
preload_sandbox: bool = True,
|
| 555 |
) -> AgentSession | None:
|
| 556 |
"""Return a live runtime session, lazily restoring it from Mongo."""
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
self._update_hf_identity(
|
| 562 |
-
existing,
|
| 563 |
-
hf_token=hf_token,
|
| 564 |
-
hf_username=hf_username,
|
| 565 |
-
)
|
| 566 |
-
self._restart_cpu_preload_if_token_recovered(
|
| 567 |
-
existing,
|
| 568 |
-
preload_sandbox=preload_sandbox,
|
| 569 |
-
)
|
| 570 |
-
return existing
|
| 571 |
-
return None
|
| 572 |
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
hf_username=hf_username,
|
|
|
|
|
|
|
| 586 |
)
|
| 587 |
-
self.
|
| 588 |
-
|
| 589 |
-
preload_sandbox=preload_sandbox,
|
| 590 |
-
)
|
| 591 |
-
return existing
|
| 592 |
-
return None
|
| 593 |
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
self._create_session_sync,
|
| 612 |
-
session_id=session_id,
|
| 613 |
-
user_id=owner or user_id,
|
| 614 |
-
hf_username=hf_username,
|
| 615 |
-
hf_token=hf_token,
|
| 616 |
-
model=model,
|
| 617 |
-
event_queue=event_queue,
|
| 618 |
-
notification_destinations=meta.get("notification_destinations") or [],
|
| 619 |
-
)
|
| 620 |
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
restored_messages.append(Message.model_validate(raw))
|
| 627 |
-
except Exception as e:
|
| 628 |
-
logger.warning("Dropping malformed restored message: %s", e)
|
| 629 |
-
if restored_messages:
|
| 630 |
-
# Keep the freshly-rendered system prompt, then attach the durable
|
| 631 |
-
# non-system context so tools/date/user context stay current.
|
| 632 |
-
session.context_manager.items = [
|
| 633 |
-
session.context_manager.items[0],
|
| 634 |
-
*restored_messages,
|
| 635 |
-
]
|
| 636 |
-
|
| 637 |
-
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 638 |
-
session.turn_count = int(meta.get("turn_count") or 0)
|
| 639 |
-
session.auto_approval_enabled = bool(meta.get("auto_approval_enabled", False))
|
| 640 |
-
raw_cap = meta.get("auto_approval_cost_cap_usd")
|
| 641 |
-
session.auto_approval_cost_cap_usd = (
|
| 642 |
-
float(raw_cap) if isinstance(raw_cap, int | float) else None
|
| 643 |
-
)
|
| 644 |
-
session.auto_approval_estimated_spend_usd = float(
|
| 645 |
-
meta.get("auto_approval_estimated_spend_usd") or 0.0
|
| 646 |
-
)
|
| 647 |
|
| 648 |
-
|
| 649 |
-
if not isinstance(created_at, datetime):
|
| 650 |
-
created_at = datetime.utcnow()
|
| 651 |
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
hf_username=hf_username,
|
| 659 |
-
hf_token=hf_token,
|
| 660 |
-
created_at=created_at,
|
| 661 |
-
is_active=True,
|
| 662 |
-
is_processing=False,
|
| 663 |
-
claude_counted=bool(meta.get("claude_counted")),
|
| 664 |
-
title=meta.get("title"),
|
| 665 |
-
)
|
| 666 |
-
started = await self._start_agent_session(
|
| 667 |
-
agent_session=agent_session,
|
| 668 |
-
event_queue=event_queue,
|
| 669 |
-
tool_router=tool_router,
|
| 670 |
-
)
|
| 671 |
-
if started is not agent_session:
|
| 672 |
-
self._update_hf_identity(
|
| 673 |
-
started,
|
| 674 |
hf_token=hf_token,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
hf_username=hf_username,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
)
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
async def create_session(
|
| 684 |
self,
|
|
@@ -706,7 +880,11 @@ class SessionManager:
|
|
| 706 |
SessionCapacityError: If the server or user has reached the
|
| 707 |
maximum number of concurrent sessions.
|
| 708 |
"""
|
| 709 |
-
# ── Capacity checks ───────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
async with self._lock:
|
| 711 |
active_count = self.active_session_count
|
| 712 |
if active_count >= MAX_SESSIONS:
|
|
@@ -724,47 +902,82 @@ class SessionManager:
|
|
| 724 |
error_type="per_user",
|
| 725 |
)
|
| 726 |
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
# Run blocking constructors in a thread to keep the event loop responsive.
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
|
|
|
|
|
|
| 743 |
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
|
|
|
| 754 |
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
|
| 763 |
-
|
| 764 |
-
|
| 765 |
|
| 766 |
-
|
| 767 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
|
| 769 |
async def _track_pro_status(
|
| 770 |
self, agent_session: AgentSession, *, is_pro: bool
|
|
|
|
| 6 |
import os
|
| 7 |
import uuid
|
| 8 |
from dataclasses import dataclass, field
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any, Optional
|
| 12 |
|
|
|
|
| 100 |
is_processing: bool = False # True while a submission is being executed
|
| 101 |
broadcaster: Any = None
|
| 102 |
title: str | None = None
|
| 103 |
+
# Last time this session was accessed (for idle-unload logic)
|
| 104 |
+
last_access: datetime = field(default_factory=datetime.utcnow)
|
| 105 |
# True once this session has been counted against the user's daily
|
| 106 |
# Claude quota. Guards double-counting when the user re-selects an
|
| 107 |
# Anthropic model mid-session.
|
|
|
|
| 124 |
MAX_SESSIONS: int = 200
|
| 125 |
MAX_SESSIONS_PER_USER: int = 10
|
| 126 |
DEFAULT_YOLO_COST_CAP_USD: float = 5.0
|
| 127 |
+
INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS: int = 600
|
| 128 |
+
INACTIVE_SESSION_IDLE_THRESHOLD: timedelta = timedelta(hours=24)
|
| 129 |
SANDBOX_SHUTDOWN_CLEANUP_CONCURRENCY: int = 10
|
| 130 |
SANDBOX_SHUTDOWN_CLEANUP_TIMEOUT_S: float = 60.0
|
| 131 |
|
|
|
|
| 145 |
self.persistence_store = get_session_store()
|
| 146 |
await self.persistence_store.init()
|
| 147 |
await self.messaging_gateway.start()
|
| 148 |
+
# Start background cleanup task to unload long-idle sessions.
|
| 149 |
+
self._cleanup_task = asyncio.create_task(self._unload_inactive_sessions_loop())
|
| 150 |
|
| 151 |
async def close(self) -> None:
|
| 152 |
"""Flush and close shared background resources."""
|
| 153 |
await self._cleanup_all_sandboxes_on_close()
|
| 154 |
+
# Cancel cleanup task
|
| 155 |
+
if getattr(self, "_cleanup_task", None):
|
| 156 |
+
self._cleanup_task.cancel()
|
| 157 |
+
try:
|
| 158 |
+
await self._cleanup_task
|
| 159 |
+
except asyncio.CancelledError:
|
| 160 |
+
pass
|
| 161 |
await self.messaging_gateway.close()
|
| 162 |
if self.persistence_store is not None:
|
| 163 |
await self.persistence_store.close()
|
|
|
|
| 210 |
logger.info("Session initialized in %.2fs", t1 - t0)
|
| 211 |
return tool_router, session
|
| 212 |
|
| 213 |
+
def _make_reserved_session(
|
| 214 |
+
self,
|
| 215 |
+
*,
|
| 216 |
+
session_id: str,
|
| 217 |
+
user_id: str,
|
| 218 |
+
hf_username: str | None,
|
| 219 |
+
hf_token: str | None,
|
| 220 |
+
submission_queue: asyncio.Queue,
|
| 221 |
+
) -> AgentSession:
|
| 222 |
+
"""Create a placeholder session that reserves capacity under lock."""
|
| 223 |
+
return AgentSession(
|
| 224 |
+
session_id=session_id,
|
| 225 |
+
session=None,
|
| 226 |
+
tool_router=None,
|
| 227 |
+
submission_queue=submission_queue,
|
| 228 |
+
user_id=user_id,
|
| 229 |
+
hf_username=hf_username,
|
| 230 |
+
hf_token=hf_token,
|
| 231 |
+
is_active=True,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
async def _release_reserved_session_slot(
|
| 235 |
+
self,
|
| 236 |
+
session_id: str,
|
| 237 |
+
reserved_session: AgentSession | None = None,
|
| 238 |
+
) -> None:
|
| 239 |
+
"""Remove a reserved placeholder if it is still present."""
|
| 240 |
+
async with self._lock:
|
| 241 |
+
current = self.sessions.get(session_id)
|
| 242 |
+
if current is None:
|
| 243 |
+
return
|
| 244 |
+
if current is reserved_session or getattr(current, "session", None) is None:
|
| 245 |
+
self.sessions.pop(session_id, None)
|
| 246 |
+
|
| 247 |
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
|
| 248 |
return [msg.model_dump(mode="json") for msg in session.context_manager.items]
|
| 249 |
|
|
|
|
| 374 |
async with self._lock:
|
| 375 |
existing = self.sessions.get(agent_session.session_id)
|
| 376 |
if existing:
|
| 377 |
+
# If an earlier coroutine reserved the slot with a placeholder
|
| 378 |
+
# AgentSession (session is None), replace it with the real
|
| 379 |
+
# agent_session. Otherwise return the existing live session.
|
| 380 |
+
if getattr(existing, "session", None) is None:
|
| 381 |
+
self.sessions[agent_session.session_id] = agent_session
|
| 382 |
+
else:
|
| 383 |
+
# Update access time for the existing live session
|
| 384 |
+
existing.last_access = datetime.utcnow()
|
| 385 |
+
return existing
|
| 386 |
+
else:
|
| 387 |
+
self.sessions[agent_session.session_id] = agent_session
|
| 388 |
|
| 389 |
task = asyncio.create_task(
|
| 390 |
self._run_session(
|
|
|
|
| 395 |
)
|
| 396 |
)
|
| 397 |
agent_session.task = task
|
| 398 |
+
# mark last access time when the session has been started
|
| 399 |
+
agent_session.last_access = datetime.utcnow()
|
| 400 |
return agent_session
|
| 401 |
|
| 402 |
@staticmethod
|
|
|
|
| 552 |
last_err,
|
| 553 |
)
|
| 554 |
|
| 555 |
+
async def _unload_inactive_sessions_loop(self) -> None:
|
| 556 |
+
"""Background task: unload sessions that have been idle for too long.
|
| 557 |
+
|
| 558 |
+
Sessions idle for more than 24 hours are persisted and removed from
|
| 559 |
+
memory to free capacity. The loop runs periodically (every 10 minutes).
|
| 560 |
+
"""
|
| 561 |
+
try:
|
| 562 |
+
while True:
|
| 563 |
+
await asyncio.sleep(INACTIVE_SESSION_SWEEP_INTERVAL_SECONDS)
|
| 564 |
+
now = datetime.utcnow()
|
| 565 |
+
to_unload: list[tuple[str, AgentSession]] = []
|
| 566 |
+
async with self._lock:
|
| 567 |
+
for sid, agent_session in list(self.sessions.items()):
|
| 568 |
+
try:
|
| 569 |
+
if not agent_session.is_active:
|
| 570 |
+
continue
|
| 571 |
+
if getattr(agent_session, "is_processing", False):
|
| 572 |
+
continue
|
| 573 |
+
last = getattr(
|
| 574 |
+
agent_session,
|
| 575 |
+
"last_access",
|
| 576 |
+
agent_session.created_at,
|
| 577 |
+
)
|
| 578 |
+
if now - last > INACTIVE_SESSION_IDLE_THRESHOLD:
|
| 579 |
+
# Mark inactive, but keep it resident until the
|
| 580 |
+
# snapshot has been persisted successfully.
|
| 581 |
+
agent_session.is_active = False
|
| 582 |
+
to_unload.append((sid, agent_session))
|
| 583 |
+
except Exception as e:
|
| 584 |
+
logger.debug("Skipping unload check for %s: %s", sid, e)
|
| 585 |
+
|
| 586 |
+
for sid, agent_session in to_unload:
|
| 587 |
+
try:
|
| 588 |
+
await self.persist_session_snapshot(
|
| 589 |
+
agent_session,
|
| 590 |
+
runtime_state=self._runtime_state(agent_session),
|
| 591 |
+
status="inactive",
|
| 592 |
+
)
|
| 593 |
+
except Exception as e:
|
| 594 |
+
logger.warning(
|
| 595 |
+
"Failed to persist snapshot before unloading %s: %s",
|
| 596 |
+
sid,
|
| 597 |
+
e,
|
| 598 |
+
)
|
| 599 |
+
# Keep the session in memory so the next cleanup cycle
|
| 600 |
+
# can retry persistence and so callers can still inspect
|
| 601 |
+
# the session state.
|
| 602 |
+
agent_session.is_active = True
|
| 603 |
+
continue
|
| 604 |
+
|
| 605 |
+
removed = False
|
| 606 |
+
async with self._lock:
|
| 607 |
+
current = self.sessions.get(sid)
|
| 608 |
+
if current is agent_session:
|
| 609 |
+
self.sessions.pop(sid, None)
|
| 610 |
+
removed = True
|
| 611 |
+
|
| 612 |
+
if not removed:
|
| 613 |
+
# Session was replaced or revived while we were persisting.
|
| 614 |
+
continue
|
| 615 |
+
|
| 616 |
+
# Cancel the running task if present once the snapshot is safe.
|
| 617 |
+
try:
|
| 618 |
+
if agent_session.task:
|
| 619 |
+
agent_session.task.cancel()
|
| 620 |
+
try:
|
| 621 |
+
await asyncio.wait_for(agent_session.task, timeout=5)
|
| 622 |
+
except Exception:
|
| 623 |
+
pass
|
| 624 |
+
except Exception:
|
| 625 |
+
pass
|
| 626 |
+
|
| 627 |
+
logger.info("Unloaded inactive session %s due to inactivity", sid)
|
| 628 |
+
except asyncio.CancelledError:
|
| 629 |
+
return
|
| 630 |
+
|
| 631 |
async def persist_session_snapshot(
|
| 632 |
self,
|
| 633 |
agent_session: AgentSession,
|
|
|
|
| 688 |
preload_sandbox: bool = True,
|
| 689 |
) -> AgentSession | None:
|
| 690 |
"""Return a live runtime session, lazily restoring it from Mongo."""
|
| 691 |
+
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 692 |
+
event_queue: asyncio.Queue = asyncio.Queue()
|
| 693 |
+
reserved_session: AgentSession | None = None
|
| 694 |
+
should_release_reserved_slot = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
|
| 696 |
+
try:
|
| 697 |
+
async with self._lock:
|
| 698 |
+
existing = self.sessions.get(session_id)
|
| 699 |
+
if existing:
|
| 700 |
+
if getattr(existing, "session", None) is None:
|
| 701 |
+
return None
|
| 702 |
+
if self._can_access_session(existing, user_id):
|
| 703 |
+
self._update_hf_identity(
|
| 704 |
+
existing,
|
| 705 |
+
hf_token=hf_token,
|
| 706 |
+
hf_username=hf_username,
|
| 707 |
+
)
|
| 708 |
+
self._restart_cpu_preload_if_token_recovered(
|
| 709 |
+
existing,
|
| 710 |
+
preload_sandbox=preload_sandbox,
|
| 711 |
+
)
|
| 712 |
+
existing.last_access = datetime.utcnow()
|
| 713 |
+
return existing
|
| 714 |
+
return None
|
| 715 |
|
| 716 |
+
active_count = self.active_session_count
|
| 717 |
+
if active_count >= MAX_SESSIONS:
|
| 718 |
+
logger.warning(
|
| 719 |
+
"Cannot restore session %s: server at capacity (%d/%d)",
|
| 720 |
+
session_id,
|
| 721 |
+
active_count,
|
| 722 |
+
MAX_SESSIONS,
|
| 723 |
+
)
|
| 724 |
+
return None
|
| 725 |
+
|
| 726 |
+
reserved_session = self._make_reserved_session(
|
| 727 |
+
session_id=session_id,
|
| 728 |
+
user_id=user_id,
|
| 729 |
hf_username=hf_username,
|
| 730 |
+
hf_token=hf_token,
|
| 731 |
+
submission_queue=submission_queue,
|
| 732 |
)
|
| 733 |
+
self.sessions[session_id] = reserved_session
|
| 734 |
+
should_release_reserved_slot = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
|
| 736 |
+
store = self._store()
|
| 737 |
+
loaded = await store.load_session(session_id)
|
| 738 |
+
if not loaded:
|
| 739 |
+
return None
|
| 740 |
|
| 741 |
+
async with self._lock:
|
| 742 |
+
existing = self.sessions.get(session_id)
|
| 743 |
+
if existing is not reserved_session:
|
| 744 |
+
if existing and getattr(existing, "session", None) is not None:
|
| 745 |
+
if self._can_access_session(existing, user_id):
|
| 746 |
+
self._update_hf_identity(
|
| 747 |
+
existing,
|
| 748 |
+
hf_token=hf_token,
|
| 749 |
+
hf_username=hf_username,
|
| 750 |
+
)
|
| 751 |
+
self._restart_cpu_preload_if_token_recovered(
|
| 752 |
+
existing,
|
| 753 |
+
preload_sandbox=preload_sandbox,
|
| 754 |
+
)
|
| 755 |
+
existing.last_access = datetime.utcnow()
|
| 756 |
+
should_release_reserved_slot = False
|
| 757 |
+
return existing
|
| 758 |
+
return None
|
| 759 |
|
| 760 |
+
meta = loaded.get("metadata") or {}
|
| 761 |
+
owner = str(meta.get("user_id") or "")
|
| 762 |
+
if user_id != "dev" and owner != "dev" and owner != user_id:
|
| 763 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
+
await self._cleanup_persisted_sandbox(
|
| 766 |
+
session_id,
|
| 767 |
+
meta,
|
| 768 |
+
hf_token=hf_token,
|
| 769 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
|
| 771 |
+
from litellm import Message
|
|
|
|
|
|
|
| 772 |
|
| 773 |
+
model = meta.get("model") or self.config.model_name
|
| 774 |
+
tool_router, session = await asyncio.to_thread(
|
| 775 |
+
self._create_session_sync,
|
| 776 |
+
session_id=session_id,
|
| 777 |
+
user_id=owner or user_id,
|
| 778 |
+
hf_username=hf_username,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
hf_token=hf_token,
|
| 780 |
+
model=model,
|
| 781 |
+
event_queue=event_queue,
|
| 782 |
+
notification_destinations=meta.get("notification_destinations") or [],
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
restored_messages: list[Message] = []
|
| 786 |
+
for raw in loaded.get("messages") or []:
|
| 787 |
+
if not isinstance(raw, dict) or raw.get("role") == "system":
|
| 788 |
+
continue
|
| 789 |
+
try:
|
| 790 |
+
restored_messages.append(Message.model_validate(raw))
|
| 791 |
+
except Exception as e:
|
| 792 |
+
logger.warning("Dropping malformed restored message: %s", e)
|
| 793 |
+
if restored_messages:
|
| 794 |
+
# Keep the freshly-rendered system prompt, then attach the durable
|
| 795 |
+
# non-system context so tools/date/user context stay current.
|
| 796 |
+
session.context_manager.items = [
|
| 797 |
+
session.context_manager.items[0],
|
| 798 |
+
*restored_messages,
|
| 799 |
+
]
|
| 800 |
+
|
| 801 |
+
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 802 |
+
session.turn_count = int(meta.get("turn_count") or 0)
|
| 803 |
+
session.auto_approval_enabled = bool(
|
| 804 |
+
meta.get("auto_approval_enabled", False)
|
| 805 |
+
)
|
| 806 |
+
raw_cap = meta.get("auto_approval_cost_cap_usd")
|
| 807 |
+
session.auto_approval_cost_cap_usd = (
|
| 808 |
+
float(raw_cap) if isinstance(raw_cap, int | float) else None
|
| 809 |
+
)
|
| 810 |
+
session.auto_approval_estimated_spend_usd = float(
|
| 811 |
+
meta.get("auto_approval_estimated_spend_usd") or 0.0
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
created_at = meta.get("created_at")
|
| 815 |
+
if not isinstance(created_at, datetime):
|
| 816 |
+
created_at = datetime.utcnow()
|
| 817 |
+
|
| 818 |
+
agent_session = AgentSession(
|
| 819 |
+
session_id=session_id,
|
| 820 |
+
session=session,
|
| 821 |
+
tool_router=tool_router,
|
| 822 |
+
submission_queue=submission_queue,
|
| 823 |
+
user_id=owner or user_id,
|
| 824 |
hf_username=hf_username,
|
| 825 |
+
hf_token=hf_token,
|
| 826 |
+
created_at=created_at,
|
| 827 |
+
is_active=True,
|
| 828 |
+
is_processing=False,
|
| 829 |
+
claude_counted=bool(meta.get("claude_counted")),
|
| 830 |
+
title=meta.get("title"),
|
| 831 |
)
|
| 832 |
+
started = await self._start_agent_session(
|
| 833 |
+
agent_session=agent_session,
|
| 834 |
+
event_queue=event_queue,
|
| 835 |
+
tool_router=tool_router,
|
| 836 |
+
)
|
| 837 |
+
if started is not agent_session:
|
| 838 |
+
self._update_hf_identity(
|
| 839 |
+
started,
|
| 840 |
+
hf_token=hf_token,
|
| 841 |
+
hf_username=hf_username,
|
| 842 |
+
)
|
| 843 |
+
started.last_access = datetime.utcnow()
|
| 844 |
+
should_release_reserved_slot = False
|
| 845 |
+
return started
|
| 846 |
+
if preload_sandbox:
|
| 847 |
+
self._start_cpu_sandbox_preload(agent_session)
|
| 848 |
+
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
| 849 |
+
should_release_reserved_slot = False
|
| 850 |
+
return agent_session
|
| 851 |
+
except Exception:
|
| 852 |
+
raise
|
| 853 |
+
finally:
|
| 854 |
+
if should_release_reserved_slot:
|
| 855 |
+
await self._release_reserved_session_slot(session_id, reserved_session)
|
| 856 |
|
| 857 |
async def create_session(
|
| 858 |
self,
|
|
|
|
| 880 |
SessionCapacityError: If the server or user has reached the
|
| 881 |
maximum number of concurrent sessions.
|
| 882 |
"""
|
| 883 |
+
# ── Capacity checks & reservation ───────────────────────────
|
| 884 |
+
# Create lightweight queues up-front (non-blocking).
|
| 885 |
+
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 886 |
+
event_queue: asyncio.Queue = asyncio.Queue()
|
| 887 |
+
|
| 888 |
async with self._lock:
|
| 889 |
active_count = self.active_session_count
|
| 890 |
if active_count >= MAX_SESSIONS:
|
|
|
|
| 902 |
error_type="per_user",
|
| 903 |
)
|
| 904 |
|
| 905 |
+
session_id = str(uuid.uuid4())
|
| 906 |
+
# Reserve the slot with a placeholder AgentSession so concurrent
|
| 907 |
+
# creators cannot exceed MAX_SESSIONS. The placeholder has no
|
| 908 |
+
# real Session/tool_router yet (session is None) but counts
|
| 909 |
+
# towards `active_session_count` because `is_active` is True.
|
| 910 |
+
placeholder = AgentSession(
|
| 911 |
+
session_id=session_id,
|
| 912 |
+
session=None,
|
| 913 |
+
tool_router=None,
|
| 914 |
+
submission_queue=submission_queue,
|
| 915 |
+
user_id=user_id,
|
| 916 |
+
hf_username=hf_username,
|
| 917 |
+
hf_token=hf_token,
|
| 918 |
+
is_active=True,
|
| 919 |
+
)
|
| 920 |
+
self.sessions[session_id] = placeholder
|
| 921 |
|
| 922 |
# Run blocking constructors in a thread to keep the event loop responsive.
|
| 923 |
+
agent_session: AgentSession | None = None
|
| 924 |
+
try:
|
| 925 |
+
tool_router, session = await asyncio.to_thread(
|
| 926 |
+
self._create_session_sync,
|
| 927 |
+
session_id=session_id,
|
| 928 |
+
user_id=user_id,
|
| 929 |
+
hf_username=hf_username,
|
| 930 |
+
hf_token=hf_token,
|
| 931 |
+
model=model,
|
| 932 |
+
event_queue=event_queue,
|
| 933 |
+
)
|
| 934 |
|
| 935 |
+
# Create wrapper with the real session resources and replace the
|
| 936 |
+
# placeholder in _start_agent_session.
|
| 937 |
+
agent_session = AgentSession(
|
| 938 |
+
session_id=session_id,
|
| 939 |
+
session=session,
|
| 940 |
+
tool_router=tool_router,
|
| 941 |
+
submission_queue=submission_queue,
|
| 942 |
+
user_id=user_id,
|
| 943 |
+
hf_username=hf_username,
|
| 944 |
+
hf_token=hf_token,
|
| 945 |
+
)
|
| 946 |
|
| 947 |
+
await self._start_agent_session(
|
| 948 |
+
agent_session=agent_session,
|
| 949 |
+
event_queue=event_queue,
|
| 950 |
+
tool_router=tool_router,
|
| 951 |
+
)
|
| 952 |
+
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 953 |
+
self._start_cpu_sandbox_preload(agent_session)
|
| 954 |
|
| 955 |
+
if is_pro is not None and user_id and user_id != "dev":
|
| 956 |
+
await self._track_pro_status(agent_session, is_pro=is_pro)
|
| 957 |
|
| 958 |
+
logger.info(f"Created session {session_id} for user {user_id}")
|
| 959 |
+
return session_id
|
| 960 |
+
except Exception:
|
| 961 |
+
cleanup_task: asyncio.Task | None = None
|
| 962 |
+
async with self._lock:
|
| 963 |
+
current = self.sessions.get(session_id)
|
| 964 |
+
if current and (
|
| 965 |
+
current is agent_session
|
| 966 |
+
or getattr(current, "session", None) is None
|
| 967 |
+
):
|
| 968 |
+
self.sessions.pop(session_id, None)
|
| 969 |
+
if agent_session is not None:
|
| 970 |
+
agent_session.is_active = False
|
| 971 |
+
cleanup_task = agent_session.task
|
| 972 |
+
if cleanup_task and not cleanup_task.done():
|
| 973 |
+
cleanup_task.cancel()
|
| 974 |
+
try:
|
| 975 |
+
await cleanup_task
|
| 976 |
+
except asyncio.CancelledError:
|
| 977 |
+
pass
|
| 978 |
+
except Exception:
|
| 979 |
+
pass
|
| 980 |
+
raise
|
| 981 |
|
| 982 |
async def _track_pro_status(
|
| 983 |
self, agent_session: AgentSession, *, is_pro: bool
|
tests/unit/test_session_capacity.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import importlib
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import types
|
| 9 |
+
|
| 10 |
+
_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@pytest.fixture
|
| 14 |
+
def session_manager_module(monkeypatch):
|
| 15 |
+
"""Import backend.session_manager with temporary dependency stubs.
|
| 16 |
+
|
| 17 |
+
The stubs are inserted with monkeypatch so they are restored after the test,
|
| 18 |
+
and the imported session_manager module is removed from sys.modules to avoid
|
| 19 |
+
leaking the stubbed import state into other tests.
|
| 20 |
+
"""
|
| 21 |
+
with monkeypatch.context() as m:
|
| 22 |
+
m.syspath_prepend(str(_BACKEND_DIR))
|
| 23 |
+
|
| 24 |
+
litellm_stub = types.ModuleType("litellm")
|
| 25 |
+
|
| 26 |
+
class _DummyMessage:
|
| 27 |
+
@staticmethod
|
| 28 |
+
def model_validate(raw):
|
| 29 |
+
return SimpleNamespace(**raw)
|
| 30 |
+
|
| 31 |
+
setattr(litellm_stub, "Message", _DummyMessage)
|
| 32 |
+
m.setitem(sys.modules, "litellm", litellm_stub)
|
| 33 |
+
|
| 34 |
+
fastmcp_stub = types.ModuleType("fastmcp")
|
| 35 |
+
|
| 36 |
+
class _DummyClient:
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
setattr(fastmcp_stub, "Client", _DummyClient)
|
| 40 |
+
m.setitem(sys.modules, "fastmcp", fastmcp_stub)
|
| 41 |
+
|
| 42 |
+
m.setitem(sys.modules, "thefuzz", types.ModuleType("thefuzz"))
|
| 43 |
+
m.setitem(
|
| 44 |
+
sys.modules,
|
| 45 |
+
"huggingface_hub",
|
| 46 |
+
types.ModuleType("huggingface_hub"),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
module = importlib.import_module("session_manager")
|
| 50 |
+
yield module
|
| 51 |
+
|
| 52 |
+
sys.modules.pop("session_manager", None)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@pytest.mark.asyncio
|
| 56 |
+
async def test_restore_denied_when_at_capacity(session_manager_module, caplog):
|
| 57 |
+
manager = session_manager_module.SessionManager()
|
| 58 |
+
# Fill in-memory sessions up to MAX_SESSIONS
|
| 59 |
+
for i in range(session_manager_module.MAX_SESSIONS):
|
| 60 |
+
manager.sessions[str(i)] = session_manager_module.AgentSession(
|
| 61 |
+
session_id=str(i),
|
| 62 |
+
session=object(),
|
| 63 |
+
tool_router=None,
|
| 64 |
+
submission_queue=asyncio.Queue(),
|
| 65 |
+
is_active=True,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
class DummyStore:
|
| 69 |
+
enabled = True
|
| 70 |
+
|
| 71 |
+
async def load_session(self, sid):
|
| 72 |
+
return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}
|
| 73 |
+
|
| 74 |
+
manager.persistence_store = DummyStore()
|
| 75 |
+
|
| 76 |
+
caplog.set_level("WARNING")
|
| 77 |
+
res = await manager.ensure_session_loaded("restored", user_id="test")
|
| 78 |
+
assert res is None
|
| 79 |
+
assert any("Cannot restore session" in rec.message for rec in caplog.records)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@pytest.mark.asyncio
|
| 83 |
+
async def test_restore_allowed_under_capacity(session_manager_module, monkeypatch):
|
| 84 |
+
manager = session_manager_module.SessionManager()
|
| 85 |
+
manager.sessions.clear()
|
| 86 |
+
|
| 87 |
+
class DummyStore:
|
| 88 |
+
enabled = True
|
| 89 |
+
|
| 90 |
+
async def load_session(self, sid):
|
| 91 |
+
return {"metadata": {"user_id": "test", "model": "gpt"}, "messages": []}
|
| 92 |
+
|
| 93 |
+
manager.persistence_store = DummyStore()
|
| 94 |
+
|
| 95 |
+
# Replace the heavy sync constructor with a lightweight fake
|
| 96 |
+
def fake_create_session_sync(*, session_id, user_id, hf_username, hf_token, model, event_queue, notification_destinations):
|
| 97 |
+
class SimpleSession:
|
| 98 |
+
def __init__(self):
|
| 99 |
+
self.context_manager = SimpleNamespace(items=[object()])
|
| 100 |
+
self.pending_approval = None
|
| 101 |
+
self.turn_count = 0
|
| 102 |
+
|
| 103 |
+
return None, SimpleSession()
|
| 104 |
+
|
| 105 |
+
monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync)
|
| 106 |
+
|
| 107 |
+
# Fake _start_agent_session to register the session without starting tasks
|
| 108 |
+
async def fake_start_agent_session(*, agent_session, event_queue, tool_router):
|
| 109 |
+
async with manager._lock:
|
| 110 |
+
manager.sessions[agent_session.session_id] = agent_session
|
| 111 |
+
return agent_session
|
| 112 |
+
|
| 113 |
+
monkeypatch.setattr(manager, "_start_agent_session", fake_start_agent_session)
|
| 114 |
+
|
| 115 |
+
res = await manager.ensure_session_loaded("restored", user_id="test")
|
| 116 |
+
assert res is not None
|
| 117 |
+
assert res.session is not None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@pytest.mark.asyncio
|
| 121 |
+
async def test_restore_rolls_back_placeholder_on_load_failure(session_manager_module):
|
| 122 |
+
manager = session_manager_module.SessionManager()
|
| 123 |
+
|
| 124 |
+
class FailingStore:
|
| 125 |
+
enabled = True
|
| 126 |
+
|
| 127 |
+
async def load_session(self, sid):
|
| 128 |
+
raise RuntimeError("load failed")
|
| 129 |
+
|
| 130 |
+
manager.persistence_store = FailingStore()
|
| 131 |
+
|
| 132 |
+
with pytest.raises(RuntimeError, match="load failed"):
|
| 133 |
+
await manager.ensure_session_loaded("restored", user_id="test")
|
| 134 |
+
|
| 135 |
+
assert manager.sessions == {}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@pytest.mark.asyncio
|
| 139 |
+
async def test_create_session_rolls_back_placeholder_on_failure(
|
| 140 |
+
session_manager_module, monkeypatch
|
| 141 |
+
):
|
| 142 |
+
manager = session_manager_module.SessionManager()
|
| 143 |
+
|
| 144 |
+
def fake_create_session_sync(**kwargs):
|
| 145 |
+
raise RuntimeError("boom")
|
| 146 |
+
|
| 147 |
+
monkeypatch.setattr(manager, "_create_session_sync", fake_create_session_sync)
|
| 148 |
+
|
| 149 |
+
with pytest.raises(RuntimeError, match="boom"):
|
| 150 |
+
await manager.create_session(user_id="u1")
|
| 151 |
+
|
| 152 |
+
assert manager.sessions == {}
|