Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Regression tests for server-side session persistence restore/access.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import sys | |
| import threading | |
| from datetime import datetime, UTC | |
| from pathlib import Path | |
| from types import SimpleNamespace | |
| from typing import Any | |
| import pytest | |
# Locate the "backend" directory under this file's third ancestor directory
# and put it first on sys.path so the backend imports below can be resolved.
_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("backend")
if str(_BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(_BACKEND_DIR))
| from agent.core.session_persistence import NoopSessionStore # noqa: E402 | |
| from session_manager import AgentSession, SessionManager # noqa: E402 | |
class FakeRuntimeSession:
    """Lightweight stand-in for the runtime session object used by these tests.

    It carries only plain attributes plus the two auto-approval helpers
    defined below; nothing here talks to a real backend.
    """

    def __init__(self, *, hf_token: str | None = None, model: str = "test-model"):
        # Identity and conversation state.
        self.hf_token = hf_token
        self.context_manager = SimpleNamespace(items=[])
        self.pending_approval = None
        self.turn_count = 0
        self.config = SimpleNamespace(model_name=model)
        self.notification_destinations = []

        # Auto-approval policy: disabled, uncapped, nothing spent yet.
        self.auto_approval_enabled = False
        self.auto_approval_cost_cap_usd = None
        self.auto_approval_estimated_spend_usd = 0.0

        # Sandbox bookkeeping: no sandbox and no preload in flight.
        self.sandbox = None
        self.sandbox_hardware = None
        self.sandbox_preload_task = None
        self.sandbox_preload_cancel_event = None

    def auto_approval_policy_summary(self):
        """Return the auto-approval settings together with the remaining budget.

        ``remaining_usd`` is ``None`` when no cap is set; otherwise it is the
        cap minus the estimated spend, never going below zero.
        """
        limit = self.auto_approval_cost_cap_usd
        if limit is None:
            budget_left = None
        else:
            budget_left = max(0, limit - self.auto_approval_estimated_spend_usd)
        return {
            "enabled": self.auto_approval_enabled,
            "cost_cap_usd": limit,
            "estimated_spend_usd": self.auto_approval_estimated_spend_usd,
            "remaining_usd": budget_left,
        }

    def set_auto_approval_policy(self, *, enabled, cost_cap_usd):
        """Store the enabled flag and cost cap; the recorded spend is untouched."""
        self.auto_approval_cost_cap_usd = cost_cap_usd
        self.auto_approval_enabled = enabled
class RestoreStore(NoopSessionStore):
    """In-memory store that serves one persisted session for restore tests.

    It counts ``load_session`` calls and records every
    ``update_session_fields`` call so tests can inspect them.
    """

    enabled = True

    def __init__(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        messages: list[dict[str, Any]] | None = None,
        delay: float = 0,
    ) -> None:
        # A missing (or empty) metadata argument falls back to a default
        # session owned by "owner".
        if not metadata:
            metadata = {
                "session_id": "persisted-session",
                "user_id": "owner",
                "model": "test-model",
                "created_at": datetime.now(UTC),
            }
        self.metadata = metadata
        self.messages = messages if messages else []
        self.delay = delay
        self.load_calls = 0
        self.updated_fields: list[tuple[str, dict[str, Any]]] = []

    async def load_session(self, session_id: str, **_: Any) -> dict[str, Any] | None:
        """Return a copy of the metadata plus the stored messages.

        Sleeps for ``delay`` seconds first when a delay is configured. The
        requested id is filled in for ``session_id`` and ``_id`` only when the
        metadata does not already define them.
        """
        self.load_calls += 1
        if self.delay:
            await asyncio.sleep(self.delay)
        snapshot = dict(self.metadata)
        for key in ("session_id", "_id"):
            snapshot.setdefault(key, session_id)
        return {"metadata": snapshot, "messages": self.messages}

    async def update_session_fields(self, session_id: str, **fields: Any) -> None:
        """Record the update and merge the fields into the stored metadata."""
        self.updated_fields.append((session_id, fields))
        self.metadata.update(fields)
