Spaces:
Sleeping
Sleeping
| """ | |
| Baseline agents for the Ask Answer Env environment (v1). | |
| Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3). | |
| With only 3 steps, agents can ask at most 2 slots before being forced to answer, | |
| creating a non-trivial tradeoff between information gathering and guessing. | |
| Baselines: | |
| - A: city+date (ask city, ask date, guess budget) | |
| - B: city+budget (ask city, ask budget, guess date) | |
| - C: style+city (trap: wastes a question on distractor) | |
| - Random: random actions | |
| - Oracle: knows hidden state, answers immediately (upper bound) | |
| """ | |
import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots

# Type aliases
# Hidden episode state in fixed slot order.
HiddenTuple = Tuple[str, str, str, str]  # (city, date, budget, style)
# A strategy maps (known slots, steps remaining, optional hidden state) to the next action.
StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]

# Default guesses when slot is unknown
DEFAULT_CITY = "Paris"
DEFAULT_DATE = "mid_feb"
DEFAULT_BUDGET = "mid"
DEFAULT_STYLE = "relax"

# Valid slot values (for random baseline)
# NOTE(review): assumed to mirror the server-side slot vocabularies — confirm against ask_answer_env.
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]
@dataclass
class EpisodeResult:
    """Result of a single episode.

    The @dataclass decorator is required: run_episode() constructs this
    with keyword arguments, which a plain annotated class (with only
    class-level annotations and the default __init__) would reject with
    a TypeError.
    """

    total_reward: float  # sum of per-step rewards over the episode
    revealed: "HiddenTuple"  # (city, date, budget, style) known at episode end
    steps_taken: int  # number of env steps actually taken
    core_correct_count: int  # 0-3: how many core slots were correct
    core_all_correct: bool  # True if all 3 core slots correct
# =============================================================================
# Strategy Functions
# =============================================================================
def strategy_city_date(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy A: Ask city, ask date, then answer (guess budget).

    Expected behavior with MAX_STEPS=3:
    - Step 1: ASK city
    - Step 2: ASK date
    - Step 3: ANSWER with known city+date, guess budget
    """
    # Ask the two high-value core slots in priority order.
    for slot_name, slot_value in (("city", known.city), ("date", known.date)):
        if slot_value is None:
            return AskAnswerAction(type="ask", slot=slot_name)
    # Both asked slots known: answer, falling back to the default budget guess.
    budget_guess = known.budget if known.budget else DEFAULT_BUDGET
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=known.date,
        budget=budget_guess,
        style=known.style,  # None if not asked
    )
def strategy_city_budget(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy B: Ask city, ask budget, then answer (guess date).

    Expected behavior with MAX_STEPS=3:
    - Step 1: ASK city
    - Step 2: ASK budget
    - Step 3: ANSWER with known city+budget, guess date
    """
    # Guard clauses: ask whichever target slot is still unknown.
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    if known.budget is None:
        return AskAnswerAction(type="ask", slot="budget")
    # City and budget known; date falls back to the default guess.
    date_guess = known.date or DEFAULT_DATE
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=date_guess,
        budget=known.budget,
        style=known.style,  # None if not asked
    )
