Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Ask Answer Env Environment Client.""" | |
| from typing import Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| from openenv.core import EnvClient | |
| from .models import AskAnswerAction, AskAnswerObservation, KnownSlots | |
class AskAnswerEnv(
    EnvClient[AskAnswerAction, AskAnswerObservation, State]
):
    """
    Client for the Ask Answer Env Environment.

    A slot-filling environment where agents must decide between asking
    clarifying questions or answering early to maximize reward.

    Example:
        >>> with AskAnswerEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset(seed=42)
        ...     print(result.observation.prompt)
        ...     print(result.observation.known)
        ...
        ...     # Ask about city
        ...     result = client.step(AskAnswerAction(type="ask", slot="city"))
        ...     print(f"City: {result.observation.known.city}")
        ...
        ...     # Answer with known values
        ...     result = client.step(AskAnswerAction(
        ...         type="answer",
        ...         city=result.observation.known.city,
        ...         date="mid_feb",
        ...         budget="high"
        ...     ))
        ...     print(f"Reward: {result.reward}, Done: {result.done}")

    Example with Docker:
        >>> client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
        >>> try:
        ...     result = client.reset(seed=42)
        ...     # ... interact with environment
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: AskAnswerAction) -> Dict:
        """
        Convert AskAnswerAction to JSON payload for step message.

        An ``"ask"`` action carries only the slot being asked about. Any
        other action type is serialized as an answer and carries the
        ``city``, ``date``, ``budget`` and ``style`` values of the action.

        Args:
            action: AskAnswerAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        payload = {"type": action.type}
        if action.type == "ask":
            payload["slot"] = action.slot
        else:  # answer
            payload["city"] = action.city
            payload["date"] = action.date
            payload["budget"] = action.budget
            payload["style"] = action.style
        return payload

    def _parse_result(self, payload: Dict) -> StepResult[AskAnswerObservation]:
        """
        Parse server response into StepResult[AskAnswerObservation].

        Missing fields fall back to defaults: an empty set of known slots,
        the prompt ``"Plan a short trip for me."``, ``steps_left=0``,
        ``done=False`` and ``reward=None``. An ``"observation"`` or
        ``"known"`` entry that is present but null is treated the same as a
        missing one.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with AskAnswerObservation
        """
        # dict.get only uses its default when the key is absent; an explicit
        # JSON null would come back as None and break the lookups below, so
        # normalise it to an empty mapping.
        obs_data = payload.get("observation") or {}
        known_data = obs_data.get("known") or {}

        # Read once: the same values feed both the observation and the result.
        reward = payload.get("reward")
        done = payload.get("done", False)

        known = KnownSlots(
            city=known_data.get("city"),
            date=known_data.get("date"),
            budget=known_data.get("budget"),
            style=known_data.get("style"),
        )
        observation = AskAnswerObservation(
            prompt=obs_data.get("prompt", "Plan a short trip for me."),
            known=known,
            steps_left=obs_data.get("steps_left", 0),
            done=done,
            reward=reward,
            core_correct_count=obs_data.get("core_correct_count"),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id and step_count (step_count
            defaults to 0 when absent from the payload)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )