Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +199 -7
- __init__.py +17 -0
- client.py +123 -0
- exp.py +463 -0
- models.py +59 -0
- openenv.yaml +7 -0
- openenv_ask_answer_env.egg-info/PKG-INFO +9 -0
- openenv_ask_answer_env.egg-info/SOURCES.txt +18 -0
- openenv_ask_answer_env.egg-info/dependency_links.txt +1 -0
- openenv_ask_answer_env.egg-info/entry_points.txt +2 -0
- openenv_ask_answer_env.egg-info/requires.txt +5 -0
- openenv_ask_answer_env.egg-info/top_level.txt +1 -0
- pyproject.toml +45 -0
- server/__init__.py +11 -0
- server/app.py +81 -0
- server/ask_answer_env_environment.py +294 -0
- server/requirements.txt +6 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=ask_answer_env
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 81 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,12 +1,204 @@
|
|
| 1 |
---
|
| 2 |
title: Ask Answer Env
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Ask Answer Env
|
| 3 |
+
emoji: 🎯
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
+
base_path: /web
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- rl
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Ask Answer Env (v1)
|
| 15 |
+
|
| 16 |
+
A deterministic OpenEnv environment for training RL agents to decide between **asking clarifying questions** or **answering early** under budget constraints.
|
| 17 |
+
|
| 18 |
+
## Overview
|
| 19 |
+
|
| 20 |
+
The agent receives a user prompt ("Plan a short trip for me.") and must discover hidden slot values by asking questions before providing a final answer. With only **3 steps** and **4 slots** (3 core + 1 distractor), the agent must prioritize which questions to ask.
|
| 21 |
+
|
| 22 |
+
**Key design goals:**
|
| 23 |
+
- No ML, no NLP — just structured interaction + delayed reward
|
| 24 |
+
- Deterministic given a seed
|
| 25 |
+
- Budget constraints force non-trivial tradeoffs (can only ask 2 of 4 slots)
|
| 26 |
+
- Graded reward structure (partial credit for correct slots)
|
| 27 |
+
|
| 28 |
+
## Hidden State
|
| 29 |
+
|
| 30 |
+
At each episode reset, the environment samples (with seeded RNG):
|
| 31 |
+
- `city` ∈ `["Paris", "Rome", "Tokyo", "Goa"]` (core)
|
| 32 |
+
- `date` ∈ `["next_weekend", "mid_feb", "march"]` (core)
|
| 33 |
+
- `budget` ∈ `["low", "mid", "high"]` (core)
|
| 34 |
+
- `style` ∈ `["relax", "adventure", "food"]` (distractor)
|
| 35 |
+
|
| 36 |
+
The agent cannot see hidden values unless it asks.
|
| 37 |
+
|
| 38 |
+
## Action Space
|
| 39 |
+
|
| 40 |
+
**ASK** — reveal a slot:
|
| 41 |
+
```python
|
| 42 |
+
AskAnswerAction(type="ask", slot="city") # or "date", "budget", "style"
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
**ANSWER** — end episode with guesses:
|
| 46 |
+
```python
|
| 47 |
+
AskAnswerAction(type="answer", city="Paris", date="mid_feb", budget="high", style="relax")
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Observation
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
{
|
| 54 |
+
"prompt": "Plan a short trip for me.",
|
| 55 |
+
"known": {
|
| 56 |
+
"city": None | str,
|
| 57 |
+
"date": None | str,
|
| 58 |
+
"budget": None | str,
|
| 59 |
+
"style": None | str
|
| 60 |
+
},
|
| 61 |
+
"steps_left": int, # starts at 3
|
| 62 |
+
"core_correct_count": int | None # populated after ANSWER (0-3)
|
| 63 |
+
}
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Rewards (v1 - Graded Scoring)
|
| 67 |
+
|
| 68 |
+
| Event | Reward |
|
| 69 |
+
|-------|--------|
|
| 70 |
+
| Step penalty (always) | -0.05 |
|
| 71 |
+
| ASK unknown slot | +0.10 |
|
| 72 |
+
| ASK already-known slot | -0.20 |
|
| 73 |
+
| City correct | +0.40 |
|
| 74 |
+
| Date correct | +0.40 |
|
| 75 |
+
| Budget correct | +0.40 |
|
| 76 |
+
| Style correct (bonus) | +0.10 |
|
| 77 |
+
| All 3 core slots correct (bonus) | +0.20 |
|
| 78 |
+
| Any core slot wrong (penalty) | -0.60 |
|
| 79 |
+
|
| 80 |
+
**Oracle reward (theoretical max):** +1.45 (knows everything, answers perfectly in 1 step)
|
| 81 |
+
|
| 82 |
+
## Baseline Results
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
==========================================================================================
|
| 86 |
+
RESULTS SUMMARY (200 episodes each)
|
| 87 |
+
==========================================================================================
|
| 88 |
+
Baseline Mean Std Pos% Core% AvgCore
|
| 89 |
+
------------------------------------------------------------------------------------------
|
| 90 |
+
Oracle (theoretical) +1.450 0.000 100% 100% 3.00/3
|
| 91 |
+
B: city+budget +0.634 0.560 100% 32% 2.32/3
|
| 92 |
+
A: city+date +0.604 0.547 100% 30% 2.29/3
|
| 93 |
+
C: style+city (trap) +0.284 0.483 50% 11% 1.61/3
|
| 94 |
+
Random -0.134 0.530 30% 6% 1.08/3
|
| 95 |
+
------------------------------------------------------------------------------------------
|
| 96 |
+
|
| 97 |
+
Column legend:
|
| 98 |
+
Mean = mean total reward
|
| 99 |
+
Pos% = positive_return_rate (% episodes with reward > 0)
|
| 100 |
+
Core% = core_success_rate (% episodes with all 3 core slots correct)
|
| 101 |
+
AvgCore = avg_core_correct (mean # of core slots correct, out of 3)
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
**Key insights:**
|
| 105 |
+
- A/B strategies (ask 2 core slots) achieve ~100% positive return
|
| 106 |
+
- C strategy (wastes a question on style distractor) drops to ~50%
|
| 107 |
+
- Random baseline performs poorly (~30% positive return)
|
| 108 |
+
- Core success rate ~30% for A/B matches expected 1/3 (guessing 1 slot)
|
| 109 |
+
|
| 110 |
+
## Quick Start
|
| 111 |
+
|
| 112 |
+
### Build Docker Image
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
# For local use (root Dockerfile used by HF Spaces)
|
| 116 |
+
docker build -t ask_answer_env-env:latest .
|
| 117 |
+
|
| 118 |
+
# Or use server/Dockerfile (equivalent)
|
| 119 |
+
docker build -t ask_answer_env-env:latest -f server/Dockerfile .
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Run Baseline Tests
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
uv run python exp.py
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Example Usage
|
| 129 |
+
|
| 130 |
+
```python
|
| 131 |
+
from ask_answer_env import AskAnswerEnv, AskAnswerAction
|
| 132 |
+
|
| 133 |
+
client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
|
| 134 |
+
try:
|
| 135 |
+
result = client.reset(seed=42)
|
| 136 |
+
print(f"Steps left: {result.observation.steps_left}") # 3
|
| 137 |
+
|
| 138 |
+
# Ask about city (step 1)
|
| 139 |
+
result = client.step(AskAnswerAction(type="ask", slot="city"))
|
| 140 |
+
print(f"City: {result.observation.known.city}")
|
| 141 |
+
|
| 142 |
+
# Ask about date (step 2)
|
| 143 |
+
result = client.step(AskAnswerAction(type="ask", slot="date"))
|
| 144 |
+
print(f"Date: {result.observation.known.date}")
|
| 145 |
+
|
| 146 |
+
# Must answer now (step 3) - guess budget
|
| 147 |
+
known = result.observation.known
|
| 148 |
+
result = client.step(AskAnswerAction(
|
| 149 |
+
type="answer",
|
| 150 |
+
city=known.city,
|
| 151 |
+
date=known.date,
|
| 152 |
+
budget="mid", # guess
|
| 153 |
+
))
|
| 154 |
+
print(f"Final reward: {result.reward}")
|
| 155 |
+
print(f"Core correct: {result.observation.core_correct_count}/3")
|
| 156 |
+
finally:
|
| 157 |
+
client.close()
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Testing (`exp.py`)
|
| 161 |
+
|
| 162 |
+
The `exp.py` script contains:
|
| 163 |
+
|
| 164 |
+
### 1. Determinism Tests
|
| 165 |
+
Verifies same seed → identical trajectories and rewards.
|
| 166 |
+
|
| 167 |
+
### 2. Seed Sensitivity Test
|
| 168 |
+
Confirms different seeds produce different hidden states.
|
| 169 |
+
|
| 170 |
+
### 3. Baseline Comparison
|
| 171 |
+
Runs 5 strategies over 200 episodes each:
|
| 172 |
+
- **Oracle**: Theoretical upper bound (knows hidden state)
|
| 173 |
+
- **A: city+date**: Ask city, ask date, guess budget
|
| 174 |
+
- **B: city+budget**: Ask city, ask budget, guess date
|
| 175 |
+
- **C: style+city (trap)**: Wastes a question on distractor
|
| 176 |
+
- **Random**: Random ask/answer decisions
|
| 177 |
+
|
| 178 |
+
### 4. Ordering Verification
|
| 179 |
+
Confirms: Oracle > A ≈ B >> C > Random
|
| 180 |
+
|
| 181 |
+
## Project Structure
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
ask_answer_env/
|
| 185 |
+
├── __init__.py # Module exports
|
| 186 |
+
├── models.py # AskAnswerAction, AskAnswerObservation, KnownSlots
|
| 187 |
+
├── client.py # AskAnswerEnv client (WebSocket)
|
| 188 |
+
├── exp.py # Baseline strategies + acceptance tests
|
| 189 |
+
├── Dockerfile # Root Dockerfile (for HF Spaces)
|
| 190 |
+
├── server/
|
| 191 |
+
│ ├── ask_answer_env_environment.py # Core environment logic
|
| 192 |
+
│ ├── app.py # FastAPI server
|
| 193 |
+
│ └── Dockerfile
|
| 194 |
+
├── openenv.yaml # OpenEnv manifest
|
| 195 |
+
├── pyproject.toml # Dependencies
|
| 196 |
+
└── uv.lock # Locked deps
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
## Episode Rules
|
| 200 |
+
|
| 201 |
+
- `max_steps = 3`
|
| 202 |
+
- Episode ends when agent sends ANSWER or steps run out
|
| 203 |
+
- Auto-fail (steps exhausted) gives -1.0 reward
|
| 204 |
+
- With 3 steps, agent can ask at most 2 slots before forced to answer/fail
|
__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Ask Answer Env Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import AskAnswerEnv
|
| 10 |
+
from .models import AskAnswerAction, AskAnswerObservation, KnownSlots
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"AskAnswerAction",
|
| 14 |
+
"AskAnswerObservation",
|
| 15 |
+
"KnownSlots",
|
| 16 |
+
"AskAnswerEnv",
|
| 17 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Ask Answer Env Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core.client_types import StepResult
|
| 12 |
+
from openenv.core.env_server.types import State
|
| 13 |
+
from openenv.core import EnvClient
|
| 14 |
+
|
| 15 |
+
from .models import AskAnswerAction, AskAnswerObservation, KnownSlots
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AskAnswerEnv(
|
| 19 |
+
EnvClient[AskAnswerAction, AskAnswerObservation, State]
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Client for the Ask Answer Env Environment.
|
| 23 |
+
|
| 24 |
+
A slot-filling environment where agents must decide between asking
|
| 25 |
+
clarifying questions or answering early to maximize reward.
|
| 26 |
+
|
| 27 |
+
Example:
|
| 28 |
+
>>> with AskAnswerEnv(base_url="http://localhost:8000") as client:
|
| 29 |
+
... result = client.reset(seed=42)
|
| 30 |
+
... print(result.observation.prompt)
|
| 31 |
+
... print(result.observation.known)
|
| 32 |
+
...
|
| 33 |
+
... # Ask about city
|
| 34 |
+
... result = client.step(AskAnswerAction(type="ask", slot="city"))
|
| 35 |
+
... print(f"City: {result.observation.known.city}")
|
| 36 |
+
...
|
| 37 |
+
... # Answer with known values
|
| 38 |
+
... result = client.step(AskAnswerAction(
|
| 39 |
+
... type="answer",
|
| 40 |
+
... city=result.observation.known.city,
|
| 41 |
+
... date="mid_feb",
|
| 42 |
+
... budget="high"
|
| 43 |
+
... ))
|
| 44 |
+
... print(f"Reward: {result.reward}, Done: {result.done}")
|
| 45 |
+
|
| 46 |
+
Example with Docker:
|
| 47 |
+
>>> client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
|
| 48 |
+
>>> try:
|
| 49 |
+
... result = client.reset(seed=42)
|
| 50 |
+
... # ... interact with environment
|
| 51 |
+
... finally:
|
| 52 |
+
... client.close()
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def _step_payload(self, action: AskAnswerAction) -> Dict:
|
| 56 |
+
"""
|
| 57 |
+
Convert AskAnswerAction to JSON payload for step message.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
action: AskAnswerAction instance
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Dictionary representation suitable for JSON encoding
|
| 64 |
+
"""
|
| 65 |
+
payload = {"type": action.type}
|
| 66 |
+
if action.type == "ask":
|
| 67 |
+
payload["slot"] = action.slot
|
| 68 |
+
else: # answer
|
| 69 |
+
payload["city"] = action.city
|
| 70 |
+
payload["date"] = action.date
|
| 71 |
+
payload["budget"] = action.budget
|
| 72 |
+
payload["style"] = action.style
|
| 73 |
+
return payload
|
| 74 |
+
|
| 75 |
+
def _parse_result(self, payload: Dict) -> StepResult[AskAnswerObservation]:
|
| 76 |
+
"""
|
| 77 |
+
Parse server response into StepResult[AskAnswerObservation].
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
payload: JSON response data from server
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
StepResult with AskAnswerObservation
|
| 84 |
+
"""
|
| 85 |
+
obs_data = payload.get("observation", {})
|
| 86 |
+
known_data = obs_data.get("known", {})
|
| 87 |
+
|
| 88 |
+
known = KnownSlots(
|
| 89 |
+
city=known_data.get("city"),
|
| 90 |
+
date=known_data.get("date"),
|
| 91 |
+
budget=known_data.get("budget"),
|
| 92 |
+
style=known_data.get("style"),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
observation = AskAnswerObservation(
|
| 96 |
+
prompt=obs_data.get("prompt", "Plan a short trip for me."),
|
| 97 |
+
known=known,
|
| 98 |
+
steps_left=obs_data.get("steps_left", 0),
|
| 99 |
+
done=payload.get("done", False),
|
| 100 |
+
reward=payload.get("reward"),
|
| 101 |
+
core_correct_count=obs_data.get("core_correct_count"),
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return StepResult(
|
| 105 |
+
observation=observation,
|
| 106 |
+
reward=payload.get("reward"),
|
| 107 |
+
done=payload.get("done", False),
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def _parse_state(self, payload: Dict) -> State:
|
| 111 |
+
"""
|
| 112 |
+
Parse server response into State object.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
payload: JSON response from state request
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
State object with episode_id and step_count
|
| 119 |
+
"""
|
| 120 |
+
return State(
|
| 121 |
+
episode_id=payload.get("episode_id"),
|
| 122 |
+
step_count=payload.get("step_count", 0),
|
| 123 |
+
)
|
exp.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline agents for the Ask Answer Env environment (v1).
|
| 3 |
+
|
| 4 |
+
Tests different ask-vs-act strategies under budget constraints (MAX_STEPS=3).
|
| 5 |
+
With only 3 steps, agents can ask at most 2 slots before being forced to answer,
|
| 6 |
+
creating a non-trivial tradeoff between information gathering and guessing.
|
| 7 |
+
|
| 8 |
+
Baselines:
|
| 9 |
+
- A: city+date (ask city, ask date, guess budget)
|
| 10 |
+
- B: city+budget (ask city, ask budget, guess date)
|
| 11 |
+
- C: style+city (trap: wastes a question on distractor)
|
| 12 |
+
- Random: random actions
|
| 13 |
+
- Oracle: knows hidden state, answers immediately (upper bound)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import random
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Callable, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
from ask_answer_env import AskAnswerEnv, AskAnswerAction, KnownSlots
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Type aliases
|
| 24 |
+
HiddenTuple = Tuple[str, str, str, str] # (city, date, budget, style)
|
| 25 |
+
StrategyFn = Callable[[KnownSlots, int, Optional[HiddenTuple]], AskAnswerAction]
|
| 26 |
+
|
| 27 |
+
# Default guesses when slot is unknown
|
| 28 |
+
DEFAULT_CITY = "Paris"
|
| 29 |
+
DEFAULT_DATE = "mid_feb"
|
| 30 |
+
DEFAULT_BUDGET = "mid"
|
| 31 |
+
DEFAULT_STYLE = "relax"
|
| 32 |
+
|
| 33 |
+
# Valid slot values (for random baseline)
|
| 34 |
+
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
|
| 35 |
+
DATES = ["next_weekend", "mid_feb", "march"]
|
| 36 |
+
BUDGETS = ["low", "mid", "high"]
|
| 37 |
+
STYLES = ["relax", "adventure", "food"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class EpisodeResult:
|
| 42 |
+
"""Result of a single episode."""
|
| 43 |
+
total_reward: float
|
| 44 |
+
revealed: HiddenTuple
|
| 45 |
+
steps_taken: int
|
| 46 |
+
core_correct_count: int # 0-3: how many core slots were correct
|
| 47 |
+
core_all_correct: bool # True if all 3 core slots correct
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# =============================================================================
|
| 51 |
+
# Strategy Functions
|
| 52 |
+
# =============================================================================
|
| 53 |
+
|
| 54 |
+
def strategy_city_date(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
|
| 55 |
+
"""
|
| 56 |
+
Strategy A: Ask city, ask date, then answer (guess budget).
|
| 57 |
+
|
| 58 |
+
Expected behavior with MAX_STEPS=3:
|
| 59 |
+
- Step 1: ASK city
|
| 60 |
+
- Step 2: ASK date
|
| 61 |
+
- Step 3: ANSWER with known city+date, guess budget
|
| 62 |
+
"""
|
| 63 |
+
if known.city is None:
|
| 64 |
+
return AskAnswerAction(type="ask", slot="city")
|
| 65 |
+
elif known.date is None:
|
| 66 |
+
return AskAnswerAction(type="ask", slot="date")
|
| 67 |
+
else:
|
| 68 |
+
return AskAnswerAction(
|
| 69 |
+
type="answer",
|
| 70 |
+
city=known.city,
|
| 71 |
+
date=known.date,
|
| 72 |
+
budget=known.budget if known.budget else DEFAULT_BUDGET,
|
| 73 |
+
style=known.style, # None if not asked
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def strategy_city_budget(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
|
| 78 |
+
"""
|
| 79 |
+
Strategy B: Ask city, ask budget, then answer (guess date).
|
| 80 |
+
|
| 81 |
+
Expected behavior with MAX_STEPS=3:
|
| 82 |
+
- Step 1: ASK city
|
| 83 |
+
- Step 2: ASK budget
|
| 84 |
+
- Step 3: ANSWER with known city+budget, guess date
|
| 85 |
+
"""
|
| 86 |
+
if known.city is None:
|
| 87 |
+
return AskAnswerAction(type="ask", slot="city")
|
| 88 |
+
elif known.budget is None:
|
| 89 |
+
return AskAnswerAction(type="ask", slot="budget")
|
| 90 |
+
else:
|
| 91 |
+
return AskAnswerAction(
|
| 92 |
+
type="answer",
|
| 93 |
+
city=known.city,
|
| 94 |
+
date=known.date if known.date else DEFAULT_DATE,
|
| 95 |
+
budget=known.budget,
|
| 96 |
+
style=known.style,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def strategy_style_city(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
|
| 101 |
+
"""
|
| 102 |
+
Strategy C (TRAP): Ask style first, then city, guess date+budget.
|
| 103 |
+
|
| 104 |
+
This is a BAD strategy because:
|
| 105 |
+
- Style only gives +0.1 bonus (vs +0.4 for core slots)
|
| 106 |
+
- Wastes a question on a low-value distractor
|
| 107 |
+
- Must guess 2 core slots instead of 1
|
| 108 |
+
|
| 109 |
+
Expected behavior with MAX_STEPS=3:
|
| 110 |
+
- Step 1: ASK style (bad choice!)
|
| 111 |
+
- Step 2: ASK city
|
| 112 |
+
- Step 3: ANSWER with known style+city, guess date+budget
|
| 113 |
+
"""
|
| 114 |
+
if known.style is None:
|
| 115 |
+
return AskAnswerAction(type="ask", slot="style")
|
| 116 |
+
elif known.city is None:
|
| 117 |
+
return AskAnswerAction(type="ask", slot="city")
|
| 118 |
+
else:
|
| 119 |
+
return AskAnswerAction(
|
| 120 |
+
type="answer",
|
| 121 |
+
city=known.city,
|
| 122 |
+
date=known.date if known.date else DEFAULT_DATE,
|
| 123 |
+
budget=known.budget if known.budget else DEFAULT_BUDGET,
|
| 124 |
+
style=known.style,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def strategy_random(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
|
| 129 |
+
"""
|
| 130 |
+
Random baseline: randomly ask or answer with random values.
|
| 131 |
+
|
| 132 |
+
50% chance to ask a random unknown slot, 50% chance to answer.
|
| 133 |
+
If no unknown slots, always answer.
|
| 134 |
+
"""
|
| 135 |
+
unknown_slots = []
|
| 136 |
+
if known.city is None:
|
| 137 |
+
unknown_slots.append("city")
|
| 138 |
+
if known.date is None:
|
| 139 |
+
unknown_slots.append("date")
|
| 140 |
+
if known.budget is None:
|
| 141 |
+
unknown_slots.append("budget")
|
| 142 |
+
if known.style is None:
|
| 143 |
+
unknown_slots.append("style")
|
| 144 |
+
|
| 145 |
+
# If we have unknown slots and coin flip says ask
|
| 146 |
+
if unknown_slots and random.random() < 0.5:
|
| 147 |
+
slot = random.choice(unknown_slots)
|
| 148 |
+
return AskAnswerAction(type="ask", slot=slot)
|
| 149 |
+
|
| 150 |
+
# Otherwise answer with random guesses for unknown slots
|
| 151 |
+
return AskAnswerAction(
|
| 152 |
+
type="answer",
|
| 153 |
+
city=known.city if known.city else random.choice(CITIES),
|
| 154 |
+
date=known.date if known.date else random.choice(DATES),
|
| 155 |
+
budget=known.budget if known.budget else random.choice(BUDGETS),
|
| 156 |
+
style=known.style if known.style else random.choice(STYLES),
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def strategy_oracle(known: KnownSlots, steps_left: int, hidden: Optional[HiddenTuple] = None) -> AskAnswerAction:
|
| 161 |
+
"""
|
| 162 |
+
Oracle baseline: knows hidden state, answers perfectly in 1 step.
|
| 163 |
+
|
| 164 |
+
This is the THEORETICAL UPPER BOUND.
|
| 165 |
+
|
| 166 |
+
In practice, this strategy function is NOT used because the server
|
| 167 |
+
doesn't expose hidden state to the client. Instead, we hardcode
|
| 168 |
+
the oracle's reward as 1.45 in run_baseline_test().
|
| 169 |
+
|
| 170 |
+
Reward breakdown:
|
| 171 |
+
-0.05 (step) + 0.4×3 (core) + 0.1 (style) + 0.2 (bonus) = +1.45
|
| 172 |
+
"""
|
| 173 |
+
if hidden is None:
|
| 174 |
+
raise ValueError("Oracle strategy requires hidden state")
|
| 175 |
+
|
| 176 |
+
city, date, budget, style = hidden
|
| 177 |
+
return AskAnswerAction(
|
| 178 |
+
type="answer",
|
| 179 |
+
city=city,
|
| 180 |
+
date=date,
|
| 181 |
+
budget=budget,
|
| 182 |
+
style=style,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# =============================================================================
|
| 187 |
+
# Episode Runner
|
| 188 |
+
# =============================================================================
|
| 189 |
+
|
| 190 |
+
def run_episode(
    client: AskAnswerEnv,
    strategy: StrategyFn,
    seed: int = 42,
    hidden: Optional[HiddenTuple] = None,
    verbose: bool = False,
) -> EpisodeResult:
    """
    Run a single episode with the given strategy.

    Drives the client loop: reset with the seed, then repeatedly query the
    strategy for an action and step the environment until done.

    Args:
        client: AskAnswerEnv client instance
        strategy: Function that takes (known, steps_left, hidden) and returns action
        seed: Random seed for reproducibility
        hidden: Hidden state tuple (required for oracle strategy)
        verbose: Whether to print step-by-step info

    Returns:
        EpisodeResult with total_reward, revealed slots, and steps taken
    """
    result = client.reset(seed=seed)
    total_reward = 0.0
    steps = 0

    if verbose:
        print(f"=== Episode Start (seed={seed}) ===")
        print(f"Steps left: {result.observation.steps_left}")

    # Step until the server signals episode end (ANSWER or auto-fail).
    while not result.done:
        obs = result.observation
        action = strategy(obs.known, obs.steps_left, hidden)

        result = client.step(action)
        total_reward += result.reward
        steps += 1

        if verbose:
            if action.type == "ask":
                # The asked slot's value appears in the post-step observation.
                slot_val = getattr(result.observation.known, action.slot)
                print(f" Step {steps}: ASK {action.slot} -> {slot_val}, reward={result.reward:+.2f}")
            else:
                print(f" Step {steps}: ANSWER city={action.city}, date={action.date}, "
                      f"budget={action.budget}, style={action.style}, reward={result.reward:+.2f}")

    final = result.observation.known
    revealed = (final.city, final.date, final.budget, final.style)

    # Extract correctness info (available when done=True after ANSWER).
    # Auto-fail episodes leave core_correct_count as None -> coerced to 0 here.
    core_correct_count = result.observation.core_correct_count or 0
    core_all_correct = core_correct_count == 3

    if verbose:
        print(f" Total reward: {total_reward:+.2f}")
        print(f" Core correct: {core_correct_count}/3")
        print()

    return EpisodeResult(
        total_reward=total_reward,
        revealed=revealed,
        steps_taken=steps,
        core_correct_count=core_correct_count,
        core_all_correct=core_all_correct,
    )
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# =============================================================================
|
| 256 |
+
# Acceptance Tests
|
| 257 |
+
# =============================================================================
|
| 258 |
+
|
| 259 |
+
@dataclass
class BaselineStats:
    """Statistics for a baseline over multiple episodes."""
    name: str  # baseline label used in logs and the results table
    mean_reward: float  # mean total reward across episodes
    std_reward: float  # population standard deviation of total reward
    positive_return_rate: float  # % of episodes with reward > 0
    core_success_rate: float  # % of episodes with all 3 core slots correct
    avg_core_correct: float  # average number of core slots correct (0-3)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def run_baseline_test(
    client: AskAnswerEnv,
    name: str,
    strategy: StrategyFn,
    num_episodes: int = 200,
    needs_oracle: bool = False,
) -> BaselineStats:
    """
    Run multiple episodes with a strategy and compute statistics.

    Args:
        client: AskAnswerEnv client instance
        name: Name of the baseline for logging
        strategy: Strategy function (ignored when needs_oracle=True)
        num_episodes: Number of episodes to run (must be >= 1)
        needs_oracle: If True, return theoretical oracle values instead of
            running episodes (the server never exposes hidden state)

    Returns:
        BaselineStats with all metrics
    """
    if needs_oracle:
        # Oracle is a THEORETICAL upper bound - knows hidden state,
        # answers perfectly in 1 step.
        #
        # Reward: -0.05 + 0.4×3 + 0.1 + 0.2 = +1.45
        # Core correct: 3/3 always
        return BaselineStats(
            name=name,
            mean_reward=1.45,
            std_reward=0.0,
            positive_return_rate=1.0,
            core_success_rate=1.0,
            avg_core_correct=3.0,
        )

    # Stdlib statistics replaces the hand-rolled mean/variance; pstdev is the
    # population standard deviation (divides by n), matching the original math.
    from statistics import fmean, pstdev

    results: List[EpisodeResult] = [
        run_episode(client, strategy, seed=seed) for seed in range(num_episodes)
    ]

    rewards = [r.total_reward for r in results]
    mean_reward = fmean(rewards)
    std_reward = pstdev(rewards)

    n = len(results)
    positive_return_rate = sum(1 for r in rewards if r > 0) / n
    core_success_rate = sum(1 for r in results if r.core_all_correct) / n
    avg_core_correct = sum(r.core_correct_count for r in results) / n

    return BaselineStats(
        name=name,
        mean_reward=mean_reward,
        std_reward=std_reward,
        positive_return_rate=positive_return_rate,
        core_success_rate=core_success_rate,
        avg_core_correct=avg_core_correct,
    )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def run_acceptance_tests(client: AskAnswerEnv, num_episodes: int = 200) -> bool:
    """
    Run all baseline tests and print results table.

    Expected ordering:
        Oracle > A ≈ B >> C > Random

    Args:
        client: AskAnswerEnv client instance
        num_episodes: episodes per baseline

    Returns:
        True iff every expected-ordering check passes.
    """
    print(f"\nRunning {num_episodes} episodes per baseline...\n")

    # (display name, strategy fn or None, uses theoretical oracle values)
    baselines = [
        ("Oracle (theoretical)", None, True),
        ("A: city+date", strategy_city_date, False),
        ("B: city+budget", strategy_city_budget, False),
        ("C: style+city (trap)", strategy_style_city, False),
        ("Random", strategy_random, False),
    ]

    all_stats: List[BaselineStats] = []
    for name, strategy, is_oracle in baselines:
        stats = run_baseline_test(client, name, strategy, num_episodes, needs_oracle=is_oracle)
        all_stats.append(stats)
        print(f" {name}: mean={stats.mean_reward:+.3f}, core_success={stats.core_success_rate:.1%}")

    # Print results table
    print("\n" + "=" * 90)
    print("RESULTS SUMMARY")
    print("=" * 90)
    header = f"{'Baseline':<22} {'Mean':>8} {'Std':>7} {'Pos%':>7} {'Core%':>7} {'AvgCore':>8}"
    print(header)
    print("-" * 90)

    # Best-first ordering by mean reward.
    for s in sorted(all_stats, key=lambda x: -x.mean_reward):
        print(f"{s.name:<22} {s.mean_reward:>+8.3f} {s.std_reward:>7.3f} "
              f"{s.positive_return_rate:>6.0%} {s.core_success_rate:>6.0%} "
              f"{s.avg_core_correct:>7.2f}/3")
    print("-" * 90)

    print("\nColumn legend:")
    print(" Mean = mean total reward")
    print(" Std = standard deviation of reward")
    print(" Pos% = positive_return_rate (% episodes with reward > 0)")
    print(" Core% = core_success_rate (% episodes with all 3 core slots correct)")
    print(" AvgCore = avg_core_correct (mean # of core slots correct, out of 3)")

    # Verify expected ordering
    result_dict = {s.name: s.mean_reward for s in all_stats}

    checks = [
        ("Oracle > A", result_dict["Oracle (theoretical)"] > result_dict["A: city+date"]),
        ("A ≈ B", abs(result_dict["A: city+date"] - result_dict["B: city+budget"]) < 0.1),
        ("A > C", result_dict["A: city+date"] > result_dict["C: style+city (trap)"]),
        ("C > Random", result_dict["C: style+city (trap)"] > result_dict["Random"]),
    ]

    print("\nExpected ordering checks:")
    all_passed = True
    for check_name, passed in checks:
        status = "PASS" if passed else "FAIL"
        print(f" {check_name}: {status}")
        if not passed:
            all_passed = False

    return all_passed
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# =============================================================================
|
| 396 |
+
# Determinism Tests (kept from v0)
|
| 397 |
+
# =============================================================================
|
| 398 |
+
|
| 399 |
+
def test_determinism(client: AskAnswerEnv, seed: int = 42, runs: int = 3) -> bool:
    """Check that replaying the same seed yields identical trajectories."""
    episodes = [run_episode(client, strategy_city_date, seed=seed) for _ in range(runs)]

    rewards = [ep.total_reward for ep in episodes]
    revealed = [ep.revealed for ep in episodes]

    # Every run must agree on both the revealed tuple and the total reward.
    identical = len(set(revealed)) == 1 and len(set(rewards)) == 1
    print(f"Determinism (seed={seed}): {revealed[0]} x{runs}, identical={identical}")
    return identical
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def test_seed_sensitivity(client: AskAnswerEnv, num_seeds: int = 20) -> bool:
    """Verify different seeds produce different hidden states."""
    # Max possible: 4 * 3 * 3 * 3 = 108 (with style)
    unique = {
        run_episode(client, strategy_city_date, seed=s).revealed
        for s in range(num_seeds)
    }
    print(f"Seed sensitivity: {len(unique)} unique tuples from {num_seeds} seeds")
    return len(unique) > 1
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# =============================================================================
|
| 427 |
+
# Main
|
| 428 |
+
# =============================================================================
|
| 429 |
+
|
| 430 |
+
if __name__ == "__main__":
    # Launch the environment container; the client talks to it over HTTP.
    client = AskAnswerEnv.from_docker_image("ask_answer_env-env:latest")
    try:
        print("=" * 60)
        print("ASK-ANSWER ENV v1 ACCEPTANCE TESTS")
        print("=" * 60)

        # Quick determinism check
        print("\n1. DETERMINISM TESTS")
        print("-" * 40)
        test_determinism(client, seed=42)
        test_determinism(client, seed=123)
        test_seed_sensitivity(client)

        # Run a single verbose episode to show behavior
        print("\n2. EXAMPLE EPISODE (Strategy A: city+date)")
        print("-" * 40)
        run_episode(client, strategy_city_date, seed=42, verbose=True)

        print("\n3. EXAMPLE EPISODE (Strategy C: style+city - TRAP)")
        print("-" * 40)
        run_episode(client, strategy_style_city, seed=42, verbose=True)

        # Full acceptance tests
        print("\n4. BASELINE COMPARISON")
        print("-" * 40)
        passed = run_acceptance_tests(client, num_episodes=200)

        print("\n" + "=" * 60)
        print(f"ALL TESTS: {'PASSED' if passed else 'FAILED'}")
        print("=" * 60)

    finally:
        # Always tear the container down, even if a test raises.
        client.close()
|
models.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the Ask Answer Env Environment.
|
| 9 |
+
|
| 10 |
+
A slot-filling environment where agents must decide between asking
|
| 11 |
+
clarifying questions or answering early.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Literal, Optional
|
| 15 |
+
|
| 16 |
+
from pydantic import BaseModel, Field
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server.types import Action, Observation
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AskAnswerAction(Action):
    """
    Action for the Ask Answer Env environment.

    Two action types:
    - "ask": Ask about a specific slot (city, date, budget, or style)
    - "answer": Provide final answer with slot values

    For "ask" only `slot` is meaningful; for "answer" only the four
    slot-value fields are meaningful. Unused fields stay None.
    """

    type: Literal["ask", "answer"] = Field(..., description="Action type: 'ask' or 'answer'")
    slot: Optional[Literal["city", "date", "budget", "style"]] = Field(
        default=None, description="Slot to ask about (only for 'ask' type)"
    )
    city: Optional[str] = Field(default=None, description="City answer (only for 'answer' type)")
    date: Optional[str] = Field(default=None, description="Date answer (only for 'answer' type)")
    budget: Optional[str] = Field(default=None, description="Budget answer (only for 'answer' type)")
    style: Optional[str] = Field(default=None, description="Style answer (only for 'answer' type)")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class KnownSlots(BaseModel):
    """Slots that have been revealed to the agent (None = still hidden)."""

    city: Optional[str] = Field(default=None, description="Known city value")
    date: Optional[str] = Field(default=None, description="Known date value")
    budget: Optional[str] = Field(default=None, description="Known budget value")
    style: Optional[str] = Field(default=None, description="Known style value")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AskAnswerObservation(Observation):
    """Observation from the Ask Answer Env environment."""

    prompt: str = Field(default="Plan a short trip for me.", description="The user prompt")
    known: KnownSlots = Field(default_factory=KnownSlots, description="Currently known slot values")
    steps_left: int = Field(default=3, description="Number of steps remaining")

    # Correctness info (populated when done=True after ANSWER action)
    core_correct_count: Optional[int] = Field(
        default=None, description="Number of core slots correct (0-3), None if not answered yet"
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: ask_answer_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
openenv_ask_answer_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-ask_answer_env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Ask Answer Env environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.1
|
| 7 |
+
Provides-Extra: dev
|
| 8 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 9 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_ask_answer_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
__init__.py
|
| 3 |
+
client.py
|
| 4 |
+
models.py
|
| 5 |
+
pyproject.toml
|
| 6 |
+
./__init__.py
|
| 7 |
+
./client.py
|
| 8 |
+
./exp.py
|
| 9 |
+
./models.py
|
| 10 |
+
openenv_ask_answer_env.egg-info/PKG-INFO
|
| 11 |
+
openenv_ask_answer_env.egg-info/SOURCES.txt
|
| 12 |
+
openenv_ask_answer_env.egg-info/dependency_links.txt
|
| 13 |
+
openenv_ask_answer_env.egg-info/entry_points.txt
|
| 14 |
+
openenv_ask_answer_env.egg-info/requires.txt
|
| 15 |
+
openenv_ask_answer_env.egg-info/top_level.txt
|
| 16 |
+
server/__init__.py
|
| 17 |
+
server/app.py
|
| 18 |
+
server/ask_answer_env_environment.py
|
openenv_ask_answer_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_ask_answer_env.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = ask_answer_env.server.app:main
|
openenv_ask_answer_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.1
|
| 2 |
+
|
| 3 |
+
[dev]
|
| 4 |
+
pytest>=8.0.0
|
| 5 |
+
pytest-cov>=4.0.0
|
openenv_ask_answer_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ask_answer_env
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-ask_answer_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Ask Answer Env environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.1",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m ask_answer_env.server.app
|
| 40 |
+
server = "ask_answer_env.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["ask_answer_env", "ask_answer_env.server"]
|
| 45 |
+
package-dir = { "ask_answer_env" = ".", "ask_answer_env.server" = "server" }
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Ask Answer Env environment server components."""
|
| 8 |
+
|
| 9 |
+
from .ask_answer_env_environment import AskAnswerEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["AskAnswerEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Ask Answer Env Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the AskAnswerEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    # Fail fast with an actionable message if openenv-core is missing.
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

# Import from local models.py (PYTHONPATH includes /app/env in Docker)
from models import AskAnswerAction, AskAnswerObservation
from .ask_answer_env_environment import AskAnswerEnvironment


# Create the app with web interface and README integration
app = create_app(
    AskAnswerEnvironment,
    AskAnswerAction,
    AskAnswerObservation,
    env_name="ask_answer_env",
    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        uv run --project . server --port 8001
        python -m ask_answer_env.server.app

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn ask_answer_env.server.app:app --workers 4
    """
    # Imported lazily so importing this module never requires uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
    import argparse

    # main() already accepts a host, but the CLI previously only exposed
    # --port; expose --host too (default unchanged, fully backward compatible).
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    main(host=args.host, port=args.port)
|
server/ask_answer_env_environment.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Ask Answer Env Environment Implementation.
|
| 9 |
+
|
| 10 |
+
A deterministic slot-filling environment where agents must decide between
|
| 11 |
+
asking clarifying questions or answering early to maximize reward.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import random
|
| 15 |
+
from typing import Optional
|
| 16 |
+
from uuid import uuid4
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server.interfaces import Environment
|
| 19 |
+
from openenv.core.env_server.types import State
|
| 20 |
+
|
| 21 |
+
from models import AskAnswerAction, AskAnswerObservation, KnownSlots
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Constants
|
| 25 |
+
CITIES = ["Paris", "Rome", "Tokyo", "Goa"]
|
| 26 |
+
DATES = ["next_weekend", "mid_feb", "march"]
|
| 27 |
+
BUDGETS = ["low", "mid", "high"]
|
| 28 |
+
STYLES = ["relax", "adventure", "food"] # Distractor slot
|
| 29 |
+
MAX_STEPS = 3 # Forces agent to guess at least 1 core slot
|
| 30 |
+
PROMPT = "Plan a short trip for me."
|
| 31 |
+
|
| 32 |
+
# Rewards (unchanged from v0)
|
| 33 |
+
STEP_PENALTY = -0.05
|
| 34 |
+
ASK_UNKNOWN_REWARD = 0.1
|
| 35 |
+
ASK_KNOWN_PENALTY = -0.2
|
| 36 |
+
AUTO_FAIL_PENALTY = -1.0
|
| 37 |
+
|
| 38 |
+
# Graded answer rewards (v1)
|
| 39 |
+
ANSWER_CITY_CORRECT = 0.4
|
| 40 |
+
ANSWER_DATE_CORRECT = 0.4
|
| 41 |
+
ANSWER_BUDGET_CORRECT = 0.4
|
| 42 |
+
ANSWER_STYLE_CORRECT_BONUS = 0.1 # Optional nice-to-have
|
| 43 |
+
ANSWER_CORE_ALL_CORRECT_BONUS = 0.2
|
| 44 |
+
ANSWER_CORE_ANY_WRONG_PENALTY = -0.6
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AskAnswerEnvironment(Environment):
|
| 48 |
+
"""
|
| 49 |
+
A slot-filling environment for training RL agents.
|
| 50 |
+
|
| 51 |
+
The agent must decide between:
|
| 52 |
+
- Asking clarifying questions (ASK) to reveal hidden slot values
|
| 53 |
+
- Answering early (ANSWER) to end the episode
|
| 54 |
+
|
| 55 |
+
Hidden state (city, date, budget, style) is sampled at reset with a seeded RNG.
|
| 56 |
+
The agent can ask about slots to reveal their values before answering.
|
| 57 |
+
|
| 58 |
+
With MAX_STEPS=3, the agent can only ask 2 slots before being forced to answer,
|
| 59 |
+
creating a non-trivial ask-vs-act tradeoff. The "style" slot is a distractor
|
| 60 |
+
that provides less reward than core slots (city, date, budget).
|
| 61 |
+
|
| 62 |
+
Rewards:
|
| 63 |
+
- Step penalty: -0.05 per step
|
| 64 |
+
- ASK unknown slot: +0.1
|
| 65 |
+
- ASK known slot: -0.2
|
| 66 |
+
- ANSWER: graded per-slot (+0.4 each core, +0.1 style)
|
| 67 |
+
- Core all correct bonus: +0.2
|
| 68 |
+
- Core any wrong penalty: -0.6
|
| 69 |
+
- Auto-fail (steps exhausted): -1.0
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 73 |
+
|
| 74 |
+
    def __init__(self):
        """Initialize the ask_answer_env environment.

        Real episode setup happens in reset(); here we only establish the
        attributes so the instance is valid before the first reset.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._rng: random.Random = random.Random()

        # Hidden truth (sampled at reset)
        self._hidden_city: str = ""
        self._hidden_date: str = ""
        self._hidden_budget: str = ""
        self._hidden_style: str = ""

        # Known slots (revealed through ASK actions)
        self._known: KnownSlots = KnownSlots()
        self._steps_left: int = MAX_STEPS
        self._done: bool = False
|
| 89 |
+
|
| 90 |
+
def reset(self, seed: Optional[int] = None) -> AskAnswerObservation:
|
| 91 |
+
"""
|
| 92 |
+
Reset the environment with optional seed for determinism.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
seed: Random seed for reproducibility
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
AskAnswerObservation with initial state
|
| 99 |
+
"""
|
| 100 |
+
self._state = State(episode_id=str(uuid4()), step_count=0)
|
| 101 |
+
|
| 102 |
+
# Initialize RNG with seed
|
| 103 |
+
if seed is not None:
|
| 104 |
+
self._rng = random.Random(seed)
|
| 105 |
+
else:
|
| 106 |
+
self._rng = random.Random()
|
| 107 |
+
|
| 108 |
+
# Sample hidden truth
|
| 109 |
+
self._hidden_city = self._rng.choice(CITIES)
|
| 110 |
+
self._hidden_date = self._rng.choice(DATES)
|
| 111 |
+
self._hidden_budget = self._rng.choice(BUDGETS)
|
| 112 |
+
self._hidden_style = self._rng.choice(STYLES)
|
| 113 |
+
|
| 114 |
+
# Reset known slots and step counter
|
| 115 |
+
self._known = KnownSlots()
|
| 116 |
+
self._steps_left = MAX_STEPS
|
| 117 |
+
self._done = False
|
| 118 |
+
|
| 119 |
+
return AskAnswerObservation(
|
| 120 |
+
prompt=PROMPT,
|
| 121 |
+
known=self._known,
|
| 122 |
+
steps_left=self._steps_left,
|
| 123 |
+
done=False,
|
| 124 |
+
reward=0.0,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
    def step(self, action: AskAnswerAction) -> AskAnswerObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: AskAnswerAction with type 'ask' or 'answer'

        Returns:
            AskAnswerObservation with updated state and reward
        """
        # Episodes are terminal: stepping a finished episode is a zero-reward no-op.
        if self._done:
            return AskAnswerObservation(
                prompt=PROMPT,
                known=self._known,
                steps_left=self._steps_left,
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1

        # Always apply step penalty
        reward = STEP_PENALTY
        done = False

        if action.type == "ask":
            reward += self._handle_ask(action.slot)
            self._steps_left -= 1

            # Check for auto-fail
            # Exhausting steps without answering REPLACES the accumulated
            # reward with the flat auto-fail penalty (= assignment, not +=).
            if self._steps_left == 0:
                reward = AUTO_FAIL_PENALTY
                done = True

        elif action.type == "answer":
            reward += self._handle_answer(action)
            done = True

        self._done = done

        # Calculate core_correct_count when episode ends via ANSWER
        # (style is a distractor and deliberately excluded from the core count).
        core_correct_count = None
        if done and action.type == "answer":
            core_correct_count = sum([
                action.city == self._hidden_city,
                action.date == self._hidden_date,
                action.budget == self._hidden_budget,
            ])

        return AskAnswerObservation(
            prompt=PROMPT,
            known=self._known,
            steps_left=self._steps_left,
            done=done,
            reward=reward,
            core_correct_count=core_correct_count,
        )
|
| 184 |
+
|
| 185 |
+
def _handle_ask(self, slot: Optional[str]) -> float:
|
| 186 |
+
"""
|
| 187 |
+
Handle ASK action - reveal a slot if unknown.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
slot: The slot to ask about ('city', 'date', 'budget', or 'style')
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Reward for the ASK action
|
| 194 |
+
"""
|
| 195 |
+
if slot == "city":
|
| 196 |
+
if self._known.city is not None:
|
| 197 |
+
return ASK_KNOWN_PENALTY
|
| 198 |
+
self._known = KnownSlots(
|
| 199 |
+
city=self._hidden_city,
|
| 200 |
+
date=self._known.date,
|
| 201 |
+
budget=self._known.budget,
|
| 202 |
+
style=self._known.style,
|
| 203 |
+
)
|
| 204 |
+
return ASK_UNKNOWN_REWARD
|
| 205 |
+
|
| 206 |
+
elif slot == "date":
|
| 207 |
+
if self._known.date is not None:
|
| 208 |
+
return ASK_KNOWN_PENALTY
|
| 209 |
+
self._known = KnownSlots(
|
| 210 |
+
city=self._known.city,
|
| 211 |
+
date=self._hidden_date,
|
| 212 |
+
budget=self._known.budget,
|
| 213 |
+
style=self._known.style,
|
| 214 |
+
)
|
| 215 |
+
return ASK_UNKNOWN_REWARD
|
| 216 |
+
|
| 217 |
+
elif slot == "budget":
|
| 218 |
+
if self._known.budget is not None:
|
| 219 |
+
return ASK_KNOWN_PENALTY
|
| 220 |
+
self._known = KnownSlots(
|
| 221 |
+
city=self._known.city,
|
| 222 |
+
date=self._known.date,
|
| 223 |
+
budget=self._hidden_budget,
|
| 224 |
+
style=self._known.style,
|
| 225 |
+
)
|
| 226 |
+
return ASK_UNKNOWN_REWARD
|
| 227 |
+
|
| 228 |
+
elif slot == "style":
|
| 229 |
+
if self._known.style is not None:
|
| 230 |
+
return ASK_KNOWN_PENALTY
|
| 231 |
+
self._known = KnownSlots(
|
| 232 |
+
city=self._known.city,
|
| 233 |
+
date=self._known.date,
|
| 234 |
+
budget=self._known.budget,
|
| 235 |
+
style=self._hidden_style,
|
| 236 |
+
)
|
| 237 |
+
return ASK_UNKNOWN_REWARD
|
| 238 |
+
|
| 239 |
+
# Invalid slot
|
| 240 |
+
return ASK_KNOWN_PENALTY
|
| 241 |
+
|
| 242 |
+
def _handle_answer(self, action: AskAnswerAction) -> float:
|
| 243 |
+
"""
|
| 244 |
+
Handle ANSWER action with graded rewards.
|
| 245 |
+
|
| 246 |
+
Reward structure:
|
| 247 |
+
- Per-slot rewards: +0.4 for each correct core slot (city, date, budget)
|
| 248 |
+
- Style bonus: +0.1 if style provided and correct (ignored if None)
|
| 249 |
+
- Core bonus: +0.2 if all core slots correct
|
| 250 |
+
- Core penalty: -0.6 if any core slot wrong
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
action: The answer action with city, date, budget, style values
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
Reward for the ANSWER action
|
| 257 |
+
"""
|
| 258 |
+
reward = 0.0
|
| 259 |
+
|
| 260 |
+
# Check core slots
|
| 261 |
+
city_correct = action.city == self._hidden_city
|
| 262 |
+
date_correct = action.date == self._hidden_date
|
| 263 |
+
budget_correct = action.budget == self._hidden_budget
|
| 264 |
+
|
| 265 |
+
# Per-slot rewards for core slots
|
| 266 |
+
if city_correct:
|
| 267 |
+
reward += ANSWER_CITY_CORRECT
|
| 268 |
+
if date_correct:
|
| 269 |
+
reward += ANSWER_DATE_CORRECT
|
| 270 |
+
if budget_correct:
|
| 271 |
+
reward += ANSWER_BUDGET_CORRECT
|
| 272 |
+
|
| 273 |
+
# Style bonus (only if provided and correct, ignored if None)
|
| 274 |
+
if action.style is not None and action.style == self._hidden_style:
|
| 275 |
+
reward += ANSWER_STYLE_CORRECT_BONUS
|
| 276 |
+
|
| 277 |
+
# Core bonus/penalty
|
| 278 |
+
core_all_correct = city_correct and date_correct and budget_correct
|
| 279 |
+
if core_all_correct:
|
| 280 |
+
reward += ANSWER_CORE_ALL_CORRECT_BONUS
|
| 281 |
+
else:
|
| 282 |
+
reward += ANSWER_CORE_ANY_WRONG_PENALTY
|
| 283 |
+
|
| 284 |
+
return reward
|
| 285 |
+
|
| 286 |
+
@property
|
| 287 |
+
def state(self) -> State:
|
| 288 |
+
"""
|
| 289 |
+
Get the current environment state.
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
Current State with episode_id and step_count
|
| 293 |
+
"""
|
| 294 |
+
return self._state
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|