class CloseableResource:
    """Minimal async-closeable object that remembers whether it was closed."""

    def __init__(self) -> None:
        # Flipped to True by close(); tests assert on this flag.
        self.closed = False

    async def close(self) -> None:
        """Mark the resource as closed."""
        self.closed = True
def _manager_with_store(store: NoopSessionStore) -> SessionManager:
    """Build a ``SessionManager`` around ``store`` without running ``__init__``.

    The instance is allocated with ``object.__new__`` and only the attributes
    assigned below are populated by hand.
    """
    mgr = object.__new__(SessionManager)
    mgr.config = SimpleNamespace(model_name="test-model")
    mgr.sessions = {}
    mgr._lock = asyncio.Lock()
    mgr.persistence_store = store
    # Tests that assert on gateway closing replace this with their own instance.
    mgr.messaging_gateway = CloseableResource()
    return mgr
def _runtime_agent_session(
    session_id: str,
    *,
    user_id: str = "owner",
    hf_token: str | None = "owner-token",
) -> AgentSession:
    """Create an ``AgentSession`` wrapping a fresh ``FakeRuntimeSession``.

    The same ``hf_token`` is given to both the wrapper and the fake runtime
    session; the tool router is a bare ``object()`` placeholder.
    """
    return AgentSession(
        session_id=session_id,
        session=FakeRuntimeSession(hf_token=hf_token),  # type: ignore[arg-type]
        tool_router=object(),  # type: ignore[arg-type]
        submission_queue=asyncio.Queue(),
        user_id=user_id,
        hf_token=hf_token,
    )
async def test_update_session_auto_approval_defaults_to_five_dollars():
    """Enabling auto-approval without providing a cap yields a 5.0 USD cap."""
    manager = _manager_with_store(NoopSessionStore())
    manager.sessions["s1"] = _runtime_agent_session("s1", user_id="owner")

    policy = await manager.update_session_auto_approval(
        "s1",
        enabled=True,
        cost_cap_usd=None,
        cap_provided=False,
    )

    assert policy["enabled"] is True
    # Nothing has been spent yet, so the whole cap is still available.
    assert policy["cost_cap_usd"] == 5.0
    assert policy["remaining_usd"] == 5.0
| def _install_fake_runtime(manager: SessionManager) -> asyncio.Event: | |
| stop = asyncio.Event() | |
| manager.run_calls = 0 # type: ignore[attr-defined] | |
| def fake_create_session_sync(**kwargs: Any): | |
| return object(), FakeRuntimeSession( | |
| hf_token=kwargs.get("hf_token"), | |
| model=kwargs.get("model") or "test-model", | |
| ) | |
| async def fake_run_session(*_: Any) -> None: | |
| manager.run_calls += 1 # type: ignore[attr-defined] | |
| await stop.wait() | |
| manager._create_session_sync = fake_create_session_sync # type: ignore[method-assign] | |
| manager._run_session = fake_run_session # type: ignore[method-assign] | |
| return stop | |
async def _cancel_runtime_tasks(manager: SessionManager) -> None:
    """Cancel every unfinished session task on ``manager`` and wait for them.

    Sessions without a task, or whose task has already finished, are skipped.
    """
    pending = []
    for agent_session in manager.sessions.values():
        task = agent_session.task
        if task and not task.done():
            pending.append(task)

    for task in pending:
        task.cancel()

    if not pending:
        return
    # return_exceptions=True collects the CancelledError results instead of
    # re-raising them here.
    await asyncio.gather(*pending, return_exceptions=True)