def strategy_style_city(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy C (TRAP): Ask style first, then city, guess date+budget.

    This is a BAD strategy because:
    - Style only gives +0.1 bonus (vs +0.4 for core slots)
    - Wastes a question on a low-value distractor
    - Must guess 2 core slots instead of 1

    Expected behavior with MAX_STEPS=3:
    - Step 1: ASK style (bad choice!)
    - Step 2: ASK city
    - Step 3: ANSWER with known style+city, guess date+budget
    """
    if known.style is not None and known.city is not None:
        # Both questions spent: answer now, guessing the two remaining core slots.
        return AskAnswerAction(
            type="answer",
            city=known.city,
            date=known.date or DEFAULT_DATE,
            budget=known.budget or DEFAULT_BUDGET,
            style=known.style,
        )
    # Ask style first (the trap), then city.
    pending_slot = "style" if known.style is None else "city"
    return AskAnswerAction(type="ask", slot=pending_slot)
def strategy_random(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Random baseline: randomly ask or answer with random values.

    50% chance to ask a random unknown slot, 50% chance to answer.
    If no unknown slots, always answer.
    """
    # Collect unknown slots in fixed order (city, date, budget, style) so the
    # RNG call sequence matches across runs.
    current = {
        "city": known.city,
        "date": known.date,
        "budget": known.budget,
        "style": known.style,
    }
    unknown = [name for name, value in current.items() if value is None]

    # Coin flip: ask a random unknown slot half the time.
    if unknown and random.random() < 0.5:
        return AskAnswerAction(type="ask", slot=random.choice(unknown))

    def _fill(value, pool):
        # Keep a known value; otherwise draw a uniform random guess.
        return value if value else random.choice(pool)

    return AskAnswerAction(
        type="answer",
        city=_fill(known.city, CITIES),
        date=_fill(known.date, DATES),
        budget=_fill(known.budget, BUDGETS),
        style=_fill(known.style, STYLES),
    )
def strategy_oracle(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Oracle baseline: knows hidden state, answers perfectly in 1 step.

    This is the THEORETICAL UPPER BOUND.
    In practice, this strategy function is NOT used because the server
    doesn't expose hidden state to the client. Instead, we hardcode
    the oracle's reward as 1.45 in run_baseline_test().

    Reward breakdown:
    -0.05 (step) + 0.4×3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45
    """
    if hidden is None:
        raise ValueError("Oracle strategy requires hidden state")
    # Answer with the exact hidden values: (city, date, budget, style).
    return AskAnswerAction(
        type="answer",
        city=hidden[0],
        date=hidden[1],
        budget=hidden[2],
        style=hidden[3],
    )
# =============================================================================
# Episode Runner
# =============================================================================
def run_episode(
    client: AskAnswerEnv,
    strategy: StrategyFn,
    seed: int = 42,
    hidden: Optional[HiddenTuple] = None,
    verbose: bool = False,
) -> EpisodeResult:
    """
    Run a single episode with the given strategy.

    Args:
        client: AskAnswerEnv client instance
        strategy: Function that takes (known, steps_left, hidden) and returns action
        seed: Random seed for reproducibility
        hidden: Hidden state tuple (required for oracle strategy)
        verbose: Whether to print step-by-step info

    Returns:
        EpisodeResult with total_reward, revealed slots, and steps taken
    """
    result = client.reset(seed=seed)
    total_reward = 0.0
    steps = 0
    if verbose:
        print(f"=== Episode Start (seed={seed}) ===")
        print(f"Steps left: {result.observation.steps_left}")
    # Step until the env terminates (answer submitted or step budget exhausted).
    while not result.done:
        obs = result.observation
        action = strategy(obs.known, obs.steps_left, hidden)
        result = client.step(action)
        total_reward += result.reward
        steps += 1
        if verbose:
            if action.type == "ask":
                # The asked slot's value is revealed in the post-step observation.
                slot_val = getattr(result.observation.known, action.slot)
                print(f" Step {steps}: ASK {action.slot} -> {slot_val}, reward={result.reward:+.2f}")
            else:
                print(f" Step {steps}: ANSWER city={action.city}, date={action.date}, "
                      f"budget={action.budget}, style={action.style}, reward={result.reward:+.2f}")
    final = result.observation.known
    revealed = (final.city, final.date, final.budget, final.style)
    # Extract correctness info (available when done=True after ANSWER)
    # NOTE(review): 'or 0' assumes core_correct_count may be None when the
    # episode ends without an ANSWER — confirm against the env's schema.
    core_correct_count = result.observation.core_correct_count or 0
    core_all_correct = core_correct_count == 3
    if verbose:
        print(f" Total reward: {total_reward:+.2f}")
        print(f" Core correct: {core_correct_count}/3")
        print()
    return EpisodeResult(
        total_reward=total_reward,
        revealed=revealed,
        steps_taken=steps,
        core_correct_count=core_correct_count,
        core_all_correct=core_all_correct,
    )
# =============================================================================
# Acceptance Tests
# =============================================================================
@dataclass
class BaselineStats:
    """Statistics for a baseline over multiple episodes.

    The @dataclass decorator is required: run_baseline_test() constructs
    this with keyword arguments, which a plain annotated class (default
    no-arg __init__) would reject with a TypeError.
    """

    name: str  # display name of the baseline
    mean_reward: float  # mean total reward across episodes
    std_reward: float  # population standard deviation of total reward
    positive_return_rate: float  # % of episodes with reward > 0
    core_success_rate: float  # % of episodes with all 3 core slots correct
    avg_core_correct: float  # average number of core slots correct (0-3)
def run_baseline_test(
    client: AskAnswerEnv,
    name: str,
    strategy: Optional[StrategyFn],
    num_episodes: int = 200,
    needs_oracle: bool = False,
) -> BaselineStats:
    """
    Run multiple episodes with a strategy and compute statistics.

    Args:
        client: AskAnswerEnv client instance
        name: Name of the baseline for logging
        strategy: Strategy function. May be None when needs_oracle=True
            (callers pass None for the theoretical oracle row).
        num_episodes: Number of episodes to run
        needs_oracle: If True, skip execution and return theoretical oracle values

    Returns:
        BaselineStats with all metrics

    Raises:
        ValueError: if strategy is None and needs_oracle is False.
    """
    if needs_oracle:
        # Oracle is a THEORETICAL upper bound - knows hidden state,
        # answers perfectly in 1 step.
        #
        # Reward: -0.05 + 0.4×3 + 0.1 + 0.2 = +1.45
        # Core correct: 3/3 always
        return BaselineStats(
            name=name,
            mean_reward=1.45,
            std_reward=0.0,
            positive_return_rate=1.0,
            core_success_rate=1.0,
            avg_core_correct=3.0,
        )
    if strategy is None:
        # Fail loudly instead of a confusing TypeError inside run_episode.
        raise ValueError("strategy is required unless needs_oracle=True")
    results: List[EpisodeResult] = []
    # Seeds 0..num_episodes-1 make runs reproducible across invocations.
    for seed in range(num_episodes):
        result = run_episode(client, strategy, seed=seed)
        results.append(result)
    rewards = [r.total_reward for r in results]
    mean_reward = sum(rewards) / len(rewards)
    # Population (not sample) variance: divide by N.
    variance = sum((r - mean_reward) ** 2 for r in rewards) / len(rewards)
    std_reward = variance ** 0.5
    positive_return_rate = sum(1 for r in rewards if r > 0) / len(rewards)
    core_success_rate = sum(1 for r in results if r.core_all_correct) / len(results)
    avg_core_correct = sum(r.core_correct_count for r in results) / len(results)
    return BaselineStats(
        name=name,
        mean_reward=mean_reward,
        std_reward=std_reward,
        positive_return_rate=positive_return_rate,
        core_success_rate=core_success_rate,
        avg_core_correct=avg_core_correct,
    )
def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200) -> bool:
    """
    Run all baseline tests and print results table.

    Expected ordering:
      Oracle > A ≈ B >> C > Random

    Args:
        client: AskAnswerEnv client instance
        num_episodes: episodes per baseline (oracle row is theoretical, not run)

    Returns:
        True only if every expected-ordering check passes.
    """
    print(f"\nRunning {num_episodes} episodes per baseline...\n")
    # (display name, strategy fn or None, is_oracle) — oracle uses hardcoded stats.
    baselines = [
        ("Oracle (theoretical)", None, True),
        ("A: city+date", strategy_city_date, False),
        ("B: city+budget", strategy_city_budget, False),
        ("C: style+city (trap)", strategy_style_city, False),
        ("Random", strategy_random, False),
    ]
    all_stats: List[BaselineStats] = []
    for name, strategy, is_oracle in baselines:
        stats = run_baseline_test(client, name, strategy, num_episodes, needs_oracle=is_oracle)
        all_stats.append(stats)
        print(f" {name}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")
    # Print results table
    print("\n" + "=" * 90)
    print("RESULTS SUMMARY")
    print("=" * 90)
    header = f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} {'Core%':>7} {'AvgCore':>8}"
    print(header)
    print("-" * 90)
    # Rows sorted best-first by mean reward.
    for s in sorted(all_stats, key=lambda x: -x.mean_reward):
        print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
              f"{s.positive_return_rate:>6.0%} {s.core_success_rate:>6.0%} "
              f"{s.avg_core_correct:>7.2f}/3")
    print("-" * 90)
    print("\nColumn legend:")
    print(" Mean = mean total reward")
    print(" Std = standard deviation of reward")
    print(" Pos% = positive_return_rate (% episodes with reward > 0)")
    print(" Core% = core_success_rate (% episodes with all 3 core slots correct)")
    print(" AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")
    # Verify expected ordering
    # NOTE: keys here must match the display names in `baselines` exactly.
    result_dict = {s.name: s.mean_reward for s in all_stats}
    checks = [
        ("Oracle > A", result_dict["Oracle (theoretical)"] > result_dict["A: city+date"]),
        ("A ≈ B", abs(result_dict["A: city+date"] - result_dict["B: city+budget"]) < 0.1),
        ("A > C", result_dict["A: city+date"] > result_dict["C: style+city (trap)"]),
        ("C > Random", result_dict["C: style+city (trap)"] > result_dict["Random"]),
    ]
    print("\nExpected ordering checks:")
    all_passed = True
    for check_name, passed in checks:
        status = "PASS" if passed else "FAIL"
        print(f" {check_name}: {status}")
        if not passed:
            all_passed = False
    return all_passed
# =============================================================================
# Determinism Tests (kept from v0)
# =============================================================================
def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
    """Test that the same seed produces identical trajectories."""
    # Replay the same seeded episode several times with a fixed strategy.
    episodes = [run_episode(client, strategy_city_date, seed=seed) for _ in range(runs)]
    rewards = [ep.total_reward for ep in episodes]
    revealed = [ep.revealed for ep in episodes]
    # Deterministic iff every run produced the same reveal tuple and reward.
    identical = len(set(revealed)) == 1 and len(set(rewards)) == 1
    print(f"Determinism (seed={seed}): {revealed[0]} x{runs}, identical={identical}")
    return identical
def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
    """Verify different seeds produce different hidden states."""
    # Max possible: 4 * 3 * 3 * 3 = 108 (with style)
    distinct = set()
    for s in range(num_seeds):
        distinct.add(run_episode(client, strategy_city_date, seed=s).revealed)
    print(f"Seed sensitivity: {len(distinct)} unique tuples from {num_seeds} seeds")
    # At least two distinct hidden tuples proves the seed actually matters.
    return len(distinct) > 1
# =============================================================================
# Main
# =============================================================================
| if __name__ == "__main__": | |
| client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest") | |
| try: | |
| print("=" * 60) | |
| print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS") | |
| print("=" * 60) | |
| # Quick determinism check | |
| print("\n1. DETERMINISM TESTS") | |
| print("-" * 40) | |
| test_determinism(client, seed=42) | |
| test_determinism(client, seed=123) | |
| test_seed_sensitivity(client) | |
| # Run a single verbose episode to show behavior | |
| print("\n2. EXAMPLE EPISODE (Strategy A: city+date)") | |
| print("-" * 40) | |
| run_episode(client, strategy_city_date, seed=42, verbose=True) | |
| print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)") | |
| print("-" * 40) | |
| run_episode(client, strategy_style_city, seed=42, verbose=True) | |
| # Full acceptance tests | |
| print("\n4. BASELINE COMPARISON") | |
| print("-" * 40) | |
| passed = run_acceptance_tests(client, num_episodes=200) | |
| print("\n" + "=" * 60) | |
| print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}") | |
| print("=" * 60) | |
| finally: | |
| client.close() | |