ask_answer_env / server /ask_answer_env_environment.py
ujjwalsg's picture
Upload folder using huggingface_hub
371cfc1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Ask-Answer environment implementation.
A slot-filling environment (reproducible when reset with a seed) in which an
agent must decide between asking clarifying questions and answering early to
maximize reward.
"""
import random
from typing import Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import AskAnswerAction, AskAnswerObservation, KnownSlots
# ---------------------------------------------------------------------------
# Slot vocabularies: the candidate values for each hidden slot.
# ---------------------------------------------------------------------------
CITIES = [
    "Paris",
    "Rome",
    "Tokyo",
    "Goa",
]
DATES = [
    "next_weekend",
    "mid_feb",
    "march",
]
BUDGETS = [
    "low",
    "mid",
    "high",
]
# "style" is the distractor slot.
STYLES = [
    "relax",
    "adventure",
    "food",
]

# ---------------------------------------------------------------------------
# Episode settings.
# ---------------------------------------------------------------------------
# Three steps forces the agent to guess at least one core slot.
MAX_STEPS = 3
PROMPT = "Plan a short trip for me."

# ---------------------------------------------------------------------------
# Per-step and ASK rewards (unchanged from v0).
# ---------------------------------------------------------------------------
STEP_PENALTY = -0.05
ASK_UNKNOWN_REWARD = 0.1
ASK_KNOWN_PENALTY = -0.2
AUTO_FAIL_PENALTY = -1.0

# ---------------------------------------------------------------------------
# Graded ANSWER rewards (v1).
# ---------------------------------------------------------------------------
ANSWER_CITY_CORRECT = 0.4
ANSWER_DATE_CORRECT = 0.4
ANSWER_BUDGET_CORRECT = 0.4
# Optional nice-to-have on top of the core slots.
ANSWER_STYLE_CORRECT_BONUS = 0.1
ANSWER_CORE_ALL_CORRECT_BONUS = 0.2
ANSWER_CORE_ANY_WRONG_PENALTY = -0.6
class AskAnswerEnvironment(Environment):
    """
    A slot-filling environment for training RL agents.

    The agent must decide between:
    - Asking clarifying questions (ASK) to reveal hidden slot values
    - Answering early (ANSWER) to end the episode

    Hidden state (city, date, budget, style) is sampled at reset from an RNG
    that is seeded when a seed is supplied. The agent can ask about slots to
    reveal their values before answering.

    With MAX_STEPS=3, the agent can only ask 2 slots before being forced to
    answer, creating a non-trivial ask-vs-act tradeoff. The "style" slot is a
    distractor that provides less reward than core slots (city, date, budget).

    Rewards:
    - Step penalty: -0.05 per step
    - ASK unknown slot: +0.1
    - ASK known (or invalid) slot: -0.2
    - Unrecognised action type: -0.2, and it consumes a step like an ASK
    - ANSWER: graded per-slot (+0.4 each core, +0.1 style)
    - Core all correct bonus: +0.2
    - Core any wrong penalty: -0.6
    - Auto-fail (steps exhausted without answering): exactly -1.0; this
      replaces the step penalty and any ASK reward earned on that step

    Note that only non-ANSWER actions decrement ``steps_left``; an ANSWER ends
    the episode without changing the counter.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Initialize the ask_answer_env environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._rng: random.Random = random.Random()
        # Hidden truth (sampled at reset; empty until the first reset()).
        self._hidden_city: str = ""
        self._hidden_date: str = ""
        self._hidden_budget: str = ""
        self._hidden_style: str = ""
        # Known slots (revealed through ASK actions).
        self._known: KnownSlots = KnownSlots()
        self._steps_left: int = MAX_STEPS
        self._done: bool = False

    def reset(self, seed: Optional[int] = None) -> AskAnswerObservation:
        """
        Reset the environment with an optional seed for determinism.

        Args:
            seed: Random seed for reproducibility. ``None`` lets the RNG seed
                itself from system entropy.

        Returns:
            AskAnswerObservation with the initial state (nothing known,
            ``steps_left == MAX_STEPS``, reward 0.0).
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # random.Random(None) seeds from system entropy, so no branch on
        # ``seed is None`` is needed.
        self._rng = random.Random(seed)
        # Sample hidden truth. The order of the draws matters for seeded
        # reproducibility, so it must stay city, date, budget, style.
        self._hidden_city = self._rng.choice(CITIES)
        self._hidden_date = self._rng.choice(DATES)
        self._hidden_budget = self._rng.choice(BUDGETS)
        self._hidden_style = self._rng.choice(STYLES)
        # Reset known slots and step counter.
        self._known = KnownSlots()
        self._steps_left = MAX_STEPS
        self._done = False
        return self._observation(done=False, reward=0.0)

    def step(self, action: AskAnswerAction) -> AskAnswerObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: AskAnswerAction with type 'ask' or 'answer'. Any other
                type is treated as an invalid ASK: it is penalised and
                consumes a step, so the episode is always bounded.

        Returns:
            AskAnswerObservation with updated state and reward. Once the
            episode is done, further calls return a terminal observation
            with reward 0.0 and change nothing.
        """
        if self._done:
            return self._observation(done=True, reward=0.0)

        self._state.step_count += 1
        # Always apply step penalty.
        reward = STEP_PENALTY
        done = False
        # Only populated when the episode ends via ANSWER.
        core_correct_count = None

        if action.type == "answer":
            reward += self._handle_answer(action)
            core_correct_count = sum(self._core_matches(action))
            done = True
        else:
            if action.type == "ask":
                reward += self._handle_ask(action.slot)
            else:
                # Unrecognised action type: penalise like an invalid slot
                # instead of letting the episode continue for free.
                reward += ASK_KNOWN_PENALTY
            self._steps_left -= 1
            # Auto-fail: the last step was spent without answering. The
            # reward is overwritten (not added to) by the fail penalty.
            if self._steps_left <= 0:
                reward = AUTO_FAIL_PENALTY
                done = True

        self._done = done
        return self._observation(
            done=done,
            reward=reward,
            core_correct_count=core_correct_count,
        )

    def _observation(
        self, *, done: bool, reward: float, **extra: object
    ) -> AskAnswerObservation:
        """Build an observation from the current known slots and step budget.

        ``extra`` is forwarded verbatim to AskAnswerObservation so callers
        control exactly which optional fields are set.
        """
        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=done,
            reward=reward,
            **extra,
        )

    def _core_matches(self, action: AskAnswerAction) -> tuple:
        """Return (city_ok, date_ok, budget_ok) for the answer in ``action``."""
        return (
            action.city == self._hidden_city,
            action.date == self._hidden_date,
            action.budget == self._hidden_budget,
        )

    def _handle_ask(self, slot: Optional[str]) -> float:
        """
        Handle ASK action - reveal a slot if unknown.

        Args:
            slot: The slot to ask about ('city', 'date', 'budget', or 'style')

        Returns:
            ASK_UNKNOWN_REWARD if the slot was newly revealed, otherwise
            ASK_KNOWN_PENALTY (slot already known, or not a valid slot name).
        """
        hidden_by_slot = {
            "city": self._hidden_city,
            "date": self._hidden_date,
            "budget": self._hidden_budget,
            "style": self._hidden_style,
        }
        # Invalid slot (including None).
        if slot not in hidden_by_slot:
            return ASK_KNOWN_PENALTY
        # Re-asking a slot that was already revealed.
        if getattr(self._known, slot) is not None:
            return ASK_KNOWN_PENALTY
        # Build a fresh KnownSlots rather than mutating the current one
        # (presumably so observations already returned are not altered).
        revealed = {name: getattr(self._known, name) for name in hidden_by_slot}
        revealed[slot] = hidden_by_slot[slot]
        self._known = KnownSlots(**revealed)
        return ASK_UNKNOWN_REWARD

    def _handle_answer(self, action: AskAnswerAction) -> float:
        """
        Handle ANSWER action with graded rewards.

        Reward structure:
        - Per-slot rewards: +0.4 for each correct core slot (city, date, budget)
        - Style bonus: +0.1 if style provided and correct (ignored if None)
        - Core bonus: +0.2 if all core slots correct
        - Core penalty: -0.6 if any core slot wrong

        Args:
            action: The answer action with city, date, budget, style values

        Returns:
            Reward for the ANSWER action (the step penalty is not included).
        """
        city_correct, date_correct, budget_correct = self._core_matches(action)
        reward = 0.0
        # Per-slot rewards for core slots.
        if city_correct:
            reward += ANSWER_CITY_CORRECT
        if date_correct:
            reward += ANSWER_DATE_CORRECT
        if budget_correct:
            reward += ANSWER_BUDGET_CORRECT
        # Style bonus (only if provided and correct, ignored if None).
        if action.style is not None and action.style == self._hidden_style:
            reward += ANSWER_STYLE_CORRECT_BONUS
        # Core bonus/penalty.
        if city_correct and date_correct and budget_correct:
            reward += ANSWER_CORE_ALL_CORRECT_BONUS
        else:
            reward += ANSWER_CORE_ANY_WRONG_PENALTY
        return reward

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count
        """
        return self._state