async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
    """close() stops an in-flight preload, deletes the owned sandbox, closes resources.

    The session carries a running preload task that polls a
    ``threading.Event``. After ``manager.close()`` the test expects that event
    to be set, the preload coroutine to have run to completion, the sandbox's
    ``delete`` to have been called once, and both the messaging gateway and
    the persistence store to be closed.
    """
    deleted: list[str] = []

    async def fake_record_sandbox_destroy(*args, **kwargs):
        # Telemetry stub: accept any arguments and do nothing.
        pass

    monkeypatch.setattr(
        "agent.core.telemetry.record_sandbox_destroy",
        fake_record_sandbox_destroy,
    )
    store = NoopSessionStore()
    manager = _manager_with_store(store)
    gateway = CloseableResource()
    persistence = CloseableResource()
    manager.messaging_gateway = gateway  # type: ignore[assignment]
    manager.persistence_store = persistence  # type: ignore[assignment]

    cancel_event = asyncio.Event()
    preload_cancel_event = threading.Event()

    async def preload():
        # Poll until asked to stop, then record that the loop exited normally.
        while not preload_cancel_event.is_set():
            await asyncio.sleep(0)
        cancel_event.set()

    session = FakeRuntimeSession(hf_token="token")
    session.session_id = "s1"
    session.persistence_store = NoopSessionStore()
    session.sandbox = SimpleNamespace(
        space_id="owner/sandbox-12345678",
        _owns_space=True,
        delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
    )
    session.sandbox_hardware = "cpu-basic"
    session.sandbox_preload_cancel_event = preload_cancel_event
    # Keep a local handle: close() may clear the attribute on the session.
    preload_task = asyncio.create_task(preload())
    session.sandbox_preload_task = preload_task
    manager.sessions["s1"] = AgentSession(
        session_id="s1",
        session=session,  # type: ignore[arg-type]
        tool_router=object(),  # type: ignore[arg-type]
        submission_queue=asyncio.Queue(),
        user_id="owner",
        hf_token="token",
    )

    try:
        await manager.close()

        assert preload_cancel_event.is_set()
        assert cancel_event.is_set()
        assert deleted == ["owner/sandbox-12345678"]
        assert gateway.closed is True
        assert persistence.closed is True
    finally:
        # If close() raised or left the preload running, do not leak its
        # polling loop into the event loop after this test finishes.
        if not preload_task.done():
            preload_task.cancel()
            await asyncio.gather(preload_task, return_exceptions=True)
async def test_close_closes_resources_when_sandbox_cleanup_fails():
    """close() attempts every sandbox cleanup and still closes shared resources.

    Both sessions use the "owner-token" token, so the fake cleanup raises for
    each of them.
    """
    manager = _manager_with_store(NoopSessionStore())
    gateway, persistence = CloseableResource(), CloseableResource()
    manager.messaging_gateway = gateway  # type: ignore[assignment]
    manager.persistence_store = persistence  # type: ignore[assignment]
    for session_id in ("s1", "s2"):
        manager.sessions[session_id] = _runtime_agent_session(session_id)

    attempted_tokens: list[str] = []

    async def fake_cleanup(session):
        # Record the attempt before failing so the test can count attempts.
        attempted_tokens.append(session.hf_token)
        if session.hf_token == "owner-token":
            raise RuntimeError("boom")

    manager._cleanup_sandbox = fake_cleanup  # type: ignore[method-assign]

    await manager.close()

    assert attempted_tokens == ["owner-token", "owner-token"]
    assert gateway.closed is True
    assert persistence.closed is True
async def test_existing_session_rejects_cross_user_token_overwrite():
    """A different user gets ``None`` back and cannot replace the stored token."""
    manager = _manager_with_store(NoopSessionStore())
    victim_session = _runtime_agent_session(
        "s1", user_id="victim", hf_token="victim-token"
    )
    manager.sessions["s1"] = victim_session

    outcome = await manager.ensure_session_loaded(
        "s1", user_id="attacker", hf_token="attacker-token"
    )

    assert outcome is None
    # Neither the wrapper nor the runtime session picked up the attacker token.
    assert victim_session.hf_token == "victim-token"
    assert victim_session.session.hf_token == "victim-token"
async def test_existing_session_updates_token_after_access_check():
    """The owning user gets the same session back with its token refreshed."""
    manager = _manager_with_store(NoopSessionStore())
    owned_session = _runtime_agent_session(
        "s1", user_id="owner", hf_token="old-token"
    )
    manager.sessions["s1"] = owned_session

    loaded = await manager.ensure_session_loaded(
        "s1", user_id="owner", hf_token="new-token"
    )

    assert loaded is owned_session
    # The new token is stored on both the wrapper and the runtime session.
    assert owned_session.hf_token == "new-token"
    assert owned_session.session.hf_token == "new-token"
