ujjwalsg committed on
Commit 371cfc1 · verified · 1 parent: 5bf9713

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ # - In-repo environments (with local OpenEnv sources)
+ # - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=ask_answer_env
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,12 +1,204 @@
 ---
 title: Ask Answer Env
- emoji:
- colorFrom: green
- colorTo: purple
+ emoji: 🎯
+ colorFrom: blue
+ colorTo: green
 sdk: docker
- pinned: false
- license: mit
- short_description: OpenEnv for training to decide between asking & answering
+ app_port: 8000
+ base_path: /web
+ tags:
+   - openenv
+   - rl
 ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Ask Answer Env (v1)
+
+ A deterministic OpenEnv environment for training RL agents to decide between **asking clarifying questions** and **answering early** under budget constraints.
+
+ ## Overview
+
+ The agent receives a user prompt ("Plan a short trip for me.") and must discover hidden slot values by asking questions before providing a final answer. With only **3 steps** and **4 slots** (3 core + 1 distractor), the agent must prioritize which questions to ask.
+
+ **Key design goals:**
+ - No ML, no NLP — just structured interaction + delayed reward
+ - Deterministic given a seed
+ - Budget constraints force non-trivial tradeoffs (the agent can ask only 2 of the 4 slots)
+ - Graded reward structure (partial credit for correct slots)
+
+ ## Hidden State
+
+ At each episode reset, the environment samples (with seeded RNG):
+ - `city` ∈ `["Paris", "Rome", "Tokyo", "Goa"]` (core)
+ - `date` ∈ `["next_weekend", "mid_feb", "march"]` (core)
+ - `budget` ∈ `["low", "mid", "high"]` (core)
+ - `style` ∈ `["relax", "adventure", "food"]` (distractor)
+
+ The agent cannot see hidden values unless it asks.
+
+ ## Action Space
+
+ **ASK** — reveal a slot:
+ ```python
+ AskAnswerAction(type="ask", slot="city")  # or "date", "budget", "style"
+ ```
+
+ **ANSWER** — end the episode with guesses:
+ ```python
+ AskAnswerAction(type="answer", city="Paris", date="mid_feb", budget="high", style="relax")
+ ```
+
+ ## Observation
+
+ ```python
+ {
+     "prompt": "Plan a short trip for me.",
+     "known": {
+         "city": None | str,
+         "date": None | str,
+         "budget": None | str,
+         "style": None | str,
+     },
+     "steps_left": int,                 # starts at 3
+     "core_correct_count": int | None,  # populated after ANSWER (0-3)
+ }
+ ```
+
+ ## Rewards (v1 - Graded Scoring)
+
+ | Event | Reward |
+ |-------|--------|
+ | Step penalty (always) | -0.05 |
+ | ASK unknown slot | +0.10 |
+ | ASK already-known slot | -0.20 |
+ | City correct | +0.40 |
+ | Date correct | +0.40 |
+ | Budget correct | +0.40 |
+ | Style correct (bonus) | +0.10 |
+ | All 3 core slots correct (bonus) | +0.20 |
+ | Any core slot wrong (penalty) | -0.60 |
+
+ **Oracle reward (theoretical max):** +1.45 (knows everything, answers perfectly in 1 step)
+
+ ## Baseline Results
+
+ ```
+ ==========================================================================================
+ RESULTS SUMMARY (200 episodes each)
+ ==========================================================================================
+ Baseline                   Mean     Std    Pos%   Core%  AvgCore
+ ------------------------------------------------------------------------------------------
+ Oracle (theoretical)     +1.450   0.000    100%    100%   3.00/3
+ B: city+budget           +0.634   0.560    100%     32%   2.32/3
+ A: city+date             +0.604   0.547    100%     30%   2.29/3
+ C: style+city (trap)     +0.284   0.483     50%     11%   1.61/3
+ Random                   -0.134   0.530     30%      6%   1.08/3
+ ------------------------------------------------------------------------------------------
+
+ Column legend:
+   Mean    = mean total reward
+   Pos%    = positive_return_rate (% episodes with reward > 0)
+   Core%   = core_success_rate (% episodes with all 3 core slots correct)
+   AvgCore = avg_core_correct (mean # of core slots correct, out of 3)
+ ```
+
+ **Key insights:**
+ - The A/B strategies (ask 2 core slots) achieve ~100% positive return
+ - The C strategy (which wastes a question on the style distractor) drops to ~50%
+ - The random baseline performs poorly (~30% positive return)
+ - The ~30% core success rate for A/B matches the expected 1/3 chance of guessing the one remaining slot
+
+ ## Quick Start
+
+ ### Build Docker Image
+
+ ```bash
+ # For local use (root Dockerfile used by HF Spaces)
+ docker build -t ask_answer_env-env:latest .
+
+ # Or use server/Dockerfile (equivalent)
+ docker build -t ask_answer_env-env:latest -f server/Dockerfile .
+ ```
+
+ ### Run Baseline Tests
+
+ ```bash
+ uv run python exp.py
+ ```
+
+ ### Example Usage
+
+ ```python
+ from ask_answer_env import AskAnswerEnv, AskAnswerAction
+
+ client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
+ try:
+     result = client.reset(seed=42)
+     print(f"Steps left: {result.observation.steps_left}")  # 3
+
+     # Ask about city (step 1)
+     result = client.step(AskAnswerAction(type="ask", slot="city"))
+     print(f"City: {result.observation.known.city}")
+
+     # Ask about date (step 2)
+     result = client.step(AskAnswerAction(type="ask", slot="date"))
+     print(f"Date: {result.observation.known.date}")
+
+     # Must answer now (step 3) - guess budget
+     known = result.observation.known
+     result = client.step(AskAnswerAction(
+         type="answer",
+         city=known.city,
+         date=known.date,
+         budget="mid",  # guess
+     ))
+     print(f"Final reward: {result.reward}")
+     print(f"Core correct: {result.observation.core_correct_count}/3")
+ finally:
+     client.close()
+ ```
+
+ ## Testing (`exp.py`)
+
+ The `exp.py` script contains:
+
+ ### 1. Determinism Tests
+ Verifies that the same seed produces identical trajectories and rewards.
+
+ ### 2. Seed Sensitivity Test
+ Confirms that different seeds produce different hidden states.
+
+ ### 3. Baseline Comparison
+ Runs 5 strategies over 200 episodes each:
+ - **Oracle**: Theoretical upper bound (knows the hidden state)
+ - **A: city+date**: Ask city, ask date, guess budget
+ - **B: city+budget**: Ask city, ask budget, guess date
+ - **C: style+city (trap)**: Wastes a question on the distractor
+ - **Random**: Random ask/answer decisions
+
+ ### 4. Ordering Verification
+ Confirms: Oracle > A ≈ B >> C > Random
+
+ ## Project Structure
+
+ ```
+ ask_answer_env/
+ ├── __init__.py        # Module exports
+ ├── models.py          # AskAnswerAction, AskAnswerObservation, KnownSlots
+ ├── client.py          # AskAnswerEnv client (WebSocket)
+ ├── exp.py             # Baseline strategies + acceptance tests
+ ├── Dockerfile         # Root Dockerfile (for HF Spaces)
+ ├── server/
+ │   ├── ask_answer_env_environment.py  # Core environment logic
+ │   ├── app.py                         # FastAPI server
+ │   └── Dockerfile
+ ├── openenv.yaml       # OpenEnv manifest
+ ├── pyproject.toml     # Dependencies
+ └── uv.lock            # Locked deps
+ ```
+
+ ## Episode Rules
+
+ - `max_steps = 3`
+ - The episode ends when the agent sends ANSWER or steps run out
+ - Auto-fail (steps exhausted) gives a -1.0 reward
+ - With 3 steps, the agent can ask at most 2 slots before being forced to answer
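The reward table above is enough to predict the baseline numbers analytically. The sketch below is illustrative only: it assumes the unknown budget guess is uniformly correct 1 time in 3 (matching the 3 possible values) and that the unasked style slot earns no bonus.

```python
# Expected return of Strategy A (ask city, ask date, answer with guessed budget),
# derived from the documented reward table.
STEP_PENALTY = -0.05
ASK_UNKNOWN_BONUS = 0.10
CORE_CORRECT = 0.40
STYLE_BONUS = 0.10        # not earned by Strategy A: style is never asked
ALL_CORE_BONUS = 0.20
ANY_CORE_WRONG = -0.60

# Steps 1-2: two useful ASKs, each paying the step penalty.
ask_phase = 2 * (STEP_PENALTY + ASK_UNKNOWN_BONUS)

# Step 3: ANSWER. city and date are known; budget is a 1-in-3 guess.
p_budget = 1 / 3
answer_phase = (
    STEP_PENALTY
    + 2 * CORE_CORRECT                            # city + date correct
    + p_budget * (CORE_CORRECT + ALL_CORE_BONUS)  # budget guessed right
    + (1 - p_budget) * ANY_CORE_WRONG             # budget guessed wrong
)

expected_strategy_a = ask_phase + answer_phase
oracle_max = STEP_PENALTY + 3 * CORE_CORRECT + STYLE_BONUS + ALL_CORE_BONUS

print(round(expected_strategy_a, 2))  # 0.65
print(round(oracle_max, 2))           # 1.45
```

The analytic +0.65 is roughly consistent with the measured +0.604 over 200 episodes, and the oracle value matches the stated +1.45 maximum.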
__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Ask Answer Env Environment."""
+
+ from .client import AskAnswerEnv
+ from .models import AskAnswerAction, AskAnswerObservation, KnownSlots
+
+ __all__ = [
+     "AskAnswerAction",
+     "AskAnswerObservation",
+     "KnownSlots",
+     "AskAnswerEnv",
+ ]
client.py ADDED
@@ -0,0 +1,123 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Ask Answer Env Environment Client."""
+
+ from typing import Dict
+
+ from openenv.core import EnvClient
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+
+ from .models import AskAnswerAction, AskAnswerObservation, KnownSlots
+
+
+ class AskAnswerEnv(EnvClient[AskAnswerAction, AskAnswerObservation, State]):
+     """
+     Client for the Ask Answer Env Environment.
+
+     A slot-filling environment where agents must decide between asking
+     clarifying questions or answering early to maximize reward.
+
+     Example:
+         >>> with AskAnswerEnv(base_url="http://localhost:8000") as client:
+         ...     result = client.reset(seed=42)
+         ...     print(result.observation.prompt)
+         ...     print(result.observation.known)
+         ...
+         ...     # Ask about city
+         ...     result = client.step(AskAnswerAction(type="ask", slot="city"))
+         ...     print(f"City: {result.observation.known.city}")
+         ...
+         ...     # Answer with known values
+         ...     result = client.step(AskAnswerAction(
+         ...         type="answer",
+         ...         city=result.observation.known.city,
+         ...         date="mid_feb",
+         ...         budget="high",
+         ...     ))
+         ...     print(f"Reward: {result.reward}, Done: {result.done}")
+
+     Example with Docker:
+         >>> client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
+         >>> try:
+         ...     result = client.reset(seed=42)
+         ...     # ... interact with the environment
+         ... finally:
+         ...     client.close()
+     """
+
+     def _step_payload(self, action: AskAnswerAction) -> Dict:
+         """
+         Convert an AskAnswerAction to a JSON payload for the step message.
+
+         Args:
+             action: AskAnswerAction instance
+
+         Returns:
+             Dictionary representation suitable for JSON encoding
+         """
+         payload = {"type": action.type}
+         if action.type == "ask":
+             payload["slot"] = action.slot
+         else:  # answer
+             payload["city"] = action.city
+             payload["date"] = action.date
+             payload["budget"] = action.budget
+             payload["style"] = action.style
+         return payload
+
+     def _parse_result(self, payload: Dict) -> StepResult[AskAnswerObservation]:
+         """
+         Parse the server response into a StepResult[AskAnswerObservation].
+
+         Args:
+             payload: JSON response data from the server
+
+         Returns:
+             StepResult with AskAnswerObservation
+         """
+         obs_data = payload.get("observation", {})
+         known_data = obs_data.get("known", {})
+
+         known = KnownSlots(
+             city=known_data.get("city"),
+             date=known_data.get("date"),
+             budget=known_data.get("budget"),
+             style=known_data.get("style"),
+         )
+
+         observation = AskAnswerObservation(
+             prompt=obs_data.get("prompt", "Plan a short trip for me."),
+             known=known,
+             steps_left=obs_data.get("steps_left", 0),
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+             core_correct_count=obs_data.get("core_correct_count"),
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         """
+         Parse the server response into a State object.
+
+         Args:
+             payload: JSON response from a state request
+
+         Returns:
+             State object with episode_id and step_count
+         """
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
exp.py ADDED
@@ -0,0 +1,463 @@
1
+ """
2
+ Baseline agents for the Ask Answer Env environment (v1).
3
+
4
+ Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3).
5
+ With only 3 steps, agents can ask at most 2 slots before being forced to answer,
6
+ creating a non-trivial tradeoff between information gathering and guessing.
7
+
8
+ Baselines:
9
+ - A: city+date (ask city, ask date, guess budget)
10
+ - B: city+budget (ask city, ask budget, guess date)
11
+ - C: style+city (trap: wastes a question on distractor)
12
+ - Random: random actions
13
+ - Oracle: knows hidden state, answers immediately (upper bound)
14
+ """
15
+
16
+ import random
17
+ from dataclasses import dataclass
18
+ from typing import Callable, List, Optional, Tuple
19
+
20
+ from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots
21
+
22
+
23
+ # Type aliases
24
+ HiddenTuple = Tuple[str, str, str, str] # (city, date, budget, style)
25
+ StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]
26
+
27
+ # Default guesses when slot is unknown
28
+ DEFAULT_CITY = "Paris"
29
+ DEFAULT_DATE = "mid_feb"
30
+ DEFAULT_BUDGET = "mid"
31
+ DEFAULT_STYLE = "relax"
32
+
33
+ # Valid slot values (for random baseline)
34
+ CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
35
+ DATES = ["next_weekend", "mid_feb", "march"]
36
+ BUDGETS = ["low", "mid", "high"]
37
+ STYLES = ["relax", "adventure", "food"]
38
+
39
+
40
+ @dataclass
41
+ class EpisodeResult:
42
+ """Result of a single episode."""
43
+ total_reward: float
44
+ revealed: HiddenTuple
45
+ steps_taken: int
46
+ core_correct_count: int # 0-3: how many core slots were correct
47
+ core_all_correct: bool # True if all 3 core slots correct
48
+
49
+
50
+ # =============================================================================
51
+ # Strategy Functions
52
+ # =============================================================================
53
+
54
+ def strategy_city_date(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
55
+ """
56
+ Strategy A: Ask city, ask date, then answer (guess budget).
57
+
58
+ Expected behavior with MAX_STEPS=3:
59
+ - Step 1: ASK city
60
+ - Step 2: ASK date
61
+ - Step 3: ANSWER with known city+date, guess budget
62
+ """
63
+ if known.city is None:
64
+ return AskAnswerAction(type="ask", slot="city")
65
+ elif known.date is None:
66
+ return AskAnswerAction(type="ask", slot="date")
67
+ else:
68
+ return AskAnswerAction(
69
+ type="answer",
70
+ city=known.city,
71
+ date=known.date,
72
+ budget=known.budget if known.budget else DEFAULT_BUDGET,
73
+ style=known.style, # None if not asked
74
+ )
75
+
76
+
77
+ def strategy_city_budget(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
78
+ """
79
+ Strategy B: Ask city, ask budget, then answer (guess date).
80
+
81
+ Expected behavior with MAX_STEPS=3:
82
+ - Step 1: ASK city
83
+ - Step 2: ASK budget
84
+ - Step 3: ANSWER with known city+budget, guess date
85
+ """
86
+ if known.city is None:
87
+ return AskAnswerAction(type="ask", slot="city")
88
+ elif known.budget is None:
89
+ return AskAnswerAction(type="ask", slot="budget")
90
+ else:
91
+ return AskAnswerAction(
92
+ type="answer",
93
+ city=known.city,
94
+ date=known.date if known.date else DEFAULT_DATE,
95
+ budget=known.budget,
96
+ style=known.style,
97
+ )
98
+
99
+
100
+ def strategy_style_city(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
101
+ """
102
+ Strategy C (TRAP): Ask style first, then city, guess date+budget.
103
+
104
+ This is a BAD strategy because:
105
+ - Style only gives +0.1 bonus (vs +0.4 for core slots)
106
+ - Wastes a question on a low-value distractor
107
+ - Must guess 2 core slots instead of 1
108
+
109
+ Expected behavior with MAX_STEPS=3:
110
+ - Step 1: ASK style (bad choice!)
111
+ - Step 2: ASK city
112
+ - Step 3: ANSWER with known style+city, guess date+budget
113
+ """
114
+ if known.style is None:
115
+ return AskAnswerAction(type="ask", slot="style")
116
+ elif known.city is None:
117
+ return AskAnswerAction(type="ask", slot="city")
118
+ else:
119
+ return AskAnswerAction(
120
+ type="answer",
121
+ city=known.city,
122
+ date=known.date if known.date else DEFAULT_DATE,
123
+ budget=known.budget if known.budget else DEFAULT_BUDGET,
124
+ style=known.style,
125
+ )
126
+
127
+
128
+ def strategy_random(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
129
+ """
130
+ Random baseline: randomly ask or answer with random values.
131
+
132
+ 50% chance to ask a random unknown slot, 50% chance to answer.
133
+ If no unknown slots, always answer.
134
+ """
135
+ unknown_slots = []
136
+ if known.city is None:
137
+ unknown_slots.append("city")
138
+ if known.date is None:
139
+ unknown_slots.append("date")
140
+ if known.budget is None:
141
+ unknown_slots.append("budget")
142
+ if known.style is None:
143
+ unknown_slots.append("style")
144
+
145
+ # If we have unknown slots and coin flip says ask
146
+ if unknown_slots and random.random() < 0.5:
147
+ slot = random.choice(unknown_slots)
148
+ return AskAnswerAction(type="ask", slot=slot)
149
+
150
+ # Otherwise answer with random guesses for unknown slots
151
+ return AskAnswerAction(
152
+ type="answer",
153
+ city=known.city if known.city else random.choice(CITIES),
154
+ date=known.date if known.date else random.choice(DATES),
155
+ budget=known.budget if known.budget else random.choice(BUDGETS),
156
+ style=known.style if known.style else random.choice(STYLES),
157
+ )
158
+
159
+
160
+ def strategy_oracle(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
161
+ """
162
+ Oracle baseline: knows hidden state, answers perfectly in 1 step.
163
+
164
+ This is the THEORETICAL UPPER BOUND.
165
+
166
+ In practice, this strategy function is NOT used because the server
167
+ doesn't expose hidden state to the client. Instead, we hardcode
168
+ the oracle's reward as 1.45 in run_baseline_test().
169
+
170
+ Reward breakdown:
171
+ -0.05 (step) + 0.4×3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45
172
+ """
173
+ if hidden is None:
174
+ raise ValueError("Oracle strategy requires hidden state")
175
+
176
+ city, date, budget, style = hidden
177
+ return AskAnswerAction(
178
+ type="answer",
179
+ city=city,
180
+ date=date,
181
+ budget=budget,
182
+ style=style,
183
+ )
184
+
185
+
186
+ # =============================================================================
187
+ # Episode Runner
188
+ # =============================================================================
189
+
190
+ def run_episode(
191
+ client: AskAnswerEnv,
192
+ strategy: StrategyFn,
193
+ seed: int = 42,
194
+ hidden: Optional[HiddenTuple] = None,
195
+ verbose: bool = False,
196
+ ) -> EpisodeResult:
197
+ """
198
+ Run a single episode with the given strategy.
199
+
200
+ Args:
201
+ client: AskAnswerEnv client instance
202
+ strategy: Function that takes (known, steps_left, hidden) and returns action
203
+ seed: Random seed for reproducibility
204
+ hidden: Hidden state tuple (required for oracle strategy)
205
+ verbose: Whether to print step-by-step info
206
+
207
+ Returns:
208
+ EpisodeResult with total_reward, revealed slots, and steps taken
209
+ """
210
+ result = client.reset(seed=seed)
211
+ total_reward = 0.0
212
+ steps = 0
213
+
214
+ if verbose:
215
+ print(f"=== Episode Start (seed={seed}) ===")
216
+ print(f"Steps left: {result.observation.steps_left}")
217
+
218
+ while not result.done:
219
+ obs = result.observation
220
+ action = strategy(obs.known, obs.steps_left, hidden)
221
+
222
+ result = client.step(action)
223
+ total_reward += result.reward
224
+ steps += 1
225
+
226
+ if verbose:
227
+ if action.type == "ask":
228
+ slot_val = getattr(result.observation.known, action.slot)
229
+ print(f" Step {steps}: ASK {action.slot} -> {slot_val}, reward={result.reward:+.2f}")
230
+ else:
231
+ print(f" Step {steps}: ANSWER city={action.city}, date={action.date}, "
232
+ f"budget={action.budget}, style={action.style}, reward={result.reward:+.2f}")
233
+
234
+ final = result.observation.known
235
+ revealed = (final.city, final.date, final.budget, final.style)
236
+
237
+ # Extract correctness info (available when done=True after ANSWER)
238
+ core_correct_count = result.observation.core_correct_count or 0
239
+ core_all_correct = core_correct_count == 3
240
+
241
+ if verbose:
242
+ print(f" Total reward: {total_reward:+.2f}")
243
+ print(f" Core correct: {core_correct_count}/3")
244
+ print()
245
+
246
+ return EpisodeResult(
247
+ total_reward=total_reward,
248
+ revealed=revealed,
249
+ steps_taken=steps,
250
+ core_correct_count=core_correct_count,
251
+ core_all_correct=core_all_correct,
252
+ )
253
+
254
+
255
+ # =============================================================================
256
+ # Acceptance Tests
257
+ # =============================================================================
258
+
259
+ @dataclass
260
+ class BaselineStats:
261
+ """Statistics for a baseline over multiple episodes."""
262
+ name: str
263
+ mean_reward: float
264
+ std_reward: float
265
+ positive_return_rate: float # % of episodes with reward > 0
266
+ core_success_rate: float # % of episodes with all 3 core slots correct
267
+ avg_core_correct: float # average number of core slots correct (0-3)
268
+
269
+
270
+ def run_baseline_test(
271
+ client: AskAnswerEnv,
272
+ name: str,
273
+ strategy: StrategyFn,
274
+ num_episodes: int = 200,
275
+ needs_oracle: bool = False,
276
+ ) -> BaselineStats:
277
+ """
278
+ Run multiple episodes with a strategy and compute statistics.
279
+
280
+ Args:
281
+ client: AskAnswerEnv client instance
282
+ name: Name of the baseline for logging
283
+ strategy: Strategy function
284
+ num_episodes: Number of episodes to run
285
+ needs_oracle: If True, use theoretical oracle values
286
+
287
+ Returns:
288
+ BaselineStats with all metrics
289
+ """
290
+ if needs_oracle:
291
+ # Oracle is a THEORETICAL upper bound - knows hidden state,
292
+ # answers perfectly in 1 step.
293
+ #
294
+ # Reward: -0.05 + 0.4×3 + 0.1 + 0.2 = +1.45
295
+ # Core correct: 3/3 always
296
+ return BaselineStats(
297
+ name=name,
298
+ mean_reward=1.45,
299
+ std_reward=0.0,
300
+ positive_return_rate=1.0,
301
+ core_success_rate=1.0,
302
+ avg_core_correct=3.0,
303
+ )
304
+
305
+ results: List[EpisodeResult] = []
306
+
307
+ for seed in range(num_episodes):
308
+ result = run_episode(client, strategy, seed=seed)
309
+ results.append(result)
310
+
311
+ rewards = [r.total_reward for r in results]
312
+ mean_reward = sum(rewards) / len(rewards)
313
+ variance = sum((r - mean_reward) ** 2 for r in rewards) / len(rewards)
314
+ std_reward = variance ** 0.5
315
+
316
+ positive_return_rate = sum(1 for r in rewards if r > 0) / len(rewards)
317
+ core_success_rate = sum(1 for r in results if r.core_all_correct) / len(results)
318
+ avg_core_correct = sum(r.core_correct_count for r in results) / len(results)
319
+
320
+ return BaselineStats(
321
+ name=name,
322
+ mean_reward=mean_reward,
323
+ std_reward=std_reward,
324
+ positive_return_rate=positive_return_rate,
325
+ core_success_rate=core_success_rate,
326
+ avg_core_correct=avg_core_correct,
327
+ )
328
+
329
+
330
+ def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200) -> bool:
331
+ """
332
+ Run all baseline tests and print results table.
333
+
334
+ Expected ordering:
335
+ Oracle > A ≈ B >> C > Random
336
+ """
337
+ print(f"\nRunning {num_episodes} episodes per baseline...\n")
338
+
339
+ baselines = [
340
+ ("Oracle (theoretical)", None, True),
341
+ ("A: city+date", strategy_city_date, False),
342
+ ("B: city+budget", strategy_city_budget, False),
343
+ ("C: style+city (trap)", strategy_style_city, False),
344
+ ("Random", strategy_random, False),
345
+ ]
346
+
347
+ all_stats: List[BaselineStats] = []
348
+ for name, strategy, is_oracle in baselines:
349
+ stats = run_baseline_test(client, name, strategy, num_episodes, needs_oracle=is_oracle)
350
+ all_stats.append(stats)
351
+ print(f" {name}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")
352
+
353
+ # Print results table
354
+ print("\n" + "=" * 90)
355
+ print("RESULTS SUMMARY")
356
+ print("=" * 90)
357
+ header = f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} {'Core%':>7} {'AvgCore':>8}"
358
+ print(header)
359
+ print("-" * 90)
360
+
361
+ for s in sorted(all_stats, key=lambda x: -x.mean_reward):
362
+ print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
363
+ f"{s.positive_return_rate:>6.0%} {s.core_success_rate:>6.0%} "
364
+ f"{s.avg_core_correct:>7.2f}/3")
365
+ print("-" * 90)
366
+
367
+ print("\nColumn legend:")
368
+ print(" Mean = mean total reward")
369
+ print(" Std = standard deviation of reward")
370
+ print(" Pos% = positive_return_rate (% episodes with reward > 0)")
371
+ print(" Core% = core_success_rate (% episodes with all 3 core slots correct)")
372
+ print(" AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")
373
+
374
+ # Verify expected ordering
375
+ result_dict = {s.name: s.mean_reward for s in all_stats}
376
+
377
+ checks = [
378
+ ("Oracle > A", result_dict["Oracle (theoretical)"] > result_dict["A: city+date"]),
379
+         ("A ≈ B", abs(result_dict["A: city+date"] - result_dict["B: city+budget"]) < 0.1),
+         ("A > C", result_dict["A: city+date"] > result_dict["C: style+city (trap)"]),
+         ("C > Random", result_dict["C: style+city (trap)"] > result_dict["Random"]),
+     ]
+
+     print("\nExpected ordering checks:")
+     all_passed = True
+     for check_name, passed in checks:
+         status = "PASS" if passed else "FAIL"
+         print(f"  {check_name}: {status}")
+         if not passed:
+             all_passed = False
+
+     return all_passed
+
+
+ # =============================================================================
+ # Determinism Tests (kept from v0)
+ # =============================================================================
+
+ def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
+     """Test that the same seed produces identical trajectories."""
+     trajectories = []
+     for _ in range(runs):
+         result = run_episode(client, strategy_city_date, seed=seed)
+         trajectories.append((result.total_reward, result.revealed))
+
+     rewards = [t[0] for t in trajectories]
+     revealed = [t[1] for t in trajectories]
+
+     identical = len(set(revealed)) == 1 and len(set(rewards)) == 1
+     print(f"Determinism (seed={seed}): {revealed[0]} x{runs}, identical={identical}")
+     return identical
+
+
+ def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
+     """Verify different seeds produce different hidden states."""
+     unique = set()
+     for seed in range(num_seeds):
+         result = run_episode(client, strategy_city_date, seed=seed)
+         unique.add(result.revealed)
+
+     # Max possible: 4 * 3 * 3 * 3 = 108 (with style)
+     print(f"Seed sensitivity: {len(unique)} unique tuples from {num_seeds} seeds")
+     return len(unique) > 1
+
+
+ # =============================================================================
+ # Main
+ # =============================================================================
+
+ if __name__ == "__main__":
+     client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
+     try:
+         print("=" * 60)
+         print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS")
+         print("=" * 60)
+
+         # Quick determinism check
+         print("\n1. DETERMINISM TESTS")
+         print("-" * 40)
+         test_determinism(client, seed=42)
+         test_determinism(client, seed=123)
+         test_seed_sensitivity(client)
+
+         # Run a single verbose episode to show behavior
+         print("\n2. EXAMPLE EPISODE (Strategy A: city+date)")
+         print("-" * 40)
+         run_episode(client, strategy_city_date, seed=42, verbose=True)
+
+         print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)")
+         print("-" * 40)
+         run_episode(client, strategy_style_city, seed=42, verbose=True)
+
+         # Full acceptance tests
+         print("\n4. BASELINE COMPARISON")
+         print("-" * 40)
+         passed = run_acceptance_tests(client, num_episodes=200)
+
+         print("\n" + "=" * 60)
+         print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}")
+         print("=" * 60)
+
+     finally:
+         client.close()
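The determinism and seed-sensitivity tests above rest on one property: `reset(seed)` builds a fresh `random.Random(seed)` and draws every hidden slot from it. A stdlib-only sketch of that property, with the candidate lists mirrored from the environment's constants and no OpenEnv client involved:

```python
import random

# Candidate values mirrored from the environment's constants
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
DATES = ["next_weekend", "mid_feb", "march"]
BUDGETS = ["low", "mid", "high"]
STYLES = ["relax", "adventure", "food"]


def sample_hidden(seed: int) -> tuple:
    # Mirrors AskAnswerEnvironment.reset: a fresh seeded RNG drives all four draws,
    # so the same seed always yields the same hidden tuple.
    rng = random.Random(seed)
    return (rng.choice(CITIES), rng.choice(DATES), rng.choice(BUDGETS), rng.choice(STYLES))


# Same seed -> identical hidden tuple; different seeds usually differ
assert sample_hidden(42) == sample_hidden(42)
unique = {sample_hidden(s) for s in range(20)}
print(f"{len(unique)} unique tuples from 20 seeds")
```

This is why `test_determinism` can compare whole trajectories for equality rather than using tolerances.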
models.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Data models for the Ask Answer Env Environment.
+
+ A slot-filling environment where agents must decide between asking
+ clarifying questions or answering early.
+ """
+
+ from typing import Literal, Optional
+
+ from pydantic import BaseModel, Field
+
+ from openenv.core.env_server.types import Action, Observation
+
+
+ class AskAnswerAction(Action):
+     """
+     Action for the Ask Answer Env environment.
+
+     Two action types:
+     - "ask": Ask about a specific slot (city, date, budget, or style)
+     - "answer": Provide final answer with slot values
+     """
+
+     type: Literal["ask", "answer"] = Field(..., description="Action type: 'ask' or 'answer'")
+     slot: Optional[Literal["city", "date", "budget", "style"]] = Field(
+         default=None, description="Slot to ask about (only for 'ask' type)"
+     )
+     city: Optional[str] = Field(default=None, description="City answer (only for 'answer' type)")
+     date: Optional[str] = Field(default=None, description="Date answer (only for 'answer' type)")
+     budget: Optional[str] = Field(default=None, description="Budget answer (only for 'answer' type)")
+     style: Optional[str] = Field(default=None, description="Style answer (only for 'answer' type)")
+
+
+ class KnownSlots(BaseModel):
+     """Slots that have been revealed to the agent."""
+
+     city: Optional[str] = Field(default=None, description="Known city value")
+     date: Optional[str] = Field(default=None, description="Known date value")
+     budget: Optional[str] = Field(default=None, description="Known budget value")
+     style: Optional[str] = Field(default=None, description="Known style value")
+
+
+ class AskAnswerObservation(Observation):
+     """Observation from the Ask Answer Env environment."""
+
+     prompt: str = Field(default="Plan a short trip for me.", description="The user prompt")
+     known: KnownSlots = Field(default_factory=KnownSlots, description="Currently known slot values")
+     steps_left: int = Field(default=3, description="Number of steps remaining")
+
+     # Correctness info (populated when done=True after ANSWER action)
+     core_correct_count: Optional[int] = Field(
+         default=None, description="Number of core slots correct (0-3), None if not answered yet"
+     )
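For illustration, the two shapes that `AskAnswerAction` admits look like this as plain dictionaries. These field/value combinations are taken from the model definitions above; the exact wire format `EnvClient` uses to send them is not shown in this diff and may wrap them differently:

```python
# An "ask" action carries only the slot to reveal (slot is constrained by a Literal)
ask_action = {"type": "ask", "slot": "city"}

# An "answer" action carries guesses for the slots; style is optional and a None
# style is simply ignored by the reward logic
answer_action = {
    "type": "answer",
    "city": "Paris",
    "date": "mid_feb",
    "budget": "low",
    "style": None,
}

# The Literal constraints in the model restrict these values:
assert ask_action["slot"] in {"city", "date", "budget", "style"}
assert answer_action["type"] in {"ask", "answer"}
```

Anything outside the `Literal` choices (e.g. `slot="weather"`) would be rejected by pydantic validation before reaching the environment.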
openenv.yaml ADDED
@@ -0,0 +1,7 @@
+ spec_version: 1
+ name: ask_answer_env
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+
openenv_ask_answer_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,9 @@
+ Metadata-Version: 2.4
+ Name: openenv-ask_answer_env
+ Version: 0.1.0
+ Summary: Ask Answer Env environment for OpenEnv
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core]>=0.2.1
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_ask_answer_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,18 @@
+ README.md
+ __init__.py
+ client.py
+ models.py
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./exp.py
+ ./models.py
+ openenv_ask_answer_env.egg-info/PKG-INFO
+ openenv_ask_answer_env.egg-info/SOURCES.txt
+ openenv_ask_answer_env.egg-info/dependency_links.txt
+ openenv_ask_answer_env.egg-info/entry_points.txt
+ openenv_ask_answer_env.egg-info/requires.txt
+ openenv_ask_answer_env.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/ask_answer_env_environment.py
openenv_ask_answer_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_ask_answer_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = ask_answer_env.server.app:main
openenv_ask_answer_env.egg-info/requires.txt ADDED
@@ -0,0 +1,5 @@
+ openenv-core[core]>=0.2.1
+
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
openenv_ask_answer_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ ask_answer_env
pyproject.toml ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-ask_answer_env"
+ version = "0.1.0"
+ description = "Ask Answer Env environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
+     # Install from GitHub:
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.1",
+     # Environment-specific dependencies
+     # Add all dependencies needed for your environment here
+     # Examples:
+     # "numpy>=1.19.0",
+     # "torch>=2.0.0",
+     # "gymnasium>=0.29.0",
+     # "openspiel>=1.0.0",
+     # "smolagents>=1.22.0,<2",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m ask_answer_env.server.app
+ server = "ask_answer_env.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["ask_answer_env", "ask_answer_env.server"]
+ package-dir = { "ask_answer_env" = ".", "ask_answer_env.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Ask Answer Env environment server components."""
+
+ from .ask_answer_env_environment import AskAnswerEnvironment
+
+ __all__ = ["AskAnswerEnvironment"]
server/app.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ FastAPI application for the Ask Answer Env Environment.
+
+ This module creates an HTTP server that exposes the AskAnswerEnvironment
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
+
+ Endpoints:
+     - POST /reset: Reset the environment
+     - POST /step: Execute an action
+     - GET /state: Get current environment state
+     - GET /schema: Get action/observation schemas
+     - WS /ws: WebSocket endpoint for persistent sessions
+
+ Usage:
+     # Development (with auto-reload):
+     uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+     # Production:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+     # Or run directly:
+     python -m server.app
+ """
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:  # pragma: no cover
+     raise ImportError(
+         "openenv is required for the web interface. Install dependencies with 'uv sync'."
+     ) from e
+
+ # Import from local models.py (PYTHONPATH includes /app/env in Docker)
+ from models import AskAnswerAction, AskAnswerObservation
+
+ from .ask_answer_env_environment import AskAnswerEnvironment
+
+
+ # Create the app with web interface and README integration
+ app = create_app(
+     AskAnswerEnvironment,
+     AskAnswerAction,
+     AskAnswerObservation,
+     env_name="ask_answer_env",
+     max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     """
+     Entry point for direct execution via uv run or python -m.
+
+     This function enables running the server without Docker:
+         uv run --project . server
+         uv run --project . server --port 8001
+         python -m ask_answer_env.server.app
+
+     Args:
+         host: Host address to bind to (default: "0.0.0.0")
+         port: Port number to listen on (default: 8000)
+
+     For production deployments, consider using uvicorn directly with
+     multiple workers:
+         uvicorn ask_answer_env.server.app:app --workers 4
+     """
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=8000)
+     args = parser.parse_args()
+     main(port=args.port)
server/ask_answer_env_environment.py ADDED
@@ -0,0 +1,294 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Ask Answer Env Environment Implementation.
+
+ A deterministic slot-filling environment where agents must decide between
+ asking clarifying questions or answering early to maximize reward.
+ """
+
+ import random
+ from typing import Optional
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ from models import AskAnswerAction, AskAnswerObservation, KnownSlots
+
+
+ # Constants
+ CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
+ DATES = ["next_weekend", "mid_feb", "march"]
+ BUDGETS = ["low", "mid", "high"]
+ STYLES = ["relax", "adventure", "food"]  # Distractor slot
+ MAX_STEPS = 3  # Forces agent to guess at least 1 core slot
+ PROMPT = "Plan a short trip for me."
+
+ # Rewards (unchanged from v0)
+ STEP_PENALTY = -0.05
+ ASK_UNKNOWN_REWARD = 0.1
+ ASK_KNOWN_PENALTY = -0.2
+ AUTO_FAIL_PENALTY = -1.0
+
+ # Graded answer rewards (v1)
+ ANSWER_CITY_CORRECT = 0.4
+ ANSWER_DATE_CORRECT = 0.4
+ ANSWER_BUDGET_CORRECT = 0.4
+ ANSWER_STYLE_CORRECT_BONUS = 0.1  # Optional nice-to-have
+ ANSWER_CORE_ALL_CORRECT_BONUS = 0.2
+ ANSWER_CORE_ANY_WRONG_PENALTY = -0.6
+
+
+ class AskAnswerEnvironment(Environment):
+     """
+     A slot-filling environment for training RL agents.
+
+     The agent must decide between:
+     - Asking clarifying questions (ASK) to reveal hidden slot values
+     - Answering early (ANSWER) to end the episode
+
+     Hidden state (city, date, budget, style) is sampled at reset with a seeded RNG.
+     The agent can ask about slots to reveal their values before answering.
+
+     With MAX_STEPS=3, the agent can only ask 2 slots before being forced to answer,
+     creating a non-trivial ask-vs-act tradeoff. The "style" slot is a distractor
+     that provides less reward than core slots (city, date, budget).
+
+     Rewards:
+     - Step penalty: -0.05 per step
+     - ASK unknown slot: +0.1
+     - ASK known slot: -0.2
+     - ANSWER: graded per-slot (+0.4 each core, +0.1 style)
+     - Core all correct bonus: +0.2
+     - Core any wrong penalty: -0.6
+     - Auto-fail (steps exhausted): -1.0
+     """
+
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     def __init__(self):
+         """Initialize the ask_answer_env environment."""
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._rng: random.Random = random.Random()
+
+         # Hidden truth (sampled at reset)
+         self._hidden_city: str = ""
+         self._hidden_date: str = ""
+         self._hidden_budget: str = ""
+         self._hidden_style: str = ""
+
+         # Known slots (revealed through ASK actions)
+         self._known: KnownSlots = KnownSlots()
+         self._steps_left: int = MAX_STEPS
+         self._done: bool = False
+
+     def reset(self, seed: Optional[int] = None) -> AskAnswerObservation:
+         """
+         Reset the environment with optional seed for determinism.
+
+         Args:
+             seed: Random seed for reproducibility
+
+         Returns:
+             AskAnswerObservation with initial state
+         """
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+
+         # Initialize RNG with seed
+         if seed is not None:
+             self._rng = random.Random(seed)
+         else:
+             self._rng = random.Random()
+
+         # Sample hidden truth
+         self._hidden_city = self._rng.choice(CITIES)
+         self._hidden_date = self._rng.choice(DATES)
+         self._hidden_budget = self._rng.choice(BUDGETS)
+         self._hidden_style = self._rng.choice(STYLES)
+
+         # Reset known slots and step counter
+         self._known = KnownSlots()
+         self._steps_left = MAX_STEPS
+         self._done = False
+
+         return AskAnswerObservation(
+             prompt=PROMPT,
+             known=self._known,
+             steps_left=self._steps_left,
+             done=False,
+             reward=0.0,
+         )
+
+     def step(self, action: AskAnswerAction) -> AskAnswerObservation:  # type: ignore[override]
+         """
+         Execute a step in the environment.
+
+         Args:
+             action: AskAnswerAction with type 'ask' or 'answer'
+
+         Returns:
+             AskAnswerObservation with updated state and reward
+         """
+         if self._done:
+             return AskAnswerObservation(
+                 prompt=PROMPT,
+                 known=self._known,
+                 steps_left=self._steps_left,
+                 done=True,
+                 reward=0.0,
+             )
+
+         self._state.step_count += 1
+
+         # Always apply step penalty
+         reward = STEP_PENALTY
+         done = False
+
+         if action.type == "ask":
+             reward += self._handle_ask(action.slot)
+             self._steps_left -= 1
+
+             # Check for auto-fail
+             if self._steps_left == 0:
+                 reward = AUTO_FAIL_PENALTY
+                 done = True
+
+         elif action.type == "answer":
+             reward += self._handle_answer(action)
+             done = True
+
+         self._done = done
+
+         # Calculate core_correct_count when episode ends via ANSWER
+         core_correct_count = None
+         if done and action.type == "answer":
+             core_correct_count = sum([
+                 action.city == self._hidden_city,
+                 action.date == self._hidden_date,
+                 action.budget == self._hidden_budget,
+             ])
+
+         return AskAnswerObservation(
+             prompt=PROMPT,
+             known=self._known,
+             steps_left=self._steps_left,
+             done=done,
+             reward=reward,
+             core_correct_count=core_correct_count,
+         )
+
+     def _handle_ask(self, slot: Optional[str]) -> float:
+         """
+         Handle ASK action - reveal a slot if unknown.
+
+         Args:
+             slot: The slot to ask about ('city', 'date', 'budget', or 'style')
+
+         Returns:
+             Reward for the ASK action
+         """
+         if slot == "city":
+             if self._known.city is not None:
+                 return ASK_KNOWN_PENALTY
+             self._known = KnownSlots(
+                 city=self._hidden_city,
+                 date=self._known.date,
+                 budget=self._known.budget,
+                 style=self._known.style,
+             )
+             return ASK_UNKNOWN_REWARD
+
+         elif slot == "date":
+             if self._known.date is not None:
+                 return ASK_KNOWN_PENALTY
+             self._known = KnownSlots(
+                 city=self._known.city,
+                 date=self._hidden_date,
+                 budget=self._known.budget,
+                 style=self._known.style,
+             )
+             return ASK_UNKNOWN_REWARD
+
+         elif slot == "budget":
+             if self._known.budget is not None:
+                 return ASK_KNOWN_PENALTY
+             self._known = KnownSlots(
+                 city=self._known.city,
+                 date=self._known.date,
+                 budget=self._hidden_budget,
+                 style=self._known.style,
+             )
+             return ASK_UNKNOWN_REWARD
+
+         elif slot == "style":
+             if self._known.style is not None:
+                 return ASK_KNOWN_PENALTY
+             self._known = KnownSlots(
+                 city=self._known.city,
+                 date=self._known.date,
+                 budget=self._known.budget,
+                 style=self._hidden_style,
+             )
+             return ASK_UNKNOWN_REWARD
+
+         # Invalid slot
+         return ASK_KNOWN_PENALTY
+
+     def _handle_answer(self, action: AskAnswerAction) -> float:
+         """
+         Handle ANSWER action with graded rewards.
+
+         Reward structure:
+         - Per-slot rewards: +0.4 for each correct core slot (city, date, budget)
+         - Style bonus: +0.1 if style provided and correct (ignored if None)
+         - Core bonus: +0.2 if all core slots correct
+         - Core penalty: -0.6 if any core slot wrong
+
+         Args:
+             action: The answer action with city, date, budget, style values
+
+         Returns:
+             Reward for the ANSWER action
+         """
+         reward = 0.0
+
+         # Check core slots
+         city_correct = action.city == self._hidden_city
+         date_correct = action.date == self._hidden_date
+         budget_correct = action.budget == self._hidden_budget
+
+         # Per-slot rewards for core slots
+         if city_correct:
+             reward += ANSWER_CITY_CORRECT
+         if date_correct:
+             reward += ANSWER_DATE_CORRECT
+         if budget_correct:
+             reward += ANSWER_BUDGET_CORRECT
+
+         # Style bonus (only if provided and correct, ignored if None)
+         if action.style is not None and action.style == self._hidden_style:
+             reward += ANSWER_STYLE_CORRECT_BONUS
+
+         # Core bonus/penalty
+         core_all_correct = city_correct and date_correct and budget_correct
+         if core_all_correct:
+             reward += ANSWER_CORE_ALL_CORRECT_BONUS
+         else:
+             reward += ANSWER_CORE_ANY_WRONG_PENALTY
+
+         return reward
+
+     @property
+     def state(self) -> State:
+         """
+         Get the current environment state.
+
+         Returns:
+             Current State with episode_id and step_count
+         """
+         return self._state
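To make the graded reward concrete, here is the arithmetic for the "ask city, ask date, then answer" strategy, using the constants from the environment above. The three per-slot constants are collapsed into one name here since they are all 0.4; this is a sketch of the bookkeeping, not part of the environment code:

```python
# Reward constants mirrored from ask_answer_env_environment.py
STEP_PENALTY = -0.05
ASK_UNKNOWN_REWARD = 0.1
ANSWER_SLOT_CORRECT = 0.4        # stands in for ANSWER_{CITY,DATE,BUDGET}_CORRECT
CORE_ALL_CORRECT_BONUS = 0.2
CORE_ANY_WRONG_PENALTY = -0.6

# Each useful ask nets the step penalty plus the ask reward
ask_reward = STEP_PENALTY + ASK_UNKNOWN_REWARD  # +0.05 per ask

# Best case: the unasked budget slot is guessed right (1-in-3 under uniform sampling)
best = 2 * ask_reward + STEP_PENALTY + 3 * ANSWER_SLOT_CORRECT + CORE_ALL_CORRECT_BONUS

# Typical case: the budget guess is wrong, so the core penalty applies
miss = 2 * ask_reward + STEP_PENALTY + 2 * ANSWER_SLOT_CORRECT + CORE_ANY_WRONG_PENALTY

print(f"best={best:.2f}, miss={miss:.2f}")
```

This is why the acceptance tests expect strategy A to beat the style+city "trap": asking the distractor slot trades a 0.4 core reward for at most a 0.1 style bonus.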
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+
+
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff