import logging
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from src.agents.thread_state import SandboxState, ThreadDataState
from src.sandbox import get_sandbox_provider
from src.utils.runtime import get_runtime_thread_id
|
|
|
|
class SandboxMiddlewareState(AgentState):
    """Agent state extended with the sandbox-related keys.

    The extra keys mirror those of the `ThreadState` schema, so this schema
    stays compatible with it. Both keys may be absent or explicitly ``None``.
    """

    # Sandbox assignment for the current thread (None / missing until acquired).
    sandbox: NotRequired[SandboxState | None]
    # Per-thread data carried alongside the sandbox.
    thread_data: NotRequired[ThreadDataState | None]
|
|
|
|
logger = logging.getLogger(__name__)


class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Lifecycle Management:
    - With lazy_init=True (default): no sandbox is acquired in before_agent;
      acquisition is deferred (by design, to the first tool call — that path
      lives outside this class).
    - With lazy_init=False: a sandbox is acquired in before_agent on the first
      invocation whose state has no sandbox yet.
    - Once ``state["sandbox"]`` is set, it is reused across later turns of the
      same thread (no re-acquisition happens here).
    - This middleware never releases the sandbox after an agent call, to avoid
      wasteful recreation; cleanup is expected to happen at application
      shutdown via SandboxProvider.shutdown().
    """

    state_schema = SandboxMiddlewareState

    def __init__(self, lazy_init: bool = True):
        """Initialize sandbox middleware.

        Args:
            lazy_init: If True, defer sandbox acquisition until first tool call.
                If False, acquire sandbox eagerly in before_agent().
                Default is True for optimal performance.
        """
        super().__init__()
        self._lazy_init = lazy_init

    def _acquire_sandbox(self, thread_id: str) -> str:
        """Acquire a sandbox for ``thread_id`` from the configured provider.

        Args:
            thread_id: Identifier of the conversation thread the sandbox is for.

        Returns:
            The sandbox ID returned by the provider's ``acquire``.
        """
        provider = get_sandbox_provider()
        sandbox_id = provider.acquire(thread_id)
        # Logged after the provider call returns, so the message is past tense.
        logger.info("Acquired sandbox %s for thread %s", sandbox_id, thread_id)
        return sandbox_id

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Eagerly attach a sandbox to the state when lazy init is disabled.

        Args:
            state: Current agent state; ``sandbox`` may be missing or None.
            runtime: Runtime used to resolve the thread ID.

        Returns:
            ``{"sandbox": {"sandbox_id": ...}}`` when a sandbox was just
            acquired; otherwise whatever the parent ``before_agent`` returns.

        Raises:
            ValueError: If eager acquisition is needed but no thread ID can be
                resolved from the runtime.
        """
        # Nothing to do when acquisition is deferred or a sandbox is already
        # assigned (the key may be absent or explicitly None).
        if self._lazy_init or state.get("sandbox") is not None:
            return super().before_agent(state, runtime)

        thread_id = get_runtime_thread_id(runtime)
        if thread_id is None:
            raise ValueError("Thread ID is required in runtime context or metadata")
        logger.debug("Thread ID: %s", thread_id)
        sandbox_id = self._acquire_sandbox(thread_id)
        return {"sandbox": {"sandbox_id": sandbox_id}}
|
|