async def test_existing_session_retries_preload_after_token_recovered():
    """A preload that failed for lack of a token is retried once a token arrives.

    The session starts with no token, a finished preload future and the
    missing-token error. Loading it with a token must clear both and start a
    new preload.
    """
    manager = _manager_with_store(NoopSessionStore())
    tokenless = _runtime_agent_session("s1", user_id="owner", hf_token=None)

    # An already-resolved future stands in for the finished preload task.
    finished_preload = asyncio.get_running_loop().create_future()
    finished_preload.set_result(None)
    tokenless.session.sandbox_preload_task = finished_preload
    tokenless.session.sandbox_preload_error = (
        "No HF token available. Cannot create sandbox."
    )
    manager.sessions["s1"] = tokenless

    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session):
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    loaded = await manager.ensure_session_loaded(
        "s1",
        user_id="owner",
        hf_token="new-token",
    )

    assert loaded is tokenless
    assert tokenless.hf_token == "new-token"
    assert tokenless.session.hf_token == "new-token"
    assert tokenless.session.sandbox_preload_error is None
    assert tokenless.session.sandbox_preload_task is None
    assert preload_requests == ["s1"]
async def test_existing_session_does_not_retry_preload_when_disabled():
    """With ``preload_sandbox=False`` the token updates but no preload restarts.

    The stored missing-token error must stay in place and the preload hook
    must not be called.
    """
    manager = _manager_with_store(NoopSessionStore())
    tokenless = _runtime_agent_session("s1", user_id="owner", hf_token=None)

    # An already-resolved future stands in for the finished preload task.
    finished_preload = asyncio.get_running_loop().create_future()
    finished_preload.set_result(None)
    tokenless.session.sandbox_preload_task = finished_preload
    tokenless.session.sandbox_preload_error = (
        "No HF token available. Cannot create sandbox."
    )
    manager.sessions["s1"] = tokenless

    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session):
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    loaded = await manager.ensure_session_loaded(
        "s1",
        user_id="owner",
        hf_token="new-token",
        preload_sandbox=False,
    )

    assert loaded is tokenless
    assert tokenless.hf_token == "new-token"
    assert tokenless.session.hf_token == "new-token"
    assert tokenless.session.sandbox_preload_error == (
        "No HF token available. Cannot create sandbox."
    )
    assert preload_requests == []
async def test_existing_session_does_not_restart_preload_after_teardown():
    """A finished, error-free preload is left alone even when no sandbox exists.

    The session already has a token, no sandbox, a completed preload future
    and no preload error. Loading it again must keep that future and must not
    call the preload hook.
    """
    manager = _manager_with_store(NoopSessionStore())
    torn_down = _runtime_agent_session("s1", user_id="owner", hf_token="token")

    finished_preload = asyncio.get_running_loop().create_future()
    finished_preload.set_result(None)
    torn_down.session.sandbox = None
    torn_down.session.sandbox_preload_task = finished_preload
    torn_down.session.sandbox_preload_error = None
    manager.sessions["s1"] = torn_down

    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session):
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    loaded = await manager.ensure_session_loaded(
        "s1",
        user_id="owner",
        hf_token="token",
    )

    assert loaded is torn_down
    assert torn_down.session.sandbox_preload_task is finished_preload
    assert torn_down.session.sandbox_preload_error is None
    assert preload_requests == []
