File size: 16,255 Bytes
371cfc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
"""
Baseline agents for the Ask Answer Env environment (v1).

Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3).
With only 3 steps, agents can ask at most 2 slots before being forced to answer,
creating a non-trivial tradeoff between information gathering and guessing.

Baselines:
- A: city+date (ask city, ask date, guess budget)
- B: city+budget (ask city, ask budget, guess date)
- C: style+city (trap: wastes a question on distractor)
- Random: random actions
- Oracle: knows hidden state, answers immediately (upper bound)
"""

import random
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots


# Type aliases
# Hidden (ground-truth) slot values, in the same order as EpisodeResult.revealed.
HiddenTuple = Tuple[str, str, str, str]  # (city, date, budget, style)
# Signature shared by every strategy_* function:
#   (known slots, steps_left, hidden state or None) -> next action.
# Only strategy_oracle actually reads the hidden state.
StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]

# Default guesses used by strategies A/B/C for a core slot they never asked.
# NOTE(review): DEFAULT_CITY and DEFAULT_STYLE are not referenced by any
# strategy in this file (A/B/C always ask city before answering and pass
# style through unchanged; strategy_random samples from CITIES/STYLES
# instead) -- confirm before relying on or removing them.
DEFAULT_CITY = "Paris"
DEFAULT_DATE = "mid_feb"
DEFAULT_BUDGET = "mid"
DEFAULT_STYLE = "relax"

# Valid slot values (for random baseline)
# strategy_random draws its guesses for unknown slots from these lists.
# NOTE(review): presumably these mirror the environment's slot vocabularies;
# keep them in sync with ask_answer_env if that changes.
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]


@dataclass
class EpisodeResult:
    """Outcome summary for a single ``run_episode`` call.

    Attributes:
        total_reward: Sum of the per-step rewards over the whole episode.
        revealed: Final known slots as (city, date, budget, style); any slot
            that was never revealed is ``None``.
        steps_taken: Number of actions sent to the environment.
        core_correct_count: How many of the 3 core slots the final answer
            got right (0-3).
        core_all_correct: Whether all 3 core slots were correct.
    """

    total_reward: float
    revealed: HiddenTuple
    steps_taken: int
    core_correct_count: int
    core_all_correct: bool


# =============================================================================
# Strategy Functions
# =============================================================================

def strategy_city_date(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy A: learn city and date, then answer with a budget guess.

    Trajectory with MAX_STEPS=3:
    - Step 1: ASK city
    - Step 2: ASK date
    - Step 3: ANSWER with the revealed city+date and DEFAULT_BUDGET

    ``steps_left`` and ``hidden`` are accepted only to satisfy the
    StrategyFn signature; they are not used.
    """
    # Ask for the first still-unknown slot, in priority order.
    for slot in ("city", "date"):
        if getattr(known, slot) is None:
            return AskAnswerAction(type="ask", slot=slot)

    # Both target slots are known: commit, guessing budget if necessary.
    budget_guess = known.budget or DEFAULT_BUDGET
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=known.date,
        budget=budget_guess,
        style=known.style,  # left as None unless style happened to be revealed
    )


def strategy_city_budget(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy B: learn city and budget, then answer with a date guess.

    Trajectory with MAX_STEPS=3:
    - Step 1: ASK city
    - Step 2: ASK budget
    - Step 3: ANSWER with the revealed city+budget and DEFAULT_DATE

    ``steps_left`` and ``hidden`` are unused (StrategyFn signature only).
    """
    if known.city is None:
        return AskAnswerAction(type="ask", slot="city")
    if known.budget is None:
        return AskAnswerAction(type="ask", slot="budget")

    # City and budget are in hand; fall back to the default date.
    date_guess = known.date or DEFAULT_DATE
    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=date_guess,
        budget=known.budget,
        style=known.style,
    )


