Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Ask Answer Env Environment Implementation. | |
| A deterministic slot-filling environment where agents must decide between | |
| asking clarifying questions or answering early to maximize reward. | |
| """ | |
| import random | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import AskAnswerAction, AskAnswerObservation, KnownSlots | |
| # Constants | |
| CITIES = ["Paris", "Rome", "Tokyo", "Goa"] | |
| DATES = ["next_weekend", "mid_feb", "march"] | |
| BUDGETS = ["low", "mid", "high"] | |
| STYLES = ["relax", "adventure", "food"] # Distractor slot | |
| MAX_STEPS = 3 # Forces agent to guess at least 1 core slot | |
| PROMPT = "Plan a short trip for me." | |
| # Rewards (unchanged from v0) | |
| STEP_PENALTY = -0.05 | |
| ASK_UNKNOWN_REWARD = 0.1 | |
| ASK_KNOWN_PENALTY = -0.2 | |
| AUTO_FAIL_PENALTY = -1.0 | |
| # Graded answer rewards (v1) | |
| ANSWER_CITY_CORRECT = 0.4 | |
| ANSWER_DATE_CORRECT = 0.4 | |
| ANSWER_BUDGET_CORRECT = 0.4 | |
| ANSWER_STYLE_CORRECT_BONUS = 0.1 # Optional nice-to-have | |
| ANSWER_CORE_ALL_CORRECT_BONUS = 0.2 | |
| ANSWER_CORE_ANY_WRONG_PENALTY = -0.6 | |
class AskAnswerEnvironment(Environment):
    """
    A slot-filling environment for training RL agents.

    The agent must decide between:
    - Asking clarifying questions (ASK) to reveal hidden slot values
    - Answering early (ANSWER) to end the episode

    Hidden state (city, date, budget, style) is sampled at reset with a seeded RNG.
    The agent can ask about slots to reveal their values before answering.
    With MAX_STEPS=3, the agent can only ask 2 slots before being forced to answer,
    creating a non-trivial ask-vs-act tradeoff. The "style" slot is a distractor
    that provides less reward than core slots (city, date, budget).

    Rewards:
        - Step penalty: -0.05 per step
        - ASK unknown slot: +0.1
        - ASK known slot: -0.2
        - ANSWER: graded per-slot (+0.4 each core, +0.1 style)
        - Core all correct bonus: +0.2
        - Core any wrong penalty: -0.6
        - Auto-fail (steps exhausted): -1.0
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialize the ask_answer_env environment with empty hidden state."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._rng: random.Random = random.Random()
        # Hidden truth (sampled at reset)
        self._hidden_city: str = ""
        self._hidden_date: str = ""
        self._hidden_budget: str = ""
        self._hidden_style: str = ""
        # Known slots (revealed through ASK actions)
        self._known: KnownSlots = KnownSlots()
        self._steps_left: int = MAX_STEPS
        self._done: bool = False

    def reset(self, seed: Optional[int] = None) -> AskAnswerObservation:
        """
        Reset the environment with optional seed for determinism.

        Args:
            seed: Random seed for reproducibility. When None, the RNG is
                unseeded and episodes are non-deterministic.

        Returns:
            AskAnswerObservation with initial state (no slots known,
            full step budget, zero reward).
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # A fresh Random instance per reset keeps episodes independent.
        self._rng = random.Random(seed) if seed is not None else random.Random()
        # Sample hidden truth
        self._hidden_city = self._rng.choice(CITIES)
        self._hidden_date = self._rng.choice(DATES)
        self._hidden_budget = self._rng.choice(BUDGETS)
        self._hidden_style = self._rng.choice(STYLES)
        # Reset known slots and step counter
        self._known = KnownSlots()
        self._steps_left = MAX_STEPS
        self._done = False
        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=False,
            reward=0.0,
        )

    def step(self, action: AskAnswerAction) -> AskAnswerObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: AskAnswerAction with type 'ask' or 'answer'

        Returns:
            AskAnswerObservation with updated state and reward. Stepping a
            finished episode is a no-op that returns a zero-reward, done
            observation.
        """
        if self._done:
            return AskAnswerObservation(
                prompt=PROMPT,
                known=self._known,
                steps_left=self._steps_left,
                done=True,
                reward=0.0,
            )
        self._state.step_count += 1
        # Always apply step penalty
        reward = STEP_PENALTY
        done = False
        if action.type == "ask":
            reward += self._handle_ask(action.slot)
            self._steps_left -= 1
            # Auto-fail: the ask reward is REPLACED (not added to) by the
            # penalty when the step budget is exhausted without answering.
            if self._steps_left == 0:
                reward = AUTO_FAIL_PENALTY
                done = True
        elif action.type == "answer":
            reward += self._handle_answer(action)
            done = True
        # NOTE(review): an unrecognized action.type falls through here —
        # it costs only the step penalty, does not consume a step, and never
        # ends the episode. Confirm this is intended (an agent could loop
        # forever on invalid actions).
        self._done = done
        # Expose per-core-slot accuracy only when the episode ends via ANSWER
        core_correct_count = None
        if done and action.type == "answer":
            core_correct_count = sum([
                action.city == self._hidden_city,
                action.date == self._hidden_date,
                action.budget == self._hidden_budget,
            ])
        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=done,
            reward=reward,
            core_correct_count=core_correct_count,
        )

    def _handle_ask(self, slot: Optional[str]) -> float:
        """
        Handle ASK action - reveal a slot if unknown.

        Args:
            slot: The slot to ask about ('city', 'date', 'budget', or 'style')

        Returns:
            ASK_UNKNOWN_REWARD if the slot was hidden and is now revealed;
            ASK_KNOWN_PENALTY if the slot was already known or invalid.
        """
        # Single table replaces four duplicated per-slot branches.
        hidden = {
            "city": self._hidden_city,
            "date": self._hidden_date,
            "budget": self._hidden_budget,
            "style": self._hidden_style,
        }
        if slot not in hidden:
            # Invalid slot name is penalized the same as re-asking a known slot.
            return ASK_KNOWN_PENALTY
        if getattr(self._known, slot) is not None:
            return ASK_KNOWN_PENALTY
        # Rebuild KnownSlots immutably with only the asked slot revealed.
        revealed = {name: getattr(self._known, name) for name in hidden}
        revealed[slot] = hidden[slot]
        self._known = KnownSlots(**revealed)
        return ASK_UNKNOWN_REWARD

    def _handle_answer(self, action: AskAnswerAction) -> float:
        """
        Handle ANSWER action with graded rewards.

        Reward structure:
        - Per-slot rewards: +0.4 for each correct core slot (city, date, budget)
        - Style bonus: +0.1 if style provided and correct (ignored if None)
        - Core bonus: +0.2 if all core slots correct
        - Core penalty: -0.6 if any core slot wrong

        Args:
            action: The answer action with city, date, budget, style values

        Returns:
            Reward for the ANSWER action
        """
        reward = 0.0
        # Check core slots
        city_correct = action.city == self._hidden_city
        date_correct = action.date == self._hidden_date
        budget_correct = action.budget == self._hidden_budget
        # Per-slot rewards for core slots
        if city_correct:
            reward += ANSWER_CITY_CORRECT
        if date_correct:
            reward += ANSWER_DATE_CORRECT
        if budget_correct:
            reward += ANSWER_BUDGET_CORRECT
        # Style bonus (only if provided and correct; a None style is neither
        # rewarded nor penalized)
        if action.style is not None and action.style == self._hidden_style:
            reward += ANSWER_STYLE_CORRECT_BONUS
        # Core bonus/penalty: exactly one of the two always applies
        core_all_correct = city_correct and date_correct and budget_correct
        if core_all_correct:
            reward += ANSWER_CORE_ALL_CORRECT_BONUS
        else:
            reward += ANSWER_CORE_ANY_WRONG_PENALTY
        return reward

    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count
        """
        return self._state