| import json |
| import logging |
| from abc import ABC, abstractmethod |
| from enum import Enum |
| from functools import cached_property |
| from typing import Any, Literal, Optional, Self |
|
|
| from pydantic import BaseModel |
|
|
| from proxy_lite.history import ToolCall |
| from proxy_lite.tools import Tool, ToolExecutionResponse |
|
|
|
|
class EventType(str, Enum):
    """Tag identifying which kind of event an ``Event`` model represents.

    Members subclass ``str``, so each compares equal to its plain string
    value (for example ``EventType.ACTION == "action"``).
    """

    # Tag of the ``Observation`` model.
    OBSERVATION = "observation"
    # Tag of the ``Action`` model.
    ACTION = "action"
    # Not used as a tag by any model defined in this module.
    MESSAGE = "message"
|
|
|
|
class Event(BaseModel):
    """Common base model for events in an agent/environment exchange.

    ``type`` records which ``EventType`` the event is; subclasses pin it to
    a single member through a ``Literal`` annotation with a matching default.
    """

    type: EventType
|
|
|
|
class State(BaseModel):
    """Snapshot of environment content carried inside an ``Observation``.

    Every field is optional and defaults to ``None``.
    """

    text: str | None = None
    # NOTE(review): encoding of the image string (URL, base64, ...) is not
    # visible here — confirm against the environment implementations.
    image: str | None = None
    html: str | None = None
    tool_responses: list[ToolExecutionResponse] | None = None
|
|
|
|
class Observation(Event):
    """Event the environment produces to describe its current state.

    It is the return type of ``BaseEnvironment.initialise``,
    ``execute_action`` and ``observe``.
    """

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    state: State
    # Required. Presumably signals that the episode has ended — confirm
    # against the environment implementations.
    terminated: bool
    reward: float | None = None
    info: dict[str, Any] | None = None
|
|
|
|
class Action(Event):
    """Event describing what the agent asks the environment to do.

    It is the argument type of ``BaseEnvironment.execute_action``. All payload
    fields are optional and default to ``None``.
    """

    type: Literal[EventType.ACTION] = EventType.ACTION
    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    info: dict[str, Any] | None = None
|
|
|
|
class BaseEnvironmentConfig(BaseModel):
    """Field-less base model for environment configuration classes.

    Concrete configurations subclass it and can be registered under a name
    with ``Environments.register_environment_config``.
    """
|
|
|
|
class BaseEnvironment(BaseModel, ABC):
    """Abstract base for environments an agent interacts with.

    Subclasses supply the tool list plus the initialise/act/observe/evaluate
    hooks. Instances can be used as async context managers: the default
    ``__aenter__`` returns ``self`` and the default ``__aexit__`` does nothing.
    """

    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        # Allow field types pydantic has no built-in validator for,
        # such as ``logging.Logger``.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Returns None, so exceptions raised inside ``async with`` propagate.
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str: ...

    # NOTE(review): functools.cached_property does not expose
    # ``__isabstractmethod__``, so ABC likely does not treat ``tools`` as
    # abstract; a subclass that forgets to override it would return None here.
    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]: ...

    @abstractmethod
    async def initialise(self) -> Observation: ...

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation: ...

    @abstractmethod
    async def observe(self) -> Observation: ...

    @abstractmethod
    async def evaluate(self, **kwargs: Any) -> dict[str, Any]: ...

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch ``tool_call`` to the first tool with a callable of that name.

        The call's ``arguments`` are JSON-decoded, and decoded a second time if
        the first pass yields a string (double-encoded JSON). They are then
        passed as keyword arguments, and the awaited result is returned.

        Raises:
            ValueError: If no tool in ``self.tools`` has a callable attribute
                with the requested name.
        """
        function = tool_call.function
        name = function["name"]
        # NOTE(review): ``name`` is resolved with getattr, so private/dunder
        # attributes are reachable unless tool-call names are validated upstream.
        for tool in self.tools:
            method = getattr(tool, name, None)
            # Skip attributes that merely share the name but are not callable,
            # instead of failing with "object is not callable".
            if not callable(method):
                continue
            arguments = json.loads(function["arguments"])
            if isinstance(arguments, str):
                # Arguments arrived JSON-encoded twice; decode the inner layer.
                arguments = json.loads(arguments)
            return await method(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        """Return extra environment info; the default implementation is empty."""
        return {}
|
|
|
|
class Environments:
    """Name-keyed registries of environment classes and their config classes.

    Both registries are class-level dicts shared by every user of this class.
    Registering a name that is already present replaces the previous entry.
    """

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...

        Returns:
            A decorator that stores the class and returns it unchanged.
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...

        Returns:
            A decorator that stores the class and returns it unchanged.
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            # The message already names the missing key; drop the KeyError context.
            raise ValueError(f"Environment '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            # The message already names the missing key; drop the KeyError context.
            raise ValueError(f"Environment config for '{name}' not found.") from None
|
|