def strategy_style_city(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Strategy C (TRAP): spend the first question on style, then ask city.

    Why this is a poor strategy:
    - Style is worth only a +0.1 bonus, versus +0.4 for each core slot
    - One of the two available questions goes to that low-value distractor
    - Two core slots (date and budget) must be guessed instead of one

    Trajectory with MAX_STEPS=3:
    - Step 1: ASK style (the trap)
    - Step 2: ASK city
    - Step 3: ANSWER with the revealed style+city, DEFAULT_DATE and DEFAULT_BUDGET

    ``steps_left`` and ``hidden`` are unused (StrategyFn signature only).
    """
    pending = next(
        (slot for slot in ("style", "city") if getattr(known, slot) is None),
        None,
    )
    if pending is not None:
        return AskAnswerAction(type="ask", slot=pending)

    return AskAnswerAction(
        type="answer",
        city=known.city,
        date=known.date or DEFAULT_DATE,
        budget=known.budget or DEFAULT_BUDGET,
        style=known.style,
    )


def strategy_random(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Random baseline: flip a coin between asking and answering.

    While at least one slot is unknown, there is a 50% chance to ASK a
    uniformly chosen unknown slot. Otherwise (or once every slot is known) it
    ANSWERs, filling each unknown slot with a random value from CITIES,
    DATES, BUDGETS or STYLES.

    Note: draws from the module-level ``random`` RNG, which the environment
    seed passed to ``client.reset`` does not control.
    """
    slot_order = ("city", "date", "budget", "style")
    unknown_slots = [slot for slot in slot_order if getattr(known, slot) is None]

    # The coin is only flipped when something is left to ask, so the RNG
    # call sequence is the same as a short-circuiting `and`.
    if unknown_slots:
        if random.random() < 0.5:
            return AskAnswerAction(type="ask", slot=random.choice(unknown_slots))

    candidates = {"city": CITIES, "date": DATES, "budget": BUDGETS, "style": STYLES}
    answer_fields = {}
    # Fixed slot order keeps the sequence of random.choice calls deterministic.
    for slot in slot_order:
        current = getattr(known, slot)
        answer_fields[slot] = current if current else random.choice(candidates[slot])
    return AskAnswerAction(type="answer", **answer_fields)


def strategy_oracle(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
    """
    Oracle baseline: answer immediately with the true hidden values.

    This is the THEORETICAL UPPER BOUND. The server does not expose hidden
    state to the client, so this function is not called in practice;
    run_baseline_test() reports a hardcoded reward of 1.45 for the oracle.

    Reward breakdown:
        -0.05 (step) + 0.4×3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45

    Raises:
        ValueError: if ``hidden`` is None.
    """
    if hidden is None:
        raise ValueError("Oracle strategy requires hidden state")

    true_city, true_date, true_budget, true_style = hidden
    return AskAnswerAction(
        type="answer",
        city=true_city,
        date=true_date,
        budget=true_budget,
        style=true_style,
    )


# =============================================================================
# Episode Runner
# =============================================================================

def run_episode(
    client: AskAnswerEnv,
    strategy: StrategyFn,
    seed: int = 42,
    hidden: Optional[HiddenTuple] = None,
    verbose: bool = False,
) -> EpisodeResult:
    """
    Play one episode to completion with ``strategy`` and summarize it.

    Args:
        client: AskAnswerEnv client instance (reset, then stepped until done)
        strategy: Callable ``(known, steps_left, hidden) -> action`` invoked
            once per step
        seed: Seed passed to ``client.reset`` for reproducibility
        hidden: Hidden state forwarded to the strategy (required by the oracle)
        verbose: If True, print one line per step plus a final summary

    Returns:
        EpisodeResult with the summed reward, the final known slots
        (``revealed``), the number of steps taken and core-slot correctness.
    """
    step_result = client.reset(seed=seed)
    reward_sum = 0.0
    n_steps = 0

    if verbose:
        print(f"=== Episode Start (seed={seed}) ===")
        print(f"Steps left: {step_result.observation.steps_left}")

    while not step_result.done:
        observation = step_result.observation
        action = strategy(observation.known, observation.steps_left, hidden)

        step_result = client.step(action)
        reward_sum += step_result.reward
        n_steps += 1

        if not verbose:
            continue
        reward_txt = f"reward={step_result.reward:+.2f}"
        if action.type == "ask":
            revealed_value = getattr(step_result.observation.known, action.slot)
            print(f"  Step {n_steps}: ASK {action.slot} -> {revealed_value}, {reward_txt}")
        else:
            print(
                f"  Step {n_steps}: ANSWER city={action.city}, date={action.date}, "
                f"budget={action.budget}, style={action.style}, {reward_txt}"
            )

    final_known = step_result.observation.known
    revealed = (final_known.city, final_known.date, final_known.budget, final_known.style)

    # core_correct_count is expected only once the episode ends with an
    # ANSWER; a missing/None value is counted as zero correct.
    n_core_correct = step_result.observation.core_correct_count or 0

    if verbose:
        print(f"  Total reward: {reward_sum:+.2f}")
        print(f"  Core correct: {n_core_correct}/3")
        print()

    return EpisodeResult(
        total_reward=reward_sum,
        revealed=revealed,
        steps_taken=n_steps,
        core_correct_count=n_core_correct,
        core_all_correct=(n_core_correct == 3),
    )


# =============================================================================
# Acceptance Tests
# =============================================================================

@dataclass
class BaselineStats:
    """Aggregate metrics for one baseline across many episodes.

    All rates are fractions in [0, 1]; the results table formats them as
    percentages.

    Attributes:
        name: Display name of the baseline.
        mean_reward: Mean total reward per episode.
        std_reward: Standard deviation of the total reward.
        positive_return_rate: Fraction of episodes whose reward was > 0.
        core_success_rate: Fraction of episodes with all 3 core slots correct.
        avg_core_correct: Mean number of core slots correct (0-3).
    """

    name: str
    mean_reward: float
    std_reward: float
    positive_return_rate: float
    core_success_rate: float
    avg_core_correct: float


def run_baseline_test(
    client: AskAnswerEnv,
    name: str,
    strategy: StrategyFn,
    num_episodes: int = 200,
    needs_oracle: bool = False,
) -> BaselineStats:
    """
    Run multiple episodes with a strategy and compute statistics.

    Episode ``i`` uses environment seed ``i``, so every baseline is evaluated
    on the same sequence of seeds.

    Args:
        client: AskAnswerEnv client instance
        name: Name of the baseline for logging
        strategy: Strategy function (ignored, and may be None, when
            ``needs_oracle`` is True)
        num_episodes: Number of episodes to run; must be >= 1 for a
            non-oracle baseline
        needs_oracle: If True, return the theoretical oracle values without
            touching the environment

    Returns:
        BaselineStats with all metrics (rates are fractions in [0, 1];
        ``std_reward`` is the population standard deviation).

    Raises:
        ValueError: if a non-oracle baseline has no strategy or
            ``num_episodes`` < 1.
    """
    if needs_oracle:
        # Oracle is a THEORETICAL upper bound - knows hidden state,
        # answers perfectly in 1 step.
        #
        # Reward: -0.05 + 0.4×3 + 0.1 + 0.2 = +1.45
        # Core correct: 3/3 always
        return BaselineStats(
            name=name,
            mean_reward=1.45,
            std_reward=0.0,
            positive_return_rate=1.0,
            core_success_rate=1.0,
            avg_core_correct=3.0,
        )

    if strategy is None:
        raise ValueError(f"Baseline {name!r} needs a strategy unless needs_oracle=True")
    if num_episodes < 1:
        # Without this guard every average below divides by zero.
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    results: List[EpisodeResult] = [
        run_episode(client, strategy, seed=seed) for seed in range(num_episodes)
    ]
    n = len(results)

    rewards = [r.total_reward for r in results]
    mean_reward = sum(rewards) / n
    # Population variance (divide by n, not n - 1).
    variance = sum((r - mean_reward) ** 2 for r in rewards) / n
    std_reward = variance ** 0.5

    positive_return_rate = sum(1 for r in rewards if r > 0) / n
    core_success_rate = sum(1 for r in results if r.core_all_correct) / n
    avg_core_correct = sum(r.core_correct_count for r in results) / n

    return BaselineStats(
        name=name,
        mean_reward=mean_reward,
        std_reward=std_reward,
        positive_return_rate=positive_return_rate,
        core_success_rate=core_success_rate,
        avg_core_correct=avg_core_correct,
    )


def _print_results_table(all_stats: List[BaselineStats]) -> None:
    """Print baselines sorted by mean reward (best first), then a column legend."""
    header = (
        f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} "
        f"{'Core%':>7} {'AvgCore':>8}"
    )
    # Size the rules to the table so they line up with the columns.
    width = len(header)

    print("\n" + "=" * width)
    print("RESULTS SUMMARY")
    print("=" * width)
    print(header)
    print("-" * width)

    for s in sorted(all_stats, key=lambda st: st.mean_reward, reverse=True):
        # Column widths match the header exactly (Pos%/Core% are 7 wide;
        # "N.NN/3" is 6 + 2 = 8 wide like 'AvgCore').
        print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
              f"{s.positive_return_rate:>7.0%} {s.core_success_rate:>7.0%} "
              f"{s.avg_core_correct:>6.2f}/3")
    print("-" * width)

    print("\nColumn legend:")
    print("  Mean    = mean total reward")
    print("  Std     = standard deviation of reward")
    print("  Pos%    = positive_return_rate (% episodes with reward > 0)")
    print("  Core%   = core_success_rate (% episodes with all 3 core slots correct)")
    print("  AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")