async def test_concurrent_lazy_restore_starts_only_one_agent_task():
    """Two overlapping restores of one session share a single runtime.

    The store delays its load so both calls are in flight together. They must
    return the same object, register one session, start one run loop and
    schedule one preload.
    """
    manager = _manager_with_store(RestoreStore(delay=0.01))
    stop = _install_fake_runtime(manager)
    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    try:
        restores = [
            manager.ensure_session_loaded("persisted-session", user_id="owner")
            for _ in range(2)
        ]
        first, second = await asyncio.gather(*restores)
        # Yield once so the newly created run task gets a chance to start.
        await asyncio.sleep(0)

        assert first is second
        assert list(manager.sessions) == ["persisted-session"]
        assert manager.run_calls == 1  # type: ignore[attr-defined]
        assert preload_requests == ["persisted-session"]
        assert not stop.is_set()
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_create_session_schedules_cpu_sandbox_preload():
    """Creating a session registers it and schedules exactly one CPU preload.

    It also checks that the runtime session carries neither of the two
    ``_ml_intern_artifact_collection_*`` attributes.
    """
    manager = _manager_with_store(NoopSessionStore())
    stop = _install_fake_runtime(manager)
    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    try:
        new_id = await manager.create_session(user_id="owner", hf_token="token")

        assert preload_requests == [new_id]
        assert new_id in manager.sessions
        created_runtime = manager.sessions[new_id].session
        assert not hasattr(created_runtime, "_ml_intern_artifact_collection_task")
        assert not hasattr(created_runtime, "_ml_intern_artifact_collection_slug")
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_lazy_restore_schedules_cpu_sandbox_preload():
    """Restoring a persisted session registers it and schedules one CPU preload.

    It also checks that the restored runtime session carries neither of the
    two ``_ml_intern_artifact_collection_*`` attributes.
    """
    manager = _manager_with_store(RestoreStore())
    stop = _install_fake_runtime(manager)
    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    try:
        loaded = await manager.ensure_session_loaded(
            "persisted-session", user_id="owner"
        )

        assert loaded is not None
        assert preload_requests == ["persisted-session"]
        assert "persisted-session" in manager.sessions
        assert not hasattr(loaded.session, "_ml_intern_artifact_collection_task")
        assert not hasattr(loaded.session, "_ml_intern_artifact_collection_slug")
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch):
    """Restoring a session with an active persisted sandbox deletes that Space.

    The deletion must use the caller's token, the store must end up with the
    sandbox id cleared and its status set to "destroyed", and a new preload
    must still be scheduled.
    """
    delete_calls: list[tuple[str, str, str]] = []

    class FakeApi:
        """Replacement for ``huggingface_hub.HfApi`` that records deletions."""

        def __init__(self, token=None):
            self.token = token

        def delete_repo(self, repo_id, repo_type):
            delete_calls.append((self.token, repo_id, repo_type))

    monkeypatch.setattr("huggingface_hub.HfApi", FakeApi)

    persisted_metadata = {
        "session_id": "persisted-session",
        "user_id": "owner",
        "model": "test-model",
        "created_at": datetime.now(UTC),
        "sandbox_space_id": "owner/sandbox-12345678",
        "sandbox_hardware": "cpu-basic",
        "sandbox_owner": "owner",
        "sandbox_created_at": datetime.now(UTC),
        "sandbox_status": "active",
    }
    store = RestoreStore(metadata=persisted_metadata)
    manager = _manager_with_store(store)
    stop = _install_fake_runtime(manager)
    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    try:
        loaded = await manager.ensure_session_loaded(
            "persisted-session",
            user_id="owner",
            hf_token="user-token",
        )

        assert loaded is not None
        assert delete_calls == [("user-token", "owner/sandbox-12345678", "space")]
        assert preload_requests == ["persisted-session"]
        assert store.metadata["sandbox_space_id"] is None
        assert store.metadata["sandbox_status"] == "destroyed"
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_lazy_restore_can_skip_cpu_sandbox_preload_after_cleanup(monkeypatch):
    """With ``preload_sandbox=False`` the stale sandbox is deleted but none is preloaded."""
    deleted_repos: list[str] = []

    class FakeApi:
        """Replacement for ``huggingface_hub.HfApi`` that records deleted repo ids."""

        def __init__(self, token=None):
            self.token = token

        def delete_repo(self, repo_id, repo_type):
            deleted_repos.append(repo_id)

    monkeypatch.setattr("huggingface_hub.HfApi", FakeApi)

    persisted_metadata = {
        "session_id": "persisted-session",
        "user_id": "owner",
        "model": "test-model",
        "created_at": datetime.now(UTC),
        "sandbox_space_id": "owner/sandbox-87654321",
        "sandbox_status": "active",
    }
    store = RestoreStore(metadata=persisted_metadata)
    manager = _manager_with_store(store)
    stop = _install_fake_runtime(manager)
    preload_requests: list[str] = []

    def fake_start_cpu_sandbox_preload(agent_session: AgentSession) -> None:
        preload_requests.append(agent_session.session_id)

    manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload  # type: ignore[method-assign]

    try:
        loaded = await manager.ensure_session_loaded(
            "persisted-session",
            user_id="owner",
            hf_token="user-token",
            preload_sandbox=False,
        )

        assert loaded is not None
        assert deleted_repos == ["owner/sandbox-87654321"]
        assert preload_requests == []
        assert store.metadata["sandbox_space_id"] is None
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_lazy_restore_preserves_pending_approval_tool_calls():
    """Persisted pending-approval tool calls come back as attribute-style objects.

    The stored entry is a plain dict. After restore it must be reachable under
    ``pending_approval["tool_calls"]`` with ``.id`` and ``.function`` access.
    """
    persisted_call = {
        "id": "call_123",
        "type": "function",
        "function": {
            "name": "create_file",
            "arguments": '{"path":"app.py"}',
        },
    }
    store = RestoreStore(
        metadata={
            "session_id": "approval-session",
            "user_id": "owner",
            "model": "test-model",
            "pending_approval": [persisted_call],
        }
    )
    manager = _manager_with_store(store)
    stop = _install_fake_runtime(manager)

    try:
        loaded = await manager.ensure_session_loaded(
            "approval-session", user_id="owner"
        )

        assert loaded is not None
        restored_calls = loaded.session.pending_approval["tool_calls"]
        assert len(restored_calls) == 1
        only_call = restored_calls[0]
        assert only_call.id == "call_123"
        assert only_call.function.name == "create_file"
        assert only_call.function.arguments == '{"path":"app.py"}'
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_lazy_restore_preserves_auto_approval_policy():
    """Persisted auto-approval settings and spend are applied to the restored session."""
    persisted_metadata = {
        "session_id": "yolo-session",
        "user_id": "owner",
        "model": "test-model",
        "auto_approval_enabled": True,
        "auto_approval_cost_cap_usd": 5.0,
        "auto_approval_estimated_spend_usd": 1.25,
    }
    manager = _manager_with_store(RestoreStore(metadata=persisted_metadata))
    stop = _install_fake_runtime(manager)

    try:
        loaded = await manager.ensure_session_loaded("yolo-session", user_id="owner")

        assert loaded is not None
        runtime = loaded.session
        assert runtime.auto_approval_enabled is True
        assert runtime.auto_approval_cost_cap_usd == 5.0
        assert runtime.auto_approval_estimated_spend_usd == 1.25
        # 5.0 cap minus 1.25 already spent.
        assert runtime.auto_approval_policy_summary()["remaining_usd"] == 3.75
    finally:
        stop.set()
        await _cancel_runtime_tasks(manager)
