"""
Baseline agents for the Ask Answer Env environment (v1).

Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3).
With only 3 steps, agents can ask about at most 2 slots before being forced to
answer, creating a non-trivial tradeoff between information gathering and
guessing.

Baselines:
    - A: city+date   (ask city, ask date, guess budget)
    - B: city+budget (ask city, ask budget, guess date)
    - C: style+city  (trap: wastes a question on the distractor slot)
    - Random: random actions
    - Oracle: knows the hidden state, answers immediately (upper bound)
"""

import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots

# Type aliases
HiddenTuple = Tuple[str, str, str, str]  # (city, date, budget, style)
# Slots as known to the agent at the end of an episode; slots that were never
# asked may still be None.
RevealedTuple = Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]
StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]

# Default guesses when a slot is unknown
DEFAULT_CITY = "Paris"
DEFAULT_DATE = "mid_feb"
DEFAULT_BUDGET = "mid"
DEFAULT_STYLE = "relax"

# Valid slot values (for the random baseline)
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]

# Number of core slots (city, date, budget); style is the non-core distractor.
NUM_CORE_SLOTS = 3

# Theoretical oracle return for a single perfect ANSWER step:
# -0.05 (step) + 0.4 * 3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45
ORACLE_REWARD = 1.45


@dataclass
class EpisodeResult:
    """Result of a single episode."""

    total_reward: float
    revealed: RevealedTuple  # final known (city, date, budget, style); unasked slots may be None
    steps_taken: int
    core_correct_count: int  # 0-3: how many core slots were correct
    core_all_correct: bool  # True if all 3 core slots correct


# =============================================================================
# Strategy Functions
# =============================================================================