def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200, *, approx_tol: float = 0.1) -> bool:
    """
    Run all baseline tests, print a results table, and check the ordering.

    Expected ordering:
    Oracle > A ≈ B >> C > Random

    Args:
        client: AskAnswerEnv client instance
        num_episodes: Episodes per (non-oracle) baseline
        approx_tol: Largest absolute gap in mean reward for which "A ≈ B"
            still passes

    Returns:
        True if every ordering check passed, False otherwise.
    """
    # Single source of truth for the names used both as labels and as lookup keys.
    oracle_name = "Oracle (theoretical)"
    a_name = "A: city+date"
    b_name = "B: city+budget"
    c_name = "C: style+city (trap)"
    random_name = "Random"

    print(f"\nRunning {num_episodes} episodes per baseline...\n")

    baselines = [
        (oracle_name, None, True),
        (a_name, strategy_city_date, False),
        (b_name, strategy_city_budget, False),
        (c_name, strategy_style_city, False),
        (random_name, strategy_random, False),
    ]

    all_stats: List[BaselineStats] = []
    for name, strategy, is_oracle in baselines:
        stats = run_baseline_test(client, name, strategy, num_episodes, needs_oracle=is_oracle)
        all_stats.append(stats)
        print(f"  {name}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")

    _print_results_table(all_stats)

    # Verify expected ordering
    mean = {s.name: s.mean_reward for s in all_stats}
    checks = [
        ("Oracle > A", mean[oracle_name] > mean[a_name]),
        ("A ≈ B", abs(mean[a_name] - mean[b_name]) < approx_tol),
        ("A > C", mean[a_name] > mean[c_name]),
        # The documented ordering ranks B above C too, so check it explicitly.
        ("B > C", mean[b_name] > mean[c_name]),
        ("C > Random", mean[c_name] > mean[random_name]),
    ]

    print("\nExpected ordering checks:")
    for check_name, passed in checks:
        print(f"  {check_name}: {'PASS' if passed else 'FAIL'}")

    return all(passed for _, passed in checks)


