| """Regression tests for server-side session persistence restore/access.""" |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import sys |
| import threading |
| from datetime import datetime, UTC |
| from pathlib import Path |
| from types import SimpleNamespace |
| from typing import Any |
|
|
| import pytest |
|
|
| _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend" |
| if str(_BACKEND_DIR) not in sys.path: |
| sys.path.insert(0, str(_BACKEND_DIR)) |
|
|
| from agent.core.session_persistence import NoopSessionStore |
| from session_manager import AgentSession, SessionManager |
|
|
|
|
| class FakeRuntimeSession: |
| def __init__(self, *, hf_token: str | None = None, model: str = "test-model"): |
| self.hf_token = hf_token |
| self.context_manager = SimpleNamespace(items=[]) |
| self.pending_approval = None |
| self.turn_count = 0 |
| self.config = SimpleNamespace(model_name=model) |
| self.notification_destinations = [] |
| self.auto_approval_enabled = False |
| self.auto_approval_cost_cap_usd = None |
| self.auto_approval_estimated_spend_usd = 0.0 |
| self.sandbox = None |
| self.sandbox_hardware = None |
| self.sandbox_preload_task = None |
| self.sandbox_preload_cancel_event = None |
|
|
| def auto_approval_policy_summary(self): |
| cap = self.auto_approval_cost_cap_usd |
| remaining = ( |
| None |
| if cap is None |
| else max(0, cap - self.auto_approval_estimated_spend_usd) |
| ) |
| return { |
| "enabled": self.auto_approval_enabled, |
| "cost_cap_usd": cap, |
| "estimated_spend_usd": self.auto_approval_estimated_spend_usd, |
| "remaining_usd": remaining, |
| } |
|
|
| def set_auto_approval_policy(self, *, enabled, cost_cap_usd): |
| self.auto_approval_enabled = enabled |
| self.auto_approval_cost_cap_usd = cost_cap_usd |
|
|
|
|
| class RestoreStore(NoopSessionStore): |
| enabled = True |
|
|
| def __init__( |
| self, |
| *, |
| metadata: dict[str, Any] | None = None, |
| messages: list[dict[str, Any]] | None = None, |
| delay: float = 0, |
| ) -> None: |
| self.metadata = metadata or { |
| "session_id": "persisted-session", |
| "user_id": "owner", |
| "model": "test-model", |
| "created_at": datetime.now(UTC), |
| } |
| self.messages = messages or [] |
| self.delay = delay |
| self.load_calls = 0 |
| self.updated_fields: list[tuple[str, dict[str, Any]]] = [] |
|
|
| async def load_session(self, session_id: str, **_: Any) -> dict[str, Any] | None: |
| self.load_calls += 1 |
| if self.delay: |
| await asyncio.sleep(self.delay) |
| metadata = dict(self.metadata) |
| metadata.setdefault("session_id", session_id) |
| metadata.setdefault("_id", session_id) |
| return {"metadata": metadata, "messages": self.messages} |
|
|
| async def update_session_fields(self, session_id: str, **fields: Any) -> None: |
| self.updated_fields.append((session_id, fields)) |
| self.metadata.update(fields) |
|
|
|
|
| class CloseableResource: |
| def __init__(self) -> None: |
| self.closed = False |
|
|
| async def close(self) -> None: |
| self.closed = True |
|
|
|
|
| def _manager_with_store(store: NoopSessionStore) -> SessionManager: |
| manager = object.__new__(SessionManager) |
| manager.config = SimpleNamespace(model_name="test-model") |
| manager.sessions = {} |
| manager._lock = asyncio.Lock() |
| manager.persistence_store = store |
| manager.messaging_gateway = CloseableResource() |
| return manager |
|
|
|
|
| def _runtime_agent_session( |
| session_id: str, |
| *, |
| user_id: str = "owner", |
| hf_token: str | None = "owner-token", |
| ) -> AgentSession: |
| runtime_session = FakeRuntimeSession(hf_token=hf_token) |
| return AgentSession( |
| session_id=session_id, |
| session=runtime_session, |
| tool_router=object(), |
| submission_queue=asyncio.Queue(), |
| user_id=user_id, |
| hf_token=hf_token, |
| ) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_update_session_auto_approval_defaults_to_five_dollars(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="owner") |
| manager.sessions["s1"] = existing |
|
|
| summary = await manager.update_session_auto_approval( |
| "s1", |
| enabled=True, |
| cost_cap_usd=None, |
| cap_provided=False, |
| ) |
|
|
| assert summary["enabled"] is True |
| assert summary["cost_cap_usd"] == 5.0 |
| assert summary["remaining_usd"] == 5.0 |
|
|
|
|
| def _install_fake_runtime(manager: SessionManager) -> asyncio.Event: |
| stop = asyncio.Event() |
| manager.run_calls = 0 |
|
|
| def fake_create_session_sync(**kwargs: Any): |
| return object(), FakeRuntimeSession( |
| hf_token=kwargs.get("hf_token"), |
| model=kwargs.get("model") or "test-model", |
| ) |
|
|
| async def fake_run_session(*_: Any) -> None: |
| manager.run_calls += 1 |
| await stop.wait() |
|
|
| manager._create_session_sync = fake_create_session_sync |
| manager._run_session = fake_run_session |
| return stop |
|
|
|
|
| async def _cancel_runtime_tasks(manager: SessionManager) -> None: |
| tasks = [ |
| agent_session.task |
| for agent_session in manager.sessions.values() |
| if agent_session.task and not agent_session.task.done() |
| ] |
| for task in tasks: |
| task.cancel() |
| if tasks: |
| await asyncio.gather(*tasks, return_exceptions=True) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch): |
| deleted: list[str] = [] |
|
|
| async def fake_record_sandbox_destroy(*args, **kwargs): |
| pass |
|
|
| monkeypatch.setattr( |
| "agent.core.telemetry.record_sandbox_destroy", |
| fake_record_sandbox_destroy, |
| ) |
|
|
| store = NoopSessionStore() |
| manager = _manager_with_store(store) |
| gateway = CloseableResource() |
| persistence = CloseableResource() |
| manager.messaging_gateway = gateway |
| manager.persistence_store = persistence |
|
|
| cancel_event = asyncio.Event() |
| preload_cancel_event = threading.Event() |
|
|
| async def preload(): |
| while not preload_cancel_event.is_set(): |
| await asyncio.sleep(0) |
| cancel_event.set() |
|
|
| session = FakeRuntimeSession(hf_token="token") |
| session.session_id = "s1" |
| session.persistence_store = NoopSessionStore() |
| session.sandbox = SimpleNamespace( |
| space_id="owner/sandbox-12345678", |
| _owns_space=True, |
| delete=lambda: deleted.append("owner/sandbox-12345678"), |
| ) |
| session.sandbox_hardware = "cpu-basic" |
| session.sandbox_preload_cancel_event = preload_cancel_event |
| session.sandbox_preload_task = asyncio.create_task(preload()) |
| manager.sessions["s1"] = AgentSession( |
| session_id="s1", |
| session=session, |
| tool_router=object(), |
| submission_queue=asyncio.Queue(), |
| user_id="owner", |
| hf_token="token", |
| ) |
|
|
| await manager.close() |
|
|
| assert preload_cancel_event.is_set() |
| assert cancel_event.is_set() |
| assert deleted == ["owner/sandbox-12345678"] |
| assert gateway.closed is True |
| assert persistence.closed is True |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_close_closes_resources_when_sandbox_cleanup_fails(): |
| manager = _manager_with_store(NoopSessionStore()) |
| gateway = CloseableResource() |
| persistence = CloseableResource() |
| manager.messaging_gateway = gateway |
| manager.persistence_store = persistence |
| manager.sessions["s1"] = _runtime_agent_session("s1") |
| manager.sessions["s2"] = _runtime_agent_session("s2") |
| cleaned: list[str] = [] |
|
|
| async def fake_cleanup(session): |
| cleaned.append(session.hf_token) |
| if session.hf_token == "owner-token": |
| raise RuntimeError("boom") |
|
|
| manager._cleanup_sandbox = fake_cleanup |
|
|
| await manager.close() |
|
|
| assert cleaned == ["owner-token", "owner-token"] |
| assert gateway.closed is True |
| assert persistence.closed is True |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_existing_session_rejects_cross_user_token_overwrite(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="victim", hf_token="victim-token") |
| manager.sessions["s1"] = existing |
|
|
| result = await manager.ensure_session_loaded( |
| "s1", user_id="attacker", hf_token="attacker-token" |
| ) |
|
|
| assert result is None |
| assert existing.hf_token == "victim-token" |
| assert existing.session.hf_token == "victim-token" |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_existing_session_updates_token_after_access_check(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="owner", hf_token="old-token") |
| manager.sessions["s1"] = existing |
|
|
| result = await manager.ensure_session_loaded( |
| "s1", user_id="owner", hf_token="new-token" |
| ) |
|
|
| assert result is existing |
| assert existing.hf_token == "new-token" |
| assert existing.session.hf_token == "new-token" |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_existing_session_retries_preload_after_token_recovered(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="owner", hf_token=None) |
| done_task = asyncio.get_running_loop().create_future() |
| done_task.set_result(None) |
| existing.session.sandbox_preload_task = done_task |
| existing.session.sandbox_preload_error = ( |
| "No HF token available. Cannot create sandbox." |
| ) |
| manager.sessions["s1"] = existing |
| started: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session): |
| started.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| result = await manager.ensure_session_loaded( |
| "s1", |
| user_id="owner", |
| hf_token="new-token", |
| ) |
|
|
| assert result is existing |
| assert existing.hf_token == "new-token" |
| assert existing.session.hf_token == "new-token" |
| assert existing.session.sandbox_preload_error is None |
| assert existing.session.sandbox_preload_task is None |
| assert started == ["s1"] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_existing_session_does_not_retry_preload_when_disabled(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="owner", hf_token=None) |
| done_task = asyncio.get_running_loop().create_future() |
| done_task.set_result(None) |
| existing.session.sandbox_preload_task = done_task |
| existing.session.sandbox_preload_error = ( |
| "No HF token available. Cannot create sandbox." |
| ) |
| manager.sessions["s1"] = existing |
| started: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session): |
| started.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| result = await manager.ensure_session_loaded( |
| "s1", |
| user_id="owner", |
| hf_token="new-token", |
| preload_sandbox=False, |
| ) |
|
|
| assert result is existing |
| assert existing.hf_token == "new-token" |
| assert existing.session.hf_token == "new-token" |
| assert existing.session.sandbox_preload_error == ( |
| "No HF token available. Cannot create sandbox." |
| ) |
| assert started == [] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_existing_session_does_not_restart_preload_after_teardown(): |
| manager = _manager_with_store(NoopSessionStore()) |
| existing = _runtime_agent_session("s1", user_id="owner", hf_token="token") |
| done_task = asyncio.get_running_loop().create_future() |
| done_task.set_result(None) |
| existing.session.sandbox = None |
| existing.session.sandbox_preload_task = done_task |
| existing.session.sandbox_preload_error = None |
| manager.sessions["s1"] = existing |
| started: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session): |
| started.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| result = await manager.ensure_session_loaded( |
| "s1", |
| user_id="owner", |
| hf_token="token", |
| ) |
|
|
| assert result is existing |
| assert existing.session.sandbox_preload_task is done_task |
| assert existing.session.sandbox_preload_error is None |
| assert started == [] |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_concurrent_lazy_restore_starts_only_one_agent_task(): |
| store = RestoreStore(delay=0.01) |
| manager = _manager_with_store(store) |
| stop = _install_fake_runtime(manager) |
| scheduled: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| scheduled.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| try: |
| first, second = await asyncio.gather( |
| manager.ensure_session_loaded("persisted-session", user_id="owner"), |
| manager.ensure_session_loaded("persisted-session", user_id="owner"), |
| ) |
| await asyncio.sleep(0) |
|
|
| assert first is second |
| assert list(manager.sessions) == ["persisted-session"] |
| assert manager.run_calls == 1 |
| assert scheduled == ["persisted-session"] |
| assert not stop.is_set() |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_create_session_schedules_cpu_sandbox_preload(): |
| manager = _manager_with_store(NoopSessionStore()) |
| stop = _install_fake_runtime(manager) |
| scheduled: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| scheduled.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| try: |
| session_id = await manager.create_session(user_id="owner", hf_token="token") |
|
|
| assert scheduled == [session_id] |
| assert session_id in manager.sessions |
| runtime_session = manager.sessions[session_id].session |
| assert not hasattr(runtime_session, "_ml_intern_artifact_collection_task") |
| assert not hasattr(runtime_session, "_ml_intern_artifact_collection_slug") |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_lazy_restore_schedules_cpu_sandbox_preload(): |
| manager = _manager_with_store(RestoreStore()) |
| stop = _install_fake_runtime(manager) |
| scheduled: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| scheduled.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| try: |
| restored = await manager.ensure_session_loaded( |
| "persisted-session", user_id="owner" |
| ) |
|
|
| assert restored is not None |
| assert scheduled == ["persisted-session"] |
| assert "persisted-session" in manager.sessions |
| assert not hasattr(restored.session, "_ml_intern_artifact_collection_task") |
| assert not hasattr(restored.session, "_ml_intern_artifact_collection_slug") |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch): |
| deleted: list[tuple[str, str, str]] = [] |
|
|
| class FakeApi: |
| def __init__(self, token=None): |
| self.token = token |
|
|
| def delete_repo(self, repo_id, repo_type): |
| deleted.append((self.token, repo_id, repo_type)) |
|
|
| monkeypatch.setattr("huggingface_hub.HfApi", FakeApi) |
|
|
| store = RestoreStore( |
| metadata={ |
| "session_id": "persisted-session", |
| "user_id": "owner", |
| "model": "test-model", |
| "created_at": datetime.now(UTC), |
| "sandbox_space_id": "owner/sandbox-12345678", |
| "sandbox_hardware": "cpu-basic", |
| "sandbox_owner": "owner", |
| "sandbox_created_at": datetime.now(UTC), |
| "sandbox_status": "active", |
| } |
| ) |
| manager = _manager_with_store(store) |
| stop = _install_fake_runtime(manager) |
| scheduled: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| scheduled.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| try: |
| restored = await manager.ensure_session_loaded( |
| "persisted-session", |
| user_id="owner", |
| hf_token="user-token", |
| ) |
|
|
| assert restored is not None |
| assert deleted == [("user-token", "owner/sandbox-12345678", "space")] |
| assert scheduled == ["persisted-session"] |
| assert store.metadata["sandbox_space_id"] is None |
| assert store.metadata["sandbox_status"] == "destroyed" |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_lazy_restore_can_skip_cpu_sandbox_preload_after_cleanup(monkeypatch): |
| deleted: list[str] = [] |
|
|
| class FakeApi: |
| def __init__(self, token=None): |
| self.token = token |
|
|
| def delete_repo(self, repo_id, repo_type): |
| deleted.append(repo_id) |
|
|
| monkeypatch.setattr("huggingface_hub.HfApi", FakeApi) |
|
|
| store = RestoreStore( |
| metadata={ |
| "session_id": "persisted-session", |
| "user_id": "owner", |
| "model": "test-model", |
| "created_at": datetime.now(UTC), |
| "sandbox_space_id": "owner/sandbox-87654321", |
| "sandbox_status": "active", |
| } |
| ) |
| manager = _manager_with_store(store) |
| stop = _install_fake_runtime(manager) |
| scheduled: list[str] = [] |
|
|
| def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None: |
| scheduled.append(agent_session.session_id) |
|
|
| manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload |
|
|
| try: |
| restored = await manager.ensure_session_loaded( |
| "persisted-session", |
| user_id="owner", |
| hf_token="user-token", |
| preload_sandbox=False, |
| ) |
|
|
| assert restored is not None |
| assert deleted == ["owner/sandbox-87654321"] |
| assert scheduled == [] |
| assert store.metadata["sandbox_space_id"] is None |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_lazy_restore_preserves_pending_approval_tool_calls(): |
| store = RestoreStore( |
| metadata={ |
| "session_id": "approval-session", |
| "user_id": "owner", |
| "model": "test-model", |
| "pending_approval": [ |
| { |
| "id": "call_123", |
| "type": "function", |
| "function": { |
| "name": "create_file", |
| "arguments": '{"path":"app.py"}', |
| }, |
| } |
| ], |
| } |
| ) |
| manager = _manager_with_store(store) |
| stop = _install_fake_runtime(manager) |
|
|
| try: |
| restored = await manager.ensure_session_loaded( |
| "approval-session", user_id="owner" |
| ) |
|
|
| assert restored is not None |
| tool_calls = restored.session.pending_approval["tool_calls"] |
| assert len(tool_calls) == 1 |
| assert tool_calls[0].id == "call_123" |
| assert tool_calls[0].function.name == "create_file" |
| assert tool_calls[0].function.arguments == '{"path":"app.py"}' |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_lazy_restore_preserves_auto_approval_policy(): |
| store = RestoreStore( |
| metadata={ |
| "session_id": "yolo-session", |
| "user_id": "owner", |
| "model": "test-model", |
| "auto_approval_enabled": True, |
| "auto_approval_cost_cap_usd": 5.0, |
| "auto_approval_estimated_spend_usd": 1.25, |
| } |
| ) |
| manager = _manager_with_store(store) |
| stop = _install_fake_runtime(manager) |
|
|
| try: |
| restored = await manager.ensure_session_loaded("yolo-session", user_id="owner") |
|
|
| assert restored is not None |
| assert restored.session.auto_approval_enabled is True |
| assert restored.session.auto_approval_cost_cap_usd == 5.0 |
| assert restored.session.auto_approval_estimated_spend_usd == 1.25 |
| assert restored.session.auto_approval_policy_summary()["remaining_usd"] == 3.75 |
| finally: |
| stop.set() |
| await _cancel_runtime_tasks(manager) |
|
|
|
|
| @pytest.mark.asyncio |
| async def test_list_sessions_dev_uses_store_dev_visibility(): |
| class ListStore(NoopSessionStore): |
| enabled = True |
|
|
| def __init__(self) -> None: |
| self.seen_user_id: str | None = None |
|
|
| async def list_sessions(self, user_id: str, **_: Any) -> list[dict[str, Any]]: |
| self.seen_user_id = user_id |
| if user_id == "dev": |
| return [ |
| { |
| "session_id": "s1", |
| "user_id": "alice", |
| "model": "m", |
| "created_at": datetime.now(UTC), |
| "auto_approval_enabled": True, |
| "auto_approval_cost_cap_usd": 5.0, |
| "auto_approval_estimated_spend_usd": 2.0, |
| }, |
| { |
| "session_id": "s2", |
| "user_id": "bob", |
| "model": "m", |
| "created_at": datetime.now(UTC), |
| }, |
| ] |
| return [] |
|
|
| store = ListStore() |
| manager = _manager_with_store(store) |
|
|
| sessions = await manager.list_sessions(user_id="dev") |
|
|
| assert store.seen_user_id == "dev" |
| assert {session["session_id"] for session in sessions} == {"s1", "s2"} |
| yolo = next(session for session in sessions if session["session_id"] == "s1") |
| assert yolo["auto_approval"] == { |
| "enabled": True, |
| "cost_cap_usd": 5.0, |
| "estimated_spend_usd": 2.0, |
| "remaining_usd": 3.0, |
| } |
|
|