# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Ask Answer Env Environment Implementation.

A deterministic slot-filling environment where agents must decide between
asking clarifying questions or answering early to maximize reward.
"""

import random
from typing import Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from models import AskAnswerAction, AskAnswerObservation, KnownSlots

# Constants
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]  # Distractor slot

MAX_STEPS = 3  # Forces agent to guess at least 1 core slot
PROMPT = "Plan a short trip for me."

# Rewards (unchanged from v0)
STEP_PENALTY = -0.05
ASK_UNKNOWN_REWARD = 0.1
ASK_KNOWN_PENALTY = -0.2
AUTO_FAIL_PENALTY = -1.0

# Graded answer rewards (v1)
ANSWER_CITY_CORRECT = 0.4
ANSWER_DATE_CORRECT = 0.4
ANSWER_BUDGET_CORRECT = 0.4
ANSWER_STYLE_CORRECT_BONUS = 0.1  # Optional nice-to-have
ANSWER_CORE_ALL_CORRECT_BONUS = 0.2
ANSWER_CORE_ANY_WRONG_PENALTY = -0.6

# All askable slot names. The first three are "core" slots; "style" is a
# lower-reward distractor. Each slot name N maps to the hidden attribute
# `_hidden_N` and the `KnownSlots` field N, which is what lets
# `_handle_ask` stay data-driven instead of one copy-pasted branch per slot.
_SLOT_NAMES = ("city", "date", "budget", "style")


class AskAnswerEnvironment(Environment):
    """
    A slot-filling environment for training RL agents.

    The agent must decide between:
    - Asking clarifying questions (ASK) to reveal hidden slot values
    - Answering early (ANSWER) to end the episode

    Hidden state (city, date, budget, style) is sampled at reset with a
    seeded RNG. The agent can ask about slots to reveal their values before
    answering. With MAX_STEPS=3, the agent can only ask 2 slots before being
    forced to answer, creating a non-trivial ask-vs-act tradeoff.

    The "style" slot is a distractor that provides less reward than core
    slots (city, date, budget).

    Rewards:
    - Step penalty: -0.05 per step
    - ASK unknown slot: +0.1
    - ASK known slot: -0.2
    - ANSWER: graded per-slot (+0.4 each core, +0.1 style)
    - Core all correct bonus: +0.2
    - Core any wrong penalty: -0.6
    - Auto-fail (steps exhausted): -1.0
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialize the ask_answer_env environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._rng: random.Random = random.Random()

        # Hidden truth (sampled at reset)
        self._hidden_city: str = ""
        self._hidden_date: str = ""
        self._hidden_budget: str = ""
        self._hidden_style: str = ""

        # Known slots (revealed through ASK actions)
        self._known: KnownSlots = KnownSlots()
        self._steps_left: int = MAX_STEPS
        self._done: bool = False

    def reset(self, seed: Optional[int] = None) -> AskAnswerObservation:
        """
        Reset the environment with optional seed for determinism.

        Args:
            seed: Random seed for reproducibility

        Returns:
            AskAnswerObservation with initial state
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Initialize RNG with seed; an unseeded Random gives a fresh episode.
        if seed is not None:
            self._rng = random.Random(seed)
        else:
            self._rng = random.Random()

        # Sample hidden truth
        self._hidden_city = self._rng.choice(CITIES)
        self._hidden_date = self._rng.choice(DATES)
        self._hidden_budget = self._rng.choice(BUDGETS)
        self._hidden_style = self._rng.choice(STYLES)

        # Reset known slots and step counter
        self._known = KnownSlots()
        self._steps_left = MAX_STEPS
        self._done = False

        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=False,
            reward=0.0,
        )

    def step(self, action: AskAnswerAction) -> AskAnswerObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: AskAnswerAction with type 'ask' or 'answer'

        Returns:
            AskAnswerObservation with updated state and reward
        """
        # Episode already over: return a terminal observation with no reward.
        if self._done:
            return AskAnswerObservation(
                prompt=PROMPT,
                known=self._known,
                steps_left=self._steps_left,
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1

        # Always apply step penalty
        reward = STEP_PENALTY
        done = False

        if action.type == "ask":
            reward += self._handle_ask(action.slot)
            self._steps_left -= 1

            # Auto-fail: the step budget is exhausted without an ANSWER.
            # The penalty REPLACES the accumulated ask reward/penalty.
            if self._steps_left == 0:
                reward = AUTO_FAIL_PENALTY
                done = True

        elif action.type == "answer":
            reward += self._handle_answer(action)
            done = True
        # NOTE(review): an unrecognized action.type costs only the step
        # penalty and does not consume a step — presumably the action model
        # restricts type to 'ask'/'answer'; confirm against AskAnswerAction.

        self._done = done

        # Report how many core slots were correct when ending via ANSWER.
        core_correct_count = None
        if done and action.type == "answer":
            core_correct_count = sum([
                action.city == self._hidden_city,
                action.date == self._hidden_date,
                action.budget == self._hidden_budget,
            ])

        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=done,
            reward=reward,
            core_correct_count=core_correct_count,
        )

    def _handle_ask(self, slot: Optional[str]) -> float:
        """
        Handle ASK action - reveal a slot if unknown.

        Data-driven over _SLOT_NAMES instead of one near-identical branch
        per slot, so adding a slot cannot silently diverge between branches.

        Args:
            slot: The slot to ask about ('city', 'date', 'budget', or 'style')

        Returns:
            Reward for the ASK action: ASK_UNKNOWN_REWARD when a hidden slot
            is revealed, ASK_KNOWN_PENALTY when the slot is already known,
            invalid, or None.
        """
        if slot not in _SLOT_NAMES:
            # Invalid (or missing) slot is penalized like re-asking a known one.
            return ASK_KNOWN_PENALTY

        if getattr(self._known, slot) is not None:
            return ASK_KNOWN_PENALTY

        # Rebuild KnownSlots with the newly revealed value, preserving the
        # original replace-don't-mutate style.
        values = {name: getattr(self._known, name) for name in _SLOT_NAMES}
        values[slot] = getattr(self, f"_hidden_{slot}")
        self._known = KnownSlots(**values)
        return ASK_UNKNOWN_REWARD

    def _handle_answer(self, action: AskAnswerAction) -> float:
        """
        Handle ANSWER action with graded rewards.

        Reward structure:
        - Per-slot rewards: +0.4 for each correct core slot (city, date, budget)
        - Style bonus: +0.1 if style provided and correct (ignored if None)
        - Core bonus: +0.2 if all core slots correct
        - Core penalty: -0.6 if any core slot wrong

        Args:
            action: The answer action with city, date, budget, style values

        Returns:
            Reward for the ANSWER action
        """
        reward = 0.0

        # Check core slots
        city_correct = action.city == self._hidden_city
        date_correct = action.date == self._hidden_date
        budget_correct = action.budget == self._hidden_budget

        # Per-slot rewards for core slots
        if city_correct:
            reward += ANSWER_CITY_CORRECT
        if date_correct:
            reward += ANSWER_DATE_CORRECT
        if budget_correct:
            reward += ANSWER_BUDGET_CORRECT

        # Style bonus (only if provided and correct, ignored if None)
        if action.style is not None and action.style == self._hidden_style:
            reward += ANSWER_STYLE_CORRECT_BONUS

        # Core bonus/penalty: all-correct bonus, otherwise a flat penalty
        # regardless of how many core slots were wrong.
        core_all_correct = city_correct and date_correct and budget_correct
        if core_all_correct:
            reward += ANSWER_CORE_ALL_CORRECT_BONUS
        else:
            reward += ANSWER_CORE_ANY_WRONG_PENALTY

        return reward

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count
        """
        return self._state