# =============================================================================
# Determinism Tests (kept from v0)
# =============================================================================

def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
    """Check that replaying one seed ``runs`` times yields identical outcomes.

    Each run uses strategy A. Both the total reward and the final ``revealed``
    slots are compared across runs. Prints a one-line summary and returns
    whether all runs matched.

    NOTE(review): the ``test_`` prefix plus a ``client`` parameter means
    pytest would fail to collect this if the module were ever picked up by
    test discovery (``client`` is not a fixture) -- confirm the module name.
    """
    episodes = [run_episode(client, strategy_city_date, seed=seed) for _ in range(runs)]

    distinct_rewards = {ep.total_reward for ep in episodes}
    distinct_revealed = {ep.revealed for ep in episodes}
    identical = len(distinct_revealed) == 1 and len(distinct_rewards) == 1

    print(f"Determinism (seed={seed}): {episodes[0].revealed} x{runs}, identical={identical}")
    return identical


def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
    """Check that seeds 0..num_seeds-1 do not all produce the same outcome.

    Runs strategy A once per seed, collects the distinct ``revealed`` tuples,
    prints how many there were, and passes if there is more than one.
    """
    distinct = {
        run_episode(client, strategy_city_date, seed=s).revealed
        for s in range(num_seeds)
    }

    # NOTE(review): strategy A only asks city and date. If the env leaves
    # unasked slots as None in the final observation, only city/date can vary
    # here (at most 4 * 3 = 12 tuples), not all 4 * 3 * 3 * 3 = 108 hidden
    # combinations -- confirm against the environment.
    print(f"Seed sensitivity: {len(distinct)} unique tuples from {num_seeds} seeds")
    return len(distinct) > 1


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
    try:
        print("=" * 60)
        print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS")
        print("=" * 60)

        # Quick determinism check. The verdicts are kept so they count toward
        # the final result; a list (not a generator) ensures every check runs
        # even if an earlier one fails.
        print("\n1. DETERMINISM TESTS")
        print("-" * 40)
        determinism_passed = all([
            test_determinism(client, seed=42),
            test_determinism(client, seed=123),
            test_seed_sensitivity(client),
        ])

        # Run a single verbose episode to show behavior
        print("\n2. EXAMPLE EPISODE (Strategy A: city+date)")
        print("-" * 40)
        run_episode(client, strategy_city_date, seed=42, verbose=True)

        print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)")
        print("-" * 40)
        run_episode(client, strategy_style_city, seed=42, verbose=True)

        # Full acceptance tests
        print("\n4. BASELINE COMPARISON")
        print("-" * 40)
        acceptance_passed = run_acceptance_tests(client, num_episodes=200)

        passed = determinism_passed and acceptance_passed

        print("\n" + "=" * 60)
        print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}")
        print("=" * 60)

    finally:
        client.close()

    # Report failure through the exit status so shells/CI can detect it.
    if not passed:
        raise SystemExit(1)