diff --git a/Dockerfile b/Dockerfile index 878e0dc0eaa7cbf9a6ebda5cb36fa3dae12d0e88..d69d4ebf3cf2e9bc87bb4d688bfc4d047c5c83bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,39 +4,66 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Use the standard openenv base image -# Built from: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile . -# In GitHub Actions, this is overridden to use the GHCR base image ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ghcr.io/meta-pytorch/openenv-base:latest AS builder + +WORKDIR /app + + +COPY . /app/env + +WORKDIR /app/env + +RUN if ! command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + && rm -rf /var/lib/apt/lists/* + +RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \ + install -m 0755 /root/.local/bin/uv /usr/local/bin/uv && \ + install -m 0755 /root/.local/bin/uvx /usr/local/bin/uvx + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Pre-download GPT-2 tokenizer to avoid permission issues at runtime. +# HF_HOME must match the runtime value so the cache is copied to the right place. +RUN HF_HOME=/.cache /app/env/.venv/bin/python -c "from transformers import GPT2Tokenizer; GPT2Tokenizer.from_pretrained('gpt2')" + +# Final runtime stage FROM ghcr.io/meta-pytorch/openenv-base:latest -# Install dependencies and run setup -COPY envs/chat_env/server/requirements.txt /tmp/requirements.txt -COPY envs/chat_env/server/install_deps.sh /tmp/install_deps.sh -RUN chmod +x /tmp/install_deps.sh && \ - /tmp/install_deps.sh && \ - rm /tmp/install_deps.sh /tmp/requirements.txt +WORKDIR /app + +COPY --from=builder /app/env/.venv /app/.venv +COPY --from=builder /app/env /app/env +COPY --from=builder /.cache /.cache -# Set environment variables +ENV PATH="/app/.venv/bin:$PATH" +ENV PYTHONPATH="/app/env:$PYTHONPATH" ENV HF_HOME=/.cache ENV TRANSFORMERS_CACHE=/.cache - -# Environment variables that can be overridden at runtime ENV TOKENIZER_NAME=gpt2 ENV SYSTEM_PROMPT="You are a helpful AI assistant." -ENV ENABLE_WEB_INTERFACE=false - -# Copy only what's needed for this environment -COPY src/core/ /app/src/core/ -COPY envs/chat_env/ /app/envs/chat_env/ - -# Copy README for web interface documentation -COPY envs/chat_env/README.md /app/README.md +ENV ENABLE_WEB_INTERFACE=true -# Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:8000/health || exit 1 + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 -# Run the FastAPI server -CMD ["uvicorn", "envs.chat_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] -ENV PYTHONPATH=/app/src/core:/app/src:${PYTHONPATH} +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/README.md b/README.md index 67450c4c6ca94d2907ebe81bf290ee13eb29b9c9..741e94228203b277e880aa07b741c0ef655d452d 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ pinned: false app_port: 8000 base_path: /web tags: - - openenv-main + - openenv-0.2.3 - openenv --- @@ -17,7 +17,7 @@ tags: This Space is built from OpenEnv environment `chat_env`. - Space URL: `https://huggingface.co/spaces/openenv/chat_env` -- OpenEnv pinned ref: `main` +- OpenEnv pinned ref: `0.2.3` - Hub tag: `openenv` ### Connecting from Code diff --git a/client.py b/client.py index 67c3b1195ddfcea680920638c34b908517aec2c3..eb4a3de9d88522d3e00ec041aec5a20c109d6730 100644 --- a/client.py +++ b/client.py @@ -13,7 +13,6 @@ via WebSocket for persistent sessions. from typing import Any, Dict -import torch from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient from openenv.core.env_server.interfaces import Message @@ -28,8 +27,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): This client maintains a persistent WebSocket connection to the environment server, enabling efficient multi-step interactions with lower latency. - Note: Since ChatEnvironment works with PyTorch tensors, the client - serializes tokens as lists for transport and deserializes them back to tensors. + The client transports token ids as plain JSON lists. Example: >>> # Connect to a running server @@ -38,8 +36,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): ... print(result.observation.messages) ... ... # Send an action with tokens - ... import torch - ... tokens = torch.tensor([[1, 2, 3, 4, 5]]) + ... tokens = [1, 2, 3, 4, 5] ... result = client.step(ChatAction(tokens=tokens)) ... print(result.observation.messages) ... print(result.reward) @@ -49,7 +46,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): >>> client = ChatEnv.from_docker_image("chat-env:latest") >>> try: ... result = client.reset() - ... result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + ... result = client.step(ChatAction(tokens=[1, 2, 3])) ... finally: ... client.close() """ @@ -58,23 +55,14 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ Convert ChatAction to JSON payload for step request. - Since PyTorch tensors can't be directly serialized to JSON, - we convert them to nested lists. - Args: action: ChatAction instance with tokens Returns: Dictionary representation suitable for JSON encoding """ - # Convert tensor to list for JSON serialization - if isinstance(action.tokens, torch.Tensor): - tokens_list = action.tokens.tolist() - else: - tokens_list = action.tokens - return { - "tokens": tokens_list, + "tokens": action.tokens, "metadata": action.metadata, } @@ -90,15 +78,8 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ obs_data = payload.get("observation", {}) - # Convert tokens list back to tensor tokens_data = obs_data.get("tokens", []) - if isinstance(tokens_data, list): - if tokens_data: - tokens = torch.tensor(tokens_data) - else: - tokens = torch.tensor([]) - else: - tokens = torch.tensor([]) + tokens = tokens_data if isinstance(tokens_data, list) else [] # Parse messages messages = obs_data.get("messages", []) @@ -130,14 +111,11 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): # Parse history messages history_messages = payload.get("history_messages", []) - # Parse history tokens - convert lists back to tensors + # Parse history tokens history_tokens_data = payload.get("history_tokens", []) history_tokens = [] for token_list in history_tokens_data: - if token_list: - history_tokens.append(torch.tensor(token_list)) - else: - history_tokens.append(torch.tensor([])) + history_tokens.append(token_list if isinstance(token_list, list) else []) return ChatState( episode_id=payload.get("episode_id"), @@ -176,8 +154,6 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): raise ValueError("Message content cannot be None") # Tokenize the message - tokens = tokenizer.apply_chat_template( - conversation=[message], tokenize=True, return_tensors="pt" - ) + tokens = tokenizer.apply_chat_template(conversation=[message], tokenize=True) return ChatAction(tokens=tokens) diff --git a/envs/chat_env/client.py b/envs/chat_env/client.py index 67c3b1195ddfcea680920638c34b908517aec2c3..eb4a3de9d88522d3e00ec041aec5a20c109d6730 100644 --- a/envs/chat_env/client.py +++ b/envs/chat_env/client.py @@ -13,7 +13,6 @@ via WebSocket for persistent sessions. from typing import Any, Dict -import torch from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient from openenv.core.env_server.interfaces import Message @@ -28,8 +27,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): This client maintains a persistent WebSocket connection to the environment server, enabling efficient multi-step interactions with lower latency. - Note: Since ChatEnvironment works with PyTorch tensors, the client - serializes tokens as lists for transport and deserializes them back to tensors. + The client transports token ids as plain JSON lists. Example: >>> # Connect to a running server @@ -38,8 +36,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): ... print(result.observation.messages) ... ... # Send an action with tokens - ... import torch - ... tokens = torch.tensor([[1, 2, 3, 4, 5]]) + ... tokens = [1, 2, 3, 4, 5] ... result = client.step(ChatAction(tokens=tokens)) ... print(result.observation.messages) ... print(result.reward) @@ -49,7 +46,7 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): >>> client = ChatEnv.from_docker_image("chat-env:latest") >>> try: ... result = client.reset() - ... result = client.step(ChatAction(tokens=torch.tensor([[1, 2, 3]]))) + ... result = client.step(ChatAction(tokens=[1, 2, 3])) ... finally: ... client.close() """ @@ -58,23 +55,14 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ Convert ChatAction to JSON payload for step request. - Since PyTorch tensors can't be directly serialized to JSON, - we convert them to nested lists. - Args: action: ChatAction instance with tokens Returns: Dictionary representation suitable for JSON encoding """ - # Convert tensor to list for JSON serialization - if isinstance(action.tokens, torch.Tensor): - tokens_list = action.tokens.tolist() - else: - tokens_list = action.tokens - return { - "tokens": tokens_list, + "tokens": action.tokens, "metadata": action.metadata, } @@ -90,15 +78,8 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): """ obs_data = payload.get("observation", {}) - # Convert tokens list back to tensor tokens_data = obs_data.get("tokens", []) - if isinstance(tokens_data, list): - if tokens_data: - tokens = torch.tensor(tokens_data) - else: - tokens = torch.tensor([]) - else: - tokens = torch.tensor([]) + tokens = tokens_data if isinstance(tokens_data, list) else [] # Parse messages messages = obs_data.get("messages", []) @@ -130,14 +111,11 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): # Parse history messages history_messages = payload.get("history_messages", []) - # Parse history tokens - convert lists back to tensors + # Parse history tokens history_tokens_data = payload.get("history_tokens", []) history_tokens = [] for token_list in history_tokens_data: - if token_list: - history_tokens.append(torch.tensor(token_list)) - else: - history_tokens.append(torch.tensor([])) + history_tokens.append(token_list if isinstance(token_list, list) else []) return ChatState( episode_id=payload.get("episode_id"), @@ -176,8 +154,6 @@ class ChatEnv(EnvClient[ChatAction, ChatObservation, ChatState]): raise ValueError("Message content cannot be None") # Tokenize the message - tokens = tokenizer.apply_chat_template( - conversation=[message], tokenize=True, return_tensors="pt" - ) + tokens = tokenizer.apply_chat_template(conversation=[message], tokenize=True) return ChatAction(tokens=tokens) diff --git a/envs/chat_env/models.py b/envs/chat_env/models.py index b34ad30006e8e90cf34616693285c06bea9e754e..0a72b6ffd4c6b710c9d8782b48e04b987ebfd790 100644 --- a/envs/chat_env/models.py +++ b/envs/chat_env/models.py @@ -11,10 +11,25 @@ The Chat environment provides a chat-based interface for LLMs with support for tokenization and message history management. """ -import torch -from openenv.core.env_server.interfaces import Message from openenv.core.env_server.types import Action, Observation, State -from pydantic import Field +from pydantic import Field, field_validator + + +def _flatten_tokens(value) -> list[int]: + """Coerce nested tensor-like or sequence inputs into a flat token list.""" + if hasattr(value, "tolist") and callable(value.tolist): + value = value.tolist() + + if isinstance(value, tuple): + value = list(value) + + if isinstance(value, list): + flattened: list[int] = [] + for item in value: + flattened.extend(_flatten_tokens(item)) + return flattened + + return [int(value)] class ChatAction(Action): @@ -24,21 +39,24 @@ class ChatAction(Action): This interfaces directly with models. """ - tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([])) + tokens: list[int] = Field(..., min_length=1) - def __post_init__(self): - """Validate required Fields after initialization.""" - if self.tokens.numel() == 0: - raise ValueError("tokens is required and cannot be empty") + @field_validator("tokens", mode="before") + @classmethod + def _coerce_tokens(cls, value): + """Accept either tensors or JSON arrays on the public HTTP surface.""" + if isinstance(value, (list, tuple)) or hasattr(value, "tolist"): + return _flatten_tokens(value) + raise TypeError("tokens must be provided as a sequence of token ids") class ChatState(State): """State of the ChatEnvironment containing message history.""" - history_messages: list[Message] = Field(default_factory=list) - history_tokens: list[torch.Tensor] = Field( - default_factory=list - ) # Same len as messages + # TODO: revert to list[Message] once openenv-core ships typing_extensions.TypedDict + # in interfaces.py and chat_env/pyproject.toml pins to that release. + history_messages: list[dict[str, str]] = Field(default_factory=list) + history_tokens: list[list[int]] = Field(default_factory=list) # Same len as messages class ChatObservation(Observation): @@ -57,6 +75,7 @@ class ChatObservation(Observation): tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation """ - messages: list[Message] = Field(default_factory=list) - tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([])) + # TODO: revert to list[Message] (same as above) + messages: list[dict[str, str]] = Field(default_factory=list) + tokens: list[int] = Field(default_factory=list) # Inherited Fields from Observation ABC: reward, done, metadata diff --git a/envs/chat_env/openenv_chat_env.egg-info/PKG-INFO b/envs/chat_env/openenv_chat_env.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..bff6523fc7095e520b6e3620b26e66984213ff6e --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/PKG-INFO @@ -0,0 +1,14 @@ +Metadata-Version: 2.4 +Name: openenv-chat-env +Version: 0.1.0 +Summary: Chat Environment for OpenEnv - LLM-powered conversational agent +Requires-Python: >=3.10 +Requires-Dist: openenv-core[core]>=0.2.3 +Requires-Dist: fastapi>=0.115.0 +Requires-Dist: pydantic>=2.0.0 +Requires-Dist: uvicorn>=0.24.0 +Requires-Dist: requests>=2.31.0 +Requires-Dist: transformers +Provides-Extra: dev +Requires-Dist: pytest>=8.0.0; extra == "dev" +Requires-Dist: pytest-cov>=4.0.0; extra == "dev" diff --git a/envs/chat_env/openenv_chat_env.egg-info/SOURCES.txt b/envs/chat_env/openenv_chat_env.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..67ba456cb65d0ebdb5b27397cac3a35b7c087811 --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/SOURCES.txt @@ -0,0 +1,18 @@ +README.md +__init__.py +client.py +models.py +pyproject.toml +./__init__.py +./client.py +./models.py +openenv_chat_env.egg-info/PKG-INFO +openenv_chat_env.egg-info/SOURCES.txt +openenv_chat_env.egg-info/dependency_links.txt +openenv_chat_env.egg-info/entry_points.txt +openenv_chat_env.egg-info/requires.txt +openenv_chat_env.egg-info/top_level.txt +server/__init__.py +server/app.py +server/chat_environment.py +server/test_chat_env.py \ No newline at end of file diff --git a/envs/chat_env/openenv_chat_env.egg-info/dependency_links.txt b/envs/chat_env/openenv_chat_env.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/envs/chat_env/openenv_chat_env.egg-info/entry_points.txt b/envs/chat_env/openenv_chat_env.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..893068401fa986707ad57f3c68e84f9620c53df6 --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +server = chat_env.server.app:main diff --git a/envs/chat_env/openenv_chat_env.egg-info/requires.txt b/envs/chat_env/openenv_chat_env.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..c2e059d51585ba03dd3ae18172a1f3d5e0aeafc2 --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/requires.txt @@ -0,0 +1,10 @@ +openenv-core[core]>=0.2.3 +fastapi>=0.115.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.31.0 +transformers + +[dev] +pytest>=8.0.0 +pytest-cov>=4.0.0 diff --git a/envs/chat_env/openenv_chat_env.egg-info/top_level.txt b/envs/chat_env/openenv_chat_env.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..b5331ca79966793dd53ca1fb41d9de845358acc8 --- /dev/null +++ b/envs/chat_env/openenv_chat_env.egg-info/top_level.txt @@ -0,0 +1 @@ +chat_env diff --git a/envs/chat_env/pyproject.toml b/envs/chat_env/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..49b376aae710c417e0199b984619751668e34267 --- /dev/null +++ b/envs/chat_env/pyproject.toml @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "openenv-chat-env" +version = "0.1.0" +description = "Chat Environment for OpenEnv - LLM-powered conversational agent" +requires-python = ">=3.10" +dependencies = [ + "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3", + "fastapi>=0.115.0", + "pydantic>=2.0.0", + "uvicorn>=0.24.0", + "requests>=2.31.0", + "transformers", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", +] + +[project.scripts] +server = "chat_env.server.app:main" + +[tool.setuptools] +include-package-data = true +packages = ["chat_env", "chat_env.server"] +package-dir = { "chat_env" = ".", "chat_env.server" = "server" } diff --git a/envs/chat_env/server/Dockerfile b/envs/chat_env/server/Dockerfile index edf92b546625bb88fcc3c2ff0df3e818f5f92910..5ecee93a8a54bfb2faac5ce8869a0e05e16f233b 100644 --- a/envs/chat_env/server/Dockerfile +++ b/envs/chat_env/server/Dockerfile @@ -4,38 +4,62 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Use the standard openenv base image -# Built from: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile . -# In GitHub Actions, this is overridden to use the GHCR base image -ARG BASE_IMAGE=openenv-base:latest +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + + +COPY . /app/env + +WORKDIR /app/env + +RUN if ! command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + && rm -rf /var/lib/apt/lists/* + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Pre-download GPT-2 tokenizer to avoid permission issues at runtime. +# HF_HOME must match the runtime value so the cache is copied to the right place. +RUN HF_HOME=/.cache /app/env/.venv/bin/python -c "from transformers import GPT2Tokenizer; GPT2Tokenizer.from_pretrained('gpt2')" + +# Final runtime stage FROM ${BASE_IMAGE} -# Install dependencies and run setup -COPY envs/chat_env/server/requirements.txt /tmp/requirements.txt -COPY envs/chat_env/server/install_deps.sh /tmp/install_deps.sh -RUN chmod +x /tmp/install_deps.sh && \ - /tmp/install_deps.sh && \ - rm /tmp/install_deps.sh /tmp/requirements.txt +WORKDIR /app + +COPY --from=builder /app/env/.venv /app/.venv +COPY --from=builder /app/env /app/env +COPY --from=builder /.cache /.cache -# Set environment variables +ENV PATH="/app/.venv/bin:$PATH" +ENV PYTHONPATH="/app/env:$PYTHONPATH" ENV HF_HOME=/.cache ENV TRANSFORMERS_CACHE=/.cache - -# Environment variables that can be overridden at runtime ENV TOKENIZER_NAME=gpt2 ENV SYSTEM_PROMPT="You are a helpful AI assistant." -ENV ENABLE_WEB_INTERFACE=false - -# Copy only what's needed for this environment -COPY src/core/ /app/src/core/ -COPY envs/chat_env/ /app/envs/chat_env/ - -# Copy README for web interface documentation -COPY envs/chat_env/README.md /app/README.md +ENV ENABLE_WEB_INTERFACE=true -# Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:8000/health || exit 1 + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 -# Run the FastAPI server -CMD ["uvicorn", "envs.chat_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/envs/chat_env/server/app.py b/envs/chat_env/server/app.py index 7345911ac40c4dbd50992003f7800c30e54e0ecc..c66bf8ee0d6a9bc9461f59fb18f8d0a84da2589f 100644 --- a/envs/chat_env/server/app.py +++ b/envs/chat_env/server/app.py @@ -28,8 +28,17 @@ import os from openenv.core.env_server import create_app -from ..models import ChatAction, ChatObservation -from .chat_environment import ChatEnvironment +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from ..models import ChatAction, ChatObservation + from .chat_environment import ChatEnvironment +except ImportError as e: + if "relative import" not in str(e) and "no known parent package" not in str(e): + raise + # Standalone imports (when running via uvicorn server.app:app) + from models import ChatAction, ChatObservation + from server.chat_environment import ChatEnvironment # Initialize tokenizer based on environment variable @@ -78,7 +87,11 @@ app = create_app( ) -if __name__ == "__main__": +def main(): import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) + + +if __name__ == "__main__": + main() diff --git a/envs/chat_env/server/chat_environment.py b/envs/chat_env/server/chat_environment.py index b8b3f6b225b6881f81ad05a0162fedf1606e2a4f..90b2d01f07364486d2e90727d2b4531025336b54 100644 --- a/envs/chat_env/server/chat_environment.py +++ b/envs/chat_env/server/chat_environment.py @@ -10,7 +10,6 @@ Chat Environment Implementation. A chat-based environment for LLMs, designed as a blank canvas for conversation and RL. """ -import torch from openenv.core.env_server.interfaces import ( Environment, Message, @@ -18,7 +17,15 @@ from openenv.core.env_server.interfaces import ( Transform, ) -from ..models import ChatAction, ChatObservation, ChatState +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from ..models import ChatAction, ChatObservation, ChatState +except ImportError as e: + if "relative import" not in str(e) and "no known parent package" not in str(e): + raise + # Standalone imports (when running via uvicorn server.app:app) + from models import ChatAction, ChatObservation, ChatState class ChatEnvironment(Environment): @@ -64,34 +71,37 @@ class ChatEnvironment(Environment): system_tokens = self._tokenize_conversation([system_message]) self._state.history_tokens.append(system_tokens) - def _tokenize_conversation(self, conversation: list[Message]) -> torch.Tensor: + def _coerce_tokens(self, tokens) -> list[int]: + """Normalize tokenizer outputs into a flat list of ints.""" + if hasattr(tokens, "tolist") and callable(tokens.tolist): + tokens = tokens.tolist() + + if isinstance(tokens, tuple): + tokens = list(tokens) + + if isinstance(tokens, list): + flattened: list[int] = [] + for token in tokens: + flattened.extend(self._coerce_tokens(token)) + return flattened + + return [int(tokens)] + + def _tokenize_conversation(self, conversation: list[Message]) -> list[int]: """Tokenize a conversation with a chat-template fallback for base tokenizers.""" try: - tokens = self.tokenizer.apply_chat_template( - conversation=conversation, - tokenize=True, - return_tensors="pt", # type: ignore[arg-type] - ) + tokens = self.tokenizer.apply_chat_template(conversation=conversation, tokenize=True) except Exception: # Some tokenizers (e.g. gpt2) do not define `chat_template`. fallback_text = "".join( f"{m['role']}: {m['content']}\n" for m in conversation ) if hasattr(self.tokenizer, "encode"): - try: - tokens = self.tokenizer.encode( # type: ignore[attr-defined] - fallback_text, - return_tensors="pt", - ) - except TypeError: - token_ids = self.tokenizer.encode(fallback_text) # type: ignore[attr-defined] - tokens = torch.tensor([token_ids], dtype=torch.long) + tokens = self.tokenizer.encode(fallback_text) # type: ignore[attr-defined] else: raise ValueError("Tokenizer must support apply_chat_template or encode") - if isinstance(tokens, torch.Tensor): - return tokens - return torch.tensor(tokens, dtype=torch.long) + return self._coerce_tokens(tokens) def reset(self) -> ChatObservation: """Reset the environment to initial state. @@ -121,13 +131,13 @@ class ChatEnvironment(Environment): Returns: ChatObservation: The updated observation with the new tokens added. """ + action_tokens = [int(token) for token in action.tokens] + # Store the tokens directly from the action - self._state.history_tokens.append(action.tokens) + self._state.history_tokens.append(action_tokens) # Decode tokens to text and add as a message to history - decoded_text = self.tokenizer.decode( - action.tokens.squeeze(), skip_special_tokens=True - ) + decoded_text = self.tokenizer.decode(action_tokens, skip_special_tokens=True) assistant_message: Message = {"role": "assistant", "content": decoded_text} self._state.history_messages.append(assistant_message) @@ -143,12 +153,13 @@ class ChatEnvironment(Environment): ChatObservation: Observation with messages and flattened tokens """ if self._state.history_tokens: - # Flatten all tokens into a single 1D tensor - flattened_tokens = torch.cat( - (t.flatten() for t in self._state.history_tokens), dim=0 - ) + flattened_tokens = [ + token + for token_list in self._state.history_tokens + for token in token_list + ] else: - flattened_tokens = torch.tensor([]) + flattened_tokens = [] observation = ChatObservation( messages=self._state.history_messages.copy(), # Copy to prevent external mutation @@ -162,7 +173,7 @@ class ChatEnvironment(Environment): # If transform returns base Observation, convert back to ChatObservation return ChatObservation( messages=getattr(transformed, "messages", []), - tokens=getattr(transformed, "tokens", torch.tensor([])), + tokens=self._coerce_tokens(getattr(transformed, "tokens", [])), done=transformed.done, reward=transformed.reward, ) diff --git a/envs/chat_env/server/requirements.txt b/envs/chat_env/server/requirements.txt index 4f492ddc93de3952474c43b325c02b1b0fcdec9c..976a2b1f3998279c10c413279a095be86bf69167 100644 --- a/envs/chat_env/server/requirements.txt +++ b/envs/chat_env/server/requirements.txt @@ -1,2 +1 @@ -torch transformers diff --git a/envs/chat_env/server/test_chat_env.py b/envs/chat_env/server/test_chat_env.py index e58988c13bcae105e325d386efb4fa4d7c7e2c56..efe17b987c5ff0d6a69a34a9b023f4472ca85592 100644 --- a/envs/chat_env/server/test_chat_env.py +++ b/envs/chat_env/server/test_chat_env.py @@ -10,7 +10,6 @@ Test suite for ChatEnvironment. Proper unit tests with assertions to verify correct behavior. """ -import torch from openenv.core.env_server.interfaces import Message from ..models import ChatAction @@ -27,21 +26,21 @@ class MockTokenizer: return_tensors: str | None = None, **kwargs, ): - """Mock implementation that creates deterministic token tensors from text.""" + """Mock implementation that creates deterministic tokens from text.""" # Concatenate all message content + del tokenize, return_tensors, kwargs text = " ".join([msg["content"] for msg in conversation]) # Create deterministic tokens based on text content # Use character codes modulo 256 to get valid token IDs tokens = [ord(c) % 256 for c in text] - if return_tensors == "pt": - return torch.tensor([tokens]) return tokens def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str: """Mock decode that reverses the encoding process.""" - if isinstance(token_ids, torch.Tensor): + del skip_special_tokens, kwargs + if hasattr(token_ids, "tolist") and callable(token_ids.tolist): token_ids = token_ids.tolist() # Reverse the encoding: convert tokens back to characters @@ -63,12 +62,12 @@ def test_tokenization_consistency(): action2 = env.message_to_action(message2) # Verify tokens are identical - assert torch.equal(action1.tokens, action2.tokens), ( + assert action1.tokens == action2.tokens, ( "Same message should produce identical tokens" ) # Verify tokens are not empty - assert action1.tokens.numel() > 0, "Tokens should not be empty" + assert len(action1.tokens) > 0, "Tokens should not be empty" print("✓ test_tokenization_consistency passed") @@ -151,13 +150,13 @@ def test_token_history_accumulation(): env = ChatEnvironment(tokenizer=tokenizer) obs = env.reset() - initial_token_count = obs.tokens.numel() + initial_token_count = len(obs.tokens) # Step with first message message1 = {"role": "user", "content": "Hi"} action1 = env.message_to_action(message1) obs1 = env.step(action1) - token_count_1 = obs1.tokens.numel() + token_count_1 = len(obs1.tokens) # Tokens should increase assert token_count_1 > initial_token_count, "Token count should increase after step" @@ -166,7 +165,7 @@ def test_token_history_accumulation(): message2 = {"role": "assistant", "content": "Hello there"} action2 = env.message_to_action(message2) obs2 = env.step(action2) - token_count_2 = obs2.tokens.numel() + token_count_2 = len(obs2.tokens) # Tokens should continue to accumulate assert token_count_2 > token_count_1, ( @@ -174,8 +173,8 @@ def test_token_history_accumulation(): ) # Verify tokens are the concatenation of both messages - expected_tokens = torch.cat([action1.tokens.flatten(), action2.tokens.flatten()]) - assert torch.equal(obs2.tokens, expected_tokens), ( + expected_tokens = action1.tokens + action2.tokens + assert obs2.tokens == expected_tokens, ( "Tokens should be concatenation of all actions" ) @@ -190,7 +189,7 @@ def test_direct_token_action(): env.reset() # Create raw tokens - raw_tokens = torch.tensor([[72, 101, 108, 108, 111]]) # ASCII for "Hello" + raw_tokens = [[72, 101, 108, 108, 111]] # ASCII for "Hello" action = ChatAction(tokens=raw_tokens) # Step with raw tokens @@ -201,7 +200,7 @@ def test_direct_token_action(): assert obs.messages[0]["role"] == "assistant", "Should default to assistant role" # Verify tokens match what we sent (flattened) - assert torch.equal(obs.tokens, raw_tokens.flatten()), ( + assert obs.tokens == [72, 101, 108, 108, 111], ( "Observation tokens should match input tokens" ) @@ -211,10 +210,12 @@ def test_direct_token_action(): def test_empty_tokens_validation(): """Test that empty tokens raise a ValueError.""" try: - action = ChatAction(tokens=torch.tensor([])) + action = ChatAction(tokens=[]) assert False, "Should have raised ValueError for empty tokens" except ValueError as e: - assert "empty" in str(e).lower(), "Error message should mention empty tokens" + assert "at least 1 item" in str(e).lower(), ( + "Error message should mention the minimum token count" + ) print("✓ test_empty_tokens_validation passed") diff --git a/models.py b/models.py index b34ad30006e8e90cf34616693285c06bea9e754e..0a72b6ffd4c6b710c9d8782b48e04b987ebfd790 100644 --- a/models.py +++ b/models.py @@ -11,10 +11,25 @@ The Chat environment provides a chat-based interface for LLMs with support for tokenization and message history management. """ -import torch -from openenv.core.env_server.interfaces import Message from openenv.core.env_server.types import Action, Observation, State -from pydantic import Field +from pydantic import Field, field_validator + + +def _flatten_tokens(value) -> list[int]: + """Coerce nested tensor-like or sequence inputs into a flat token list.""" + if hasattr(value, "tolist") and callable(value.tolist): + value = value.tolist() + + if isinstance(value, tuple): + value = list(value) + + if isinstance(value, list): + flattened: list[int] = [] + for item in value: + flattened.extend(_flatten_tokens(item)) + return flattened + + return [int(value)] class ChatAction(Action): @@ -24,21 +39,24 @@ class ChatAction(Action): This interfaces directly with models. """ - tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([])) + tokens: list[int] = Field(..., min_length=1) - def __post_init__(self): - """Validate required Fields after initialization.""" - if self.tokens.numel() == 0: - raise ValueError("tokens is required and cannot be empty") + @field_validator("tokens", mode="before") + @classmethod + def _coerce_tokens(cls, value): + """Accept either tensors or JSON arrays on the public HTTP surface.""" + if isinstance(value, (list, tuple)) or hasattr(value, "tolist"): + return _flatten_tokens(value) + raise TypeError("tokens must be provided as a sequence of token ids") class ChatState(State): """State of the ChatEnvironment containing message history.""" - history_messages: list[Message] = Field(default_factory=list) - history_tokens: list[torch.Tensor] = Field( - default_factory=list - ) # Same len as messages + # TODO: revert to list[Message] once openenv-core ships typing_extensions.TypedDict + # in interfaces.py and chat_env/pyproject.toml pins to that release. + history_messages: list[dict[str, str]] = Field(default_factory=list) + history_tokens: list[list[int]] = Field(default_factory=list) # Same len as messages class ChatObservation(Observation): @@ -57,6 +75,7 @@ class ChatObservation(Observation): tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation """ - messages: list[Message] = Field(default_factory=list) - tokens: torch.Tensor = Field(default_factory=lambda: torch.tensor([])) + # TODO: revert to list[Message] (same as above) + messages: list[dict[str, str]] = Field(default_factory=list) + tokens: list[int] = Field(default_factory=list) # Inherited Fields from Observation ABC: reward, done, metadata diff --git a/openenv_chat_env.egg-info/PKG-INFO b/openenv_chat_env.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..bff6523fc7095e520b6e3620b26e66984213ff6e --- /dev/null +++ b/openenv_chat_env.egg-info/PKG-INFO @@ -0,0 +1,14 @@ +Metadata-Version: 2.4 +Name: openenv-chat-env +Version: 0.1.0 +Summary: Chat Environment for OpenEnv - LLM-powered conversational agent +Requires-Python: >=3.10 +Requires-Dist: openenv-core[core]>=0.2.3 +Requires-Dist: fastapi>=0.115.0 +Requires-Dist: pydantic>=2.0.0 +Requires-Dist: uvicorn>=0.24.0 +Requires-Dist: requests>=2.31.0 +Requires-Dist: transformers +Provides-Extra: dev +Requires-Dist: pytest>=8.0.0; extra == "dev" +Requires-Dist: pytest-cov>=4.0.0; extra == "dev" diff --git a/openenv_chat_env.egg-info/SOURCES.txt b/openenv_chat_env.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..67ba456cb65d0ebdb5b27397cac3a35b7c087811 --- /dev/null +++ b/openenv_chat_env.egg-info/SOURCES.txt @@ -0,0 +1,18 @@ +README.md +__init__.py +client.py +models.py +pyproject.toml +./__init__.py +./client.py +./models.py +openenv_chat_env.egg-info/PKG-INFO +openenv_chat_env.egg-info/SOURCES.txt +openenv_chat_env.egg-info/dependency_links.txt +openenv_chat_env.egg-info/entry_points.txt +openenv_chat_env.egg-info/requires.txt +openenv_chat_env.egg-info/top_level.txt +server/__init__.py +server/app.py +server/chat_environment.py +server/test_chat_env.py \ No newline at end of file diff --git a/openenv_chat_env.egg-info/dependency_links.txt b/openenv_chat_env.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/openenv_chat_env.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/openenv_chat_env.egg-info/entry_points.txt b/openenv_chat_env.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..893068401fa986707ad57f3c68e84f9620c53df6 --- /dev/null +++ b/openenv_chat_env.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +server = chat_env.server.app:main diff --git a/openenv_chat_env.egg-info/requires.txt b/openenv_chat_env.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..c2e059d51585ba03dd3ae18172a1f3d5e0aeafc2 --- /dev/null +++ b/openenv_chat_env.egg-info/requires.txt @@ -0,0 +1,10 @@ +openenv-core[core]>=0.2.3 +fastapi>=0.115.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.31.0 +transformers + +[dev] +pytest>=8.0.0 +pytest-cov>=4.0.0 diff --git a/openenv_chat_env.egg-info/top_level.txt b/openenv_chat_env.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..b5331ca79966793dd53ca1fb41d9de845358acc8 --- /dev/null +++ b/openenv_chat_env.egg-info/top_level.txt @@ -0,0 +1 @@ +chat_env diff --git a/pyproject.toml b/pyproject.toml index 082cfa73560b0a96a8e3660bfb4448a632c310b2..49b376aae710c417e0199b984619751668e34267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,147 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + [build-system] requires = ["setuptools>=45", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "openenv-core" -version = "0.2.2.dev0" -description = "A unified framework for reinforcement learning environments" -readme = "README.md" +name = "openenv-chat-env" +version = "0.1.0" +description = "Chat Environment for OpenEnv - LLM-powered conversational agent" requires-python = ">=3.10" dependencies = [ - # Core shared dependencies - minimal set required for all environments - # Heavy dependencies (torch, numpy, smolagents, etc.) should be in - # individual environment pyproject.toml files - "fastapi>=0.104.0", + "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3", + "fastapi>=0.115.0", "pydantic>=2.0.0", "uvicorn>=0.24.0", - "requests>=2.25.0", - # CLI dependencies - "typer>=0.9.0", - "rich>=13.0.0", - "pyyaml>=6.0", - "huggingface_hub>=0.20.0", - "openai>=2.7.2", - "tomli>=2.3.0", - "tomli-w>=1.2.0", - "websockets>=15.0.1", - # MCP support - "fastmcp>=3.0.0", - # Web UI dependencies - "gradio>=4.0.0", + "requests>=2.31.0", + "transformers", ] [project.optional-dependencies] -core = [ - "fastapi>=0.104.0", - "pydantic>=2.0.0", - "uvicorn>=0.24.0", - "requests>=2.25.0", - "websockets>=15.0.1", -] -cli = [ - "typer>=0.9.0", - "rich>=13.0.0", - "pyyaml>=6.0", - "huggingface_hub>=0.20.0", - "openai>=2.7.2", - "tomli>=2.3.0", - "tomli-w>=1.2.0", -] -docs = [ - "sphinx==7.2.6", - "pytorch-sphinx-theme2", - "sphinxcontrib.katex==0.9.10", - "docutils>=0.18.1,<0.21", - "sphinx-design==0.6.1", - "sphinxcontrib-mermaid==1.0.0", - "myst-parser", - "sphinxext-opengraph", - "sphinx-sitemap==2.7.1", - "sphinx-gallery>=0.14.0", - "matplotlib", - "nest-asyncio", - "smolagents", -] -all = [ - "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main", - "openenv-core[cli]", -] -daytona = [ - "daytona>=0.136.0", - "pyyaml>=6.0", -] -inspect = [ - "inspect-ai>=0.3.0", +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", ] [project.scripts] -openenv = "openenv.cli.__main__:main" +server = "chat_env.server.app:main" [tool.setuptools] -package-dir = {"" = "src"} include-package-data = true - -[tool.setuptools.package-data] -"openenv.cli" = ["templates/**/*"] - -[tool.setuptools.packages.find] -where = ["src"] - -[tool.coverage.run] -omit = [ - "openenv/cli/templates/**", - "**/templates/**", - "openenv/cli/__main__.py", -] - -[tool.coverage.report] -exclude_lines = [ - "pragma: no cover", - "def __repr__", - "raise AssertionError", - "raise NotImplementedError", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] - -[tool.pytest.ini_options] -asyncio_mode = "auto" -asyncio_default_fixture_loop_scope = "function" -markers = [ - "docker: Tests that require Docker to be running", - "network: Tests that require network access (HuggingFace, etc.)", - "integration: Integration tests with external resources", -] - -[dependency-groups] -dev = [ - "ruff>=0.14.0", - "usort>=1.1.0", - "pytest>=7.0", - "pytest-asyncio>=0.21", -] - -[tool.usort] -# Disable first_party auto-detection so all non-stdlib imports land in -# the same "third_party" bucket (the default_category). This matches -# pyfmt's usort behavior inside arc f, which groups openenv.* and env -# package imports together without blank-line separators. -first_party_detection = false - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint] -select = ["E", "F", "W"] -ignore = [ - "E402", # Module level import not at top of file (needed for pytest.importorskip patterns) - "E501", # Line too long (not enforced previously, would require large refactor) -] - -[tool.ruff.lint.per-file-ignores] -# Context manager variables that are intentionally unused -"tests/envs/test_websockets.py" = ["F841"] -"tests/test_cli/test_push.py" = ["F841"] -# Compatibility shim module -"src/openenv_core/__init__.py" = ["F401"] +packages = ["chat_env", "chat_env.server"] +package-dir = { "chat_env" = ".", "chat_env.server" = "server" } diff --git a/server/Dockerfile b/server/Dockerfile index edf92b546625bb88fcc3c2ff0df3e818f5f92910..5ecee93a8a54bfb2faac5ce8869a0e05e16f233b 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -4,38 +4,62 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Use the standard openenv base image -# Built from: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile . -# In GitHub Actions, this is overridden to use the GHCR base image -ARG BASE_IMAGE=openenv-base:latest +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + + +COPY . /app/env + +WORKDIR /app/env + +RUN if ! command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + && rm -rf /var/lib/apt/lists/* + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Pre-download GPT-2 tokenizer to avoid permission issues at runtime. +# HF_HOME must match the runtime value so the cache is copied to the right place. +RUN HF_HOME=/.cache /app/env/.venv/bin/python -c "from transformers import GPT2Tokenizer; GPT2Tokenizer.from_pretrained('gpt2')" + +# Final runtime stage FROM ${BASE_IMAGE} -# Install dependencies and run setup -COPY envs/chat_env/server/requirements.txt /tmp/requirements.txt -COPY envs/chat_env/server/install_deps.sh /tmp/install_deps.sh -RUN chmod +x /tmp/install_deps.sh && \ - /tmp/install_deps.sh && \ - rm /tmp/install_deps.sh /tmp/requirements.txt +WORKDIR /app + +COPY --from=builder /app/env/.venv /app/.venv +COPY --from=builder /app/env /app/env +COPY --from=builder /.cache /.cache -# Set environment variables +ENV PATH="/app/.venv/bin:$PATH" +ENV PYTHONPATH="/app/env:$PYTHONPATH" ENV HF_HOME=/.cache ENV TRANSFORMERS_CACHE=/.cache - -# Environment variables that can be overridden at runtime ENV TOKENIZER_NAME=gpt2 ENV SYSTEM_PROMPT="You are a helpful AI assistant." -ENV ENABLE_WEB_INTERFACE=false - -# Copy only what's needed for this environment -COPY src/core/ /app/src/core/ -COPY envs/chat_env/ /app/envs/chat_env/ - -# Copy README for web interface documentation -COPY envs/chat_env/README.md /app/README.md +ENV ENABLE_WEB_INTERFACE=true -# Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:8000/health || exit 1 + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 -# Run the FastAPI server -CMD ["uvicorn", "envs.chat_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/server/app.py b/server/app.py index 7345911ac40c4dbd50992003f7800c30e54e0ecc..c66bf8ee0d6a9bc9461f59fb18f8d0a84da2589f 100644 --- a/server/app.py +++ b/server/app.py @@ -28,8 +28,17 @@ import os from openenv.core.env_server import create_app -from ..models import ChatAction, ChatObservation -from .chat_environment import ChatEnvironment +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from ..models import ChatAction, ChatObservation + from .chat_environment import ChatEnvironment +except ImportError as e: + if "relative import" not in str(e) and "no known parent package" not in str(e): + raise + # Standalone imports (when running via uvicorn server.app:app) + from models import ChatAction, ChatObservation + from server.chat_environment import ChatEnvironment # Initialize tokenizer based on environment variable @@ -78,7 +87,11 @@ app = create_app( ) -if __name__ == "__main__": +def main(): import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) + + +if __name__ == "__main__": + main() diff --git a/server/chat_environment.py b/server/chat_environment.py index b8b3f6b225b6881f81ad05a0162fedf1606e2a4f..90b2d01f07364486d2e90727d2b4531025336b54 100644 --- a/server/chat_environment.py +++ b/server/chat_environment.py @@ -10,7 +10,6 @@ Chat Environment Implementation. A chat-based environment for LLMs, designed as a blank canvas for conversation and RL. """ -import torch from openenv.core.env_server.interfaces import ( Environment, Message, @@ -18,7 +17,15 @@ from openenv.core.env_server.interfaces import ( Transform, ) -from ..models import ChatAction, ChatObservation, ChatState +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from ..models import ChatAction, ChatObservation, ChatState +except ImportError as e: + if "relative import" not in str(e) and "no known parent package" not in str(e): + raise + # Standalone imports (when running via uvicorn server.app:app) + from models import ChatAction, ChatObservation, ChatState class ChatEnvironment(Environment): @@ -64,34 +71,37 @@ class ChatEnvironment(Environment): system_tokens = self._tokenize_conversation([system_message]) self._state.history_tokens.append(system_tokens) - def _tokenize_conversation(self, conversation: list[Message]) -> torch.Tensor: + def _coerce_tokens(self, tokens) -> list[int]: + """Normalize tokenizer outputs into a flat list of ints.""" + if hasattr(tokens, "tolist") and callable(tokens.tolist): + tokens = tokens.tolist() + + if isinstance(tokens, tuple): + tokens = list(tokens) + + if isinstance(tokens, list): + flattened: list[int] = [] + for token in tokens: + flattened.extend(self._coerce_tokens(token)) + return flattened + + return [int(tokens)] + + def _tokenize_conversation(self, conversation: list[Message]) -> list[int]: """Tokenize a conversation with a chat-template fallback for base tokenizers.""" try: - tokens = self.tokenizer.apply_chat_template( - conversation=conversation, - tokenize=True, - return_tensors="pt", # type: ignore[arg-type] - ) + tokens = self.tokenizer.apply_chat_template(conversation=conversation, tokenize=True) except Exception: # Some tokenizers (e.g. gpt2) do not define `chat_template`. fallback_text = "".join( f"{m['role']}: {m['content']}\n" for m in conversation ) if hasattr(self.tokenizer, "encode"): - try: - tokens = self.tokenizer.encode( # type: ignore[attr-defined] - fallback_text, - return_tensors="pt", - ) - except TypeError: - token_ids = self.tokenizer.encode(fallback_text) # type: ignore[attr-defined] - tokens = torch.tensor([token_ids], dtype=torch.long) + tokens = self.tokenizer.encode(fallback_text) # type: ignore[attr-defined] else: raise ValueError("Tokenizer must support apply_chat_template or encode") - if isinstance(tokens, torch.Tensor): - return tokens - return torch.tensor(tokens, dtype=torch.long) + return self._coerce_tokens(tokens) def reset(self) -> ChatObservation: """Reset the environment to initial state. @@ -121,13 +131,13 @@ class ChatEnvironment(Environment): Returns: ChatObservation: The updated observation with the new tokens added. """ + action_tokens = [int(token) for token in action.tokens] + # Store the tokens directly from the action - self._state.history_tokens.append(action.tokens) + self._state.history_tokens.append(action_tokens) # Decode tokens to text and add as a message to history - decoded_text = self.tokenizer.decode( - action.tokens.squeeze(), skip_special_tokens=True - ) + decoded_text = self.tokenizer.decode(action_tokens, skip_special_tokens=True) assistant_message: Message = {"role": "assistant", "content": decoded_text} self._state.history_messages.append(assistant_message) @@ -143,12 +153,13 @@ class ChatEnvironment(Environment): ChatObservation: Observation with messages and flattened tokens """ if self._state.history_tokens: - # Flatten all tokens into a single 1D tensor - flattened_tokens = torch.cat( - (t.flatten() for t in self._state.history_tokens), dim=0 - ) + flattened_tokens = [ + token + for token_list in self._state.history_tokens + for token in token_list + ] else: - flattened_tokens = torch.tensor([]) + flattened_tokens = [] observation = ChatObservation( messages=self._state.history_messages.copy(), # Copy to prevent external mutation @@ -162,7 +173,7 @@ class ChatEnvironment(Environment): # If transform returns base Observation, convert back to ChatObservation return ChatObservation( messages=getattr(transformed, "messages", []), - tokens=getattr(transformed, "tokens", torch.tensor([])), + tokens=self._coerce_tokens(getattr(transformed, "tokens", [])), done=transformed.done, reward=transformed.reward, ) diff --git a/server/requirements.txt b/server/requirements.txt index 4f492ddc93de3952474c43b325c02b1b0fcdec9c..976a2b1f3998279c10c413279a095be86bf69167 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,2 +1 @@ -torch transformers diff --git a/server/test_chat_env.py b/server/test_chat_env.py index e58988c13bcae105e325d386efb4fa4d7c7e2c56..efe17b987c5ff0d6a69a34a9b023f4472ca85592 100644 --- a/server/test_chat_env.py +++ b/server/test_chat_env.py @@ -10,7 +10,6 @@ Test suite for ChatEnvironment. Proper unit tests with assertions to verify correct behavior. """ -import torch from openenv.core.env_server.interfaces import Message from ..models import ChatAction @@ -27,21 +26,21 @@ class MockTokenizer: return_tensors: str | None = None, **kwargs, ): - """Mock implementation that creates deterministic token tensors from text.""" + """Mock implementation that creates deterministic tokens from text.""" # Concatenate all message content + del tokenize, return_tensors, kwargs text = " ".join([msg["content"] for msg in conversation]) # Create deterministic tokens based on text content # Use character codes modulo 256 to get valid token IDs tokens = [ord(c) % 256 for c in text] - if return_tensors == "pt": - return torch.tensor([tokens]) return tokens def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str: """Mock decode that reverses the encoding process.""" - if isinstance(token_ids, torch.Tensor): + del skip_special_tokens, kwargs + if hasattr(token_ids, "tolist") and callable(token_ids.tolist): token_ids = token_ids.tolist() # Reverse the encoding: convert tokens back to characters @@ -63,12 +62,12 @@ def test_tokenization_consistency(): action2 = env.message_to_action(message2) # Verify tokens are identical - assert torch.equal(action1.tokens, action2.tokens), ( + assert action1.tokens == action2.tokens, ( "Same message should produce identical tokens" ) # Verify tokens are not empty - assert action1.tokens.numel() > 0, "Tokens should not be empty" + assert len(action1.tokens) > 0, "Tokens should not be empty" print("✓ test_tokenization_consistency passed") @@ -151,13 +150,13 @@ def test_token_history_accumulation(): env = ChatEnvironment(tokenizer=tokenizer) obs = env.reset() - initial_token_count = obs.tokens.numel() + initial_token_count = len(obs.tokens) # Step with first message message1 = {"role": "user", "content": "Hi"} action1 = env.message_to_action(message1) obs1 = env.step(action1) - token_count_1 = obs1.tokens.numel() + token_count_1 = len(obs1.tokens) # Tokens should increase assert token_count_1 > initial_token_count, "Token count should increase after step" @@ -166,7 +165,7 @@ def test_token_history_accumulation(): message2 = {"role": "assistant", "content": "Hello there"} action2 = env.message_to_action(message2) obs2 = env.step(action2) - token_count_2 = obs2.tokens.numel() + token_count_2 = len(obs2.tokens) # Tokens should continue to accumulate assert token_count_2 > token_count_1, ( @@ -174,8 +173,8 @@ def test_token_history_accumulation(): ) # Verify tokens are the concatenation of both messages - expected_tokens = torch.cat([action1.tokens.flatten(), action2.tokens.flatten()]) - assert torch.equal(obs2.tokens, expected_tokens), ( + expected_tokens = action1.tokens + action2.tokens + assert obs2.tokens == expected_tokens, ( "Tokens should be concatenation of all actions" ) @@ -190,7 +189,7 @@ def test_direct_token_action(): env.reset() # Create raw tokens - raw_tokens = torch.tensor([[72, 101, 108, 108, 111]]) # ASCII for "Hello" + raw_tokens = [[72, 101, 108, 108, 111]] # ASCII for "Hello" action = ChatAction(tokens=raw_tokens) # Step with raw tokens @@ -201,7 +200,7 @@ def test_direct_token_action(): assert obs.messages[0]["role"] == "assistant", "Should default to assistant role" # Verify tokens match what we sent (flattened) - assert torch.equal(obs.tokens, raw_tokens.flatten()), ( + assert obs.tokens == [72, 101, 108, 108, 111], ( "Observation tokens should match input tokens" ) @@ -211,10 +210,12 @@ def test_direct_token_action(): def test_empty_tokens_validation(): """Test that empty tokens raise a ValueError.""" try: - action = ChatAction(tokens=torch.tensor([])) + action = ChatAction(tokens=[]) assert False, "Should have raised ValueError for empty tokens" except ValueError as e: - assert "empty" in str(e).lower(), "Error message should mention empty tokens" + assert "at least 1 item" in str(e).lower(), ( + "Error message should mention the minimum token count" + ) print("✓ test_empty_tokens_validation passed") diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 658f63ef98bf78d278b8926271c217da23c79a37..f59012b60d335e596fc25866db4c64cbeafaa5a3 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -16,11 +16,15 @@ from __future__ import annotations import asyncio import inspect import json +import logging import os import time import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, Type +from contextlib import AsyncExitStack +from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type + +_MISSING = object() from fastapi import ( Body, @@ -204,8 +208,9 @@ class HTTPEnvServer: self.observation_cls = observation_cls # Session management for WebSocket connections - self._sessions: Dict[str, Environment] = {} + self._sessions: Dict[str, Optional[Environment]] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_stacks: Dict[str, AsyncExitStack] = {} self._session_info: Dict[str, SessionInfo] = {} self._session_lock = asyncio.Lock() @@ -213,6 +218,14 @@ class HTTPEnvServer: # This is needed for environments using sync libraries (e.g., Playwright) self._executor = ThreadPoolExecutor(max_workers=32) + # Idle session reaper configuration. + # Timeout is taken from ConcurrencyConfig.session_timeout; + # None means no timeout (default — reaper is a no-op). + self._session_idle_timeout_s: Optional[float] = ( + self._concurrency_config.session_timeout + ) + self._reaper_task: Optional[asyncio.Task[None]] = None + def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. @@ -321,12 +334,37 @@ class HTTPEnvServer: ) raise EnvironmentFactoryError(factory_name) from e + # Hold the MCP session open for the lifetime of this session, + # matching the WebSocket path's AsyncExitStack pattern. This + # prevents per-request MCP transport teardown/reconnection and + # preserves FastMCP session state (ctx.set_state / ctx.get_state) + # across HTTP calls within the same OpenEnv session. + stack = AsyncExitStack() + try: + mcp_session_factory = getattr(env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory()) + await stack.enter_async_context(mcp_session_cm) + except Exception: + # MCP transport failed to start — clean up the reserved slot, + # the env, and the executor so they don't leak permanently + # against _max_concurrent_envs. + await stack.aclose() # best-effort + async with self._session_lock: + self._sessions.pop(session_id, None) + self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor) + raise + async with self._session_lock: self._sessions[session_id] = env + self._session_stacks[session_id] = stack + now = time.time() self._session_info[session_id] = SessionInfo( session_id=session_id, created_at=current_time, - last_activity_at=current_time, + last_activity_at=now, step_count=0, environment_type=type(env).__name__, ) @@ -343,8 +381,27 @@ class HTTPEnvServer: async with self._session_lock: env = self._sessions.pop(session_id, None) executor = self._session_executors.pop(session_id, None) + stack = self._session_stacks.pop(session_id, None) self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor, stack) + + async def _cleanup_session_resources( + self, + env: Optional[Environment], + executor: Optional[ThreadPoolExecutor], + stack: Optional[AsyncExitStack] = None, + ) -> None: + """Close an environment and shut down its executor (best-effort).""" + # Close the MCP session stack first — this gracefully exits the + # mcp_session() context (and the underlying FastMCP Client session) + # before we tear down the environment references. + if stack is not None: + try: + await stack.aclose() + except Exception: + pass # Best effort cleanup + # Run close() in the same executor where the env was created # This is required for thread-sensitive libraries like Playwright/greenlet if env is not None: @@ -383,6 +440,51 @@ class HTTPEnvServer: if increment_step: self._session_info[session_id].step_count += 1 + async def _reap_idle_sessions(self) -> None: + """Background task that periodically destroys sessions idle beyond the timeout.""" + timeout = self._session_idle_timeout_s + if timeout is None: + return # no timeout configured — noop + interval = max(timeout / 4, 5.0) # check frequently enough + while True: + try: + await asyncio.sleep(interval) + now = time.time() + stale_ids: list[str] = [] + async with self._session_lock: + for sid, info in self._session_info.items(): + if now - info.last_activity_at > timeout: + stale_ids.append(sid) + for sid in stale_ids: + # Re-check under lock: activity may have arrived since + # the snapshot was taken, making this session active again. + # Refresh `now` so slow _destroy_session calls don't cause + # subsequent entries to be validated against a stale clock. + now = time.time() + async with self._session_lock: + info = self._session_info.get(sid) + if info is None or (now - info.last_activity_at) <= timeout: + continue + await self._destroy_session(sid) + except asyncio.CancelledError: + break + except Exception as exc: + logging.getLogger(__name__).warning( + "Idle-session reaper encountered an error (will retry): %s", + exc, + ) + + def _start_reaper(self) -> None: + """Start the idle-session reaper if a timeout is configured.""" + if self._session_idle_timeout_s is not None and self._reaper_task is None: + self._reaper_task = asyncio.create_task(self._reap_idle_sessions()) + + def _stop_reaper(self) -> None: + """Cancel the reaper background task.""" + if self._reaper_task is not None: + self._reaper_task.cancel() + self._reaper_task = None + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: """ Get information about a specific session. @@ -458,6 +560,20 @@ class HTTPEnvServer: f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" ) + # Wire up idle-session reaper lifecycle via app events + server_ref = self + + async def _start_session_reaper() -> None: + server_ref._start_reaper() + + async def _stop_session_reaper() -> None: + server_ref._stop_reaper() + + if not getattr(app.router, "_openenv_reaper_registered", False): + app.router.on_startup.append(_start_session_reaper) + app.router.on_shutdown.append(_stop_session_reaper) + app.router._openenv_reaper_registered = True # type: ignore[attr-defined] + # Helper function to handle reset endpoint async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), @@ -526,53 +642,214 @@ class HTTPEnvServer: # Helper function to handle MCP endpoint async def mcp_handler( - request: JsonRpcRequest, session_env: Optional[Environment] = None + request: JsonRpcRequest, + session_env: Optional[Environment] = None, + session_id: Optional[str] = None, ) -> JsonRpcResponse: """ Handle MCP JSON-RPC requests. - Supports tools/list and tools/call methods in JSON-RPC 2.0 format. + Supports tools/list and tools/call methods in JSON-RPC 2.0 format, + plus OpenEnv session lifecycle methods for HTTP MCP: + - openenv/session/create + - openenv/session/close """ method = request.method request_id = request.id + params = request.params + if not isinstance(params, dict): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Params must be an object", + request_id=request_id, + ) + + # OpenEnv extension methods for explicit MCP session management. + # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics. + if method == "openenv/session/create": + if session_env is not None and session_id is not None: + return JsonRpcResponse.success( + result={"session_id": session_id}, + request_id=request_id, + ) + try: + created_session_id, _ = await self._create_session() + except SessionCapacityError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + except EnvironmentFactoryError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={"factory_name": e.factory_name}, + ) + return JsonRpcResponse.success( + result={"session_id": created_session_id}, + request_id=request_id, + ) + + if method == "openenv/session/close": + target_session_id = params.get("session_id") + if not target_session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'session_id' is required", + request_id=request_id, + ) + + if session_id is not None and target_session_id == session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Cannot close active WebSocket-managed session via MCP method", + request_id=request_id, + ) + + async with self._session_lock: + env = self._sessions.pop(target_session_id, _MISSING) + if env is not _MISSING: + executor = self._session_executors.pop(target_session_id, None) + stack = self._session_stacks.pop(target_session_id, None) + self._session_info.pop(target_session_id, None) + else: + executor = None + stack = None + + if env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {target_session_id}", + request_id=request_id, + ) + + if env is None: + # Session slot reserved but env factory still running; + # re-insert the placeholder AND the executor so + # _create_session can finish and the executor remains + # tracked for eventual shutdown. + async with self._session_lock: + self._sessions[target_session_id] = None + if executor is not None: + self._session_executors[target_session_id] = executor + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {target_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + # env/executor/stack cleanup outside the lock + await self._cleanup_session_resources(env, executor, stack) + return JsonRpcResponse.success( + result={"session_id": target_session_id, "closed": True}, + request_id=request_id, + ) + + requested_session_id = params.get("session_id") + managed_session_id = session_id # Use provided session environment or create temporary one if session_env is not None: _env = session_env should_close = False + elif requested_session_id: + async with self._session_lock: + _env = self._sessions.get(requested_session_id, _MISSING) + + if _env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {requested_session_id}", + request_id=request_id, + ) + + if _env is None: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {requested_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + should_close = False + managed_session_id = requested_session_id else: _env = self._env_factory() should_close = True try: + mcp_client = getattr(_env, "mcp_client", None) + mcp_server = getattr(_env, "mcp_server", None) + mcp_session_factory = getattr(_env, "mcp_session", None) + if method == McpMethod.TOOLS_LIST: # Check if environment is MCP-enabled - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport — call + # directly, no redundant re-entry. + tools = await mcp_client.list_tools() + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + tools = await mcp_client.list_tools() + else: + async with mcp_client: + tools = await mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() if hasattr(t, "model_dump") else dict(t) - for t in tools - ] - }, + if mcp_server: + tools = [] + for _tool_name, tool in get_server_tools(mcp_server).items(): + tools.append( + { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + ) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ) + + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", request_id=request_id, ) elif method == McpMethod.TOOLS_CALL: - params = request.params tool_name = params.get("name") arguments = params.get("arguments", {}) - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", @@ -581,15 +858,51 @@ class HTTPEnvServer: if not tool_name: return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, + JsonRpcErrorCode.INVALID_PARAMS, "Missing 'name' in params", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport. + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + else: + async with mcp_client: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif mcp_server: + server_tools = get_server_tools(mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + if inspect.iscoroutinefunction(tool.fn): + result = await tool.fn(**arguments) + else: + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, ) # Ensure result is JSON serializable @@ -614,6 +927,11 @@ class HTTPEnvServer: request_id=request_id, ) finally: + if managed_session_id: + self._update_session_activity( + managed_session_id, + increment_step=(method == McpMethod.TOOLS_CALL), + ) if should_close: _env.close() @@ -637,42 +955,59 @@ class HTTPEnvServer: try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for MCP websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() - - try: - jsonrpc_dict = json.loads(raw_message) - jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) - except json.JSONDecodeError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.PARSE_ERROR, - f"Parse error: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - except ValidationError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue + # If environment has an mcp_session context manager, hold it open + # for the lifetime of the websocket connection - try: - # Call mcp_handler with session environment - response = await mcp_handler( - jsonrpc_request, session_env=session_env + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(response.model_dump_json()) - except Exception as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=jsonrpc_request.id, - ) - await websocket.send_text(error_resp.model_dump_json()) + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, + session_env=session_env, + session_id=session_id, + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -931,120 +1266,8 @@ all schema information needed to interact with the environment. JsonRpcErrorCode.PARSE_ERROR ).model_dump() - method = request.method - params = request.params - request_id = request.id - - # Create a temporary environment for MCP access - _env = self._env_factory() - - try: - # Check if environment supports MCP - if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "Environment does not support MCP", - request_id=request_id, - ).model_dump() - - if method == McpMethod.TOOLS_LIST: - # List tools from MCP server - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() - if hasattr(t, "model_dump") - else dict(t) - for t in tools - ] - }, - request_id=request_id, - ).model_dump() - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Use server directly - tools = [] - for tool_name, tool in get_server_tools( - _env.mcp_server - ).items(): - tool_dict = { - "name": tool.name, - "description": tool.description or "", - "inputSchema": tool.parameters or {}, - } - tools.append(tool_dict) - return JsonRpcResponse.success( - result={"tools": tools}, - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - elif method == McpMethod.TOOLS_CALL: - tool_name = params.get("name") - arguments = params.get("arguments", {}) - - if not tool_name: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - "Invalid params - 'name' is required", - request_id=request_id, - ).model_dump() - - # Call tool via MCP - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments - ) - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Call tool directly on FastMCP server - server_tools = get_server_tools(_env.mcp_server) - if tool_name in server_tools: - tool = server_tools[tool_name] - result = tool.fn(**arguments) - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - f"Tool not found: {tool_name}", - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - # Make result JSON serializable - serializable_result = _make_json_serializable(result) - - return JsonRpcResponse.success( - result=serializable_result, - request_id=request_id, - ).model_dump() - - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.METHOD_NOT_FOUND, - f"Method not found: {method}", - request_id=request_id, - ).model_dump() - - except Exception as e: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=request_id, - ).model_dump() - finally: - _env.close() + response = await mcp_handler(request) + return response.model_dump() # Register WebSocket endpoint for persistent sessions @app.websocket("/ws") @@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment. try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() + # Keep MCP session open for entire websocket lifetime + # (avoids reconnect overhead on every message) - try: - message_dict = json.loads(raw_message) - except json.JSONDecodeError as e: - error_resp = WSErrorResponse( - data={ - "message": f"Invalid JSON: {e}", - "code": WSErrorCode.INVALID_JSON, - } + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - msg_type = message_dict.get("type", "") + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue - try: - match msg_type: - case "reset": - msg = WSResetMessage(**message_dict) + msg_type = message_dict.get("type", "") - is_async = ( - session_env.reset_async.__func__ - is not Environment.reset_async - ) + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) - if is_async: - sig = inspect.signature(session_env.reset_async) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await session_env.reset_async( - **valid_kwargs + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async ) - else: - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) - - self._update_session_activity(session_id) - - response = WSObservationResponse( - data=serialize_observation(observation), - ) - case "step": - msg = WSStepMessage(**message_dict) - action = deserialize_action(msg.data, self.action_cls) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = ( + await self._run_in_session_executor( + session_id, + session_env.reset, + **valid_kwargs, + ) + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) - is_async = ( - session_env.step_async.__func__ - is not Environment.step_async - ) + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action( + msg.data, self.action_cls + ) - if is_async: - observation = await session_env.step_async(action) - else: - observation = await self._run_in_session_executor( - session_id, session_env.step, action + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async ) - self._update_session_activity( - session_id, increment_step=True - ) + if is_async: + observation = await session_env.step_async( + action + ) + else: + observation = ( + await self._run_in_session_executor( + session_id, session_env.step, action + ) + ) + + self._update_session_activity( + session_id, increment_step=True + ) - response = WSObservationResponse( - data=serialize_observation(observation) - ) + response = WSObservationResponse( + data=serialize_observation(observation) + ) - case "state": - msg = WSStateMessage(**message_dict) - state = session_env.state - if hasattr(state, "model_dump"): - state_data = state.model_dump() - else: - state_data = dict(state) if state else {} - - response = WSStateResponse(data=state_data) - - case "close": - msg = WSCloseMessage(**message_dict) - break - - case "mcp": - msg = WSMCPMessage(**message_dict) - try: - rpc_request = JsonRpcRequest(**msg.data) - except (ValidationError, Exception) as e: - rpc_response = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case "mcp": + msg = WSMCPMessage(**message_dict) + try: + rpc_request = JsonRpcRequest(**msg.data) + except (ValidationError, Exception) as e: + rpc_response = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + else: + rpc_response = await mcp_handler( + rpc_request, + session_env=session_env, + session_id=session_id, + ) + response = WSMCPResponse( + data=rpc_response.model_dump() ) - else: - rpc_response = await mcp_handler( - rpc_request, - session_env=session_env, + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": WSErrorCode.UNKNOWN_TYPE, + } ) - response = WSMCPResponse(data=rpc_response.model_dump()) - - case _: - response = WSErrorResponse( - data={ - "message": f"Unknown message type: {msg_type}", - "code": WSErrorCode.UNKNOWN_TYPE, - } - ) - await websocket.send_text(response.model_dump_json()) + await websocket.send_text(response.model_dump_json()) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid message", - "code": WSErrorCode.VALIDATION_ERROR, - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception as e: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": WSErrorCode.EXECUTION_ERROR, - } - ) - await websocket.send_text(error_resp.model_dump_json()) + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": WSErrorCode.VALIDATION_ERROR, + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.EXECUTION_ERROR, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -1276,7 +1531,7 @@ def create_app( from .web_interface import create_web_interface_app return create_web_interface_app( - env, + cast(Any, env), action_cls, observation_cls, env_name, diff --git a/src/core/env_server/mcp_environment.py b/src/core/env_server/mcp_environment.py index 03f66e37897ec81796d468f3d0590d465deddea1..50ddec98d3e769bf241076a6deb3e4ff6cb229e6 100644 --- a/src/core/env_server/mcp_environment.py +++ b/src/core/env_server/mcp_environment.py @@ -56,6 +56,7 @@ import asyncio import inspect from abc import abstractmethod from collections import defaultdict +from contextlib import asynccontextmanager from typing import Any, Callable, Dict, Optional from fastmcp import Client @@ -164,6 +165,52 @@ class MCPEnvironment(Environment): # Track tool schemas for list_tools: {tool_name: {mode: schema}} self._mode_tool_schemas = defaultdict(dict) + def _require_mcp_client(self) -> Any: + """Return MCP client or raise if environment has been closed.""" + if self.mcp_client is None: + raise RuntimeError("MCP client is not available; environment is closed") + return self.mcp_client + + def _require_mcp_server(self) -> Any: + """Return MCP server or raise if environment has been closed.""" + if self.mcp_server is None: + raise RuntimeError("MCP server is not available; environment is closed") + return self.mcp_server + + @asynccontextmanager + async def mcp_session(self): + """ + Context manager for MCP client sessions. + + This wrapper serves two purposes: + + 1. **Null guard** — raises a clear error if ``close()`` has already + been called (``mcp_client`` is ``None``). + + 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__`` + creates a background ``asyncio.Task`` for session management. + When entered directly via ``AsyncExitStack`` in the HTTP session + path (``_create_session``), this task can be cancelled by ASGI + harnesses (e.g. Starlette ``TestClient``) between requests, + corrupting session state. Wrapping in an ``asynccontextmanager`` + generator isolates the task lifecycle: the generator frame keeps + ``async with client:`` suspended at ``yield``, so cleanup only + runs when the stack explicitly closes the generator — not when + the event loop cancels orphaned tasks. + + Delegates to FastMCP's ``Client`` context manager which is + reentrant: the first entry opens the transport and subsequent + (nested) entries simply increment an internal reference counter. + The transport is closed only when the outermost context exits. + + No external lock is needed because ``Client._connect`` / + ``Client._disconnect`` already serialise connection state changes + through their own ``anyio.Lock``. + """ + client = self._require_mcp_client() + async with client: + yield client + @property def supports_code_mode(self) -> bool: """Check if this environment supports code mode (execute_code).""" @@ -292,7 +339,8 @@ class MCPEnvironment(Environment): # If mode is None, register with FastMCP as usual if mode is None: - decorated_func = self.mcp_server.tool()(func) + mcp_server = self._require_mcp_server() + decorated_func = mcp_server.tool()(func) self._mode_tools[tool_name][None] = func return decorated_func @@ -372,24 +420,49 @@ class MCPEnvironment(Environment): return self._step_impl(action, timeout_s=timeout_s, **kwargs) def _handle_list_tools(self) -> ListToolsObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely(self._async_handle_list_tools()) + + async def _async_list_tools(self) -> list: """ - Handle a ListToolsAction by querying the MCP server. + Async helper to list tools from the MCP client. Returns: - ListToolsObservation containing all available tools with their - names, descriptions, and input schemas, filtered by current mode. + List of tool objects from the MCP server. """ - try: - # Get current mode - current_mode = getattr(self, "_mode", None) + async with self.mcp_session() as client: + return await client.list_tools() - # Start with tools from FastMCP server (mode=None tools) - tools_result = run_async_safely(self._async_list_tools()) + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely( + self._async_handle_call_tool(action, timeout_s=timeout_s) + ) - # Build list of Tool objects - tools = [] + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. - # Add FastMCP tools that are not mode-specific + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_session() as client: + return await client.call_tool(tool_name, arguments) + + async def _async_handle_list_tools(self) -> ListToolsObservation: + """Async version of _handle_list_tools — avoids run_async_safely.""" + try: + current_mode = getattr(self, "_mode", None) + tools_result = await self._async_list_tools() + tools = [] for tool in tools_result: if tool.name not in self._mode_tool_schemas: tools.append( @@ -401,11 +474,8 @@ class MCPEnvironment(Environment): else {}, ) ) - - # Add mode-specific tools available in current mode for tool_name, mode_schemas in self._mode_tool_schemas.items(): if None in mode_schemas: - # Tool available in all modes schema = mode_schemas[None] tools.append( Tool( @@ -415,7 +485,6 @@ class MCPEnvironment(Environment): ) ) elif current_mode in mode_schemas: - # Tool available in current mode schema = mode_schemas[current_mode] tools.append( Tool( @@ -424,65 +493,30 @@ class MCPEnvironment(Environment): input_schema=schema["input_schema"], ) ) - return ListToolsObservation(tools=tools) - except Exception as e: - # Return an observation with error in metadata return ListToolsObservation( tools=[], - metadata={ - "error": str(e), - "error_type": "list_tools_failed", - }, + metadata={"error": str(e), "error_type": "list_tools_failed"}, ) - async def _async_list_tools(self) -> list: - """ - Async helper to list tools from the MCP client. - - Returns: - List of tool objects from the MCP server. - """ - async with self.mcp_client: - return await self.mcp_client.list_tools() - - def _handle_call_tool( + async def _async_handle_call_tool( self, action: CallToolAction, timeout_s: Optional[float] = None, ) -> CallToolObservation: - """ - Handle a CallToolAction by invoking the specified tool. - - Args: - action: The CallToolAction containing tool_name and arguments. - timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). - - Returns: - CallToolObservation with the tool's result or an error. - """ + """Async version of _handle_call_tool — avoids run_async_safely.""" timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT - - # Check if this is a mode-specific tool tool_name = action.tool_name current_mode = getattr(self, "_mode", None) if tool_name in self._mode_tools: mode_info = self._mode_tools[tool_name] - - # Check if tool is available in current mode - # Tool is available if: - # 1. It has a None mode (available in all modes), OR - # 2. It has an implementation for the current mode if None in mode_info: - # Use the mode-agnostic version func = mode_info[None] elif current_mode in mode_info: - # Use the mode-specific version func = mode_info[current_mode] else: - # Tool not available in current mode return CallToolObservation( tool_name=tool_name, result=None, @@ -491,16 +525,11 @@ class MCPEnvironment(Environment): message=f"Tool '{tool_name}' not available in {current_mode} mode", ), ) - - # Call the mode-specific function directly try: - # Check if function is async and await if necessary if inspect.iscoroutinefunction(func): - result = run_async_safely(func(**action.arguments)) + result = await func(**action.arguments) else: result = func(**action.arguments) - - # Wrap result in CallToolResult format to match FastMCP behavior return CallToolObservation( tool_name=tool_name, result=CallToolResult( @@ -521,22 +550,12 @@ class MCPEnvironment(Environment): ), ) - # Not a mode-specific tool, use FastMCP try: - # Run the async call_tool with timeout - # Use run_async_safely to handle both sync and async contexts - result = run_async_safely( - asyncio.wait_for( - self._async_call_tool(action.tool_name, action.arguments), - timeout=timeout, - ) - ) - - return CallToolObservation( - tool_name=action.tool_name, - result=result, + result = await asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, ) - + return CallToolObservation(tool_name=action.tool_name, result=result) except asyncio.TimeoutError: return CallToolObservation( tool_name=action.tool_name, @@ -546,11 +565,8 @@ class MCPEnvironment(Environment): message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", ), ) - except Exception as e: error_message = str(e) - - # Determine error type based on the exception if ( "not found" in error_message.lower() or "unknown tool" in error_message.lower() @@ -563,29 +579,34 @@ class MCPEnvironment(Environment): error_type = ToolErrorType.INVALID_ARGS else: error_type = ToolErrorType.EXECUTION_ERROR - return CallToolObservation( tool_name=action.tool_name, result=None, - error=ToolError( - error_type=error_type, - message=error_message, - ), + error=ToolError(error_type=error_type, message=error_message), ) - async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + async def step_async( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: """ - Async helper to call a tool on the MCP server. + Async step that routes MCP actions without going through run_async_safely. - Args: - tool_name: Name of the tool to invoke. - arguments: Dictionary of arguments to pass to the tool. - - Returns: - The result from the tool execution. + The WebSocket handler calls this directly on the outer event loop, where + the MCP session is already open, avoiding the thread/event-loop deadlock + that occurs when the sync step() path is used via run_in_executor. """ - async with self.mcp_client: - return await self.mcp_client.call_tool(tool_name, arguments) + if isinstance(action, ListToolsAction): + return await self._async_handle_list_tools() + elif isinstance(action, CallToolAction): + return await self._async_handle_call_tool(action, timeout_s=timeout_s) + else: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs) + ) @abstractmethod def _step_impl( diff --git a/src/core/env_server/serialization.py b/src/core/env_server/serialization.py index a9b50d9aeb873794044e77ee398a7f2b5fca8093..fd5fb588c739c3dc2bfdc1a24e55d3a95cf54543 100644 --- a/src/core/env_server/serialization.py +++ b/src/core/env_server/serialization.py @@ -14,14 +14,28 @@ HTTP server and web interface implementations. from typing import Any, Dict, Type +from .mcp_types import CallToolAction, ListToolsAction from .types import Action, Observation +# MCP action types keyed by their "type" discriminator value. +# These are checked before the environment's own action_cls so that +# ListToolsAction / CallToolAction payloads are never rejected by an +# unrelated Pydantic model. +_MCP_ACTION_TYPES: Dict[str, Type[Action]] = { + "list_tools": ListToolsAction, + "call_tool": CallToolAction, +} + def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: """ Convert JSON dict to Action instance using Pydantic validation. - This is a basic deserialization that works for most environments. + MCP action types (``list_tools``, ``call_tool``) are recognised + automatically via the ``"type"`` discriminator field, regardless of + the environment's configured ``action_cls``. All other payloads + fall through to ``action_cls.model_validate()``. + For special cases (e.g., tensor fields, custom type conversions), use deserialize_action_with_preprocessing(). @@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Note: This uses Pydantic's model_validate() for automatic validation. """ + # Route MCP action types before falling through to the env action_cls. + # Only intercept when action_cls is the generic Action base or itself an + # MCP type (i.e. the server hosts an MCP environment). This avoids + # silently bypassing env-specific validation for non-MCP environments + # that happen to use "call_tool" / "list_tools" as a type discriminator. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + return action_cls.model_validate(action_data) @@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing( Raises: ValidationError: If action_data is invalid for the action class """ + # Route MCP action types before preprocessing (they don't need it). + # Same guard as deserialize_action: only intercept when action_cls is + # the generic Action base or itself an MCP type. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + processed_data = {} for key, value in action_data.items(): diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py index 284740eb408b8e2b798037918967b7a50abee72d..026093887cbb43e995df64881c849ca6ed4ac5de 100644 --- a/src/core/env_server/web_interface.py +++ b/src/core/env_server/web_interface.py @@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. from __future__ import annotations import asyncio +import inspect import json from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Type import gradio as gr -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect +from fastapi.responses import RedirectResponse from pydantic import BaseModel, ConfigDict, Field from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME @@ -269,6 +271,28 @@ class WebInterfaceManager: # Thread pool for running sync code (e.g., Playwright sync API) in async context self._executor = ThreadPoolExecutor(max_workers=1) + @staticmethod + def _get_valid_kwargs( + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only those accepted by the target function.""" + skip_params = skip_params or set() + valid_kwargs: Dict[str, Any] = {} + has_var_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in sig.parameters.values() + ) + + for key, value in kwargs.items(): + if key in skip_params: + continue + if key in sig.parameters or has_var_kwargs: + valid_kwargs[key] = value + + return valid_kwargs + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor. @@ -317,11 +341,24 @@ class WebInterfaceManager: for client in disconnected_clients: self.connected_clients.remove(client) - async def reset_environment(self) -> Dict[str, Any]: + async def reset_environment( + self, reset_kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Reset the environment and update state.""" - # Run sync reset in thread pool to avoid blocking event loop - # and to support environments using sync libraries (e.g., Playwright) - observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + reset_kwargs = reset_kwargs or {} + + is_async = self.env.reset_async.__func__ is not Environment.reset_async + sig = inspect.signature(self.env.reset_async if is_async else self.env.reset) + valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs) + + if is_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation = await self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) state: State = self.env.state # Serialize observation once using shared utility @@ -428,6 +465,16 @@ def create_web_interface_app( web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/", include_in_schema=False) + async def web_root(): + """Redirect the app root to the Gradio interface.""" + return RedirectResponse(url="/web/") + + @app.get("/web", include_in_schema=False) + async def web_root_no_slash(): + """Redirect /web to /web/ for mounted Gradio deployments behind proxies.""" + return RedirectResponse(url="/web/") + @app.get("/web/metadata") async def web_metadata(): """Get environment metadata.""" @@ -449,9 +496,9 @@ def create_web_interface_app( await web_manager.disconnect_websocket(websocket) @app.post("/web/reset") - async def web_reset(): + async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)): """Reset endpoint for web interface.""" - return await web_manager.reset_environment() + return await web_manager.reset_environment(request) @app.post("/web/step") async def web_step(request: Dict[str, Any]): @@ -475,7 +522,13 @@ def create_web_interface_app( @app.get("/web/state") async def web_state(): """State endpoint for web interface.""" - return web_manager.get_state() + try: + return web_manager.get_state() + except RuntimeError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(exc), + ) from exc action_fields = _extract_action_fields(action_cls) is_chat_env = _is_chat_env(action_cls) @@ -505,7 +558,7 @@ def create_web_interface_app( ) gradio_blocks = gr.TabbedInterface( [default_blocks, custom_blocks], - tab_names=["Playground", "Visualization"], + tab_names=["Playground", "Custom"], title=get_gradio_display_title(metadata), ) else: diff --git a/src/core/mcp_client.py b/src/core/mcp_client.py index edac3529d3a34e798781d86cf4d2495dc9611713..1d8bd38efd3595526fc25915c8fdbbe7aaeca5d5 100644 --- a/src/core/mcp_client.py +++ b/src/core/mcp_client.py @@ -52,6 +52,7 @@ Example (sync wrapper): ... result = env.call_tool("echo_message", message="Hello!") """ +import asyncio from typing import Any, Dict, List, Optional from .client_types import StepResult @@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]): ) self._tools_cache: Optional[List[Tool]] = None self.use_production_mode = False + self._production_session_id: Optional[str] = None + self._production_session_lock = asyncio.Lock() + self._jsonrpc_request_id = 0 + self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient + + def _next_request_id(self) -> int: + """Generate a monotonically increasing JSON-RPC request id.""" + self._jsonrpc_request_id += 1 + return self._jsonrpc_request_id + + def _production_mcp_url(self) -> str: + """Build HTTP MCP endpoint URL from the client's websocket URL.""" + url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") + if url.endswith("/ws"): + url = url[: -len("/ws")] + return url.rstrip("/") + "/mcp" + + async def _get_http_client(self) -> Any: + """Return a shared httpx.AsyncClient, creating one lazily.""" + if self._http_client is None: + import httpx + + self._http_client = httpx.AsyncClient() + return self._http_client + + async def _production_mcp_request( + self, method: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response.""" + client = await self._get_http_client() + response = await client.post( + self._production_mcp_url(), + json={ + "jsonrpc": "2.0", + "method": method, + "params": params or {}, + "id": self._next_request_id(), + }, + timeout=self._message_timeout, + ) + response.raise_for_status() + return response.json() + + async def _ensure_production_session(self) -> str: + """Create and cache a persistent HTTP MCP session id if needed.""" + async with self._production_session_lock: + if self._production_session_id is not None: + return self._production_session_id + + data = await self._production_mcp_request("openenv/session/create") + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Failed to create MCP session: {message}") + + session_id = data.get("result", {}).get("session_id") + if not session_id: + raise RuntimeError("Failed to create MCP session: missing session_id") + + self._production_session_id = session_id + return session_id async def list_tools(self, use_cache: bool = True) -> List[Tool]: """ @@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]): if use_cache and self._tools_cache is not None: return self._tools_cache - # Use production mode HTTP endpoint if enabled - if self.use_production_mode: - import requests - - # Convert ws:// URL to http:// URL - url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") - # Remove /ws suffix if present and add /mcp - url = url.rstrip("/ws").rstrip("/") + "/mcp" - + # Use production mode HTTP endpoint if enabled. + # Some tests instantiate with __new__ and skip __init__, so default missing flag to False. + if getattr(self, "use_production_mode", False): try: - response = requests.post( - url, - json={ - "jsonrpc": "2.0", - "method": "tools/list", - "params": {}, - "id": 1, - }, + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/list", + {"session_id": session_id}, ) - data = response.json() + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"list_tools failed: {message}") if "result" in data and "tools" in data["result"]: tools = [ Tool( @@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]): return [] result = await self.step(ListToolsAction()) - self._tools_cache = result.observation.tools + if isinstance(result.observation, ListToolsObservation): + self._tools_cache = result.observation.tools + return self._tools_cache + + # Unexpected observation type; keep API stable with an empty tool list. + self._tools_cache = [] return self._tools_cache def _step_payload(self, action: Any) -> Dict[str, Any]: @@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]): step_count=payload.get("step_count", 0), ) + async def close(self) -> None: + """ + Close client resources. + + In production MCP mode, this also closes the server-side persistent + MCP session (best effort) before closing websocket/provider resources. + """ + if self._production_session_id is not None: + try: + await self._production_mcp_request( + "openenv/session/close", + {"session_id": self._production_session_id}, + ) + except Exception: + # Best effort cleanup - do not mask normal close behavior + pass + finally: + self._production_session_id = None + + if self._http_client is not None: + try: + await self._http_client.aclose() + except Exception: + pass + finally: + self._http_client = None + + await super().close() + class MCPToolClient(MCPClientBase): """ @@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase): >>> result = await env.call_tool("greet", name="Claude") >>> print(result) # "Hello, Claude!" """ + if getattr(self, "use_production_mode", False): + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/call", + { + "name": name, + "arguments": kwargs, + "session_id": session_id, + }, + ) + + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Tool '{name}' failed: {message}") + + result = data.get("result") + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + action = CallToolAction(tool_name=name, arguments=kwargs) result = await self.step(action) obs = result.observation diff --git a/src/core/openenv/__init__.py b/src/core/openenv/__init__.py index cabe2abc6a70dacafe04f0583b27b2552bab1e47..ef29784ad031f4601adcbefc8bf9d3b9137c353f 100644 --- a/src/core/openenv/__init__.py +++ b/src/core/openenv/__init__.py @@ -14,10 +14,18 @@ __all__ = [ "SyncEnvClient", ] -try: - __version__ = metadata.version("openenv") # type: ignore[arg-type] -except metadata.PackageNotFoundError: # pragma: no cover - local dev - __version__ = "0.0.0" + +def _load_package_version() -> str: + """Resolve the installed distribution version for the OpenEnv package.""" + for distribution_name in ("openenv-core", "openenv"): + try: + return metadata.version(distribution_name) + except metadata.PackageNotFoundError: + continue + return "0.0.0" + + +__version__ = _load_package_version() _LAZY_MODULES = { diff --git a/src/core/openenv/cli/templates/openenv_env/pyproject.toml b/src/core/openenv/cli/templates/openenv_env/pyproject.toml index a8e59fbfa3dbc8a0df7c84d479e79cef062d8e61..b63103db9111f91be99328cad38b351e89810eb8 100644 --- a/src/core/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/core/openenv/cli/templates/openenv_env/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) # install from github # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", - "openenv-core[core]>=0.2.1", + "openenv-core[core]>=0.2.2", # Environment-specific dependencies # Add all dependencies needed for your environment here # Examples: diff --git a/src/core/openenv/core/env_server/http_server.py b/src/core/openenv/core/env_server/http_server.py index 658f63ef98bf78d278b8926271c217da23c79a37..f59012b60d335e596fc25866db4c64cbeafaa5a3 100644 --- a/src/core/openenv/core/env_server/http_server.py +++ b/src/core/openenv/core/env_server/http_server.py @@ -16,11 +16,15 @@ from __future__ import annotations import asyncio import inspect import json +import logging import os import time import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, Type +from contextlib import AsyncExitStack +from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type + +_MISSING = object() from fastapi import ( Body, @@ -204,8 +208,9 @@ class HTTPEnvServer: self.observation_cls = observation_cls # Session management for WebSocket connections - self._sessions: Dict[str, Environment] = {} + self._sessions: Dict[str, Optional[Environment]] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_stacks: Dict[str, AsyncExitStack] = {} self._session_info: Dict[str, SessionInfo] = {} self._session_lock = asyncio.Lock() @@ -213,6 +218,14 @@ class HTTPEnvServer: # This is needed for environments using sync libraries (e.g., Playwright) self._executor = ThreadPoolExecutor(max_workers=32) + # Idle session reaper configuration. + # Timeout is taken from ConcurrencyConfig.session_timeout; + # None means no timeout (default — reaper is a no-op). + self._session_idle_timeout_s: Optional[float] = ( + self._concurrency_config.session_timeout + ) + self._reaper_task: Optional[asyncio.Task[None]] = None + def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. @@ -321,12 +334,37 @@ class HTTPEnvServer: ) raise EnvironmentFactoryError(factory_name) from e + # Hold the MCP session open for the lifetime of this session, + # matching the WebSocket path's AsyncExitStack pattern. This + # prevents per-request MCP transport teardown/reconnection and + # preserves FastMCP session state (ctx.set_state / ctx.get_state) + # across HTTP calls within the same OpenEnv session. + stack = AsyncExitStack() + try: + mcp_session_factory = getattr(env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory()) + await stack.enter_async_context(mcp_session_cm) + except Exception: + # MCP transport failed to start — clean up the reserved slot, + # the env, and the executor so they don't leak permanently + # against _max_concurrent_envs. + await stack.aclose() # best-effort + async with self._session_lock: + self._sessions.pop(session_id, None) + self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor) + raise + async with self._session_lock: self._sessions[session_id] = env + self._session_stacks[session_id] = stack + now = time.time() self._session_info[session_id] = SessionInfo( session_id=session_id, created_at=current_time, - last_activity_at=current_time, + last_activity_at=now, step_count=0, environment_type=type(env).__name__, ) @@ -343,8 +381,27 @@ class HTTPEnvServer: async with self._session_lock: env = self._sessions.pop(session_id, None) executor = self._session_executors.pop(session_id, None) + stack = self._session_stacks.pop(session_id, None) self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor, stack) + + async def _cleanup_session_resources( + self, + env: Optional[Environment], + executor: Optional[ThreadPoolExecutor], + stack: Optional[AsyncExitStack] = None, + ) -> None: + """Close an environment and shut down its executor (best-effort).""" + # Close the MCP session stack first — this gracefully exits the + # mcp_session() context (and the underlying FastMCP Client session) + # before we tear down the environment references. + if stack is not None: + try: + await stack.aclose() + except Exception: + pass # Best effort cleanup + # Run close() in the same executor where the env was created # This is required for thread-sensitive libraries like Playwright/greenlet if env is not None: @@ -383,6 +440,51 @@ class HTTPEnvServer: if increment_step: self._session_info[session_id].step_count += 1 + async def _reap_idle_sessions(self) -> None: + """Background task that periodically destroys sessions idle beyond the timeout.""" + timeout = self._session_idle_timeout_s + if timeout is None: + return # no timeout configured — noop + interval = max(timeout / 4, 5.0) # check frequently enough + while True: + try: + await asyncio.sleep(interval) + now = time.time() + stale_ids: list[str] = [] + async with self._session_lock: + for sid, info in self._session_info.items(): + if now - info.last_activity_at > timeout: + stale_ids.append(sid) + for sid in stale_ids: + # Re-check under lock: activity may have arrived since + # the snapshot was taken, making this session active again. + # Refresh `now` so slow _destroy_session calls don't cause + # subsequent entries to be validated against a stale clock. + now = time.time() + async with self._session_lock: + info = self._session_info.get(sid) + if info is None or (now - info.last_activity_at) <= timeout: + continue + await self._destroy_session(sid) + except asyncio.CancelledError: + break + except Exception as exc: + logging.getLogger(__name__).warning( + "Idle-session reaper encountered an error (will retry): %s", + exc, + ) + + def _start_reaper(self) -> None: + """Start the idle-session reaper if a timeout is configured.""" + if self._session_idle_timeout_s is not None and self._reaper_task is None: + self._reaper_task = asyncio.create_task(self._reap_idle_sessions()) + + def _stop_reaper(self) -> None: + """Cancel the reaper background task.""" + if self._reaper_task is not None: + self._reaper_task.cancel() + self._reaper_task = None + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: """ Get information about a specific session. @@ -458,6 +560,20 @@ class HTTPEnvServer: f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" ) + # Wire up idle-session reaper lifecycle via app events + server_ref = self + + async def _start_session_reaper() -> None: + server_ref._start_reaper() + + async def _stop_session_reaper() -> None: + server_ref._stop_reaper() + + if not getattr(app.router, "_openenv_reaper_registered", False): + app.router.on_startup.append(_start_session_reaper) + app.router.on_shutdown.append(_stop_session_reaper) + app.router._openenv_reaper_registered = True # type: ignore[attr-defined] + # Helper function to handle reset endpoint async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), @@ -526,53 +642,214 @@ class HTTPEnvServer: # Helper function to handle MCP endpoint async def mcp_handler( - request: JsonRpcRequest, session_env: Optional[Environment] = None + request: JsonRpcRequest, + session_env: Optional[Environment] = None, + session_id: Optional[str] = None, ) -> JsonRpcResponse: """ Handle MCP JSON-RPC requests. - Supports tools/list and tools/call methods in JSON-RPC 2.0 format. + Supports tools/list and tools/call methods in JSON-RPC 2.0 format, + plus OpenEnv session lifecycle methods for HTTP MCP: + - openenv/session/create + - openenv/session/close """ method = request.method request_id = request.id + params = request.params + if not isinstance(params, dict): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Params must be an object", + request_id=request_id, + ) + + # OpenEnv extension methods for explicit MCP session management. + # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics. + if method == "openenv/session/create": + if session_env is not None and session_id is not None: + return JsonRpcResponse.success( + result={"session_id": session_id}, + request_id=request_id, + ) + try: + created_session_id, _ = await self._create_session() + except SessionCapacityError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + except EnvironmentFactoryError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={"factory_name": e.factory_name}, + ) + return JsonRpcResponse.success( + result={"session_id": created_session_id}, + request_id=request_id, + ) + + if method == "openenv/session/close": + target_session_id = params.get("session_id") + if not target_session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'session_id' is required", + request_id=request_id, + ) + + if session_id is not None and target_session_id == session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Cannot close active WebSocket-managed session via MCP method", + request_id=request_id, + ) + + async with self._session_lock: + env = self._sessions.pop(target_session_id, _MISSING) + if env is not _MISSING: + executor = self._session_executors.pop(target_session_id, None) + stack = self._session_stacks.pop(target_session_id, None) + self._session_info.pop(target_session_id, None) + else: + executor = None + stack = None + + if env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {target_session_id}", + request_id=request_id, + ) + + if env is None: + # Session slot reserved but env factory still running; + # re-insert the placeholder AND the executor so + # _create_session can finish and the executor remains + # tracked for eventual shutdown. + async with self._session_lock: + self._sessions[target_session_id] = None + if executor is not None: + self._session_executors[target_session_id] = executor + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {target_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + # env/executor/stack cleanup outside the lock + await self._cleanup_session_resources(env, executor, stack) + return JsonRpcResponse.success( + result={"session_id": target_session_id, "closed": True}, + request_id=request_id, + ) + + requested_session_id = params.get("session_id") + managed_session_id = session_id # Use provided session environment or create temporary one if session_env is not None: _env = session_env should_close = False + elif requested_session_id: + async with self._session_lock: + _env = self._sessions.get(requested_session_id, _MISSING) + + if _env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {requested_session_id}", + request_id=request_id, + ) + + if _env is None: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {requested_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + should_close = False + managed_session_id = requested_session_id else: _env = self._env_factory() should_close = True try: + mcp_client = getattr(_env, "mcp_client", None) + mcp_server = getattr(_env, "mcp_server", None) + mcp_session_factory = getattr(_env, "mcp_session", None) + if method == McpMethod.TOOLS_LIST: # Check if environment is MCP-enabled - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport — call + # directly, no redundant re-entry. + tools = await mcp_client.list_tools() + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + tools = await mcp_client.list_tools() + else: + async with mcp_client: + tools = await mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() if hasattr(t, "model_dump") else dict(t) - for t in tools - ] - }, + if mcp_server: + tools = [] + for _tool_name, tool in get_server_tools(mcp_server).items(): + tools.append( + { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + ) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ) + + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", request_id=request_id, ) elif method == McpMethod.TOOLS_CALL: - params = request.params tool_name = params.get("name") arguments = params.get("arguments", {}) - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", @@ -581,15 +858,51 @@ class HTTPEnvServer: if not tool_name: return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, + JsonRpcErrorCode.INVALID_PARAMS, "Missing 'name' in params", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport. + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + else: + async with mcp_client: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif mcp_server: + server_tools = get_server_tools(mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + if inspect.iscoroutinefunction(tool.fn): + result = await tool.fn(**arguments) + else: + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, ) # Ensure result is JSON serializable @@ -614,6 +927,11 @@ class HTTPEnvServer: request_id=request_id, ) finally: + if managed_session_id: + self._update_session_activity( + managed_session_id, + increment_step=(method == McpMethod.TOOLS_CALL), + ) if should_close: _env.close() @@ -637,42 +955,59 @@ class HTTPEnvServer: try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for MCP websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() - - try: - jsonrpc_dict = json.loads(raw_message) - jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) - except json.JSONDecodeError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.PARSE_ERROR, - f"Parse error: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - except ValidationError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue + # If environment has an mcp_session context manager, hold it open + # for the lifetime of the websocket connection - try: - # Call mcp_handler with session environment - response = await mcp_handler( - jsonrpc_request, session_env=session_env + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(response.model_dump_json()) - except Exception as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=jsonrpc_request.id, - ) - await websocket.send_text(error_resp.model_dump_json()) + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, + session_env=session_env, + session_id=session_id, + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -931,120 +1266,8 @@ all schema information needed to interact with the environment. JsonRpcErrorCode.PARSE_ERROR ).model_dump() - method = request.method - params = request.params - request_id = request.id - - # Create a temporary environment for MCP access - _env = self._env_factory() - - try: - # Check if environment supports MCP - if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "Environment does not support MCP", - request_id=request_id, - ).model_dump() - - if method == McpMethod.TOOLS_LIST: - # List tools from MCP server - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() - if hasattr(t, "model_dump") - else dict(t) - for t in tools - ] - }, - request_id=request_id, - ).model_dump() - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Use server directly - tools = [] - for tool_name, tool in get_server_tools( - _env.mcp_server - ).items(): - tool_dict = { - "name": tool.name, - "description": tool.description or "", - "inputSchema": tool.parameters or {}, - } - tools.append(tool_dict) - return JsonRpcResponse.success( - result={"tools": tools}, - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - elif method == McpMethod.TOOLS_CALL: - tool_name = params.get("name") - arguments = params.get("arguments", {}) - - if not tool_name: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - "Invalid params - 'name' is required", - request_id=request_id, - ).model_dump() - - # Call tool via MCP - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments - ) - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Call tool directly on FastMCP server - server_tools = get_server_tools(_env.mcp_server) - if tool_name in server_tools: - tool = server_tools[tool_name] - result = tool.fn(**arguments) - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - f"Tool not found: {tool_name}", - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - # Make result JSON serializable - serializable_result = _make_json_serializable(result) - - return JsonRpcResponse.success( - result=serializable_result, - request_id=request_id, - ).model_dump() - - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.METHOD_NOT_FOUND, - f"Method not found: {method}", - request_id=request_id, - ).model_dump() - - except Exception as e: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=request_id, - ).model_dump() - finally: - _env.close() + response = await mcp_handler(request) + return response.model_dump() # Register WebSocket endpoint for persistent sessions @app.websocket("/ws") @@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment. try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() + # Keep MCP session open for entire websocket lifetime + # (avoids reconnect overhead on every message) - try: - message_dict = json.loads(raw_message) - except json.JSONDecodeError as e: - error_resp = WSErrorResponse( - data={ - "message": f"Invalid JSON: {e}", - "code": WSErrorCode.INVALID_JSON, - } + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - msg_type = message_dict.get("type", "") + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue - try: - match msg_type: - case "reset": - msg = WSResetMessage(**message_dict) + msg_type = message_dict.get("type", "") - is_async = ( - session_env.reset_async.__func__ - is not Environment.reset_async - ) + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) - if is_async: - sig = inspect.signature(session_env.reset_async) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await session_env.reset_async( - **valid_kwargs + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async ) - else: - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) - - self._update_session_activity(session_id) - - response = WSObservationResponse( - data=serialize_observation(observation), - ) - case "step": - msg = WSStepMessage(**message_dict) - action = deserialize_action(msg.data, self.action_cls) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = ( + await self._run_in_session_executor( + session_id, + session_env.reset, + **valid_kwargs, + ) + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) - is_async = ( - session_env.step_async.__func__ - is not Environment.step_async - ) + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action( + msg.data, self.action_cls + ) - if is_async: - observation = await session_env.step_async(action) - else: - observation = await self._run_in_session_executor( - session_id, session_env.step, action + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async ) - self._update_session_activity( - session_id, increment_step=True - ) + if is_async: + observation = await session_env.step_async( + action + ) + else: + observation = ( + await self._run_in_session_executor( + session_id, session_env.step, action + ) + ) + + self._update_session_activity( + session_id, increment_step=True + ) - response = WSObservationResponse( - data=serialize_observation(observation) - ) + response = WSObservationResponse( + data=serialize_observation(observation) + ) - case "state": - msg = WSStateMessage(**message_dict) - state = session_env.state - if hasattr(state, "model_dump"): - state_data = state.model_dump() - else: - state_data = dict(state) if state else {} - - response = WSStateResponse(data=state_data) - - case "close": - msg = WSCloseMessage(**message_dict) - break - - case "mcp": - msg = WSMCPMessage(**message_dict) - try: - rpc_request = JsonRpcRequest(**msg.data) - except (ValidationError, Exception) as e: - rpc_response = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case "mcp": + msg = WSMCPMessage(**message_dict) + try: + rpc_request = JsonRpcRequest(**msg.data) + except (ValidationError, Exception) as e: + rpc_response = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + else: + rpc_response = await mcp_handler( + rpc_request, + session_env=session_env, + session_id=session_id, + ) + response = WSMCPResponse( + data=rpc_response.model_dump() ) - else: - rpc_response = await mcp_handler( - rpc_request, - session_env=session_env, + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": WSErrorCode.UNKNOWN_TYPE, + } ) - response = WSMCPResponse(data=rpc_response.model_dump()) - - case _: - response = WSErrorResponse( - data={ - "message": f"Unknown message type: {msg_type}", - "code": WSErrorCode.UNKNOWN_TYPE, - } - ) - await websocket.send_text(response.model_dump_json()) + await websocket.send_text(response.model_dump_json()) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid message", - "code": WSErrorCode.VALIDATION_ERROR, - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception as e: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": WSErrorCode.EXECUTION_ERROR, - } - ) - await websocket.send_text(error_resp.model_dump_json()) + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": WSErrorCode.VALIDATION_ERROR, + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.EXECUTION_ERROR, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -1276,7 +1531,7 @@ def create_app( from .web_interface import create_web_interface_app return create_web_interface_app( - env, + cast(Any, env), action_cls, observation_cls, env_name, diff --git a/src/core/openenv/core/env_server/mcp_environment.py b/src/core/openenv/core/env_server/mcp_environment.py index 03f66e37897ec81796d468f3d0590d465deddea1..50ddec98d3e769bf241076a6deb3e4ff6cb229e6 100644 --- a/src/core/openenv/core/env_server/mcp_environment.py +++ b/src/core/openenv/core/env_server/mcp_environment.py @@ -56,6 +56,7 @@ import asyncio import inspect from abc import abstractmethod from collections import defaultdict +from contextlib import asynccontextmanager from typing import Any, Callable, Dict, Optional from fastmcp import Client @@ -164,6 +165,52 @@ class MCPEnvironment(Environment): # Track tool schemas for list_tools: {tool_name: {mode: schema}} self._mode_tool_schemas = defaultdict(dict) + def _require_mcp_client(self) -> Any: + """Return MCP client or raise if environment has been closed.""" + if self.mcp_client is None: + raise RuntimeError("MCP client is not available; environment is closed") + return self.mcp_client + + def _require_mcp_server(self) -> Any: + """Return MCP server or raise if environment has been closed.""" + if self.mcp_server is None: + raise RuntimeError("MCP server is not available; environment is closed") + return self.mcp_server + + @asynccontextmanager + async def mcp_session(self): + """ + Context manager for MCP client sessions. + + This wrapper serves two purposes: + + 1. **Null guard** — raises a clear error if ``close()`` has already + been called (``mcp_client`` is ``None``). + + 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__`` + creates a background ``asyncio.Task`` for session management. + When entered directly via ``AsyncExitStack`` in the HTTP session + path (``_create_session``), this task can be cancelled by ASGI + harnesses (e.g. Starlette ``TestClient``) between requests, + corrupting session state. Wrapping in an ``asynccontextmanager`` + generator isolates the task lifecycle: the generator frame keeps + ``async with client:`` suspended at ``yield``, so cleanup only + runs when the stack explicitly closes the generator — not when + the event loop cancels orphaned tasks. + + Delegates to FastMCP's ``Client`` context manager which is + reentrant: the first entry opens the transport and subsequent + (nested) entries simply increment an internal reference counter. + The transport is closed only when the outermost context exits. + + No external lock is needed because ``Client._connect`` / + ``Client._disconnect`` already serialise connection state changes + through their own ``anyio.Lock``. + """ + client = self._require_mcp_client() + async with client: + yield client + @property def supports_code_mode(self) -> bool: """Check if this environment supports code mode (execute_code).""" @@ -292,7 +339,8 @@ class MCPEnvironment(Environment): # If mode is None, register with FastMCP as usual if mode is None: - decorated_func = self.mcp_server.tool()(func) + mcp_server = self._require_mcp_server() + decorated_func = mcp_server.tool()(func) self._mode_tools[tool_name][None] = func return decorated_func @@ -372,24 +420,49 @@ class MCPEnvironment(Environment): return self._step_impl(action, timeout_s=timeout_s, **kwargs) def _handle_list_tools(self) -> ListToolsObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely(self._async_handle_list_tools()) + + async def _async_list_tools(self) -> list: """ - Handle a ListToolsAction by querying the MCP server. + Async helper to list tools from the MCP client. Returns: - ListToolsObservation containing all available tools with their - names, descriptions, and input schemas, filtered by current mode. + List of tool objects from the MCP server. """ - try: - # Get current mode - current_mode = getattr(self, "_mode", None) + async with self.mcp_session() as client: + return await client.list_tools() - # Start with tools from FastMCP server (mode=None tools) - tools_result = run_async_safely(self._async_list_tools()) + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely( + self._async_handle_call_tool(action, timeout_s=timeout_s) + ) - # Build list of Tool objects - tools = [] + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. - # Add FastMCP tools that are not mode-specific + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_session() as client: + return await client.call_tool(tool_name, arguments) + + async def _async_handle_list_tools(self) -> ListToolsObservation: + """Async version of _handle_list_tools — avoids run_async_safely.""" + try: + current_mode = getattr(self, "_mode", None) + tools_result = await self._async_list_tools() + tools = [] for tool in tools_result: if tool.name not in self._mode_tool_schemas: tools.append( @@ -401,11 +474,8 @@ class MCPEnvironment(Environment): else {}, ) ) - - # Add mode-specific tools available in current mode for tool_name, mode_schemas in self._mode_tool_schemas.items(): if None in mode_schemas: - # Tool available in all modes schema = mode_schemas[None] tools.append( Tool( @@ -415,7 +485,6 @@ class MCPEnvironment(Environment): ) ) elif current_mode in mode_schemas: - # Tool available in current mode schema = mode_schemas[current_mode] tools.append( Tool( @@ -424,65 +493,30 @@ class MCPEnvironment(Environment): input_schema=schema["input_schema"], ) ) - return ListToolsObservation(tools=tools) - except Exception as e: - # Return an observation with error in metadata return ListToolsObservation( tools=[], - metadata={ - "error": str(e), - "error_type": "list_tools_failed", - }, + metadata={"error": str(e), "error_type": "list_tools_failed"}, ) - async def _async_list_tools(self) -> list: - """ - Async helper to list tools from the MCP client. - - Returns: - List of tool objects from the MCP server. - """ - async with self.mcp_client: - return await self.mcp_client.list_tools() - - def _handle_call_tool( + async def _async_handle_call_tool( self, action: CallToolAction, timeout_s: Optional[float] = None, ) -> CallToolObservation: - """ - Handle a CallToolAction by invoking the specified tool. - - Args: - action: The CallToolAction containing tool_name and arguments. - timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). - - Returns: - CallToolObservation with the tool's result or an error. - """ + """Async version of _handle_call_tool — avoids run_async_safely.""" timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT - - # Check if this is a mode-specific tool tool_name = action.tool_name current_mode = getattr(self, "_mode", None) if tool_name in self._mode_tools: mode_info = self._mode_tools[tool_name] - - # Check if tool is available in current mode - # Tool is available if: - # 1. It has a None mode (available in all modes), OR - # 2. It has an implementation for the current mode if None in mode_info: - # Use the mode-agnostic version func = mode_info[None] elif current_mode in mode_info: - # Use the mode-specific version func = mode_info[current_mode] else: - # Tool not available in current mode return CallToolObservation( tool_name=tool_name, result=None, @@ -491,16 +525,11 @@ class MCPEnvironment(Environment): message=f"Tool '{tool_name}' not available in {current_mode} mode", ), ) - - # Call the mode-specific function directly try: - # Check if function is async and await if necessary if inspect.iscoroutinefunction(func): - result = run_async_safely(func(**action.arguments)) + result = await func(**action.arguments) else: result = func(**action.arguments) - - # Wrap result in CallToolResult format to match FastMCP behavior return CallToolObservation( tool_name=tool_name, result=CallToolResult( @@ -521,22 +550,12 @@ class MCPEnvironment(Environment): ), ) - # Not a mode-specific tool, use FastMCP try: - # Run the async call_tool with timeout - # Use run_async_safely to handle both sync and async contexts - result = run_async_safely( - asyncio.wait_for( - self._async_call_tool(action.tool_name, action.arguments), - timeout=timeout, - ) - ) - - return CallToolObservation( - tool_name=action.tool_name, - result=result, + result = await asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, ) - + return CallToolObservation(tool_name=action.tool_name, result=result) except asyncio.TimeoutError: return CallToolObservation( tool_name=action.tool_name, @@ -546,11 +565,8 @@ class MCPEnvironment(Environment): message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", ), ) - except Exception as e: error_message = str(e) - - # Determine error type based on the exception if ( "not found" in error_message.lower() or "unknown tool" in error_message.lower() @@ -563,29 +579,34 @@ class MCPEnvironment(Environment): error_type = ToolErrorType.INVALID_ARGS else: error_type = ToolErrorType.EXECUTION_ERROR - return CallToolObservation( tool_name=action.tool_name, result=None, - error=ToolError( - error_type=error_type, - message=error_message, - ), + error=ToolError(error_type=error_type, message=error_message), ) - async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + async def step_async( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: """ - Async helper to call a tool on the MCP server. + Async step that routes MCP actions without going through run_async_safely. - Args: - tool_name: Name of the tool to invoke. - arguments: Dictionary of arguments to pass to the tool. - - Returns: - The result from the tool execution. + The WebSocket handler calls this directly on the outer event loop, where + the MCP session is already open, avoiding the thread/event-loop deadlock + that occurs when the sync step() path is used via run_in_executor. """ - async with self.mcp_client: - return await self.mcp_client.call_tool(tool_name, arguments) + if isinstance(action, ListToolsAction): + return await self._async_handle_list_tools() + elif isinstance(action, CallToolAction): + return await self._async_handle_call_tool(action, timeout_s=timeout_s) + else: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs) + ) @abstractmethod def _step_impl( diff --git a/src/core/openenv/core/env_server/serialization.py b/src/core/openenv/core/env_server/serialization.py index a9b50d9aeb873794044e77ee398a7f2b5fca8093..fd5fb588c739c3dc2bfdc1a24e55d3a95cf54543 100644 --- a/src/core/openenv/core/env_server/serialization.py +++ b/src/core/openenv/core/env_server/serialization.py @@ -14,14 +14,28 @@ HTTP server and web interface implementations. from typing import Any, Dict, Type +from .mcp_types import CallToolAction, ListToolsAction from .types import Action, Observation +# MCP action types keyed by their "type" discriminator value. +# These are checked before the environment's own action_cls so that +# ListToolsAction / CallToolAction payloads are never rejected by an +# unrelated Pydantic model. +_MCP_ACTION_TYPES: Dict[str, Type[Action]] = { + "list_tools": ListToolsAction, + "call_tool": CallToolAction, +} + def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: """ Convert JSON dict to Action instance using Pydantic validation. - This is a basic deserialization that works for most environments. + MCP action types (``list_tools``, ``call_tool``) are recognised + automatically via the ``"type"`` discriminator field, regardless of + the environment's configured ``action_cls``. All other payloads + fall through to ``action_cls.model_validate()``. + For special cases (e.g., tensor fields, custom type conversions), use deserialize_action_with_preprocessing(). @@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Note: This uses Pydantic's model_validate() for automatic validation. """ + # Route MCP action types before falling through to the env action_cls. + # Only intercept when action_cls is the generic Action base or itself an + # MCP type (i.e. the server hosts an MCP environment). This avoids + # silently bypassing env-specific validation for non-MCP environments + # that happen to use "call_tool" / "list_tools" as a type discriminator. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + return action_cls.model_validate(action_data) @@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing( Raises: ValidationError: If action_data is invalid for the action class """ + # Route MCP action types before preprocessing (they don't need it). + # Same guard as deserialize_action: only intercept when action_cls is + # the generic Action base or itself an MCP type. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + processed_data = {} for key, value in action_data.items(): diff --git a/src/core/openenv/core/env_server/web_interface.py b/src/core/openenv/core/env_server/web_interface.py index 284740eb408b8e2b798037918967b7a50abee72d..026093887cbb43e995df64881c849ca6ed4ac5de 100644 --- a/src/core/openenv/core/env_server/web_interface.py +++ b/src/core/openenv/core/env_server/web_interface.py @@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. from __future__ import annotations import asyncio +import inspect import json from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Type import gradio as gr -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect +from fastapi.responses import RedirectResponse from pydantic import BaseModel, ConfigDict, Field from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME @@ -269,6 +271,28 @@ class WebInterfaceManager: # Thread pool for running sync code (e.g., Playwright sync API) in async context self._executor = ThreadPoolExecutor(max_workers=1) + @staticmethod + def _get_valid_kwargs( + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only those accepted by the target function.""" + skip_params = skip_params or set() + valid_kwargs: Dict[str, Any] = {} + has_var_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in sig.parameters.values() + ) + + for key, value in kwargs.items(): + if key in skip_params: + continue + if key in sig.parameters or has_var_kwargs: + valid_kwargs[key] = value + + return valid_kwargs + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor. @@ -317,11 +341,24 @@ class WebInterfaceManager: for client in disconnected_clients: self.connected_clients.remove(client) - async def reset_environment(self) -> Dict[str, Any]: + async def reset_environment( + self, reset_kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Reset the environment and update state.""" - # Run sync reset in thread pool to avoid blocking event loop - # and to support environments using sync libraries (e.g., Playwright) - observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + reset_kwargs = reset_kwargs or {} + + is_async = self.env.reset_async.__func__ is not Environment.reset_async + sig = inspect.signature(self.env.reset_async if is_async else self.env.reset) + valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs) + + if is_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation = await self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) state: State = self.env.state # Serialize observation once using shared utility @@ -428,6 +465,16 @@ def create_web_interface_app( web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/", include_in_schema=False) + async def web_root(): + """Redirect the app root to the Gradio interface.""" + return RedirectResponse(url="/web/") + + @app.get("/web", include_in_schema=False) + async def web_root_no_slash(): + """Redirect /web to /web/ for mounted Gradio deployments behind proxies.""" + return RedirectResponse(url="/web/") + @app.get("/web/metadata") async def web_metadata(): """Get environment metadata.""" @@ -449,9 +496,9 @@ def create_web_interface_app( await web_manager.disconnect_websocket(websocket) @app.post("/web/reset") - async def web_reset(): + async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)): """Reset endpoint for web interface.""" - return await web_manager.reset_environment() + return await web_manager.reset_environment(request) @app.post("/web/step") async def web_step(request: Dict[str, Any]): @@ -475,7 +522,13 @@ def create_web_interface_app( @app.get("/web/state") async def web_state(): """State endpoint for web interface.""" - return web_manager.get_state() + try: + return web_manager.get_state() + except RuntimeError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(exc), + ) from exc action_fields = _extract_action_fields(action_cls) is_chat_env = _is_chat_env(action_cls) @@ -505,7 +558,7 @@ def create_web_interface_app( ) gradio_blocks = gr.TabbedInterface( [default_blocks, custom_blocks], - tab_names=["Playground", "Visualization"], + tab_names=["Playground", "Custom"], title=get_gradio_display_title(metadata), ) else: diff --git a/src/core/openenv/core/mcp_client.py b/src/core/openenv/core/mcp_client.py index edac3529d3a34e798781d86cf4d2495dc9611713..1d8bd38efd3595526fc25915c8fdbbe7aaeca5d5 100644 --- a/src/core/openenv/core/mcp_client.py +++ b/src/core/openenv/core/mcp_client.py @@ -52,6 +52,7 @@ Example (sync wrapper): ... result = env.call_tool("echo_message", message="Hello!") """ +import asyncio from typing import Any, Dict, List, Optional from .client_types import StepResult @@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]): ) self._tools_cache: Optional[List[Tool]] = None self.use_production_mode = False + self._production_session_id: Optional[str] = None + self._production_session_lock = asyncio.Lock() + self._jsonrpc_request_id = 0 + self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient + + def _next_request_id(self) -> int: + """Generate a monotonically increasing JSON-RPC request id.""" + self._jsonrpc_request_id += 1 + return self._jsonrpc_request_id + + def _production_mcp_url(self) -> str: + """Build HTTP MCP endpoint URL from the client's websocket URL.""" + url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") + if url.endswith("/ws"): + url = url[: -len("/ws")] + return url.rstrip("/") + "/mcp" + + async def _get_http_client(self) -> Any: + """Return a shared httpx.AsyncClient, creating one lazily.""" + if self._http_client is None: + import httpx + + self._http_client = httpx.AsyncClient() + return self._http_client + + async def _production_mcp_request( + self, method: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response.""" + client = await self._get_http_client() + response = await client.post( + self._production_mcp_url(), + json={ + "jsonrpc": "2.0", + "method": method, + "params": params or {}, + "id": self._next_request_id(), + }, + timeout=self._message_timeout, + ) + response.raise_for_status() + return response.json() + + async def _ensure_production_session(self) -> str: + """Create and cache a persistent HTTP MCP session id if needed.""" + async with self._production_session_lock: + if self._production_session_id is not None: + return self._production_session_id + + data = await self._production_mcp_request("openenv/session/create") + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Failed to create MCP session: {message}") + + session_id = data.get("result", {}).get("session_id") + if not session_id: + raise RuntimeError("Failed to create MCP session: missing session_id") + + self._production_session_id = session_id + return session_id async def list_tools(self, use_cache: bool = True) -> List[Tool]: """ @@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]): if use_cache and self._tools_cache is not None: return self._tools_cache - # Use production mode HTTP endpoint if enabled - if self.use_production_mode: - import requests - - # Convert ws:// URL to http:// URL - url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") - # Remove /ws suffix if present and add /mcp - url = url.rstrip("/ws").rstrip("/") + "/mcp" - + # Use production mode HTTP endpoint if enabled. + # Some tests instantiate with __new__ and skip __init__, so default missing flag to False. + if getattr(self, "use_production_mode", False): try: - response = requests.post( - url, - json={ - "jsonrpc": "2.0", - "method": "tools/list", - "params": {}, - "id": 1, - }, + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/list", + {"session_id": session_id}, ) - data = response.json() + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"list_tools failed: {message}") if "result" in data and "tools" in data["result"]: tools = [ Tool( @@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]): return [] result = await self.step(ListToolsAction()) - self._tools_cache = result.observation.tools + if isinstance(result.observation, ListToolsObservation): + self._tools_cache = result.observation.tools + return self._tools_cache + + # Unexpected observation type; keep API stable with an empty tool list. + self._tools_cache = [] return self._tools_cache def _step_payload(self, action: Any) -> Dict[str, Any]: @@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]): step_count=payload.get("step_count", 0), ) + async def close(self) -> None: + """ + Close client resources. + + In production MCP mode, this also closes the server-side persistent + MCP session (best effort) before closing websocket/provider resources. + """ + if self._production_session_id is not None: + try: + await self._production_mcp_request( + "openenv/session/close", + {"session_id": self._production_session_id}, + ) + except Exception: + # Best effort cleanup - do not mask normal close behavior + pass + finally: + self._production_session_id = None + + if self._http_client is not None: + try: + await self._http_client.aclose() + except Exception: + pass + finally: + self._http_client = None + + await super().close() + class MCPToolClient(MCPClientBase): """ @@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase): >>> result = await env.call_tool("greet", name="Claude") >>> print(result) # "Hello, Claude!" """ + if getattr(self, "use_production_mode", False): + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/call", + { + "name": name, + "arguments": kwargs, + "session_id": session_id, + }, + ) + + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Tool '{name}' failed: {message}") + + result = data.get("result") + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + action = CallToolAction(tool_name=name, arguments=kwargs) result = await self.step(action) obs = result.observation diff --git a/src/openenv/__init__.py b/src/openenv/__init__.py index cabe2abc6a70dacafe04f0583b27b2552bab1e47..ef29784ad031f4601adcbefc8bf9d3b9137c353f 100644 --- a/src/openenv/__init__.py +++ b/src/openenv/__init__.py @@ -14,10 +14,18 @@ __all__ = [ "SyncEnvClient", ] -try: - __version__ = metadata.version("openenv") # type: ignore[arg-type] -except metadata.PackageNotFoundError: # pragma: no cover - local dev - __version__ = "0.0.0" + +def _load_package_version() -> str: + """Resolve the installed distribution version for the OpenEnv package.""" + for distribution_name in ("openenv-core", "openenv"): + try: + return metadata.version(distribution_name) + except metadata.PackageNotFoundError: + continue + return "0.0.0" + + +__version__ = _load_package_version() _LAZY_MODULES = { diff --git a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml index a8e59fbfa3dbc8a0df7c84d479e79cef062d8e61..b63103db9111f91be99328cad38b351e89810eb8 100644 --- a/src/openenv/cli/templates/openenv_env/pyproject.toml +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ # Core OpenEnv runtime (provides FastAPI server + HTTP client types) # install from github # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", - "openenv-core[core]>=0.2.1", + "openenv-core[core]>=0.2.2", # Environment-specific dependencies # Add all dependencies needed for your environment here # Examples: diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py index 658f63ef98bf78d278b8926271c217da23c79a37..f59012b60d335e596fc25866db4c64cbeafaa5a3 100644 --- a/src/openenv/core/env_server/http_server.py +++ b/src/openenv/core/env_server/http_server.py @@ -16,11 +16,15 @@ from __future__ import annotations import asyncio import inspect import json +import logging import os import time import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, Type +from contextlib import AsyncExitStack +from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type + +_MISSING = object() from fastapi import ( Body, @@ -204,8 +208,9 @@ class HTTPEnvServer: self.observation_cls = observation_cls # Session management for WebSocket connections - self._sessions: Dict[str, Environment] = {} + self._sessions: Dict[str, Optional[Environment]] = {} self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_stacks: Dict[str, AsyncExitStack] = {} self._session_info: Dict[str, SessionInfo] = {} self._session_lock = asyncio.Lock() @@ -213,6 +218,14 @@ class HTTPEnvServer: # This is needed for environments using sync libraries (e.g., Playwright) self._executor = ThreadPoolExecutor(max_workers=32) + # Idle session reaper configuration. + # Timeout is taken from ConcurrencyConfig.session_timeout; + # None means no timeout (default — reaper is a no-op). + self._session_idle_timeout_s: Optional[float] = ( + self._concurrency_config.session_timeout + ) + self._reaper_task: Optional[asyncio.Task[None]] = None + def _validate_concurrency_safety(self) -> None: """ Validate that the environment supports the configured concurrency level. @@ -321,12 +334,37 @@ class HTTPEnvServer: ) raise EnvironmentFactoryError(factory_name) from e + # Hold the MCP session open for the lifetime of this session, + # matching the WebSocket path's AsyncExitStack pattern. This + # prevents per-request MCP transport teardown/reconnection and + # preserves FastMCP session state (ctx.set_state / ctx.get_state) + # across HTTP calls within the same OpenEnv session. + stack = AsyncExitStack() + try: + mcp_session_factory = getattr(env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory()) + await stack.enter_async_context(mcp_session_cm) + except Exception: + # MCP transport failed to start — clean up the reserved slot, + # the env, and the executor so they don't leak permanently + # against _max_concurrent_envs. + await stack.aclose() # best-effort + async with self._session_lock: + self._sessions.pop(session_id, None) + self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor) + raise + async with self._session_lock: self._sessions[session_id] = env + self._session_stacks[session_id] = stack + now = time.time() self._session_info[session_id] = SessionInfo( session_id=session_id, created_at=current_time, - last_activity_at=current_time, + last_activity_at=now, step_count=0, environment_type=type(env).__name__, ) @@ -343,8 +381,27 @@ class HTTPEnvServer: async with self._session_lock: env = self._sessions.pop(session_id, None) executor = self._session_executors.pop(session_id, None) + stack = self._session_stacks.pop(session_id, None) self._session_info.pop(session_id, None) + await self._cleanup_session_resources(env, executor, stack) + + async def _cleanup_session_resources( + self, + env: Optional[Environment], + executor: Optional[ThreadPoolExecutor], + stack: Optional[AsyncExitStack] = None, + ) -> None: + """Close an environment and shut down its executor (best-effort).""" + # Close the MCP session stack first — this gracefully exits the + # mcp_session() context (and the underlying FastMCP Client session) + # before we tear down the environment references. + if stack is not None: + try: + await stack.aclose() + except Exception: + pass # Best effort cleanup + # Run close() in the same executor where the env was created # This is required for thread-sensitive libraries like Playwright/greenlet if env is not None: @@ -383,6 +440,51 @@ class HTTPEnvServer: if increment_step: self._session_info[session_id].step_count += 1 + async def _reap_idle_sessions(self) -> None: + """Background task that periodically destroys sessions idle beyond the timeout.""" + timeout = self._session_idle_timeout_s + if timeout is None: + return # no timeout configured — noop + interval = max(timeout / 4, 5.0) # check frequently enough + while True: + try: + await asyncio.sleep(interval) + now = time.time() + stale_ids: list[str] = [] + async with self._session_lock: + for sid, info in self._session_info.items(): + if now - info.last_activity_at > timeout: + stale_ids.append(sid) + for sid in stale_ids: + # Re-check under lock: activity may have arrived since + # the snapshot was taken, making this session active again. + # Refresh `now` so slow _destroy_session calls don't cause + # subsequent entries to be validated against a stale clock. + now = time.time() + async with self._session_lock: + info = self._session_info.get(sid) + if info is None or (now - info.last_activity_at) <= timeout: + continue + await self._destroy_session(sid) + except asyncio.CancelledError: + break + except Exception as exc: + logging.getLogger(__name__).warning( + "Idle-session reaper encountered an error (will retry): %s", + exc, + ) + + def _start_reaper(self) -> None: + """Start the idle-session reaper if a timeout is configured.""" + if self._session_idle_timeout_s is not None and self._reaper_task is None: + self._reaper_task = asyncio.create_task(self._reap_idle_sessions()) + + def _stop_reaper(self) -> None: + """Cancel the reaper background task.""" + if self._reaper_task is not None: + self._reaper_task.cancel() + self._reaper_task = None + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: """ Get information about a specific session. @@ -458,6 +560,20 @@ class HTTPEnvServer: f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" ) + # Wire up idle-session reaper lifecycle via app events + server_ref = self + + async def _start_session_reaper() -> None: + server_ref._start_reaper() + + async def _stop_session_reaper() -> None: + server_ref._stop_reaper() + + if not getattr(app.router, "_openenv_reaper_registered", False): + app.router.on_startup.append(_start_session_reaper) + app.router.on_shutdown.append(_stop_session_reaper) + app.router._openenv_reaper_registered = True # type: ignore[attr-defined] + # Helper function to handle reset endpoint async def reset_handler( request: ResetRequest = Body(default_factory=ResetRequest), @@ -526,53 +642,214 @@ class HTTPEnvServer: # Helper function to handle MCP endpoint async def mcp_handler( - request: JsonRpcRequest, session_env: Optional[Environment] = None + request: JsonRpcRequest, + session_env: Optional[Environment] = None, + session_id: Optional[str] = None, ) -> JsonRpcResponse: """ Handle MCP JSON-RPC requests. - Supports tools/list and tools/call methods in JSON-RPC 2.0 format. + Supports tools/list and tools/call methods in JSON-RPC 2.0 format, + plus OpenEnv session lifecycle methods for HTTP MCP: + - openenv/session/create + - openenv/session/close """ method = request.method request_id = request.id + params = request.params + if not isinstance(params, dict): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Params must be an object", + request_id=request_id, + ) + + # OpenEnv extension methods for explicit MCP session management. + # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics. + if method == "openenv/session/create": + if session_env is not None and session_id is not None: + return JsonRpcResponse.success( + result={"session_id": session_id}, + request_id=request_id, + ) + try: + created_session_id, _ = await self._create_session() + except SessionCapacityError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + except EnvironmentFactoryError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + request_id=request_id, + data={"factory_name": e.factory_name}, + ) + return JsonRpcResponse.success( + result={"session_id": created_session_id}, + request_id=request_id, + ) + + if method == "openenv/session/close": + target_session_id = params.get("session_id") + if not target_session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'session_id' is required", + request_id=request_id, + ) + + if session_id is not None and target_session_id == session_id: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Cannot close active WebSocket-managed session via MCP method", + request_id=request_id, + ) + + async with self._session_lock: + env = self._sessions.pop(target_session_id, _MISSING) + if env is not _MISSING: + executor = self._session_executors.pop(target_session_id, None) + stack = self._session_stacks.pop(target_session_id, None) + self._session_info.pop(target_session_id, None) + else: + executor = None + stack = None + + if env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {target_session_id}", + request_id=request_id, + ) + + if env is None: + # Session slot reserved but env factory still running; + # re-insert the placeholder AND the executor so + # _create_session can finish and the executor remains + # tracked for eventual shutdown. + async with self._session_lock: + self._sessions[target_session_id] = None + if executor is not None: + self._session_executors[target_session_id] = executor + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {target_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + # env/executor/stack cleanup outside the lock + await self._cleanup_session_resources(env, executor, stack) + return JsonRpcResponse.success( + result={"session_id": target_session_id, "closed": True}, + request_id=request_id, + ) + + requested_session_id = params.get("session_id") + managed_session_id = session_id # Use provided session environment or create temporary one if session_env is not None: _env = session_env should_close = False + elif requested_session_id: + async with self._session_lock: + _env = self._sessions.get(requested_session_id, _MISSING) + + if _env is _MISSING: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Unknown session_id: {requested_session_id}", + request_id=request_id, + ) + + if _env is None: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Session {requested_session_id} is still initializing; retry shortly", + request_id=request_id, + ) + + should_close = False + managed_session_id = requested_session_id else: _env = self._env_factory() should_close = True try: + mcp_client = getattr(_env, "mcp_client", None) + mcp_server = getattr(_env, "mcp_server", None) + mcp_session_factory = getattr(_env, "mcp_session", None) + if method == McpMethod.TOOLS_LIST: # Check if environment is MCP-enabled - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport — call + # directly, no redundant re-entry. + tools = await mcp_client.list_tools() + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + tools = await mcp_client.list_tools() + else: + async with mcp_client: + tools = await mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() if hasattr(t, "model_dump") else dict(t) - for t in tools - ] - }, + if mcp_server: + tools = [] + for _tool_name, tool in get_server_tools(mcp_server).items(): + tools.append( + { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + ) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ) + + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", request_id=request_id, ) elif method == McpMethod.TOOLS_CALL: - params = request.params tool_name = params.get("name") arguments = params.get("arguments", {}) - if not hasattr(_env, "mcp_client"): + if mcp_client is None and mcp_server is None: return JsonRpcResponse.error_response( JsonRpcErrorCode.INTERNAL_ERROR, "Environment does not support MCP", @@ -581,15 +858,51 @@ class HTTPEnvServer: if not tool_name: return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, + JsonRpcErrorCode.INVALID_PARAMS, "Missing 'name' in params", request_id=request_id, ) - # Use async context manager for MCP client - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments + if mcp_client: + if managed_session_id and mcp_client.is_connected(): + # Session-managed with live transport. + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif callable(mcp_session_factory): + # Stateless request, or session-managed but the + # background transport was lost: (re-)open. + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() + ) + async with mcp_session_cm: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + else: + async with mcp_client: + result = await mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif mcp_server: + server_tools = get_server_tools(mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + if inspect.iscoroutinefunction(tool.fn): + result = await tool.fn(**arguments) + else: + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, ) # Ensure result is JSON serializable @@ -614,6 +927,11 @@ class HTTPEnvServer: request_id=request_id, ) finally: + if managed_session_id: + self._update_session_activity( + managed_session_id, + increment_step=(method == McpMethod.TOOLS_CALL), + ) if should_close: _env.close() @@ -637,42 +955,59 @@ class HTTPEnvServer: try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for MCP websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() - - try: - jsonrpc_dict = json.loads(raw_message) - jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) - except json.JSONDecodeError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.PARSE_ERROR, - f"Parse error: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue - except ValidationError as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", - ) - await websocket.send_text(error_resp.model_dump_json()) - continue + # If environment has an mcp_session context manager, hold it open + # for the lifetime of the websocket connection - try: - # Call mcp_handler with session environment - response = await mcp_handler( - jsonrpc_request, session_env=session_env + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(response.model_dump_json()) - except Exception as e: - error_resp = JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=jsonrpc_request.id, - ) - await websocket.send_text(error_resp.model_dump_json()) + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, + session_env=session_env, + session_id=session_id, + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -931,120 +1266,8 @@ all schema information needed to interact with the environment. JsonRpcErrorCode.PARSE_ERROR ).model_dump() - method = request.method - params = request.params - request_id = request.id - - # Create a temporary environment for MCP access - _env = self._env_factory() - - try: - # Check if environment supports MCP - if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "Environment does not support MCP", - request_id=request_id, - ).model_dump() - - if method == McpMethod.TOOLS_LIST: - # List tools from MCP server - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - tools = await _env.mcp_client.list_tools() - return JsonRpcResponse.success( - result={ - "tools": [ - t.model_dump() - if hasattr(t, "model_dump") - else dict(t) - for t in tools - ] - }, - request_id=request_id, - ).model_dump() - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Use server directly - tools = [] - for tool_name, tool in get_server_tools( - _env.mcp_server - ).items(): - tool_dict = { - "name": tool.name, - "description": tool.description or "", - "inputSchema": tool.parameters or {}, - } - tools.append(tool_dict) - return JsonRpcResponse.success( - result={"tools": tools}, - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - elif method == McpMethod.TOOLS_CALL: - tool_name = params.get("name") - arguments = params.get("arguments", {}) - - if not tool_name: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - "Invalid params - 'name' is required", - request_id=request_id, - ).model_dump() - - # Call tool via MCP - if hasattr(_env, "mcp_client") and _env.mcp_client: - async with _env.mcp_client: - result = await _env.mcp_client.call_tool( - name=tool_name, arguments=arguments - ) - elif hasattr(_env, "mcp_server") and _env.mcp_server: - # Call tool directly on FastMCP server - server_tools = get_server_tools(_env.mcp_server) - if tool_name in server_tools: - tool = server_tools[tool_name] - result = tool.fn(**arguments) - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_PARAMS, - f"Tool not found: {tool_name}", - request_id=request_id, - ).model_dump() - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - "MCP server not available", - request_id=request_id, - ).model_dump() - - # Make result JSON serializable - serializable_result = _make_json_serializable(result) - - return JsonRpcResponse.success( - result=serializable_result, - request_id=request_id, - ).model_dump() - - else: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.METHOD_NOT_FOUND, - f"Method not found: {method}", - request_id=request_id, - ).model_dump() - - except Exception as e: - return JsonRpcResponse.error_response( - JsonRpcErrorCode.INTERNAL_ERROR, - str(e), - request_id=request_id, - ).model_dump() - finally: - _env.close() + response = await mcp_handler(request) + return response.model_dump() # Register WebSocket endpoint for persistent sessions @app.websocket("/ws") @@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment. try: # Create session with dedicated environment session_id, session_env = await self._create_session() + if session_env is None: + raise RuntimeError( + "Session environment not initialized for websocket" + ) - while True: - # Receive message from client - raw_message = await websocket.receive_text() + # Keep MCP session open for entire websocket lifetime + # (avoids reconnect overhead on every message) - try: - message_dict = json.loads(raw_message) - except json.JSONDecodeError as e: - error_resp = WSErrorResponse( - data={ - "message": f"Invalid JSON: {e}", - "code": WSErrorCode.INVALID_JSON, - } + async with AsyncExitStack() as stack: + mcp_session_factory = getattr(session_env, "mcp_session", None) + if callable(mcp_session_factory): + mcp_session_cm = cast( + AsyncContextManager[Any], mcp_session_factory() ) - await websocket.send_text(error_resp.model_dump_json()) - continue - - msg_type = message_dict.get("type", "") + await stack.enter_async_context(mcp_session_cm) + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue - try: - match msg_type: - case "reset": - msg = WSResetMessage(**message_dict) + msg_type = message_dict.get("type", "") - is_async = ( - session_env.reset_async.__func__ - is not Environment.reset_async - ) + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) - if is_async: - sig = inspect.signature(session_env.reset_async) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await session_env.reset_async( - **valid_kwargs + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async ) - else: - sig = inspect.signature(session_env.reset) - valid_kwargs = self._get_valid_kwargs(sig, msg.data) - observation = await self._run_in_session_executor( - session_id, session_env.reset, **valid_kwargs - ) - - self._update_session_activity(session_id) - - response = WSObservationResponse( - data=serialize_observation(observation), - ) - case "step": - msg = WSStepMessage(**message_dict) - action = deserialize_action(msg.data, self.action_cls) + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs( + sig, msg.data + ) + observation = ( + await self._run_in_session_executor( + session_id, + session_env.reset, + **valid_kwargs, + ) + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) - is_async = ( - session_env.step_async.__func__ - is not Environment.step_async - ) + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action( + msg.data, self.action_cls + ) - if is_async: - observation = await session_env.step_async(action) - else: - observation = await self._run_in_session_executor( - session_id, session_env.step, action + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async ) - self._update_session_activity( - session_id, increment_step=True - ) + if is_async: + observation = await session_env.step_async( + action + ) + else: + observation = ( + await self._run_in_session_executor( + session_id, session_env.step, action + ) + ) + + self._update_session_activity( + session_id, increment_step=True + ) - response = WSObservationResponse( - data=serialize_observation(observation) - ) + response = WSObservationResponse( + data=serialize_observation(observation) + ) - case "state": - msg = WSStateMessage(**message_dict) - state = session_env.state - if hasattr(state, "model_dump"): - state_data = state.model_dump() - else: - state_data = dict(state) if state else {} - - response = WSStateResponse(data=state_data) - - case "close": - msg = WSCloseMessage(**message_dict) - break - - case "mcp": - msg = WSMCPMessage(**message_dict) - try: - rpc_request = JsonRpcRequest(**msg.data) - except (ValidationError, Exception) as e: - rpc_response = JsonRpcResponse.error_response( - JsonRpcErrorCode.INVALID_REQUEST, - f"Invalid request: {e}", + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case "mcp": + msg = WSMCPMessage(**message_dict) + try: + rpc_request = JsonRpcRequest(**msg.data) + except (ValidationError, Exception) as e: + rpc_response = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + else: + rpc_response = await mcp_handler( + rpc_request, + session_env=session_env, + session_id=session_id, + ) + response = WSMCPResponse( + data=rpc_response.model_dump() ) - else: - rpc_response = await mcp_handler( - rpc_request, - session_env=session_env, + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": WSErrorCode.UNKNOWN_TYPE, + } ) - response = WSMCPResponse(data=rpc_response.model_dump()) - - case _: - response = WSErrorResponse( - data={ - "message": f"Unknown message type: {msg_type}", - "code": WSErrorCode.UNKNOWN_TYPE, - } - ) - await websocket.send_text(response.model_dump_json()) + await websocket.send_text(response.model_dump_json()) - except ValidationError as e: - error_resp = WSErrorResponse( - data={ - "message": "Invalid message", - "code": WSErrorCode.VALIDATION_ERROR, - "errors": e.errors(), - } - ) - await websocket.send_text(error_resp.model_dump_json()) - except Exception as e: - error_resp = WSErrorResponse( - data={ - "message": str(e), - "code": WSErrorCode.EXECUTION_ERROR, - } - ) - await websocket.send_text(error_resp.model_dump_json()) + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": WSErrorCode.VALIDATION_ERROR, + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.EXECUTION_ERROR, + } + ) + await websocket.send_text(error_resp.model_dump_json()) except WebSocketDisconnect: pass @@ -1276,7 +1531,7 @@ def create_app( from .web_interface import create_web_interface_app return create_web_interface_app( - env, + cast(Any, env), action_cls, observation_cls, env_name, diff --git a/src/openenv/core/env_server/mcp_environment.py b/src/openenv/core/env_server/mcp_environment.py index 03f66e37897ec81796d468f3d0590d465deddea1..50ddec98d3e769bf241076a6deb3e4ff6cb229e6 100644 --- a/src/openenv/core/env_server/mcp_environment.py +++ b/src/openenv/core/env_server/mcp_environment.py @@ -56,6 +56,7 @@ import asyncio import inspect from abc import abstractmethod from collections import defaultdict +from contextlib import asynccontextmanager from typing import Any, Callable, Dict, Optional from fastmcp import Client @@ -164,6 +165,52 @@ class MCPEnvironment(Environment): # Track tool schemas for list_tools: {tool_name: {mode: schema}} self._mode_tool_schemas = defaultdict(dict) + def _require_mcp_client(self) -> Any: + """Return MCP client or raise if environment has been closed.""" + if self.mcp_client is None: + raise RuntimeError("MCP client is not available; environment is closed") + return self.mcp_client + + def _require_mcp_server(self) -> Any: + """Return MCP server or raise if environment has been closed.""" + if self.mcp_server is None: + raise RuntimeError("MCP server is not available; environment is closed") + return self.mcp_server + + @asynccontextmanager + async def mcp_session(self): + """ + Context manager for MCP client sessions. + + This wrapper serves two purposes: + + 1. **Null guard** — raises a clear error if ``close()`` has already + been called (``mcp_client`` is ``None``). + + 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__`` + creates a background ``asyncio.Task`` for session management. + When entered directly via ``AsyncExitStack`` in the HTTP session + path (``_create_session``), this task can be cancelled by ASGI + harnesses (e.g. Starlette ``TestClient``) between requests, + corrupting session state. Wrapping in an ``asynccontextmanager`` + generator isolates the task lifecycle: the generator frame keeps + ``async with client:`` suspended at ``yield``, so cleanup only + runs when the stack explicitly closes the generator — not when + the event loop cancels orphaned tasks. + + Delegates to FastMCP's ``Client`` context manager which is + reentrant: the first entry opens the transport and subsequent + (nested) entries simply increment an internal reference counter. + The transport is closed only when the outermost context exits. + + No external lock is needed because ``Client._connect`` / + ``Client._disconnect`` already serialise connection state changes + through their own ``anyio.Lock``. + """ + client = self._require_mcp_client() + async with client: + yield client + @property def supports_code_mode(self) -> bool: """Check if this environment supports code mode (execute_code).""" @@ -292,7 +339,8 @@ class MCPEnvironment(Environment): # If mode is None, register with FastMCP as usual if mode is None: - decorated_func = self.mcp_server.tool()(func) + mcp_server = self._require_mcp_server() + decorated_func = mcp_server.tool()(func) self._mode_tools[tool_name][None] = func return decorated_func @@ -372,24 +420,49 @@ class MCPEnvironment(Environment): return self._step_impl(action, timeout_s=timeout_s, **kwargs) def _handle_list_tools(self) -> ListToolsObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely(self._async_handle_list_tools()) + + async def _async_list_tools(self) -> list: """ - Handle a ListToolsAction by querying the MCP server. + Async helper to list tools from the MCP client. Returns: - ListToolsObservation containing all available tools with their - names, descriptions, and input schemas, filtered by current mode. + List of tool objects from the MCP server. """ - try: - # Get current mode - current_mode = getattr(self, "_mode", None) + async with self.mcp_session() as client: + return await client.list_tools() - # Start with tools from FastMCP server (mode=None tools) - tools_result = run_async_safely(self._async_list_tools()) + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """Sync wrapper — delegates to the canonical async implementation.""" + return run_async_safely( + self._async_handle_call_tool(action, timeout_s=timeout_s) + ) - # Build list of Tool objects - tools = [] + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. - # Add FastMCP tools that are not mode-specific + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_session() as client: + return await client.call_tool(tool_name, arguments) + + async def _async_handle_list_tools(self) -> ListToolsObservation: + """Async version of _handle_list_tools — avoids run_async_safely.""" + try: + current_mode = getattr(self, "_mode", None) + tools_result = await self._async_list_tools() + tools = [] for tool in tools_result: if tool.name not in self._mode_tool_schemas: tools.append( @@ -401,11 +474,8 @@ class MCPEnvironment(Environment): else {}, ) ) - - # Add mode-specific tools available in current mode for tool_name, mode_schemas in self._mode_tool_schemas.items(): if None in mode_schemas: - # Tool available in all modes schema = mode_schemas[None] tools.append( Tool( @@ -415,7 +485,6 @@ class MCPEnvironment(Environment): ) ) elif current_mode in mode_schemas: - # Tool available in current mode schema = mode_schemas[current_mode] tools.append( Tool( @@ -424,65 +493,30 @@ class MCPEnvironment(Environment): input_schema=schema["input_schema"], ) ) - return ListToolsObservation(tools=tools) - except Exception as e: - # Return an observation with error in metadata return ListToolsObservation( tools=[], - metadata={ - "error": str(e), - "error_type": "list_tools_failed", - }, + metadata={"error": str(e), "error_type": "list_tools_failed"}, ) - async def _async_list_tools(self) -> list: - """ - Async helper to list tools from the MCP client. - - Returns: - List of tool objects from the MCP server. - """ - async with self.mcp_client: - return await self.mcp_client.list_tools() - - def _handle_call_tool( + async def _async_handle_call_tool( self, action: CallToolAction, timeout_s: Optional[float] = None, ) -> CallToolObservation: - """ - Handle a CallToolAction by invoking the specified tool. - - Args: - action: The CallToolAction containing tool_name and arguments. - timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). - - Returns: - CallToolObservation with the tool's result or an error. - """ + """Async version of _handle_call_tool — avoids run_async_safely.""" timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT - - # Check if this is a mode-specific tool tool_name = action.tool_name current_mode = getattr(self, "_mode", None) if tool_name in self._mode_tools: mode_info = self._mode_tools[tool_name] - - # Check if tool is available in current mode - # Tool is available if: - # 1. It has a None mode (available in all modes), OR - # 2. It has an implementation for the current mode if None in mode_info: - # Use the mode-agnostic version func = mode_info[None] elif current_mode in mode_info: - # Use the mode-specific version func = mode_info[current_mode] else: - # Tool not available in current mode return CallToolObservation( tool_name=tool_name, result=None, @@ -491,16 +525,11 @@ class MCPEnvironment(Environment): message=f"Tool '{tool_name}' not available in {current_mode} mode", ), ) - - # Call the mode-specific function directly try: - # Check if function is async and await if necessary if inspect.iscoroutinefunction(func): - result = run_async_safely(func(**action.arguments)) + result = await func(**action.arguments) else: result = func(**action.arguments) - - # Wrap result in CallToolResult format to match FastMCP behavior return CallToolObservation( tool_name=tool_name, result=CallToolResult( @@ -521,22 +550,12 @@ class MCPEnvironment(Environment): ), ) - # Not a mode-specific tool, use FastMCP try: - # Run the async call_tool with timeout - # Use run_async_safely to handle both sync and async contexts - result = run_async_safely( - asyncio.wait_for( - self._async_call_tool(action.tool_name, action.arguments), - timeout=timeout, - ) - ) - - return CallToolObservation( - tool_name=action.tool_name, - result=result, + result = await asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, ) - + return CallToolObservation(tool_name=action.tool_name, result=result) except asyncio.TimeoutError: return CallToolObservation( tool_name=action.tool_name, @@ -546,11 +565,8 @@ class MCPEnvironment(Environment): message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", ), ) - except Exception as e: error_message = str(e) - - # Determine error type based on the exception if ( "not found" in error_message.lower() or "unknown tool" in error_message.lower() @@ -563,29 +579,34 @@ class MCPEnvironment(Environment): error_type = ToolErrorType.INVALID_ARGS else: error_type = ToolErrorType.EXECUTION_ERROR - return CallToolObservation( tool_name=action.tool_name, result=None, - error=ToolError( - error_type=error_type, - message=error_message, - ), + error=ToolError(error_type=error_type, message=error_message), ) - async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + async def step_async( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: """ - Async helper to call a tool on the MCP server. + Async step that routes MCP actions without going through run_async_safely. - Args: - tool_name: Name of the tool to invoke. - arguments: Dictionary of arguments to pass to the tool. - - Returns: - The result from the tool execution. + The WebSocket handler calls this directly on the outer event loop, where + the MCP session is already open, avoiding the thread/event-loop deadlock + that occurs when the sync step() path is used via run_in_executor. """ - async with self.mcp_client: - return await self.mcp_client.call_tool(tool_name, arguments) + if isinstance(action, ListToolsAction): + return await self._async_handle_list_tools() + elif isinstance(action, CallToolAction): + return await self._async_handle_call_tool(action, timeout_s=timeout_s) + else: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs) + ) @abstractmethod def _step_impl( diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py index a9b50d9aeb873794044e77ee398a7f2b5fca8093..fd5fb588c739c3dc2bfdc1a24e55d3a95cf54543 100644 --- a/src/openenv/core/env_server/serialization.py +++ b/src/openenv/core/env_server/serialization.py @@ -14,14 +14,28 @@ HTTP server and web interface implementations. from typing import Any, Dict, Type +from .mcp_types import CallToolAction, ListToolsAction from .types import Action, Observation +# MCP action types keyed by their "type" discriminator value. +# These are checked before the environment's own action_cls so that +# ListToolsAction / CallToolAction payloads are never rejected by an +# unrelated Pydantic model. +_MCP_ACTION_TYPES: Dict[str, Type[Action]] = { + "list_tools": ListToolsAction, + "call_tool": CallToolAction, +} + def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: """ Convert JSON dict to Action instance using Pydantic validation. - This is a basic deserialization that works for most environments. + MCP action types (``list_tools``, ``call_tool``) are recognised + automatically via the ``"type"`` discriminator field, regardless of + the environment's configured ``action_cls``. All other payloads + fall through to ``action_cls.model_validate()``. + For special cases (e.g., tensor fields, custom type conversions), use deserialize_action_with_preprocessing(). @@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Note: This uses Pydantic's model_validate() for automatic validation. """ + # Route MCP action types before falling through to the env action_cls. + # Only intercept when action_cls is the generic Action base or itself an + # MCP type (i.e. the server hosts an MCP environment). This avoids + # silently bypassing env-specific validation for non-MCP environments + # that happen to use "call_tool" / "list_tools" as a type discriminator. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + return action_cls.model_validate(action_data) @@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing( Raises: ValidationError: If action_data is invalid for the action class """ + # Route MCP action types before preprocessing (they don't need it). + # Same guard as deserialize_action: only intercept when action_cls is + # the generic Action base or itself an MCP type. + action_type = action_data.get("type") + if action_type in _MCP_ACTION_TYPES: + mcp_cls = _MCP_ACTION_TYPES[action_type] + if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values(): + return mcp_cls.model_validate(action_data) + processed_data = {} for key, value in action_data.items(): diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py index 284740eb408b8e2b798037918967b7a50abee72d..026093887cbb43e995df64881c849ca6ed4ac5de 100644 --- a/src/openenv/core/env_server/web_interface.py +++ b/src/openenv/core/env_server/web_interface.py @@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. from __future__ import annotations import asyncio +import inspect import json from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Type import gradio as gr -from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect +from fastapi.responses import RedirectResponse from pydantic import BaseModel, ConfigDict, Field from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME @@ -269,6 +271,28 @@ class WebInterfaceManager: # Thread pool for running sync code (e.g., Playwright sync API) in async context self._executor = ThreadPoolExecutor(max_workers=1) + @staticmethod + def _get_valid_kwargs( + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only those accepted by the target function.""" + skip_params = skip_params or set() + valid_kwargs: Dict[str, Any] = {} + has_var_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in sig.parameters.values() + ) + + for key, value in kwargs.items(): + if key in skip_params: + continue + if key in sig.parameters or has_var_kwargs: + valid_kwargs[key] = value + + return valid_kwargs + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): """Run a synchronous function in the thread pool executor. @@ -317,11 +341,24 @@ class WebInterfaceManager: for client in disconnected_clients: self.connected_clients.remove(client) - async def reset_environment(self) -> Dict[str, Any]: + async def reset_environment( + self, reset_kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Reset the environment and update state.""" - # Run sync reset in thread pool to avoid blocking event loop - # and to support environments using sync libraries (e.g., Playwright) - observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + reset_kwargs = reset_kwargs or {} + + is_async = self.env.reset_async.__func__ is not Environment.reset_async + sig = inspect.signature(self.env.reset_async if is_async else self.env.reset) + valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs) + + if is_async: + observation = await self.env.reset_async(**valid_kwargs) + else: + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation = await self._run_sync_in_thread_pool( + self.env.reset, **valid_kwargs + ) state: State = self.env.state # Serialize observation once using shared utility @@ -428,6 +465,16 @@ def create_web_interface_app( web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/", include_in_schema=False) + async def web_root(): + """Redirect the app root to the Gradio interface.""" + return RedirectResponse(url="/web/") + + @app.get("/web", include_in_schema=False) + async def web_root_no_slash(): + """Redirect /web to /web/ for mounted Gradio deployments behind proxies.""" + return RedirectResponse(url="/web/") + @app.get("/web/metadata") async def web_metadata(): """Get environment metadata.""" @@ -449,9 +496,9 @@ def create_web_interface_app( await web_manager.disconnect_websocket(websocket) @app.post("/web/reset") - async def web_reset(): + async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)): """Reset endpoint for web interface.""" - return await web_manager.reset_environment() + return await web_manager.reset_environment(request) @app.post("/web/step") async def web_step(request: Dict[str, Any]): @@ -475,7 +522,13 @@ def create_web_interface_app( @app.get("/web/state") async def web_state(): """State endpoint for web interface.""" - return web_manager.get_state() + try: + return web_manager.get_state() + except RuntimeError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(exc), + ) from exc action_fields = _extract_action_fields(action_cls) is_chat_env = _is_chat_env(action_cls) @@ -505,7 +558,7 @@ def create_web_interface_app( ) gradio_blocks = gr.TabbedInterface( [default_blocks, custom_blocks], - tab_names=["Playground", "Visualization"], + tab_names=["Playground", "Custom"], title=get_gradio_display_title(metadata), ) else: diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py index edac3529d3a34e798781d86cf4d2495dc9611713..1d8bd38efd3595526fc25915c8fdbbe7aaeca5d5 100644 --- a/src/openenv/core/mcp_client.py +++ b/src/openenv/core/mcp_client.py @@ -52,6 +52,7 @@ Example (sync wrapper): ... result = env.call_tool("echo_message", message="Hello!") """ +import asyncio from typing import Any, Dict, List, Optional from .client_types import StepResult @@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]): ) self._tools_cache: Optional[List[Tool]] = None self.use_production_mode = False + self._production_session_id: Optional[str] = None + self._production_session_lock = asyncio.Lock() + self._jsonrpc_request_id = 0 + self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient + + def _next_request_id(self) -> int: + """Generate a monotonically increasing JSON-RPC request id.""" + self._jsonrpc_request_id += 1 + return self._jsonrpc_request_id + + def _production_mcp_url(self) -> str: + """Build HTTP MCP endpoint URL from the client's websocket URL.""" + url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") + if url.endswith("/ws"): + url = url[: -len("/ws")] + return url.rstrip("/") + "/mcp" + + async def _get_http_client(self) -> Any: + """Return a shared httpx.AsyncClient, creating one lazily.""" + if self._http_client is None: + import httpx + + self._http_client = httpx.AsyncClient() + return self._http_client + + async def _production_mcp_request( + self, method: str, params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response.""" + client = await self._get_http_client() + response = await client.post( + self._production_mcp_url(), + json={ + "jsonrpc": "2.0", + "method": method, + "params": params or {}, + "id": self._next_request_id(), + }, + timeout=self._message_timeout, + ) + response.raise_for_status() + return response.json() + + async def _ensure_production_session(self) -> str: + """Create and cache a persistent HTTP MCP session id if needed.""" + async with self._production_session_lock: + if self._production_session_id is not None: + return self._production_session_id + + data = await self._production_mcp_request("openenv/session/create") + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Failed to create MCP session: {message}") + + session_id = data.get("result", {}).get("session_id") + if not session_id: + raise RuntimeError("Failed to create MCP session: missing session_id") + + self._production_session_id = session_id + return session_id async def list_tools(self, use_cache: bool = True) -> List[Tool]: """ @@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]): if use_cache and self._tools_cache is not None: return self._tools_cache - # Use production mode HTTP endpoint if enabled - if self.use_production_mode: - import requests - - # Convert ws:// URL to http:// URL - url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") - # Remove /ws suffix if present and add /mcp - url = url.rstrip("/ws").rstrip("/") + "/mcp" - + # Use production mode HTTP endpoint if enabled. + # Some tests instantiate with __new__ and skip __init__, so default missing flag to False. + if getattr(self, "use_production_mode", False): try: - response = requests.post( - url, - json={ - "jsonrpc": "2.0", - "method": "tools/list", - "params": {}, - "id": 1, - }, + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/list", + {"session_id": session_id}, ) - data = response.json() + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"list_tools failed: {message}") if "result" in data and "tools" in data["result"]: tools = [ Tool( @@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]): return [] result = await self.step(ListToolsAction()) - self._tools_cache = result.observation.tools + if isinstance(result.observation, ListToolsObservation): + self._tools_cache = result.observation.tools + return self._tools_cache + + # Unexpected observation type; keep API stable with an empty tool list. + self._tools_cache = [] return self._tools_cache def _step_payload(self, action: Any) -> Dict[str, Any]: @@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]): step_count=payload.get("step_count", 0), ) + async def close(self) -> None: + """ + Close client resources. + + In production MCP mode, this also closes the server-side persistent + MCP session (best effort) before closing websocket/provider resources. + """ + if self._production_session_id is not None: + try: + await self._production_mcp_request( + "openenv/session/close", + {"session_id": self._production_session_id}, + ) + except Exception: + # Best effort cleanup - do not mask normal close behavior + pass + finally: + self._production_session_id = None + + if self._http_client is not None: + try: + await self._http_client.aclose() + except Exception: + pass + finally: + self._http_client = None + + await super().close() + class MCPToolClient(MCPClientBase): """ @@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase): >>> result = await env.call_tool("greet", name="Claude") >>> print(result) # "Hello, Claude!" """ + if getattr(self, "use_production_mode", False): + session_id = await self._ensure_production_session() + data = await self._production_mcp_request( + "tools/call", + { + "name": name, + "arguments": kwargs, + "session_id": session_id, + }, + ) + + if "error" in data: + message = data.get("error", {}).get("message", "unknown error") + raise RuntimeError(f"Tool '{name}' failed: {message}") + + result = data.get("result") + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + action = CallToolAction(tool_name=name, arguments=kwargs) result = await self.step(action) obs = result.observation diff --git a/src/openenv_core.egg-info/PKG-INFO b/src/openenv_core.egg-info/PKG-INFO index 654a6035265f03dfe71cddb57b22a9385d428012..c4aa57a590f67ef81dd51bd2f893d0eb66e31dd3 100644 --- a/src/openenv_core.egg-info/PKG-INFO +++ b/src/openenv_core.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.4 Name: openenv-core -Version: 0.2.2.dev0 +Version: 0.2.3 Summary: A unified framework for reinforcement learning environments Requires-Python: >=3.10 Description-Content-Type: text/markdown @@ -19,6 +19,7 @@ Requires-Dist: tomli-w>=1.2.0 Requires-Dist: websockets>=15.0.1 Requires-Dist: fastmcp>=3.0.0 Requires-Dist: gradio>=4.0.0 +Requires-Dist: httpx>=0.28.1 Provides-Extra: core Requires-Dist: fastapi>=0.104.0; extra == "core" Requires-Dist: pydantic>=2.0.0; extra == "core" @@ -61,7 +62,7 @@ Dynamic: license-file An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs. -[![PyPI](https://img.shields.io/pypi/v/openenv?color=blue)](https://pypi.org/project/openenv/) +[![PyPI](https://img.shields.io/pypi/v/openenv-core?color=blue)](https://pypi.org/project/openenv-core/) [![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb) [![Docs](https://img.shields.io/badge/Docs-Explore-blue?logo=readthedocs&logoColor=white)](https://meta-pytorch.org/OpenEnv/) diff --git a/src/openenv_core.egg-info/SOURCES.txt b/src/openenv_core.egg-info/SOURCES.txt index f06cb83904e0d986b1d0010ab0eb4c8ca541e85e..3fa1bfb9249b08e6d379bcf153c784c94ca4a634 100644 --- a/src/openenv_core.egg-info/SOURCES.txt +++ b/src/openenv_core.egg-info/SOURCES.txt @@ -1,4 +1,5 @@ LICENSE +MANIFEST.in README.md pyproject.toml src/openenv/__init__.py @@ -19,8 +20,6 @@ src/openenv/cli/commands/serve.py src/openenv/cli/commands/skills.py src/openenv/cli/commands/validate.py src/openenv/cli/templates/__init__.py -src/openenv/cli/templates/__pycache__/__init__.cpython-311.pyc -src/openenv/cli/templates/__pycache__/__init__.cpython-313.pyc src/openenv/cli/templates/openenv_env/README.md src/openenv/cli/templates/openenv_env/__init__.py src/openenv/cli/templates/openenv_env/client.py diff --git a/src/openenv_core.egg-info/requires.txt b/src/openenv_core.egg-info/requires.txt index caf0c5f6727bc58526e2a96a3f14cc28713c32c6..c1738e33af1ddae7a155997144133b9f743b165e 100644 --- a/src/openenv_core.egg-info/requires.txt +++ b/src/openenv_core.egg-info/requires.txt @@ -12,6 +12,7 @@ tomli-w>=1.2.0 websockets>=15.0.1 fastmcp>=3.0.0 gradio>=4.0.0 +httpx>=0.28.1 [all] openenv-core[core]