async def test_list_sessions_dev_uses_store_dev_visibility():
    """Listing as "dev" forwards that id to the store and returns every row.

    The store only answers for ``user_id == "dev"`` and then returns sessions
    owned by two other users. The row with auto-approval fields must expose
    them as an ``auto_approval`` summary.
    """

    class ListStore(NoopSessionStore):
        """Store stub that remembers which user id was queried."""

        enabled = True

        def __init__(self) -> None:
            self.seen_user_id: str | None = None

        async def list_sessions(self, user_id: str, **_: Any) -> list[dict[str, Any]]:
            self.seen_user_id = user_id
            if user_id != "dev":
                return []
            return [
                {
                    "session_id": "s1",
                    "user_id": "alice",
                    "model": "m",
                    "created_at": datetime.now(UTC),
                    "auto_approval_enabled": True,
                    "auto_approval_cost_cap_usd": 5.0,
                    "auto_approval_estimated_spend_usd": 2.0,
                },
                {
                    "session_id": "s2",
                    "user_id": "bob",
                    "model": "m",
                    "created_at": datetime.now(UTC),
                },
            ]

    store = ListStore()
    manager = _manager_with_store(store)

    listed = await manager.list_sessions(user_id="dev")

    assert store.seen_user_id == "dev"
    assert {row["session_id"] for row in listed} == {"s1", "s2"}
    capped_row = next(row for row in listed if row["session_id"] == "s1")
    assert capped_row["auto_approval"] == {
        "enabled": True,
        "cost_cap_usd": 5.0,
        "estimated_spend_usd": 2.0,
        "remaining_usd": 3.0,
    }