def strategy_city_date(known: KnownSlots, steps_left: int,
                       hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy A: Ask city, ask date, then answer (guess budget).

    Expected behavior with MAX_STEPS=3:
        - Step 1: ASK city
        - Step 2: ASK date
        - Step 3: ANSWER with known city+date, guess budget
    """
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    elif known.date is None:
        return AskAnswerAction(type="ask", slot="date")
    else:
        return AskAnswerAction(
            type="answer",
            city=known.city,
            date=known.date,
            budget=known.budget if known.budget else DEFAULT_BUDGET,
            style=known.style,  # None if not asked
        )


def strategy_city_budget(known: KnownSlots, steps_left: int,
                         hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy B: Ask city, ask budget, then answer (guess date).

    Expected behavior with MAX_STEPS=3:
        - Step 1: ASK city
        - Step 2: ASK budget
        - Step 3: ANSWER with known city+budget, guess date
    """
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    elif known.budget is None:
        return AskAnswerAction(type="ask", slot="budget")
    else:
        return AskAnswerAction(
            type="answer",
            city=known.city,
            date=known.date if known.date else DEFAULT_DATE,
            budget=known.budget,
            style=known.style,
        )


def strategy_style_city(known: KnownSlots, steps_left: int,
                        hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy C (TRAP): Ask style first, then city, guess date+budget.

    This is a BAD strategy because:
        - Style only gives a +0.1 bonus (vs +0.4 for each core slot)
        - It wastes a question on a low-value distractor
        - It must guess 2 core slots instead of 1

    Expected behavior with MAX_STEPS=3:
        - Step 1: ASK style (bad choice!)
        - Step 2: ASK city
        - Step 3: ANSWER with known style+city, guess date+budget
    """
    if known.style is None:
        return AskAnswerAction(type="ask", slot="style")
    elif known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    else:
        return AskAnswerAction(
            type="answer",
            city=known.city,
            date=known.date if known.date else DEFAULT_DATE,
            budget=known.budget if known.budget else DEFAULT_BUDGET,
            style=known.style,
        )


def strategy_random(known: KnownSlots, steps_left: int,
                    hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Random baseline: randomly ask or answer with random values.

    50% chance to ask a random unknown slot, 50% chance to answer.
    If no slots are unknown, always answer.

    Uses the global ``random`` module RNG, so results are not reproducible
    across runs unless the caller seeds it.
    """
    # Fixed slot order keeps the RNG call sequence stable for a given seed.
    unknown_slots = [
        slot for slot in ("city", "date", "budget", "style")
        if getattr(known, slot) is None
    ]

    # If we have unknown slots and the coin flip says ask
    if unknown_slots and random.random() < 0.5:
        slot = random.choice(unknown_slots)
        return AskAnswerAction(type="ask", slot=slot)

    # Otherwise answer, with random guesses for unknown slots
    return AskAnswerAction(
        type="answer",
        city=known.city if known.city else random.choice(CITIES),
        date=known.date if known.date else random.choice(DATES),
        budget=known.budget if known.budget else random.choice(BUDGETS),
        style=known.style if known.style else random.choice(STYLES),
    )


def strategy_oracle(known: KnownSlots, steps_left: int,
                    hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Oracle baseline: knows the hidden state, answers perfectly in 1 step.

    This is the THEORETICAL UPPER BOUND. In practice this strategy function is
    NOT used, because the server doesn't expose the hidden state to the client.
    Instead, run_baseline_test() reports the constant ORACLE_REWARD (+1.45).

    Reward breakdown:
        -0.05 (step) + 0.4 * 3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45

    Raises:
        ValueError: if ``hidden`` is not provided.
    """
    if hidden is None:
        raise ValueError("Oracle strategy requires hidden state")

    city, date, budget, style = hidden
    return AskAnswerAction(
        type="answer",
        city=city,
        date=date,
        budget=budget,
        style=style,
    )


# =============================================================================
# Episode Runner
# =============================================================================


def run_episode(
    client: AskAnswerEnv,
    strategy: StrategyFn,
    seed: int = 42,
    hidden: Optional[HiddenTuple] = None,
    verbose: bool = False,
) -> EpisodeResult:
    """
    Run a single episode with the given strategy.

    Args:
        client: AskAnswerEnv client instance
        strategy: Function that takes (known, steps_left, hidden) and returns an action
        seed: Random seed passed to ``client.reset`` for reproducibility
        hidden: Hidden state tuple (required only for the oracle strategy)
        verbose: Whether to print step-by-step info

    Returns:
        EpisodeResult with total reward, final known slots, steps taken and
        core-slot correctness.
    """
    result = client.reset(seed=seed)
    total_reward = 0.0
    steps = 0

    if verbose:
        print(f"=== Episode Start (seed={seed}) ===")
        print(f"Steps left: {result.observation.steps_left}")

    while not result.done:
        obs = result.observation
        action = strategy(obs.known, obs.steps_left, hidden)
        result = client.step(action)
        total_reward += result.reward
        steps += 1

        if verbose:
            if action.type == "ask":
                slot_val = getattr(result.observation.known, action.slot)
                print(f"  Step {steps}: ASK {action.slot} -> {slot_val}, reward={result.reward:+.2f}")
            else:
                print(f"  Step {steps}: ANSWER city={action.city}, date={action.date}, "
                      f"budget={action.budget}, style={action.style}, reward={result.reward:+.2f}")

    final = result.observation.known
    revealed = (final.city, final.date, final.budget, final.style)

    # Extract correctness info (available when done=True after ANSWER);
    # a missing value is treated as 0 correct slots.
    core_correct_count = result.observation.core_correct_count or 0
    core_all_correct = core_correct_count == NUM_CORE_SLOTS

    if verbose:
        print(f"  Total reward: {total_reward:+.2f}")
        print(f"  Core correct: {core_correct_count}/3")
        print()

    return EpisodeResult(
        total_reward=total_reward,
        revealed=revealed,
        steps_taken=steps,
        core_correct_count=core_correct_count,
        core_all_correct=core_all_correct,
    )


# =============================================================================
# Acceptance Tests
# =============================================================================


@dataclass
class BaselineStats:
    """Statistics for a baseline over multiple episodes."""

    name: str
    mean_reward: float
    std_reward: float
    positive_return_rate: float  # fraction of episodes with reward > 0
    core_success_rate: float  # fraction of episodes with all 3 core slots correct
    avg_core_correct: float  # average number of core slots correct (0-3)


def run_baseline_test(
    client: AskAnswerEnv,
    name: str,
    strategy: Optional[StrategyFn],
    num_episodes: int = 200,
    needs_oracle: bool = False,
) -> BaselineStats:
    """
    Run multiple episodes with a strategy and compute statistics.

    Episodes use seeds ``0 .. num_episodes - 1``.

    Args:
        client: AskAnswerEnv client instance
        name: Name of the baseline for logging
        strategy: Strategy function (may be None only when needs_oracle=True)
        num_episodes: Number of episodes to run (must be positive unless needs_oracle)
        needs_oracle: If True, return theoretical oracle values without running episodes

    Returns:
        BaselineStats with all metrics.

    Raises:
        ValueError: if num_episodes <= 0 for a non-oracle baseline.
        TypeError: if strategy is None for a non-oracle baseline.
    """
    if needs_oracle:
        # Oracle is a THEORETICAL upper bound: it knows the hidden state and
        # answers perfectly in 1 step, so every episode scores ORACLE_REWARD
        # with all 3 core slots correct.
        return BaselineStats(
            name=name,
            mean_reward=ORACLE_REWARD,
            std_reward=0.0,
            positive_return_rate=1.0,
            core_success_rate=1.0,
            avg_core_correct=float(NUM_CORE_SLOTS),
        )

    if num_episodes <= 0:
        # Without this guard the averages below divide by zero.
        raise ValueError(f"num_episodes must be positive, got {num_episodes}")
    if strategy is None:
        raise TypeError(f"Baseline {name!r} requires a strategy unless needs_oracle=True")

    results: List[EpisodeResult] = [
        run_episode(client, strategy, seed=seed) for seed in range(num_episodes)
    ]
    n = len(results)

    rewards = [r.total_reward for r in results]
    mean_reward = sum(rewards) / n
    # Population (not sample) standard deviation.
    variance = sum((r - mean_reward) ** 2 for r in rewards) / n
    std_reward = variance ** 0.5

    positive_return_rate = sum(1 for r in rewards if r > 0) / n
    core_success_rate = sum(1 for r in results if r.core_all_correct) / n
    avg_core_correct = sum(r.core_correct_count for r in results) / n

    return BaselineStats(
        name=name,
        mean_reward=mean_reward,
        std_reward=std_reward,
        positive_return_rate=positive_return_rate,
        core_success_rate=core_success_rate,
        avg_core_correct=avg_core_correct,
    )


def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200) -> bool:
    """
    Run all baseline tests and print a results table.

    Expected ordering: Oracle > A ≈ B >> C > Random

    Returns:
        True if every ordering check passes, False otherwise.
    """
    print(f"\nRunning {num_episodes} episodes per baseline...\n")

    baselines: List[Tuple[str, Optional[StrategyFn], bool]] = [
        ("Oracle (theoretical)", None, True),
        ("A: city+date", strategy_city_date, False),
        ("B: city+budget", strategy_city_budget, False),
        ("C: style+city (trap)", strategy_style_city, False),
        ("Random", strategy_random, False),
    ]

    all_stats: List[BaselineStats] = []
    for name, strategy, is_oracle in baselines:
        stats = run_baseline_test(client, name, strategy, num_episodes, needs_oracle=is_oracle)
        all_stats.append(stats)
        print(f"  {name}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")

    # Print results table
    print("\n" + "=" * 90)
    print("RESULTS SUMMARY")
    print("=" * 90)
    header = f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} {'Core%':>7} {'AvgCore':>8}"
    print(header)
    print("-" * 90)
    for s in sorted(all_stats, key=lambda x: -x.mean_reward):
        print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
              f"{s.positive_return_rate:>6.0%} {s.core_success_rate:>6.0%} "
              f"{s.avg_core_correct:>7.2f}/3")
    print("-" * 90)
    print("\nColumn legend:")
    print("  Mean    = mean total reward")
    print("  Std     = standard deviation of reward")
    print("  Pos%    = positive_return_rate (% episodes with reward > 0)")
    print("  Core%   = core_success_rate (% episodes with all 3 core slots correct)")
    print("  AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")

    # Verify the documented ordering. "A ≈ B" alone does not imply B > C, so
    # both A and B are checked against C explicitly.
    result_dict = {s.name: s.mean_reward for s in all_stats}
    checks = [
        ("Oracle > A", result_dict["Oracle (theoretical)"] > result_dict["A: city+date"]),
        ("A ≈ B", abs(result_dict["A: city+date"] - result_dict["B: city+budget"]) < 0.1),
        ("A > C", result_dict["A: city+date"] > result_dict["C: style+city (trap)"]),
        ("B > C", result_dict["B: city+budget"] > result_dict["C: style+city (trap)"]),
        ("C > Random", result_dict["C: style+city (trap)"] > result_dict["Random"]),
    ]

    print("\nExpected ordering checks:")
    all_passed = True
    for check_name, passed in checks:
        status = "PASS" if passed else "FAIL"
        print(f"  {check_name}: {status}")
        if not passed:
            all_passed = False

    return all_passed


# =============================================================================
# Determinism Tests (kept from v0)
# =============================================================================


def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
    """Return True if the same seed produces identical rewards and revealed slots."""
    trajectories = []
    for _ in range(runs):
        result = run_episode(client, strategy_city_date, seed=seed)
        trajectories.append((result.total_reward, result.revealed))

    rewards = [t[0] for t in trajectories]
    revealed = [t[1] for t in trajectories]
    identical = len(set(revealed)) == 1 and len(set(rewards)) == 1
    print(f"Determinism (seed={seed}): {revealed[0]} x{runs}, identical={identical}")
    return identical


def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
    """Return True if different seeds yield more than one distinct revealed tuple."""
    unique = set()
    for seed in range(num_seeds):
        result = run_episode(client, strategy_city_date, seed=seed)
        unique.add(result.revealed)

    # The hidden space has 4 * 3 * 3 * 3 = 108 combinations, but
    # strategy_city_date only asks city and date, so fewer distinct `revealed`
    # tuples may be observable here.
    # NOTE(review): the exact bound depends on whether the env fills in
    # unasked slots after ANSWER — confirm against the server.
    print(f"Seed sensitivity: {len(unique)} unique tuples from {num_seeds} seeds")
    return len(unique) > 1


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")

    try:
        print("=" * 60)
        print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS")
        print("=" * 60)

        # Quick determinism check
        print("\n1. DETERMINISM TESTS")
        print("-" * 40)
        test_determinism(client, seed=42)
        test_determinism(client, seed=123)
        test_seed_sensitivity(client)

        # Run single verbose episodes to show behavior
        print("\n2. EXAMPLE EPISODE (Strategy A: city+date)")
        print("-" * 40)
        run_episode(client, strategy_city_date, seed=42, verbose=True)

        print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)")
        print("-" * 40)
        run_episode(client, strategy_style_city, seed=42, verbose=True)

        # Full acceptance tests
        print("\n4. BASELINE COMPARISON")
        print("-" * 40)
        passed = run_acceptance_tests(client, num_episodes=200)

        print("\n" + "=" * 60)
        print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}")
        print("=" * 60)

    finally:
        client.close()