# ask_answer_env / exp.py
# Uploaded via huggingface_hub by ujjwalsg (commit 371cfc1, verified).
"""
Baseline agents for the Ask Answer Env environment (v1).
Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3).
With only 3 steps, agents can ask at most 2 slots before being forced to answer,
creating a non-trivial tradeoff between information gathering and guessing.
Baselines:
- A: city+date (ask city, ask date, guess budget)
- B: city+budget (ask city, ask budget, guess date)
- C: style+city (trap: wastes a question on distractor)
- Random: random actions
- Oracle: knows hidden state, answers immediately (upper bound)
"""
import random
import statistics
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots
# Type aliases
HiddenTuple = Tuple[str, str, str, str]  # (city, date, budget, style)
# Strategy signature: (known_slots, steps_left, optional_hidden_state) -> action
StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]

# Default guesses when slot is unknown
DEFAULT_CITY = "Paris"
DEFAULT_DATE = "mid_feb"
DEFAULT_BUDGET = "mid"
DEFAULT_STYLE = "relax"

# Valid slot values (for random baseline)
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]
@dataclass
class EpisodeResult:
    """Result of a single episode."""
    total_reward: float  # sum of per-step rewards over the whole episode
    revealed: HiddenTuple  # (city, date, budget, style) known at episode end
    steps_taken: int  # number of env steps actually taken
    core_correct_count: int  # 0-3: how many core slots were correct
    core_all_correct: bool  # True if all 3 core slots correct
# =============================================================================
# Strategy Functions
# =============================================================================
def strategy_city_date(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy A: ask the two core slots city and date, then answer.

    With MAX_STEPS=3 this plays out as:
      1. ASK city
      2. ASK date
      3. ANSWER with the revealed city+date, guessing budget
    """
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    if known.date is None:
        return AskAnswerAction(type="ask", slot="date")
    budget_guess = known.budget if known.budget else DEFAULT_BUDGET
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=known.date,
        budget=budget_guess,
        style=known.style,  # stays None when never asked
    )
def strategy_city_budget(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy B: ask the two core slots city and budget, then answer.

    With MAX_STEPS=3 this plays out as:
      1. ASK city
      2. ASK budget
      3. ANSWER with the revealed city+budget, guessing date
    """
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    if known.budget is None:
        return AskAnswerAction(type="ask", slot="budget")
    date_guess = known.date if known.date else DEFAULT_DATE
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=date_guess,
        budget=known.budget,
        style=known.style,  # stays None when never asked
    )
def strategy_style_city(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy C (TRAP): ask style first, then city, then answer.

    Deliberately bad: style is a low-value distractor (+0.1 bonus vs +0.4
    per core slot), so spending a question on it forces the agent to guess
    two core slots (date and budget) instead of one.

    With MAX_STEPS=3 this plays out as:
      1. ASK style (bad choice!)
      2. ASK city
      3. ANSWER with the revealed style+city, guessing date and budget
    """
    if known.style is None:
        return AskAnswerAction(type="ask", slot="style")
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=known.date if known.date else DEFAULT_DATE,
        budget=known.budget if known.budget else DEFAULT_BUDGET,
        style=known.style,
    )
def strategy_random(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Random baseline: flip a coin between asking and answering.

    If any slot is still unknown, with probability 0.5 ask one of the
    unknown slots chosen uniformly at random; otherwise (or when every
    slot is known) answer, filling each unknown slot with a random
    valid value.
    """
    # Collect unknown slots in a fixed order so RNG consumption is stable.
    pending = [
        slot_name
        for slot_name, slot_value in (
            ("city", known.city),
            ("date", known.date),
            ("budget", known.budget),
            ("style", known.style),
        )
        if slot_value is None
    ]
    if pending and random.random() < 0.5:
        return AskAnswerAction(type="ask", slot=random.choice(pending))
    # Answer, guessing randomly for anything we never asked about.
    return AskAnswerAction(
        type="answer",
        city=known.city if known.city else random.choice(CITIES),
        date=known.date if known.date else random.choice(DATES),
        budget=known.budget if known.budget else random.choice(BUDGETS),
        style=known.style if known.style else random.choice(STYLES),
    )
def strategy_oracle(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Oracle baseline: answers perfectly in one step using the hidden state.

    THEORETICAL UPPER BOUND. In practice this function is never executed,
    because the server does not expose hidden state to the client; the
    oracle's reward is instead hardcoded as 1.45 in run_baseline_test().

    Reward breakdown:
        -0.05 (step) + 0.4×3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45

    Raises:
        ValueError: if `hidden` is not provided.
    """
    if hidden is None:
        raise ValueError("Oracle strategy requires hidden state")
    true_city, true_date, true_budget, true_style = hidden
    return AskAnswerAction(
        type="answer",
        city=true_city,
        date=true_date,
        budget=true_budget,
        style=true_style,
    )
# =============================================================================
# Episode Runner
# =============================================================================
def run_episode(
    client: AskAnswerEnv,
    strategy: StrategyFn,
    seed: int = 42,
    hidden: Optional[HiddenTuple] = None,
    verbose: bool = False,
) -> EpisodeResult:
    """
    Play one full episode, driving the env with `strategy` until done.

    Args:
        client: AskAnswerEnv client instance
        strategy: Function mapping (known, steps_left, hidden) to an action
        seed: Random seed for reproducibility
        hidden: Hidden state tuple (required for oracle strategy)
        verbose: Whether to print step-by-step info

    Returns:
        EpisodeResult with total_reward, revealed slots, and steps taken
    """
    outcome = client.reset(seed=seed)
    cumulative = 0.0
    step_count = 0
    if verbose:
        print(f"=== Episode Start (seed={seed}) ===")
        print(f"Steps left: {outcome.observation.steps_left}")
    while not outcome.done:
        observation = outcome.observation
        action = strategy(observation.known, observation.steps_left, hidden)
        outcome = client.step(action)
        cumulative += outcome.reward
        step_count += 1
        if verbose:
            if action.type == "ask":
                slot_val = getattr(outcome.observation.known, action.slot)
                print(f" Step {step_count}: ASK {action.slot} -> {slot_val}, reward={outcome.reward:+.2f}")
            else:
                print(f" Step {step_count}: ANSWER city={action.city}, date={action.date}, "
                      f"budget={action.budget}, style={action.style}, reward={outcome.reward:+.2f}")
    last_known = outcome.observation.known
    revealed = (last_known.city, last_known.date, last_known.budget, last_known.style)
    # Correctness info is only populated on the final ANSWER observation;
    # fall back to 0 when the env left it unset.
    n_core = outcome.observation.core_correct_count or 0
    if verbose:
        print(f" Total reward: {cumulative:+.2f}")
        print(f" Core correct: {n_core}/3")
        print()
    return EpisodeResult(
        total_reward=cumulative,
        revealed=revealed,
        steps_taken=step_count,
        core_correct_count=n_core,
        core_all_correct=(n_core == 3),
    )
# =============================================================================
# Acceptance Tests
# =============================================================================
@dataclass
class BaselineStats:
    """Statistics for a baseline over multiple episodes."""
    name: str  # human-readable baseline label, used as table row key
    mean_reward: float  # mean total reward across episodes
    std_reward: float  # population standard deviation of total reward
    positive_return_rate: float  # % of episodes with reward > 0
    core_success_rate: float  # % of episodes with all 3 core slots correct
    avg_core_correct: float  # average number of core slots correct (0-3)
def run_baseline_test(
    client: AskAnswerEnv,
    name: str,
    strategy: StrategyFn,
    num_episodes: int = 200,
    needs_oracle: bool = False,
) -> BaselineStats:
    """
    Run multiple episodes with a strategy and compute statistics.

    Args:
        client: AskAnswerEnv client instance
        name: Name of the baseline for logging
        strategy: Strategy function
        num_episodes: Number of episodes to run (must be >= 1)
        needs_oracle: If True, skip simulation and use theoretical oracle values

    Returns:
        BaselineStats with all metrics

    Raises:
        ValueError: if num_episodes < 1 (previously a bare ZeroDivisionError)
    """
    if needs_oracle:
        # Oracle is a THEORETICAL upper bound - knows hidden state,
        # answers perfectly in 1 step.
        #
        # Reward: -0.05 + 0.4×3 + 0.1 + 0.2 = +1.45
        # Core correct: 3/3 always
        return BaselineStats(
            name=name,
            mean_reward=1.45,
            std_reward=0.0,
            positive_return_rate=1.0,
            core_success_rate=1.0,
            avg_core_correct=3.0,
        )
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
    results: List[EpisodeResult] = [
        run_episode(client, strategy, seed=seed) for seed in range(num_episodes)
    ]
    rewards = [r.total_reward for r in results]
    # statistics.pstdev is the population std — same formula as the
    # previous hand-rolled variance ** 0.5, but numerically stabler.
    mean_reward = statistics.fmean(rewards)
    std_reward = statistics.pstdev(rewards)
    positive_return_rate = sum(1 for r in rewards if r > 0) / len(rewards)
    core_success_rate = sum(1 for r in results if r.core_all_correct) / len(results)
    avg_core_correct = sum(r.core_correct_count for r in results) / len(results)
    return BaselineStats(
        name=name,
        mean_reward=mean_reward,
        std_reward=std_reward,
        positive_return_rate=positive_return_rate,
        core_success_rate=core_success_rate,
        avg_core_correct=avg_core_correct,
    )
def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200) -> bool:
    """
    Run every baseline, print a comparison table, and check the expected
    ordering of mean rewards.

    Expected ordering:
        Oracle > A ≈ B >> C > Random
    """
    print(f"\nRunning {num_episodes} episodes per baseline...\n")
    specs = [
        ("Oracle (theoretical)", None, True),
        ("A: city+date", strategy_city_date, False),
        ("B: city+budget", strategy_city_budget, False),
        ("C: style+city (trap)", strategy_style_city, False),
        ("Random", strategy_random, False),
    ]
    all_stats: List[BaselineStats] = []
    for label, fn, is_oracle in specs:
        stats = run_baseline_test(client, label, fn, num_episodes, needs_oracle=is_oracle)
        all_stats.append(stats)
        print(f" {label}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")

    # Results table, best mean reward first.
    print("\n" + "=" * 90)
    print("RESULTS SUMMARY")
    print("=" * 90)
    print(f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} {'Core%':>7} {'AvgCore':>8}")
    print("-" * 90)
    for s in sorted(all_stats, key=lambda x: x.mean_reward, reverse=True):
        print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
              f"{s.positive_return_rate:>6.0%} {s.core_success_rate:>6.0%} "
              f"{s.avg_core_correct:>7.2f}/3")
    print("-" * 90)
    print("\nColumn legend:")
    print(" Mean = mean total reward")
    print(" Std = standard deviation of reward")
    print(" Pos% = positive_return_rate (% episodes with reward > 0)")
    print(" Core% = core_success_rate (% episodes with all 3 core slots correct)")
    print(" AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")

    # Verify the expected reward ordering between baselines.
    mean_by_name = {s.name: s.mean_reward for s in all_stats}
    checks = [
        ("Oracle > A", mean_by_name["Oracle (theoretical)"] > mean_by_name["A: city+date"]),
        ("A ≈ B", abs(mean_by_name["A: city+date"] - mean_by_name["B: city+budget"]) < 0.1),
        ("A > C", mean_by_name["A: city+date"] > mean_by_name["C: style+city (trap)"]),
        ("C > Random", mean_by_name["C: style+city (trap)"] > mean_by_name["Random"]),
    ]
    print("\nExpected ordering checks:")
    for check_name, passed in checks:
        print(f" {check_name}: {'PASS' if passed else 'FAIL'}")
    return all(passed for _, passed in checks)
# =============================================================================
# Determinism Tests (kept from v0)
# =============================================================================
def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
    """Check that replaying the same seed yields identical trajectories."""
    episodes = [run_episode(client, strategy_city_date, seed=seed) for _ in range(runs)]
    rewards = [ep.total_reward for ep in episodes]
    revealed = [ep.revealed for ep in episodes]
    # Every run must reveal the same tuple and earn the same reward.
    identical = len(set(revealed)) == 1 and len(set(rewards)) == 1
    print(f"Determinism (seed={seed}): {revealed[0]} x{runs}, identical={identical}")
    return identical
def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
    """Verify different seeds produce different hidden states."""
    distinct = {
        run_episode(client, strategy_city_date, seed=s).revealed
        for s in range(num_seeds)
    }
    # Max possible: 4 * 3 * 3 * 3 = 108 (with style)
    print(f"Seed sensitivity: {len(distinct)} unique tuples from {num_seeds} seeds")
    return len(distinct) > 1
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    # Start the environment server inside a Docker container.
    client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
    try:
        print("=" * 60)
        print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS")
        print("=" * 60)
        # Quick determinism check
        print("\n1. DETERMINISM TESTS")
        print("-" * 40)
        test_determinism(client, seed=42)
        test_determinism(client, seed=123)
        test_seed_sensitivity(client)
        # Run a single verbose episode to show behavior
        print("\n2. EXAMPLE EPISODE (Strategy A: city+date)")
        print("-" * 40)
        run_episode(client, strategy_city_date, seed=42, verbose=True)
        print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)")
        print("-" * 40)
        run_episode(client, strategy_style_city, seed=42, verbose=True)
        # Full acceptance tests
        print("\n4. BASELINE COMPARISON")
        print("-" * 40)
        passed = run_acceptance_tests(client, num_episodes=200)
        print("\n" + "=" * 60)
        print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}")
        print("=" * 60)
    finally:
        # Always tear down the container, even if a test raised.
        client.close()