Spaces:
Sleeping
Sleeping
Andrew Lara committed on
Commit ·
ee91164
0
Parent(s):
Deploy landing page update to Space
Browse files- .dockerignore +7 -0
- .env.example +5 -0
- .gitattributes +2 -0
- .github/workflows/ci.yml +18 -0
- .gitignore +6 -0
- CODEX_CONTEXT.md +110 -0
- Dockerfile +23 -0
- README.md +150 -0
- docs/agent_comparison.png +3 -0
- docs/budget_pacing.png +3 -0
- eval_results.json +0 -0
- openenv.yaml +12 -0
- pyproject.toml +37 -0
- reasonbudget_gym/__init__.py +3 -0
- reasonbudget_gym/baselines/__init__.py +11 -0
- reasonbudget_gym/baselines/bandit.py +82 -0
- reasonbudget_gym/baselines/greedy_max.py +18 -0
- reasonbudget_gym/baselines/oracle.py +46 -0
- reasonbudget_gym/baselines/uniform.py +22 -0
- reasonbudget_gym/client.py +51 -0
- reasonbudget_gym/data/__init__.py +0 -0
- reasonbudget_gym/data/embeddings.npy +3 -0
- reasonbudget_gym/data/generate_synthetic_cache.py +156 -0
- reasonbudget_gym/data/response_cache.json +0 -0
- reasonbudget_gym/env/__init__.py +5 -0
- reasonbudget_gym/env/config.py +37 -0
- reasonbudget_gym/env/episode_sampler.py +210 -0
- reasonbudget_gym/env/models.py +43 -0
- reasonbudget_gym/env/reason_budget_env.py +167 -0
- reasonbudget_gym/env/reward.py +44 -0
- reasonbudget_gym/eval/__init__.py +0 -0
- reasonbudget_gym/eval/evaluate.py +121 -0
- reasonbudget_gym/eval/plots.py +89 -0
- reasonbudget_gym/policy/__init__.py +3 -0
- reasonbudget_gym/policy/allocation_policy.py +127 -0
- reasonbudget_gym/server/__init__.py +0 -0
- reasonbudget_gym/server/app.py +233 -0
- reasonbudget_gym/solver/__init__.py +4 -0
- reasonbudget_gym/solver/base.py +26 -0
- reasonbudget_gym/solver/cached_solver.py +98 -0
- reasonbudget_gym/solver/live_solver.py +62 -0
- reasonbudget_gym/tests/__init__.py +0 -0
- reasonbudget_gym/tests/test_config.py +27 -0
- reasonbudget_gym/tests/test_integration.py +43 -0
- reasonbudget_gym/tests/test_reward.py +23 -0
- reasonbudget_gym/training/__init__.py +0 -0
- reasonbudget_gym/training/ppo_train.py +252 -0
- requirements.txt +8 -0
.dockerignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
.pytest_cache
|
| 4 |
+
__pycache__
|
| 5 |
+
*.py[cod]
|
| 6 |
+
.DS_Store
|
| 7 |
+
runs
|
.env.example
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Optional: only needed for live solver (GPU inference)
|
| 2 |
+
TOGETHER_API_KEY=your_key_here
|
| 3 |
+
|
| 4 |
+
# Optional: for HF Spaces deployment
|
| 5 |
+
HF_TOKEN=your_token_here
|
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
docs/*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
reasonbudget_gym/data/*.npy filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
pull_request:
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
test:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
steps:
|
| 11 |
+
- uses: actions/checkout@v4
|
| 12 |
+
- uses: actions/setup-python@v5
|
| 13 |
+
with:
|
| 14 |
+
python-version: "3.11"
|
| 15 |
+
- name: Install dependencies
|
| 16 |
+
run: pip install -e ".[dev]"
|
| 17 |
+
- name: Run tests
|
| 18 |
+
run: python -m pytest reasonbudget_gym/tests/ -v
|
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
.venv/
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
runs/
|
CODEX_CONTEXT.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Codex Context — ReasoningEconomicsEnv
|
| 2 |
+
|
| 3 |
+
## Project
|
| 4 |
+
|
| 5 |
+
- Repo root: `/Users/andrew/Mac/RL Research`
|
| 6 |
+
- GitHub repo: `git@github.com:laraandrew/reasoningeconomicsenv.git`
|
| 7 |
+
- Active branch: `polish-and-deploy`
|
| 8 |
+
- Hugging Face Space: `landrew9/CollabReasoning`
|
| 9 |
+
- Package: `reasonbudget_gym`
|
| 10 |
+
- Goal: RL environment for token-budget allocation, competition submission, Docker-based HF Space deployment
|
| 11 |
+
|
| 12 |
+
## Remotes
|
| 13 |
+
|
| 14 |
+
- `origin`: `git@github.com:laraandrew/reasoningeconomicsenv.git`
|
| 15 |
+
- `hf`: `https://huggingface.co/spaces/landrew9/CollabReasoning`
|
| 16 |
+
|
| 17 |
+
## Current State
|
| 18 |
+
|
| 19 |
+
- `main` and `polish-and-deploy` originally pointed to the same base commit.
|
| 20 |
+
- Work on `polish-and-deploy` is pushed to GitHub through commit `efdc42b`.
|
| 21 |
+
- The shipped cache works:
|
| 22 |
+
- `CachedSolver(EnvConfig())._cache` loads 500 entries.
|
| 23 |
+
- The environment now defaults to an offline-safe path for cached runs:
|
| 24 |
+
- `EpisodeSampler` uses deterministic bundled questions when the cached solver is active.
|
| 25 |
+
- Real question embeddings are enabled and cached at:
|
| 26 |
+
- `reasonbudget_gym/data/embeddings.npy`
|
| 27 |
+
- README now contains measured evaluation metrics and embedded plot assets.
|
| 28 |
+
- CI exists at `.github/workflows/ci.yml`.
|
| 29 |
+
- Dockerfile was slimmed to a runtime-only serving image suitable for HF Spaces.
|
| 30 |
+
- The Hugging Face Space repo was force-updated from a clean temporary clone because
|
| 31 |
+
Hugging Face rejected the branch's historical raw binary blobs.
|
| 32 |
+
- The live Space is currently:
|
| 33 |
+
- Hub page: `https://huggingface.co/spaces/landrew9/CollabReasoning`
|
| 34 |
+
- Host: `https://landrew9-collabreasoning.hf.space`
|
| 35 |
+
- Runtime stage: `RUNNING`
|
| 36 |
+
- Health endpoint: `/health`
|
| 37 |
+
- Root path originally returned `404`; a landing page at `/` was then added in `server/app.py`
|
| 38 |
+
|
| 39 |
+
## Local Tooling
|
| 40 |
+
|
| 41 |
+
- Hugging Face CLI installed globally via the official installer.
|
| 42 |
+
- Binary path: `/Users/andrew/.local/bin/hf`
|
| 43 |
+
- Reported version at install time: `1.8.0`
|
| 44 |
+
- Installer added `/Users/andrew/.local/bin` to `/Users/andrew/.zshrc`
|
| 45 |
+
- `git-lfs` and `git-xet` are installed and initialized globally.
|
| 46 |
+
- `.gitattributes` now tracks:
|
| 47 |
+
- `docs/*.png`
|
| 48 |
+
- `reasonbudget_gym/data/*.npy`
|
| 49 |
+
|
| 50 |
+
## Verified Commands
|
| 51 |
+
|
| 52 |
+
- Tests:
|
| 53 |
+
- `.venv/bin/python -m pytest reasonbudget_gym/tests/ -v`
|
| 54 |
+
- Result: `8 passed`
|
| 55 |
+
- Eval:
|
| 56 |
+
- `.venv/bin/python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json`
|
| 57 |
+
- Plot generation:
|
| 58 |
+
- `.venv/bin/python -c "from reasonbudget_gym.eval.plots import agent_comparison, budget_pacing; agent_comparison('eval_results.json', 'docs/agent_comparison.png'); budget_pacing('eval_results.json', 'docs/budget_pacing.png')"`
|
| 59 |
+
- PPO smoke test:
|
| 60 |
+
- `.venv/bin/python -m reasonbudget_gym.training.ppo_train --n_episodes 100 --output_dir runs/smoke`
|
| 61 |
+
- Completed successfully and wrote checkpoints.
|
| 62 |
+
- Docker:
|
| 63 |
+
- `docker build -t reasoning-economic-env .`
|
| 64 |
+
- `docker run -d -p 8000:8000 --name reasoning-economic-env-test reasoning-economic-env`
|
| 65 |
+
- `curl http://127.0.0.1:8000/health`
|
| 66 |
+
- Result: `{"status":"ok","env":"ReasonBudgetEnv","version":"0.1.0"}`
|
| 67 |
+
|
| 68 |
+
## Current Eval Numbers
|
| 69 |
+
|
| 70 |
+
From `eval_results.json` with `--n_episodes 50 --seed 42`:
|
| 71 |
+
|
| 72 |
+
| Agent | Mean Accuracy | Mean Reward | Budget Used |
|
| 73 |
+
|---|---:|---:|---:|
|
| 74 |
+
| `uniform` | 0.780 | 7.620 | 100.0% |
|
| 75 |
+
| `greedy_max` | 0.840 | 4.163 | 100.0% |
|
| 76 |
+
| `oracle` | 0.728 | 6.933 | 98.3% |
|
| 77 |
+
| `bandit` | 0.744 | 6.526 | 98.8% |
|
| 78 |
+
|
| 79 |
+
## Important Files
|
| 80 |
+
|
| 81 |
+
- `reasonbudget_gym/env/episode_sampler.py`
|
| 82 |
+
- `reasonbudget_gym/env/config.py`
|
| 83 |
+
- `reasonbudget_gym/solver/cached_solver.py`
|
| 84 |
+
- `reasonbudget_gym/eval/evaluate.py`
|
| 85 |
+
- `reasonbudget_gym/server/app.py`
|
| 86 |
+
- `Dockerfile`
|
| 87 |
+
- `README.md`
|
| 88 |
+
- `.github/workflows/ci.yml`
|
| 89 |
+
- `eval_results.json`
|
| 90 |
+
- `docs/agent_comparison.png`
|
| 91 |
+
- `docs/budget_pacing.png`
|
| 92 |
+
|
| 93 |
+
## Git History Added On This Branch
|
| 94 |
+
|
| 95 |
+
- `29b6ad0` Add gitignore for local dev artifacts
|
| 96 |
+
- `ecd0ab1` Use bundled questions for cached offline runs
|
| 97 |
+
- `9e122a2` Cache MiniLM question embeddings
|
| 98 |
+
- `c4d6234` Add GitHub Actions test workflow
|
| 99 |
+
- `fc6c606` Add baseline eval results and README plots
|
| 100 |
+
- `280a6de` Slim Docker image for HF deployment
|
| 101 |
+
- `fc4c73c` Add living Codex context file
|
| 102 |
+
- `efdc42b` Track Space binaries with Xet
|
| 103 |
+
|
| 104 |
+
## Notes For Next Codex
|
| 105 |
+
|
| 106 |
+
- Keep `HANDOFF.md` deleted; update this file instead.
|
| 107 |
+
- Do not remove `reasonbudget_gym/data/response_cache.json` or `reasonbudget_gym/data/embeddings.npy`; they are part of the current offline/demo story.
|
| 108 |
+
- The Docker image should stay lean; avoid reintroducing `sentence-transformers`, `datasets`, or training dependencies into the serving image unless truly needed.
|
| 109 |
+
- If enabling the live solver later, configure secrets in Hugging Face Space settings rather than hard-coding them.
|
| 110 |
+
- The local repo may also have an `hf` remote pointing at the Space repo; if so, pushes there will trigger Space rebuilds.
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 5 |
+
PYTHONUNBUFFERED=1
|
| 6 |
+
|
| 7 |
+
# Copy only the files needed to serve the packaged environment.
|
| 8 |
+
COPY pyproject.toml README.md openenv.yaml ./
|
| 9 |
+
COPY reasonbudget_gym ./reasonbudget_gym
|
| 10 |
+
|
| 11 |
+
# The Space serves the bundled cached environment, so it only needs the
|
| 12 |
+
# lightweight runtime deps plus an editable install of this package.
|
| 13 |
+
RUN pip install --no-cache-dir \
|
| 14 |
+
"fastapi>=0.110.0" \
|
| 15 |
+
"uvicorn[standard]>=0.29.0" \
|
| 16 |
+
"pydantic>=2.0" \
|
| 17 |
+
"numpy>=1.24" \
|
| 18 |
+
"hatchling" \
|
| 19 |
+
&& pip install --no-cache-dir --no-deps -e .
|
| 20 |
+
|
| 21 |
+
EXPOSE 8000
|
| 22 |
+
|
| 23 |
+
CMD ["uvicorn", "reasonbudget_gym.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ReasoningEconomicsEnv
|
| 3 |
+
sdk: docker
|
| 4 |
+
app_port: 8000
|
| 5 |
+
tags:
|
| 6 |
+
- openenv
|
| 7 |
+
- reasoning-economic-env
|
| 8 |
+
- rl
|
| 9 |
+
- math
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# ReasoningEconomicsEnv
|
| 13 |
+
|
| 14 |
+
**An RL environment for learning to allocate reasoning compute under budget constraints.**
|
| 15 |
+
|
| 16 |
+
> Modern reasoning models like DeepSeek-R1 "think" by generating internal tokens before
|
| 17 |
+
> answering. More tokens = deeper reasoning = better answers — but tokens cost compute and
|
| 18 |
+
> money. How should an agent decide how much to think on each problem?
|
| 19 |
+
|
| 20 |
+
ReasoningEconomicsEnv frames this as a sequential decision problem: an agent faces a series
|
| 21 |
+
of math questions with a fixed total token budget and must learn to **allocate tokens wisely**
|
| 22 |
+
— spending less on easy questions, more on hard ones.
|
| 23 |
+
|
| 24 |
+
Built on [Meta's OpenEnv framework](https://github.com/meta-pytorch/OpenEnv) for the
|
| 25 |
+
[AgentX–AgentBeats Competition](https://rdi.berkeley.edu/agentx-agentbeats) hosted by
|
| 26 |
+
Berkeley RDI.
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## How It Works
|
| 31 |
+
|
| 32 |
+
```
|
| 33 |
+
Episode (10 questions, 4000 token budget)
|
| 34 |
+
┌─────────────────────────────────────────────────────────┐
|
| 35 |
+
│ 1. Agent observes: question embedding, remaining budget │
|
| 36 |
+
│ 2. Agent decides: token allocation (50–800) │
|
| 37 |
+
│ 3. Solver attempts question with that token limit │
|
| 38 |
+
│ 4. Reward = correctness − β·cost + γ·efficiency_bonus │
|
| 39 |
+
│ 5. Repeat until all questions answered or budget gone │
|
| 40 |
+
└─────────────────────────────────────────────────────────┘
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
**Reward formula:** `R = correctness(±1/−0.1) − β·(tokens_used/budget) + γ·(savings/budget)`
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## Quick Start
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install -e .
|
| 51 |
+
|
| 52 |
+
# Run the OpenEnv server
|
| 53 |
+
uvicorn reasonbudget_gym.server.app:app --port 8000
|
| 54 |
+
|
| 55 |
+
# In another terminal — use the Python client
|
| 56 |
+
python -c "
|
| 57 |
+
from reasonbudget_gym.client import ReasonBudgetClient
|
| 58 |
+
client = ReasonBudgetClient()
|
| 59 |
+
obs = client.reset()
|
| 60 |
+
result = client.step(200)
|
| 61 |
+
print(result.reward, result.done)
|
| 62 |
+
"
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
**Or run baseline evaluation locally:**
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
|
| 69 |
+
python -m reasonbudget_gym.eval.plots eval_results.json
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## Baselines
|
| 75 |
+
|
| 76 |
+
| Agent | Mean Accuracy | Mean Reward | Budget Used |
|
| 77 |
+
|-------|---------------|-------------|-------------|
|
| 78 |
+
| `uniform` | 0.780 | 7.620 | 100.0% |
|
| 79 |
+
| `greedy_max` | 0.840 | 4.163 | 100.0% |
|
| 80 |
+
| `oracle` | 0.728 | 6.933 | 98.3% |
|
| 81 |
+
| `bandit` | 0.744 | 6.526 | 98.8% |
|
| 82 |
+
|
| 83 |
+
Evaluation command:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+
|
| 91 |
+

|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Observation Space
|
| 96 |
+
|
| 97 |
+
| Field | Shape | Description |
|
| 98 |
+
|-------|-------|-------------|
|
| 99 |
+
| `question_embedding` | 384-dim | Sentence-transformer encoding |
|
| 100 |
+
| `remaining_budget` | int | Tokens left in episode |
|
| 101 |
+
| `questions_remaining` | int | Questions left |
|
| 102 |
+
| `budget_per_remaining` | float | remaining / questions_left |
|
| 103 |
+
| `accuracy_so_far` | float | Running accuracy [0, 1] |
|
| 104 |
+
| `history` | list | Past (allocated, used, correct) tuples |
|
| 105 |
+
|
| 106 |
+
**Action:** integer token allocation, clamped to `[min_tokens, max_tokens]` and remaining budget.
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## Data
|
| 111 |
+
|
| 112 |
+
The repo ships with a deterministic offline question bundle and response cache under
|
| 113 |
+
`reasonbudget_gym/data/`, so demos and tests work without external services.
|
| 114 |
+
|
| 115 |
+
A **synthetic cache** (`reasonbudget_gym/data/response_cache.json`) simulates realistic
|
| 116 |
+
DeepSeek-R1 accuracy curves across 4 difficulty tiers: `gsm8k`, `math_l1_l2`, `math_l3`,
|
| 117 |
+
`math_l4_l5`. The sampler also caches MiniLM embeddings to
|
| 118 |
+
`reasonbudget_gym/data/embeddings.npy` after the first run.
|
| 119 |
+
|
| 120 |
+
Regenerate the synthetic cache with:
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
python reasonbudget_gym/data/generate_synthetic_cache.py
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
---
|
| 127 |
+
|
| 128 |
+
## Deployment (Docker / HF Spaces)
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
docker build -t reasoning-economic-env .
|
| 132 |
+
docker run -p 8000:8000 reasoning-economic-env
|
| 133 |
+
curl http://localhost:8000/health
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## Related Work
|
| 139 |
+
|
| 140 |
+
- **[MAS-TTS](https://github.com/jincan333/MAS-TTS):** Allocates reasoning across *agents* on
|
| 141 |
+
one problem vs. our approach of allocating across *questions* for a single agent.
|
| 142 |
+
- **[AgentTTS](https://arxiv.org/abs/2508.00890):** Test-time compute-optimal scaling across
|
| 143 |
+
multi-stage complex tasks.
|
| 144 |
+
|
| 145 |
+
---
|
| 146 |
+
|
| 147 |
+
## Citation
|
| 148 |
+
|
| 149 |
+
Part of the AgentX–AgentBeats Competition (Berkeley RDI, 2026).
|
| 150 |
+
Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) by Meta/PyTorch.
|
docs/agent_comparison.png
ADDED
|
Git LFS Details
|
docs/budget_pacing.png
ADDED
|
Git LFS Details
|
eval_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
openenv.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: reasoning-economic-env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: reasonbudget_gym.server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
tags:
|
| 8 |
+
- openenv
|
| 9 |
+
- reasoning-economic-env
|
| 10 |
+
- rl
|
| 11 |
+
- math
|
| 12 |
+
- token-budget
|
pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "reasonbudget-gym"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "RL environment for learning to allocate reasoning compute under budget constraints"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"fastapi>=0.110.0",
|
| 12 |
+
"uvicorn[standard]>=0.29.0",
|
| 13 |
+
"pydantic>=2.0",
|
| 14 |
+
"numpy>=1.24",
|
| 15 |
+
"datasets>=2.18.0",
|
| 16 |
+
"sentence-transformers>=2.7.0",
|
| 17 |
+
"matplotlib>=3.8",
|
| 18 |
+
"seaborn>=0.13",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[project.optional-dependencies]
|
| 22 |
+
dev = [
|
| 23 |
+
"pytest>=8.0",
|
| 24 |
+
"httpx>=0.27",
|
| 25 |
+
]
|
| 26 |
+
train = [
|
| 27 |
+
"torch>=2.2",
|
| 28 |
+
]
|
| 29 |
+
live = [
|
| 30 |
+
"together>=1.2",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[tool.hatch.build.targets.wheel]
|
| 34 |
+
packages = ["reasonbudget_gym"]
|
| 35 |
+
|
| 36 |
+
[tool.pytest.ini_options]
|
| 37 |
+
testpaths = ["reasonbudget_gym/tests"]
|
reasonbudget_gym/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ReasoningEconomicsEnv — token budget allocation RL environment."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
reasonbudget_gym/baselines/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .uniform import UniformBaseline
|
| 2 |
+
from .greedy_max import GreedyMaxBaseline
|
| 3 |
+
from .oracle import DifficultyOracleBaseline
|
| 4 |
+
from .bandit import LinUCBBaseline
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"UniformBaseline",
|
| 8 |
+
"GreedyMaxBaseline",
|
| 9 |
+
"DifficultyOracleBaseline",
|
| 10 |
+
"LinUCBBaseline",
|
| 11 |
+
]
|
reasonbudget_gym/baselines/bandit.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LinUCB bandit baseline: learns token allocation from question embeddings.
|
| 3 |
+
|
| 4 |
+
Uses a simplified contextual bandit (LinUCB) where:
|
| 5 |
+
- Context = question embedding (384-dim, but we project to 16-dim for speed)
|
| 6 |
+
- Arms = discrete budget tiers [50, 100, 200, 400, 800]
|
| 7 |
+
- Reward = observed correctness signal
|
| 8 |
+
"""
|
| 9 |
+
import math
|
| 10 |
+
import numpy as np
|
| 11 |
+
from ..env.models import Observation
|
| 12 |
+
from ..env.config import EnvConfig
|
| 13 |
+
|
| 14 |
+
PROJ_DIM = 16 # Projection dimension for efficiency
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LinUCBBaseline:
    """
    Linear UCB contextual bandit for token allocation.

    Projects 384-dim embeddings to PROJ_DIM via random projection,
    then maintains a separate LinUCB arm for each budget tier.
    """

    name = "bandit"

    def __init__(self, config: EnvConfig, alpha: float = 1.0, seed: int = 42):
        self.config = config
        self.alpha = alpha  # UCB exploration coefficient
        self.tiers = config.budget_tiers
        # Use a private RNG instead of np.random.seed(): seeding the global
        # RNG silently perturbs every other consumer of np.random in the
        # process. RandomState(seed) produces the exact same value stream
        # the previous global-seed code did, so behavior is unchanged.
        rng = np.random.RandomState(seed)
        # Random projection matrix: embedding_dim -> PROJ_DIM
        self._proj = rng.randn(config.embedding_dim, PROJ_DIM) / math.sqrt(PROJ_DIM)
        d = PROJ_DIM + 4  # projected embedding + 4 budget/progress scalars
        # Per-arm LinUCB parameters: A is the ridge Gram matrix, b the
        # reward-weighted context sum.
        self._A = {t: np.eye(d) for t in self.tiers}
        self._b = {t: np.zeros(d) for t in self.tiers}
        self._last_context = None
        self._last_arm = None

    def _context(self, obs: Observation) -> np.ndarray:
        """Build the feature vector: projected embedding + normalized scalars."""
        emb = np.array(obs.question_embedding, dtype=float)
        proj = emb @ self._proj
        scalars = np.array([
            obs.remaining_budget / self.config.total_budget,
            obs.questions_remaining / self.config.questions_per_episode,
            obs.budget_per_remaining / self.config.max_tokens,
            obs.accuracy_so_far,
        ])
        return np.concatenate([proj, scalars])

    def get_action(self, obs: Observation) -> int:
        """Pick the affordable budget tier with the highest UCB score."""
        ctx = self._context(obs)
        self._last_context = ctx

        # Only consider tiers we can afford; if none fit, fall back to the
        # smallest tier (the env clamps allocations to the remaining budget).
        affordable = [t for t in self.tiers if t <= obs.remaining_budget]
        if not affordable:
            affordable = [self.tiers[0]]

        ucb_scores = {}
        for arm in affordable:
            # Solve A x = v rather than forming A^-1 explicitly: faster and
            # more numerically stable than np.linalg.inv for SPD matrices.
            theta = np.linalg.solve(self._A[arm], self._b[arm])
            exploration = math.sqrt(ctx @ np.linalg.solve(self._A[arm], ctx))
            ucb_scores[arm] = theta @ ctx + self.alpha * exploration

        best_arm = max(ucb_scores, key=ucb_scores.__getitem__)
        self._last_arm = best_arm
        return best_arm

    def update(self, reward: float):
        """Update LinUCB parameters after observing a reward."""
        if self._last_context is None or self._last_arm is None:
            return
        ctx = self._last_context
        arm = self._last_arm
        self._A[arm] += np.outer(ctx, ctx)
        self._b[arm] += reward * ctx

    def reset(self):
        # Intentionally keeps learned parameters across episodes so the
        # bandit keeps improving over an evaluation run.
        pass
|
reasonbudget_gym/baselines/greedy_max.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GreedyMax baseline: always allocate the maximum fair share."""
|
| 2 |
+
from ..env.models import Observation
|
| 3 |
+
from ..env.config import EnvConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GreedyMaxBaseline:
    """Baseline that always asks for the largest allocation it can.

    Requests ``max_tokens`` every step, capped by whatever budget remains.
    """

    name = "greedy_max"

    def __init__(self, config: EnvConfig):
        # Only max_tokens is read from the config.
        self.config = config

    def get_action(self, obs: Observation) -> int:
        cap = self.config.max_tokens
        remaining = obs.remaining_budget
        return cap if cap <= remaining else remaining

    def reset(self):
        # Stateless policy: nothing to clear between episodes.
        pass
|
reasonbudget_gym/baselines/oracle.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DifficultyOracle baseline: knows question difficulty, allocates proportionally.
|
| 3 |
+
|
| 4 |
+
This is an upper bound — in practice the agent doesn't see difficulty labels.
|
| 5 |
+
"""
|
| 6 |
+
from ..env.models import Observation
|
| 7 |
+
from ..env.config import EnvConfig
|
| 8 |
+
|
| 9 |
+
# Token multipliers by difficulty (relative units)
|
| 10 |
+
DIFFICULTY_MULTIPLIERS = {
|
| 11 |
+
"gsm8k": 0.5,
|
| 12 |
+
"math_l1_l2": 0.75,
|
| 13 |
+
"math_l3": 1.25,
|
| 14 |
+
"math_l4_l5": 2.0,
|
| 15 |
+
}
|
| 16 |
+
DEFAULT_MULTIPLIER = 1.0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DifficultyOracleBaseline:
    """
    Upper-bound baseline that scales its allocation by the true difficulty.

    The evaluation harness feeds the current question's difficulty label in
    via set_difficulty() after env.step() returns info; real agents never
    see this label. Unknown labels fall back to a neutral multiplier,
    which reduces to uniform allocation.
    """

    name = "oracle"

    def __init__(self, config: EnvConfig):
        self.config = config
        self._current_difficulty = "gsm8k"

    def set_difficulty(self, difficulty: str):
        """Called by evaluation harness after env.step() returns info."""
        self._current_difficulty = difficulty

    def get_action(self, obs: Observation) -> int:
        multiplier = DIFFICULTY_MULTIPLIERS.get(
            self._current_difficulty, DEFAULT_MULTIPLIER
        )
        # Fair share of the remaining budget, scaled by difficulty.
        fair_share = obs.remaining_budget / max(1, obs.questions_remaining)
        scaled = int(fair_share * multiplier)
        # Clamp into [min_tokens, max_tokens], cap by the remaining budget,
        # but never drop below the configured floor.
        scaled = max(self.config.min_tokens, min(scaled, self.config.max_tokens))
        scaled = min(scaled, obs.remaining_budget)
        return max(self.config.min_tokens, scaled)

    def reset(self):
        self._current_difficulty = "gsm8k"
|
reasonbudget_gym/baselines/uniform.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Uniform allocation baseline: split budget equally across all questions."""
|
| 2 |
+
from ..env.models import Observation
|
| 3 |
+
from ..env.config import EnvConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class UniformBaseline:
    """Even-split baseline.

    Each step requests remaining_budget / questions_remaining tokens,
    clamped into [min_tokens, max_tokens].
    """

    name = "uniform"

    def __init__(self, config: EnvConfig):
        self.config = config

    def get_action(self, obs: Observation) -> int:
        left = obs.questions_remaining
        if left == 0:
            # Nothing left to plan for; request the floor amount.
            return self.config.min_tokens
        fair_share = int(obs.remaining_budget / left)
        lo, hi = self.config.min_tokens, self.config.max_tokens
        return max(lo, min(fair_share, hi))

    def reset(self):
        # Stateless policy: nothing to clear between episodes.
        pass
|
| 22 |
+
pass
|
reasonbudget_gym/client.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Remote HTTP client for the ReasonBudgetEnv server."""
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import requests
|
| 7 |
+
_OK = True
|
| 8 |
+
except ImportError:
|
| 9 |
+
_OK = False
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class RemoteObs:
    """Observation payload returned by the remote env server."""

    question_embedding: List[float]   # sentence embedding of the current question
    remaining_budget: int             # tokens left in the episode
    questions_remaining: int          # questions left in the episode
    budget_per_remaining: float       # remaining_budget / questions_remaining
    accuracy_so_far: float            # running accuracy in [0, 1]
    history: List[dict]               # past per-question allocation records
    done: bool                        # True once the episode has ended
    episode_id: Optional[str] = None  # server-assigned episode identifier
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class RemoteResult:
    """One env transition as returned by POST /step."""

    observation: RemoteObs  # observation after the step
    reward: float           # scalar reward for the step
    done: bool              # whether the episode terminated
    info: Dict[str, Any]    # auxiliary diagnostics from the server
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ReasonBudgetClient:
    """Thin HTTP client for a running ReasonBudgetEnv server.

    Wraps the server's /health, /info, /reset and /step endpoints and
    deserializes responses into RemoteObs / RemoteResult dataclasses.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0):
        if not _OK:
            raise ImportError("pip install requests")
        self.url = base_url.rstrip("/")
        # Per-request timeout in seconds. Without it, requests waits
        # indefinitely on an unresponsive server.
        self.timeout = timeout

    def health(self) -> dict:
        """Return the server's /health payload."""
        r = requests.get(f"{self.url}/health", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def info(self) -> dict:
        """Return the server's /info payload."""
        r = requests.get(f"{self.url}/info", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def reset(self, seed=None) -> RemoteObs:
        """Start a new episode; returns the initial observation."""
        r = requests.post(f"{self.url}/reset", json={"seed": seed},
                          timeout=self.timeout)
        r.raise_for_status()
        return RemoteObs(**r.json())

    def step(self, token_allocation: int) -> RemoteResult:
        """Allocate tokens for the current question; returns the transition."""
        r = requests.post(f"{self.url}/step",
                          json={"token_allocation": token_allocation},
                          timeout=self.timeout)
        r.raise_for_status()
        d = r.json()
        return RemoteResult(
            observation=RemoteObs(**d["observation"]),
            reward=d["reward"],
            done=d["done"],
            info=d["info"],
        )
|
reasonbudget_gym/data/__init__.py
ADDED
|
File without changes
|
reasonbudget_gym/data/embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2f5b49d773c0f605b508c0c2d2c736305cd346654b9b63af8d3a3856cd71f12
|
| 3 |
+
size 768128
|
reasonbudget_gym/data/generate_synthetic_cache.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate a synthetic response cache for offline/demo use.
|
| 3 |
+
|
| 4 |
+
Simulates realistic DeepSeek-R1 accuracy curves across difficulty tiers
|
| 5 |
+
WITHOUT any API calls. Uses MetaMathQA dataset if available, otherwise
|
| 6 |
+
falls back to synthetic arithmetic questions.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python reasonbudget_gym/data/generate_synthetic_cache.py
|
| 10 |
+
# or from package root:
|
| 11 |
+
python -m reasonbudget_gym.data.generate_synthetic_cache
|
| 12 |
+
|
| 13 |
+
Output: data/response_cache.json (~500 questions × 5 budget tiers)
|
| 14 |
+
"""
|
| 15 |
+
import hashlib
import json
import random
import sys
from pathlib import Path
from typing import Optional
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
# Accuracy model (from task spec — based on real DeepSeek-R1 behaviour)
# ---------------------------------------------------------------------------
# P(correct) for each difficulty tier at each token budget in BUDGET_TIERS;
# columns line up index-for-index with BUDGET_TIERS below.
ACCURACY_TABLE = {
    #             50    100   200   400   800
    "gsm8k": [0.85, 0.92, 0.95, 0.96, 0.97],
    "math_l1_l2":[0.55, 0.70, 0.80, 0.88, 0.92],
    "math_l3": [0.25, 0.40, 0.55, 0.70, 0.80],
    "math_l4_l5":[0.08, 0.15, 0.30, 0.50, 0.65],
}
# Token budgets an agent can allocate (one cached attempt per tier).
BUDGET_TIERS = [50, 100, 200, 400, 800]
# Number of questions written into the cache.
N_QUESTIONS = 500
# Fraction of questions (taken from the tail) held out as the eval split.
EVAL_FRACTION = 0.1
# RNG seed so the generated cache is reproducible.
SEED = 42
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Helpers
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def _make_qid(text: str, idx: int) -> str:
|
| 41 |
+
h = hashlib.md5(f"{idx}:{text}".encode()).hexdigest()[:12]
|
| 42 |
+
return f"q_{h}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _infer_difficulty(item: dict) -> str:
|
| 46 |
+
source = item.get("type", "").lower()
|
| 47 |
+
if "gsm" in source or "grade" in source:
|
| 48 |
+
return "gsm8k"
|
| 49 |
+
resp = item.get("response", "")
|
| 50 |
+
n = len(resp.split())
|
| 51 |
+
if n < 80: return "gsm8k"
|
| 52 |
+
if n < 150: return "math_l1_l2"
|
| 53 |
+
if n < 250: return "math_l3"
|
| 54 |
+
return "math_l4_l5"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _extract_answer(response: str) -> str:
|
| 58 |
+
import re
|
| 59 |
+
for pat in [r"[Tt]he answer is[:\s]+([^\n.]+)", r"####\s*(.+)", r"=\s*([^\n]+)$"]:
|
| 60 |
+
m = re.search(pat, response)
|
| 61 |
+
if m:
|
| 62 |
+
return m.group(1).strip()
|
| 63 |
+
lines = [l.strip() for l in response.split("\n") if l.strip()]
|
| 64 |
+
return lines[-1] if lines else "unknown"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _load_questions(n: int):
    """Return up to *n* (qid, text, answer, difficulty) tuples.

    Tries MetaMathQA first; any failure (missing `datasets` package,
    no network, schema change) falls back to the synthetic generator.
    """
    try:
        from datasets import load_dataset

        ds = load_dataset("meta-math/MetaMathQA", split="train", trust_remote_code=True)
        selected = list(ds.select(range(min(n, len(ds)))))
        questions = [
            (
                _make_qid(item.get("query", f"Question {idx}"), idx),
                item.get("query", f"Question {idx}"),
                _extract_answer(item.get("response", "")),
                _infer_difficulty(item),
            )
            for idx, item in enumerate(selected)
        ]
        print(f" Loaded {len(questions)} questions from MetaMathQA")
        return questions
    except Exception as e:
        print(f" Dataset unavailable ({e}), using synthetic questions")
        return _synthetic_questions(n)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _synthetic_questions(n: int):
    """Pure-Python fallback — no external deps."""
    rng = random.Random(SEED)
    # (question template, answer function over the sampled a, b, c)
    templates = (
        ("If {a} people each have {b} apples, how many apples total?", lambda a, b, c: a * b),
        ("What is {a} + {b} * {c}?", lambda a, b, c: a + b * c),
        ("A car travels at {a} km/h for {b} hours. Distance?", lambda a, b, c: a * b),
        ("Solve: {a}x = {b}. What is x?", lambda a, b, c: b // a if a else 0),
    )
    tiers = ("gsm8k", "math_l1_l2", "math_l3", "math_l4_l5")
    questions = []
    for idx in range(n):
        # Draw in the same order every iteration so the sequence is stable.
        a = rng.randint(2, 20)
        b = rng.randint(1, 15)
        c = rng.randint(1, 10)
        template, solve = templates[idx % 4]
        text = template.format(a=a, b=b, c=c)
        questions.append((_make_qid(text, idx), text, str(solve(a, b, c)), tiers[idx % 4]))
    return questions
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# Cache generation
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
def generate_cache(output_path: Optional[str] = None) -> str:
    """Generate the synthetic response cache and write it as JSON.

    For every question and every budget tier, a correct/incorrect outcome is
    sampled from ACCURACY_TABLE and a plausible token usage is recorded.

    Args:
        output_path: Destination file. Defaults to ``response_cache.json``
            next to this module.

    Returns:
        The path the cache was written to.
    """
    if output_path is None:
        # Resolve relative to this file's parent directory
        output_path = str(Path(__file__).parent / "response_cache.json")

    rng = random.Random(SEED)
    print(f"Generating synthetic cache ({N_QUESTIONS} questions × {len(BUDGET_TIERS)} tiers)...")
    questions = _load_questions(N_QUESTIONS)

    entries = {}
    for qid, qtext, answer, difficulty in questions:
        entries[qid] = {}
        acc_curve = ACCURACY_TABLE[difficulty]
        for tier_idx, tier in enumerate(BUDGET_TIERS):
            p_correct = acc_curve[tier_idx]
            correct = rng.random() < p_correct
            # tokens_used: 70-95% of budget with some noise
            used = int(tier * rng.uniform(0.70, 0.95))
            entries[qid][str(tier)] = {
                "answer": answer if correct else "unknown",
                "was_correct": correct,
                "tokens_used": used,
                "response_text": "[synthetic cache entry]",
            }

    # Hold out the tail of the question list as the eval split.
    # BUG FIX: when n_eval == 0, questions[-0:] is the WHOLE list, which would
    # have marked every question as eval; guard so an empty holdout stays empty.
    n_eval = int(len(questions) * EVAL_FRACTION)
    eval_ids = [qid for qid, *_ in questions[-n_eval:]] if n_eval > 0 else []

    cache = {"entries": entries, "eval_ids": eval_ids}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(cache, f)

    n_correct = sum(
        1 for qdata in entries.values()
        for entry in qdata.values() if entry["was_correct"]
    )
    total = len(entries) * len(BUDGET_TIERS)
    print(f" Written: {output_path}")
    print(f" Questions: {len(entries)} | Overall accuracy: {n_correct/total:.1%}")
    print(f" Eval holdout: {len(eval_ids)} questions")
    return output_path
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
    # Optional CLI argument: explicit output path for the cache file.
    generate_cache(sys.argv[1] if len(sys.argv) > 1 else None)
|
reasonbudget_gym/data/response_cache.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
reasonbudget_gym/env/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .reason_budget_env import ReasonBudgetEnv
|
| 2 |
+
from .config import EnvConfig
|
| 3 |
+
from .models import Observation, StepResult, QuestionMeta
|
| 4 |
+
|
| 5 |
+
__all__ = ["ReasonBudgetEnv", "EnvConfig", "Observation", "StepResult", "QuestionMeta"]
|
reasonbudget_gym/env/config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment configuration dataclass."""
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class EnvConfig:
    """Configuration for the ReasonBudgetEnv.

    Groups together the budget limits, episode shape, reward weights,
    solver selection, dataset source, and embedding settings.
    """

    # Token budget
    total_budget: int = 4000   # tokens available across a whole episode
    min_tokens: int = 50       # smallest per-question allocation step() accepts
    max_tokens: int = 800      # largest per-question allocation step() accepts
    # Discrete budgets the cached solver has pre-computed responses for.
    budget_tiers: List[int] = field(default_factory=lambda: [50, 100, 200, 400, 800])

    # Episode structure
    questions_per_episode: int = 10
    seed: int = 42  # seeds episode sampling for reproducibility

    # Reward weights (see reward.compute_reward)
    correct_reward: float = 1.0
    wrong_penalty: float = -0.1
    cost_penalty_weight: float = 0.0002  # β
    efficiency_bonus_weight: float = 0.3  # γ

    # Solver
    solver_type: str = "cached"  # "cached" | "live"
    cache_path: str = "data/response_cache.json"  # relative paths resolve against package root, then CWD

    # Dataset
    dataset_name: str = "meta-math/MetaMathQA"
    max_cache_questions: int = 500  # cap on questions loaded/cached

    # Embedding
    embedding_dim: int = 384
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_cache_path: str = "data/embeddings.npy"
|
reasonbudget_gym/env/episode_sampler.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sample episodes (sequences of questions) from the dataset."""
|
| 2 |
+
import hashlib
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
from .config import EnvConfig
|
| 8 |
+
from .models import QuestionMeta
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Difficulty label heuristics based on dataset origin
|
| 12 |
+
_DIFFICULTY_LEVELS = ["gsm8k", "math_l1_l2", "math_l3", "math_l4_l5"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _infer_difficulty(item: dict) -> str:
|
| 16 |
+
"""Heuristically assign difficulty from MetaMathQA metadata."""
|
| 17 |
+
source = item.get("type", "").lower()
|
| 18 |
+
query = item.get("query", "")
|
| 19 |
+
|
| 20 |
+
if "gsm" in source or "grade" in source:
|
| 21 |
+
return "gsm8k"
|
| 22 |
+
|
| 23 |
+
# Use a simple proxy: length of the chain of thought as difficulty signal
|
| 24 |
+
response = item.get("response", "")
|
| 25 |
+
cot_len = len(response.split())
|
| 26 |
+
if cot_len < 80:
|
| 27 |
+
return "gsm8k"
|
| 28 |
+
elif cot_len < 150:
|
| 29 |
+
return "math_l1_l2"
|
| 30 |
+
elif cot_len < 250:
|
| 31 |
+
return "math_l3"
|
| 32 |
+
else:
|
| 33 |
+
return "math_l4_l5"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _extract_answer(response: str) -> str:
|
| 37 |
+
"""Extract the final answer from a MetaMathQA response."""
|
| 38 |
+
# MetaMathQA answers follow "The answer is X" or "#### X"
|
| 39 |
+
import re
|
| 40 |
+
patterns = [
|
| 41 |
+
r"[Tt]he answer is[:\s]+([^\n.]+)",
|
| 42 |
+
r"####\s*(.+)",
|
| 43 |
+
r"=\s*([^\n]+)$",
|
| 44 |
+
]
|
| 45 |
+
for pat in patterns:
|
| 46 |
+
m = re.search(pat, response)
|
| 47 |
+
if m:
|
| 48 |
+
return m.group(1).strip()
|
| 49 |
+
# Fallback: last non-empty line
|
| 50 |
+
lines = [l.strip() for l in response.split("\n") if l.strip()]
|
| 51 |
+
return lines[-1] if lines else "unknown"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _make_question_id(text: str, idx: int) -> str:
|
| 55 |
+
h = hashlib.md5(f"{idx}:{text}".encode()).hexdigest()[:12]
|
| 56 |
+
return f"q_{h}"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class EpisodeSampler:
    """Loads questions from MetaMathQA and provides episode batches.

    Questions are loaded lazily on first use. In cached-solver mode the
    bundled synthetic question set is used so ids line up with the shipped
    response cache; otherwise the dataset is downloaded, with an offline
    synthetic fallback.
    """

    def __init__(self, config: EnvConfig, split: str = "train"):
        self.config = config
        self.rng = random.Random(config.seed)
        self._questions: List[QuestionMeta] = []
        self._loaded = False
        self._split = split

    def _resolve_cache_path(self) -> Path:
        return self._resolve_project_path(self.config.cache_path)

    def _resolve_project_path(self, relative_path: str) -> Path:
        """Resolve a config path: absolute as-is, else package root, else CWD."""
        path = Path(relative_path)
        if path.is_absolute():
            return path

        pkg_root = Path(__file__).resolve().parent.parent
        package_relative = pkg_root / relative_path
        cwd_relative = Path.cwd() / relative_path
        # Prefer the package-relative location unless only the CWD copy exists.
        if package_relative.exists() or not cwd_relative.exists():
            return package_relative
        return cwd_relative

    def _load_bundled_questions(self) -> List[dict]:
        """Return deterministic synthetic questions that align with the shipped cache."""
        from ..data.generate_synthetic_cache import _synthetic_questions

        questions = []
        for qid, qtext, answer, difficulty in _synthetic_questions(self.config.max_cache_questions):
            questions.append({
                "question_id": qid,
                "query": qtext,
                "answer": answer,
                "difficulty": difficulty,
            })
        return questions

    def _apply_embeddings(self, embeddings) -> None:
        # Overwrite the zero-vector placeholders in place, row by row.
        for question, embedding in zip(self._questions, embeddings):
            question.embedding = [float(x) for x in embedding]

    def _load_embeddings(self) -> None:
        """Attach real sentence embeddings, reusing the on-disk cache when valid."""
        import numpy as np

        if not self._questions:
            return

        expected_shape = (len(self._questions), self.config.embedding_dim)
        embeddings_path = self._resolve_project_path(self.config.embedding_cache_path)

        if embeddings_path.exists():
            try:
                cached = np.load(embeddings_path)
                if cached.shape == expected_shape:
                    self._apply_embeddings(cached)
                    return
            except Exception:
                pass  # corrupt or stale cache — fall through and re-encode

        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer(self.config.embedding_model)
            texts = [question.question_text for question in self._questions]
            embeddings = np.asarray(
                model.encode(texts, show_progress_bar=False),
                dtype=np.float32,
            )
            if embeddings.shape != expected_shape:
                return
            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(embeddings_path, embeddings)
            self._apply_embeddings(embeddings)
        except Exception:
            # Fall back to zero embeddings when the model is unavailable.
            return

    def _load(self):
        """Populate self._questions once (idempotent)."""
        if self._loaded:
            return

        cache_path = self._resolve_cache_path()
        if self.config.solver_type == "cached" and cache_path.exists():
            # Must match the shipped cache's question ids exactly.
            items = self._load_bundled_questions()
        else:
            try:
                from datasets import load_dataset

                ds = load_dataset(
                    self.config.dataset_name,
                    split=self._split,
                    trust_remote_code=True,
                )
                # Limit to first max_cache_questions for speed
                items = list(ds.select(range(min(self.config.max_cache_questions, len(ds)))))
            except Exception:
                # Offline fallback: generate synthetic questions
                items = self._synthetic_questions(self.config.max_cache_questions)

        self._questions = []
        for idx, item in enumerate(items):
            qtext = item.get("query", item.get("question", f"Question {idx}"))
            response = item.get("response", item.get("answer", ""))
            answer = _extract_answer(response)
            difficulty = item.get("difficulty", _infer_difficulty(item))
            qid = item.get("question_id", _make_question_id(qtext, idx))
            # Embedding is a zero-vector placeholder; real embedding done lazily
            self._questions.append(QuestionMeta(
                question_id=qid,
                question_text=qtext,
                ground_truth=answer,
                difficulty=difficulty,
                embedding=[0.0] * self.config.embedding_dim,
            ))
        self._load_embeddings()
        self._loaded = True

    def _synthetic_questions(self, n: int) -> List[dict]:
        """Generate synthetic questions when dataset is unavailable."""
        templates = [
            ("If a train travels at {v} km/h for {t} hours, how far does it go?", "{r}"),
            ("What is {a} + {b} * {c}?", "{r}"),
            ("Solve for x: {a}x + {b} = {c}", "{r}"),
            ("A store sells apples for ${p} each. How much do {n} apples cost?", "${r}"),
        ]
        # Fake CoT length per difficulty so _infer_difficulty round-trips.
        cot_lengths = {"gsm8k": 50, "math_l1_l2": 120, "math_l3": 200, "math_l4_l5": 300}
        items = []
        rng = random.Random(42)
        difficulties = ["gsm8k", "math_l1_l2", "math_l3", "math_l4_l5"]
        for i in range(n):
            tmpl, ans_tmpl = templates[i % len(templates)]
            a, b, c = rng.randint(1, 20), rng.randint(1, 10), rng.randint(1, 15)
            v, t, p, nn = rng.randint(60, 200), rng.randint(1, 10), rng.randint(1, 5), rng.randint(2, 20)
            # BUG FIX: previously every template used r = a + b * c, so the
            # ground truth matched only the "What is a + b * c?" template and
            # ans_tmpl was never applied. Compute the answer that actually
            # corresponds to the sampled question (as the sister generator in
            # data/generate_synthetic_cache.py does).
            answers = [v * t, a + b * c, (c - b) / a, p * nn]
            r = answers[i % len(templates)]
            query = tmpl.format(v=v, t=t, a=a, b=b, c=c, p=p, n=nn)
            answer = ans_tmpl.format(r=r)
            diff = difficulties[i % len(difficulties)]
            fake_cot = " word" * cot_lengths[diff]
            items.append({"query": query, "response": f"{fake_cot} The answer is {answer}", "type": diff})
        return items

    def sample_episode(self) -> List[QuestionMeta]:
        """Sample a sequence of questions_per_episode questions (with replacement)."""
        self._load()
        return self.rng.choices(self._questions, k=self.config.questions_per_episode)

    def get_all_questions(self) -> List[QuestionMeta]:
        self._load()
        return list(self._questions)
|
reasonbudget_gym/env/models.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data models shared across the environment."""
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class QuestionMeta:
    """Metadata for a single question in an episode."""
    question_id: str        # stable hash-derived id; also the response-cache key
    question_text: str      # the prompt shown to the solver
    ground_truth: str       # canonical final answer string
    difficulty: str  # gsm8k | math_l1_l2 | math_l3 | math_l4_l5
    embedding: List[float]  # 384-dim sentence-transformer encoding
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class StepInfo:
    """Record of one completed step."""
    tokens_allocated: int  # budget the agent granted this question
    tokens_used: int       # tokens the solver actually consumed (≤ allocated)
    was_correct: bool      # whether the solver's answer matched ground truth
|
| 23 |
+
|
| 24 |
+
@dataclass
class Observation:
    """Full observation returned to the agent at each step."""
    question_embedding: List[float]  # 384-dim; zeros after the episode ends
    remaining_budget: int            # tokens left for the rest of the episode
    questions_remaining: int         # questions not yet attempted (incl. current)
    budget_per_remaining: float      # remaining_budget / questions_remaining
    accuracy_so_far: float           # fraction correct over completed steps
    history: List[StepInfo]          # per-step records so far
    done: bool = False
    episode_id: Optional[str] = None  # short uuid assigned at reset()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class StepResult:
    """Result of env.step()."""
    observation: Observation  # observation for the NEXT decision point
    reward: float             # scalar reward for the step just taken
    done: bool                # True when the episode has terminated
    info: dict = field(default_factory=dict)  # diagnostics (tokens, difficulty, ...)
|
reasonbudget_gym/env/reason_budget_env.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ReasonBudgetEnv — core RL environment for token budget allocation.
|
| 3 |
+
|
| 4 |
+
Compatible with OpenEnv v0.2.x (step/reset interface).
|
| 5 |
+
"""
|
| 6 |
+
import uuid
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from .config import EnvConfig
|
| 10 |
+
from .episode_sampler import EpisodeSampler
|
| 11 |
+
from .models import Observation, StepInfo, StepResult
|
| 12 |
+
from .reward import compute_reward
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ReasonBudgetEnv:
    """
    Sequential token budget allocation environment.

    At each step the agent observes the current question embedding plus
    budget state and chooses how many tokens to allocate. The solver
    attempts the question with that limit and returns a correctness signal.
    """

    def __init__(self, config: Optional[EnvConfig] = None):
        self.config = config or EnvConfig()
        self.sampler = EpisodeSampler(self.config)
        # Solver is built lazily on first access (see the `solver` property)
        # so constructing the env never touches the cache or any API.
        self._solver = None
        self._reset_state()

    # ------------------------------------------------------------------
    # Solver lazy init
    # ------------------------------------------------------------------
    @property
    def solver(self):
        # First access instantiates the solver selected by config.solver_type;
        # subsequent accesses return the memoized instance.
        if self._solver is None:
            if self.config.solver_type == "cached":
                from ..solver.cached_solver import CachedSolver
                self._solver = CachedSolver(self.config)
            else:
                from ..solver.live_solver import LiveSolver
                self._solver = LiveSolver(self.config)
        return self._solver

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a new episode. Returns first observation."""
        self._reset_state()
        self._episode = self.sampler.sample_episode()
        self._episode_id = str(uuid.uuid4())[:8]
        return self._make_observation()

    def step(self, token_allocation: int) -> StepResult:
        """Take a step: allocate tokens to the current question.

        Raises:
            RuntimeError: if called after the episode has terminated.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() first.")

        # Clamp allocation to valid range and remaining budget
        token_allocation = max(self.config.min_tokens,
                               min(token_allocation, self.config.max_tokens))
        token_allocation = min(token_allocation, self._remaining_budget)
        if token_allocation < self.config.min_tokens:
            token_allocation = self.config.min_tokens  # allow overspend on last Q
            # NOTE: this can drive _remaining_budget negative; the done-check
            # below then terminates the episode.

        question = self._episode[self._step_idx]
        budget_before = self._remaining_budget

        # Solve
        result = self.solver.solve(
            question_id=question.question_id,
            question_text=question.question_text,
            ground_truth=question.ground_truth,
            token_budget=token_allocation,
        )

        # Accounting — the full allocation is charged against the budget even
        # if the solver used fewer tokens (unused tokens are not refunded).
        tokens_used = min(result.tokens_used, token_allocation)
        self._remaining_budget -= token_allocation
        self._step_idx += 1
        self._history.append(StepInfo(
            tokens_allocated=token_allocation,
            tokens_used=tokens_used,
            was_correct=result.was_correct,
        ))
        if result.was_correct:
            self._correct_count += 1

        # Reward
        reward = compute_reward(
            was_correct=result.was_correct,
            tokens_allocated=token_allocation,
            tokens_used=tokens_used,
            remaining_budget_before=budget_before,
            questions_remaining=len(self._episode) - self._step_idx,
            config=self.config,
        )

        # Done? Either every question was attempted or the budget can no
        # longer cover a minimum allocation.
        self._done = (
            self._step_idx >= len(self._episode)
            or self._remaining_budget < self.config.min_tokens
        )

        obs = self._make_observation()
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "was_correct": result.was_correct,
                "tokens_allocated": token_allocation,
                "tokens_used": tokens_used,
                "difficulty": question.difficulty,
                "remaining_budget": self._remaining_budget,
                "step": self._step_idx,
            },
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _reset_state(self):
        # Clear all per-episode bookkeeping.
        self._episode = []
        self._episode_id = None
        self._step_idx = 0
        self._remaining_budget = self.config.total_budget
        self._history = []
        self._correct_count = 0
        self._done = False

    def _make_observation(self) -> Observation:
        # After the last question there is no "current" embedding; emit zeros.
        if self._step_idx < len(self._episode):
            q = self._episode[self._step_idx]
            emb = q.embedding
        else:
            emb = [0.0] * self.config.embedding_dim

        n_remaining = max(0, len(self._episode) - self._step_idx)
        bpr = self._remaining_budget / n_remaining if n_remaining > 0 else 0.0
        acc = self._correct_count / self._step_idx if self._step_idx > 0 else 0.0

        return Observation(
            question_embedding=emb,
            remaining_budget=self._remaining_budget,
            questions_remaining=n_remaining,
            budget_per_remaining=bpr,
            accuracy_so_far=acc,
            history=list(self._history),
            done=self._done,
            episode_id=self._episode_id,
        )

    @property
    def observation_dim(self) -> int:
        """Flat observation dimension (embedding + 4 scalars)."""
        return self.config.embedding_dim + 4

    def flat_observation(self, obs: Observation) -> list:
        """Flatten observation to a list of floats for the policy.

        The four scalars are normalized to roughly [0, 1] using the
        config's budget and episode-length constants.
        """
        scalars = [
            obs.remaining_budget / self.config.total_budget,
            obs.questions_remaining / self.config.questions_per_episode,
            obs.budget_per_remaining / self.config.max_tokens,
            obs.accuracy_so_far,
        ]
        return list(obs.question_embedding) + scalars
reasonbudget_gym/env/reward.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward function for the ReasonBudgetEnv."""
|
| 2 |
+
from .config import EnvConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compute_reward(
    was_correct: bool,
    tokens_allocated: int,
    tokens_used: int,
    remaining_budget_before: int,
    questions_remaining: int,
    config: EnvConfig,
) -> float:
    """
    Compute the per-step reward.

    R = correctness_term − β * cost_penalty + γ * efficiency_bonus

    correctness_term: +correct_reward if correct, else wrong_penalty
    cost_penalty: tokens_used / total_budget (fractional spend)
    efficiency_bonus: savings as fraction of budget if correct,
                      0 if wrong (no credit for being cheap and wrong)
    """
    correctness_term = config.correct_reward if was_correct else config.wrong_penalty

    # Fractional budget spend this step.
    spend_fraction = tokens_used / config.total_budget

    # Underspend credit is only granted on a correct answer.
    efficiency_bonus = 0.0
    if was_correct and tokens_allocated > 0:
        unspent = max(0, tokens_allocated - tokens_used)
        efficiency_bonus = unspent / config.total_budget

    return float(
        correctness_term
        - config.cost_penalty_weight * spend_fraction
        + config.efficiency_bonus_weight * efficiency_bonus
    )
|
reasonbudget_gym/eval/__init__.py
ADDED
|
File without changes
|
reasonbudget_gym/eval/evaluate.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation harness: run all 4 baselines for N episodes and collect metrics.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
|
| 6 |
+
"""
|
| 7 |
+
import argparse
|
| 8 |
+
import json
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from ..env.config import EnvConfig
|
| 12 |
+
from ..env.reason_budget_env import ReasonBudgetEnv
|
| 13 |
+
from ..baselines import UniformBaseline, GreedyMaxBaseline, DifficultyOracleBaseline, LinUCBBaseline
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def run_episode(env, agent, config):
    """Roll out one full episode of *agent* in *env* and collect metrics.

    Returns a dict with the episode's total reward, per-question accuracy,
    token usage, budget utilization, step count and a per-step trace.
    """
    obs = env.reset()
    # Oracle-style agents expose set_difficulty(); prime them with the
    # first question's difficulty before the episode begins.
    if hasattr(agent, "set_difficulty") and env._episode:
        agent.set_difficulty(env._episode[0].difficulty)

    episode_reward = 0.0
    n_correct = 0
    tokens_spent = 0
    n_steps = 0
    trace = []
    finished = False

    while not finished:
        result = env.step(agent.get_action(obs))
        # Bandit-style agents learn online from the observed reward.
        if hasattr(agent, "update"):
            agent.update(result.reward)
        # Keep oracle agents in sync with the upcoming question.
        if hasattr(agent, "set_difficulty"):
            episode = env._episode
            if env._step_idx < len(episode):
                agent.set_difficulty(episode[env._step_idx].difficulty)

        info = result.info
        episode_reward += result.reward
        if info.get("was_correct"):
            n_correct += 1
        tokens_spent += info.get("tokens_allocated", 0)
        n_steps += 1
        trace.append({
            "step": n_steps,
            "tokens_allocated": info.get("tokens_allocated", 0),
            "was_correct": info.get("was_correct", False),
            "difficulty": info.get("difficulty", "unknown"),
            "remaining_budget": info.get("remaining_budget", 0),
            "reward": result.reward,
        })
        finished = result.done
        obs = result.observation

    return {
        "total_reward": episode_reward,
        # max(1, n_steps) guards against division by zero on empty episodes.
        "accuracy": n_correct / max(1, n_steps),
        "total_tokens_used": tokens_spent,
        "budget_utilization": tokens_spent / config.total_budget,
        "steps": n_steps,
        "per_step": trace,
    }
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def evaluate_agent(name, agent, config, n_episodes, seed):
    """Run *agent* for *n_episodes* deterministic episodes and aggregate stats."""
    env = ReasonBudgetEnv(config)
    agent.reset()

    episode_stats = []
    for episode_idx in range(n_episodes):
        # Reseed per episode so every agent sees the same question sequence.
        episode_seed = seed + episode_idx
        env.config.seed = episode_seed
        env.sampler.rng.seed(episode_seed)
        episode_stats.append(run_episode(env, agent, config))

    rewards = [s["total_reward"] for s in episode_stats]
    accuracies = [s["accuracy"] for s in episode_stats]
    utilizations = [s["budget_utilization"] for s in episode_stats]
    return {
        "agent": name,
        "n_episodes": n_episodes,
        "mean_reward": float(np.mean(rewards)),
        "std_reward": float(np.std(rewards)),
        "mean_accuracy": float(np.mean(accuracies)),
        "std_accuracy": float(np.std(accuracies)),
        "mean_budget_utilization": float(np.mean(utilizations)),
        "episodes": episode_stats,
    }
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def main():
    """CLI entry point: evaluate all four baselines and dump results as JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_episodes", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=str, default="eval_results.json")
    args = parser.parse_args()

    config = EnvConfig(seed=args.seed)
    agents = {
        "uniform": UniformBaseline(config),
        "greedy_max": GreedyMaxBaseline(config),
        "oracle": DifficultyOracleBaseline(config),
        "bandit": LinUCBBaseline(config, seed=args.seed),
    }

    results = {}
    print(f"Evaluating {len(agents)} agents × {args.n_episodes} episodes")
    for agent_name, agent in agents.items():
        print(f"  {agent_name}...", end=" ", flush=True)
        summary = evaluate_agent(agent_name, agent, config, args.n_episodes, args.seed)
        results[agent_name] = summary
        print(f"acc={summary['mean_accuracy']:.3f} reward={summary['mean_reward']:.3f}")

    # Console summary table.
    print(f"\n{'Agent':<15} {'Accuracy':>10} {'Reward':>10} {'Budget%':>10}")
    print("-" * 50)
    for agent_name, summary in results.items():
        print(f"{agent_name:<15} {summary['mean_accuracy']:>10.3f} {summary['mean_reward']:>10.3f} {summary['mean_budget_utilization']:>9.1%}")

    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {args.output}")


if __name__ == "__main__":
    main()
|
reasonbudget_gym/eval/plots.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate comparison charts from eval_results.json."""
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def agent_comparison(results_path: str, output_path: str = "docs/agent_comparison.png"):
    """Render a two-panel bar chart (accuracy, reward) comparing the agents."""
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # headless backend: safe on servers / CI
    import matplotlib.pyplot as plt

    with open(results_path) as f:
        results = json.load(f)

    names = list(results.keys())
    mean_accs = [results[n]["mean_accuracy"] for n in names]
    mean_rewards = [results[n]["mean_reward"] for n in names]
    acc_errs = [results[n]["std_accuracy"] for n in names]
    positions = np.arange(len(names))
    palette = ["#4C72B0", "#DD8452", "#55A868", "#C44E52"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("Baseline Agent Comparison — ReasoningEconomicsEnv", fontsize=13, fontweight="bold")

    # Left panel: mean accuracy with std error bars.
    acc_bars = ax1.bar(positions, mean_accs, 0.5, yerr=acc_errs, capsize=4, color=palette, alpha=0.85)
    ax1.set_title("Mean Accuracy")
    ax1.set_xticks(positions)
    ax1.set_xticklabels(names, rotation=15)
    ax1.set_ylim(0, 1.1)
    for bar, acc in zip(acc_bars, mean_accs):
        ax1.text(bar.get_x() + bar.get_width() / 2, acc + 0.03, f"{acc:.3f}", ha="center", fontsize=9)

    # Right panel: mean episode reward.
    reward_bars = ax2.bar(positions, mean_rewards, 0.5, color=palette, alpha=0.85)
    ax2.set_title("Mean Episode Reward")
    ax2.set_xticks(positions)
    ax2.set_xticklabels(names, rotation=15)
    for bar, rew in zip(reward_bars, mean_rewards):
        ax2.text(bar.get_x() + bar.get_width() / 2, rew + abs(rew) * 0.03 + 0.01, f"{rew:.3f}", ha="center", fontsize=9)

    plt.tight_layout()
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def budget_pacing(results_path: str, output_path: str = "docs/budget_pacing.png", budget: int = 4000):
    """Plot mean cumulative token spend per question for each agent.

    Args:
        results_path: JSON file produced by ``reasonbudget_gym.eval.evaluate``.
        output_path: Destination PNG (parent directories are created).
        budget: Episode token budget drawn as the dashed reference line.
            Previously hard-coded to 4000; parameterized (same default) so
            the chart stays correct for configs with a different total_budget.
    """
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # headless backend: safe on servers / CI
    import matplotlib.pyplot as plt

    with open(results_path) as f:
        results = json.load(f)

    fig, ax = plt.subplots(figsize=(10, 5))
    colors = {"uniform": "#4C72B0", "greedy_max": "#DD8452", "oracle": "#55A868", "bandit": "#C44E52"}

    for name, res in results.items():
        episodes = res["episodes"]
        max_steps = max(len(ep["per_step"]) for ep in episodes)
        # Row per episode, column per question: cumulative tokens allocated.
        mat = np.zeros((len(episodes), max_steps))
        for i, ep in enumerate(episodes):
            cumsum = 0
            for j, s in enumerate(ep["per_step"]):
                cumsum += s["tokens_allocated"]
                mat[i, j] = cumsum
            # Pad shorter episodes with their final total so means stay flat.
            mat[i, len(ep["per_step"]):] = cumsum

        mean = mat.mean(0)
        std = mat.std(0)
        steps = np.arange(1, max_steps + 1)
        c = colors.get(name, "gray")
        ax.plot(steps, mean, label=name, color=c, linewidth=2)
        ax.fill_between(steps, mean - std, mean + std, color=c, alpha=0.15)

    ax.axhline(y=budget, color="black", linestyle="--", linewidth=1.5, label=f"Budget ({budget})")
    ax.set_xlabel("Question #")
    ax.set_ylabel("Cumulative Tokens")
    ax.set_title("Budget Pacing by Agent — Mean ± 1 SD", fontweight="bold")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
    import sys

    # Optional CLI arg: path to the evaluation results JSON.
    results_file = sys.argv[1] if len(sys.argv) > 1 else "eval_results.json"
    agent_comparison(results_file)
    budget_pacing(results_file)
|
reasonbudget_gym/policy/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .allocation_policy import AllocationPolicy
|
| 2 |
+
|
| 3 |
+
__all__ = ["AllocationPolicy"]
|
reasonbudget_gym/policy/allocation_policy.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLP policy for token budget allocation.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
- Shared trunk: FC(obs_dim → 256) → ReLU → FC(256 → 128) → ReLU
|
| 6 |
+
- Actor head: FC(128 → 1) producing mean, with learned log_std
|
| 7 |
+
Output: Gaussian(mean, std) over normalised token allocation
|
| 8 |
+
- Value head: FC(128 → 1)
|
| 9 |
+
|
| 10 |
+
The action is a continuous value in [0, 1] representing fraction of
|
| 11 |
+
max_tokens to allocate. It is scaled to an integer allocation at step time.
|
| 12 |
+
"""
|
| 13 |
+
import math
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch.distributions import Normal
|
| 21 |
+
_TORCH_AVAILABLE = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
_TORCH_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _require_torch():
|
| 27 |
+
if not _TORCH_AVAILABLE:
|
| 28 |
+
raise ImportError(
|
| 29 |
+
"PyTorch is required for AllocationPolicy. "
|
| 30 |
+
"Install it with: pip install torch"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class _PolicyNet(nn.Module if _TORCH_AVAILABLE else object):
    """Actor-critic MLP: shared trunk, Gaussian actor head, scalar value head."""

    def __init__(self, obs_dim: int, hidden: int = 256):
        if not _TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        super().__init__()
        half = hidden // 2
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half),
            nn.ReLU(),
        )
        self.actor_mean = nn.Linear(half, 1)
        # Single shared (state-independent) log-std for the Gaussian policy.
        self.log_std = nn.Parameter(torch.zeros(1))
        self.value_head = nn.Linear(half, 1)

    def forward(self, x):
        features = self.trunk(x)
        # Sigmoid keeps the mean inside (0, 1): the action is a budget fraction.
        action_mean = torch.sigmoid(self.actor_mean(features))
        # Clamp std to avoid degenerate (too sharp or too flat) policies.
        action_std = torch.exp(self.log_std).clamp(0.01, 1.0)
        return action_mean, action_std, self.value_head(features)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class AllocationPolicy:
    """
    High-level wrapper around the policy network.

    Exposes get_action() compatible with the baseline agent interface,
    plus the PPO-specific evaluate_actions()/save()/load() methods.
    """

    def __init__(self, obs_dim: int, max_tokens: int, min_tokens: int = 50):
        _require_torch()
        self.obs_dim = obs_dim
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.net = _PolicyNet(obs_dim)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4)

    def _obs_to_tensor(self, obs_flat: list) -> "torch.Tensor":
        # Add a batch dimension: the network expects shape [B, obs_dim].
        return torch.tensor(obs_flat, dtype=torch.float32).unsqueeze(0)

    def get_action(self, obs_flat: list) -> tuple:
        """
        Sample a token allocation for one flattened observation.

        Returns (action_int, log_prob, value_estimate), where action_int is
        the integer token allocation in [min_tokens, max_tokens].
        """
        self.net.eval()
        with torch.no_grad():
            batch = self._obs_to_tensor(obs_flat)
            mean, std, value = self.net(batch)
            dist = Normal(mean, std)
            # NOTE(review): log_prob is evaluated at the clamped sample, so it
            # is approximate near the [0, 1] bounds — confirm this is
            # acceptable for the PPO update.
            sample = dist.sample().clamp(0.0, 1.0)
            log_prob = dist.log_prob(sample).squeeze()

        fraction = sample.item()
        allocation = int(self.min_tokens + fraction * (self.max_tokens - self.min_tokens))
        return allocation, log_prob.item(), value.squeeze().item()

    def evaluate_actions(self, obs_batch, action_fracs):
        """
        Compute (log_probs, entropy, values) for a training batch.

        obs_batch: Tensor [B, obs_dim]
        action_fracs: Tensor [B, 1] of normalised actions in [0, 1]
        """
        self.net.train()
        mean, std, values = self.net(obs_batch)
        dist = Normal(mean, std)
        return dist.log_prob(action_fracs), dist.entropy(), values

    def save(self, path: str):
        """Checkpoint network + optimizer state along with construction args."""
        _require_torch()
        state = {
            "net_state": self.net.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "obs_dim": self.obs_dim,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
        }
        torch.save(state, path)

    def load(self, path: str):
        """Restore network + optimizer state from a checkpoint file."""
        _require_torch()
        ckpt = torch.load(path, map_location="cpu")
        self.net.load_state_dict(ckpt["net_state"])
        self.optimizer.load_state_dict(ckpt["optimizer_state"])

    def reset(self):
        # No per-episode state; present for baseline-interface compatibility.
        pass
|
reasonbudget_gym/server/__init__.py
ADDED
|
File without changes
|
reasonbudget_gym/server/app.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv-compatible FastAPI server for ReasonBudgetEnv.
|
| 3 |
+
Entry point: reasonbudget_gym.server.app:app
|
| 4 |
+
"""
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI, HTTPException
|
| 8 |
+
from fastapi.responses import HTMLResponse
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
from ..env.config import EnvConfig
|
| 12 |
+
from ..env.models import Observation
|
| 13 |
+
from ..env.reason_budget_env import ReasonBudgetEnv
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResetRequest(BaseModel):
    """Request body for POST /reset."""
    # Optional RNG seed; when provided, the environment sampler is reseeded.
    seed: Optional[int] = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class StepRequest(BaseModel):
    """Request body for POST /step."""
    # Number of reasoning tokens to allocate to the current question.
    token_allocation: int
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ObsResponse(BaseModel):
    """JSON-serializable view of an environment Observation."""
    # Embedding vector of the current question.
    question_embedding: List[float]
    # Tokens left in the episode budget.
    remaining_budget: int
    # Questions still to be answered in this episode.
    questions_remaining: int
    # Budget divided across remaining questions — presumably a pacing signal;
    # see Observation for the authoritative definition.
    budget_per_remaining: float
    # Fraction of questions answered correctly so far.
    accuracy_so_far: float
    # Per-question records: tokens_allocated, tokens_used, was_correct.
    history: List[dict]
    # True once the episode has ended.
    done: bool
    # Identifier of the current episode, if any.
    episode_id: Optional[str] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class StepResponse(BaseModel):
    """Response body for POST /step."""
    # Observation after applying the allocation.
    observation: ObsResponse
    # Scalar reward for this allocation decision.
    reward: float
    # True once the episode has ended.
    done: bool
    # Auxiliary step details from the environment's info dict.
    info: Dict[str, Any]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _to_obs_response(obs: Observation) -> ObsResponse:
    """Convert an internal Observation into the API response schema."""
    history_payload = [
        {
            "tokens_allocated": step.tokens_allocated,
            "tokens_used": step.tokens_used,
            "was_correct": step.was_correct,
        }
        for step in obs.history
    ]
    return ObsResponse(
        question_embedding=obs.question_embedding,
        remaining_budget=obs.remaining_budget,
        questions_remaining=obs.questions_remaining,
        budget_per_remaining=obs.budget_per_remaining,
        accuracy_so_far=obs.accuracy_so_far,
        history=history_payload,
        done=obs.done,
        episode_id=obs.episode_id,
    )
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _landing_page() -> str:
    """Return the static HTML landing page served at GET /."""
    return """<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>ReasoningEconomicsEnv</title>
  <style>
    :root {
      color-scheme: light;
      --bg: #f6f4ec;
      --card: #fffdf6;
      --ink: #182022;
      --muted: #536067;
      --accent: #0e7c66;
      --border: #d8d1c0;
    }
    body {
      margin: 0;
      font-family: "Iowan Old Style", "Palatino Linotype", serif;
      background:
        radial-gradient(circle at top left, rgba(14, 124, 102, 0.16), transparent 30%),
        linear-gradient(180deg, #f8f6ef 0%, var(--bg) 100%);
      color: var(--ink);
    }
    main {
      max-width: 840px;
      margin: 0 auto;
      padding: 48px 24px 64px;
    }
    .eyebrow {
      text-transform: uppercase;
      letter-spacing: 0.12em;
      font-size: 0.78rem;
      color: var(--accent);
      margin-bottom: 12px;
    }
    h1 {
      font-size: clamp(2.2rem, 6vw, 4.2rem);
      line-height: 0.95;
      margin: 0 0 18px;
    }
    p {
      font-size: 1.08rem;
      line-height: 1.65;
      color: var(--muted);
      margin: 0 0 18px;
    }
    .card {
      margin-top: 32px;
      background: var(--card);
      border: 1px solid var(--border);
      border-radius: 18px;
      padding: 24px;
      box-shadow: 0 14px 50px rgba(24, 32, 34, 0.08);
    }
    .links {
      display: flex;
      flex-wrap: wrap;
      gap: 12px;
      margin: 24px 0 12px;
    }
    a {
      color: inherit;
    }
    .button {
      display: inline-block;
      text-decoration: none;
      border-radius: 999px;
      padding: 12px 18px;
      border: 1px solid var(--border);
      background: white;
      font-weight: 600;
    }
    .button.primary {
      background: var(--accent);
      border-color: var(--accent);
      color: white;
    }
    ul {
      margin: 18px 0 0;
      padding-left: 20px;
      color: var(--muted);
    }
    code {
      font-family: "SFMono-Regular", "Menlo", monospace;
      font-size: 0.95em;
    }
  </style>
</head>
<body>
  <main>
    <div class="eyebrow">AgentX–AgentBeats / OpenEnv</div>
    <h1>ReasoningEconomicsEnv</h1>
    <p>
      A reinforcement learning environment for allocating reasoning tokens across
      multi-question episodes under a fixed compute budget.
    </p>
    <p>
      This Space serves the environment as an API-first FastAPI app. Use the links
      below to inspect the schema, check runtime health, or query environment metadata.
    </p>

    <div class="links">
      <a class="button primary" href="/docs">Open API Docs</a>
      <a class="button" href="/health">Health Check</a>
      <a class="button" href="/info">Environment Info</a>
    </div>

    <section class="card">
      <p><strong>Core endpoints</strong></p>
      <ul>
        <li><code>GET /health</code> returns service status.</li>
        <li><code>GET /info</code> returns observation/action-space metadata.</li>
        <li><code>POST /reset</code> starts a fresh episode.</li>
        <li><code>POST /step</code> advances the environment by one allocation decision.</li>
      </ul>
    </section>
  </main>
</body>
</html>
"""
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def create_fastapi_app() -> FastAPI:
    """Build the FastAPI app wired to a single shared environment instance.

    NOTE(review): one module-level env means concurrent clients share episode
    state — confirm single-client usage is intended for this Space.
    """
    app = FastAPI(title="ReasoningEconomicsEnv", version="0.1.0")
    config = EnvConfig()
    env = ReasonBudgetEnv(config)

    @app.get("/", response_class=HTMLResponse)
    def index():
        """Serve the static HTML landing page."""
        return _landing_page()

    @app.get("/health")
    def health():
        """Liveness probe with env name and version."""
        return {"status": "ok", "env": "ReasonBudgetEnv", "version": "0.1.0"}

    @app.get("/info")
    def info():
        """Observation/action-space metadata for clients."""
        return {
            "name": "ReasoningEconomicsEnv",
            "observation_dim": env.observation_dim,
            "action_space": {"type": "integer", "min": config.min_tokens, "max": config.max_tokens},
            "total_budget": config.total_budget,
            "questions_per_episode": config.questions_per_episode,
        }

    @app.post("/reset", response_model=ObsResponse)
    def reset(req: ResetRequest):
        """Start a fresh episode, optionally reseeding the episode sampler."""
        if req.seed is not None:
            env.config.seed = req.seed
            env.sampler.rng.seed(req.seed)
        return _to_obs_response(env.reset())

    @app.post("/step", response_model=StepResponse)
    def step(req: StepRequest):
        """Advance the env by one allocation; env RuntimeErrors map to HTTP 400."""
        try:
            result = env.step(req.token_allocation)
        except RuntimeError as e:
            raise HTTPException(status_code=400, detail=str(e))
        return StepResponse(
            observation=_to_obs_response(result.observation),
            reward=result.reward,
            done=result.done,
            info=result.info,
        )

    return app


# OpenEnv looks for `app` at module level
app = create_fastapi_app()
# backwards-compat alias
create_app = create_fastapi_app
|
reasonbudget_gym/solver/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import BaseSolver, SolverResult
|
| 2 |
+
from .cached_solver import CachedSolver
|
| 3 |
+
|
| 4 |
+
__all__ = ["BaseSolver", "SolverResult", "CachedSolver"]
|
reasonbudget_gym/solver/base.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for solvers."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class SolverResult:
|
| 8 |
+
"""Result of a solver attempt."""
|
| 9 |
+
answer: str
|
| 10 |
+
was_correct: bool
|
| 11 |
+
tokens_used: int
|
| 12 |
+
response_text: str = ""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseSolver(ABC):
    """Attempt to answer a question within a token budget."""

    @abstractmethod
    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """Try to answer *question_text* spending at most *token_budget* tokens.

        Args:
            question_id: Stable identifier (used e.g. for cache lookups).
            question_text: The problem statement presented to the solver.
            ground_truth: Reference answer used to grade the attempt.
            token_budget: Maximum completion tokens the solver may spend.

        Returns:
            A SolverResult describing the attempt.
        """
        ...
|
reasonbudget_gym/solver/cached_solver.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CachedSolver — serves pre-computed solver results from a JSON cache.
|
| 3 |
+
|
| 4 |
+
Cache schema:
|
| 5 |
+
{
|
| 6 |
+
"entries": {
|
| 7 |
+
"<question_id>": {
|
| 8 |
+
"<budget_tier>": {
|
| 9 |
+
"answer": str,
|
| 10 |
+
"was_correct": bool,
|
| 11 |
+
"tokens_used": int,
|
| 12 |
+
"response_text": str
|
| 13 |
+
}
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"eval_ids": [list of question_ids for holdout]
|
| 17 |
+
}
|
| 18 |
+
"""
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
from .base import BaseSolver, SolverResult
|
| 24 |
+
from ..env.config import EnvConfig
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CachedSolver(BaseSolver):
    """Look up pre-computed results from a JSON cache file."""

    def __init__(self, config: EnvConfig):
        self.config = config
        self._cache: dict = {}
        self._eval_ids: list = []
        self._load_cache()

    def _load_cache(self):
        """Locate and parse the JSON cache; leave state empty if it is absent."""
        cache_path = Path(self.config.cache_path)
        if not cache_path.is_absolute():
            # Resolve relative paths against the package root first; if that
            # file does not exist, fall back to the current working directory.
            candidate = Path(__file__).parent.parent / self.config.cache_path
            if candidate.exists():
                cache_path = candidate
        if not cache_path.exists():
            # No cache on disk — solve() will use the deterministic fallback.
            return
        with open(cache_path) as f:
            payload = json.load(f)
        self._cache = payload.get("entries", {})
        self._eval_ids = payload.get("eval_ids", [])

    def _nearest_tier(self, budget: int) -> str:
        """Return the nearest budget tier <= budget (or smallest if budget < min)."""
        if self._cache:
            # Tiers are taken from the first cached question; assumes all
            # questions share the same tier set — TODO confirm cache invariant.
            raw_tiers = list(self._cache.values())[0].keys()
        else:
            raw_tiers = [50, 100, 200, 400, 800]
        tiers = sorted(int(t) for t in raw_tiers)
        chosen = tiers[0]
        for tier in tiers:
            if tier <= budget:
                chosen = tier
        return str(chosen)

    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """Serve the cached entry for the nearest tier, or a heuristic fallback."""
        per_question = self._cache.get(question_id)
        if per_question is None:
            # Question never pre-computed: simulate deterministically instead.
            return self._fallback(ground_truth, token_budget)

        entry = per_question.get(self._nearest_tier(token_budget))
        if entry is None:
            # Tier missing for this question — use its smallest available tier.
            smallest = str(sorted(int(k) for k in per_question)[0])
            entry = per_question.get(smallest)
        return SolverResult(
            answer=entry["answer"],
            was_correct=entry["was_correct"],
            tokens_used=entry["tokens_used"],
            response_text=entry.get("response_text", ""),
        )

    def _fallback(self, ground_truth: str, token_budget: int) -> SolverResult:
        """When question not in cache, use a deterministic heuristic."""
        import hashlib
        digest = int(hashlib.md5(ground_truth.encode()).hexdigest(), 16)
        # Probability of correctness grows with budget, capped at 0.9.
        p_correct = min(0.9, 0.3 + token_budget / 1600)
        correct = (digest % 100) < int(p_correct * 100)
        return SolverResult(
            answer=ground_truth if correct else "unknown",
            was_correct=correct,
            tokens_used=int(token_budget * 0.85),
            response_text="[fallback]",
        )
|
reasonbudget_gym/solver/live_solver.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LiveSolver — calls DeepSeek-R1 via Together AI for real inference.
|
| 3 |
+
|
| 4 |
+
Requires TOGETHER_API_KEY environment variable.
|
| 5 |
+
Not used by default; switch solver_type to 'live' to enable.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
from .base import BaseSolver, SolverResult
|
| 10 |
+
from ..env.config import EnvConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LiveSolver(BaseSolver):
    """Call DeepSeek-R1 API for live reasoning.

    Requires the `together` package and the TOGETHER_API_KEY environment
    variable; construction fails fast with a clear error otherwise.
    """

    MODEL = "deepseek-ai/DeepSeek-R1"

    def __init__(self, config: EnvConfig):
        self.config = config
        try:
            from together import Together
            self._client = Together(api_key=os.environ["TOGETHER_API_KEY"])
        except ImportError:
            raise ImportError("Install `together` package: pip install together")
        except KeyError:
            raise EnvironmentError("TOGETHER_API_KEY not set")

    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """Query the model with *token_budget* as max_tokens and grade the answer.

        Returns a SolverResult; response_text is truncated to 500 characters.
        """
        prompt = (
            f"Solve the following math problem step by step.\n\n"
            f"Problem: {question_text}\n\n"
            f"Provide your final answer as: The answer is <answer>"
        )
        response = self._client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=token_budget,
        )
        text = response.choices[0].message.content or ""
        # Usage may be absent on some responses; assume the full budget then.
        tokens_used = response.usage.completion_tokens if response.usage else token_budget

        # Extract the final answer from the "The answer is ..." pattern,
        # falling back to the last line of the response.
        m = re.search(r"[Tt]he answer is[:\s]+([^\n.]+)", text)
        answer = m.group(1).strip() if m else text.strip().split("\n")[-1]

        # Bidirectional substring match, guarding against empty strings.
        # BUG FIX: previously an empty extracted answer (or empty ground
        # truth) was always graded correct, because "" is a substring of
        # every string.
        ans_norm = answer.strip().lower()
        gt_norm = ground_truth.strip().lower()
        was_correct = bool(ans_norm) and bool(gt_norm) and (
            gt_norm in ans_norm or ans_norm in gt_norm
        )
        return SolverResult(
            answer=answer,
            was_correct=was_correct,
            tokens_used=tokens_used,
            response_text=text[:500],
        )
|
reasonbudget_gym/tests/__init__.py
ADDED
|
File without changes
|
reasonbudget_gym/tests/test_config.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 4 |
+
|
| 5 |
+
from reasonbudget_gym.env.config import EnvConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_default_config():
    """Default EnvConfig matches the documented baseline settings."""
    cfg = EnvConfig()
    expected = {
        "total_budget": 4000,
        "min_tokens": 50,
        "max_tokens": 800,
        "questions_per_episode": 10,
    }
    for attr, want in expected.items():
        assert getattr(cfg, attr) == want
    assert len(cfg.budget_tiers) == 5
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_custom_config():
    """Constructor overrides are reflected on the instance."""
    cfg = EnvConfig(total_budget=2000, questions_per_episode=5)
    assert (cfg.total_budget, cfg.questions_per_episode) == (2000, 5)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_reward_weights():
    """Reward weights have sane signs: positive reward, non-negative penalties."""
    cfg = EnvConfig()
    assert cfg.correct_reward > 0
    assert cfg.cost_penalty_weight >= 0
    assert cfg.efficiency_bonus_weight >= 0
|
reasonbudget_gym/tests/test_integration.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 4 |
+
|
| 5 |
+
from reasonbudget_gym.env.config import EnvConfig
|
| 6 |
+
from reasonbudget_gym.env.reason_budget_env import ReasonBudgetEnv
|
| 7 |
+
from reasonbudget_gym.baselines.uniform import UniformBaseline
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_full_episode():
    """A uniform agent can run a 5-question episode through to termination."""
    cfg = EnvConfig(questions_per_episode=5, total_budget=2000)
    env = ReasonBudgetEnv(cfg)
    agent = UniformBaseline(cfg)

    obs = env.reset()
    assert obs.questions_remaining == 5
    assert not obs.done

    episode_return = 0.0
    n_steps = 0
    while True:
        step = env.step(agent.get_action(obs))
        episode_return += step.reward
        obs = step.observation
        n_steps += 1
        assert n_steps <= 20  # safety valve against a non-terminating env
        if step.done:
            break

    assert obs.done
    assert n_steps == 5
    assert -20 < episode_return < 20
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_reset_clears_state():
    """reset() restores budget, question count and history after mid-episode steps."""
    cfg = EnvConfig(questions_per_episode=3, total_budget=1000)
    env = ReasonBudgetEnv(cfg)
    env.reset()
    for _ in range(2):
        env.step(100)
    obs = env.reset()
    assert (obs.questions_remaining, obs.remaining_budget) == (3, 1000)
    assert len(obs.history) == 0
|
reasonbudget_gym/tests/test_reward.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 4 |
+
|
| 5 |
+
from reasonbudget_gym.env.config import EnvConfig
|
| 6 |
+
from reasonbudget_gym.env.reward import compute_reward
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_correct_positive():
    """A correct answer within budget yields a strictly positive reward."""
    reward = compute_reward(True, 200, 180, 2000, 5, EnvConfig())
    assert reward > 0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_wrong_expensive_negative():
    """A wrong, budget-exhausting answer is penalised below zero."""
    cfg = EnvConfig(cost_penalty_weight=0.5)
    assert compute_reward(False, 800, 800, 4000, 1, cfg) < 0
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_efficiency_bonus():
    """Spending fewer tokens on a correct answer earns a strictly higher reward."""
    cfg = EnvConfig(efficiency_bonus_weight=0.5)
    frugal = compute_reward(True, 400, 50, 2000, 5, cfg)
    exhaustive = compute_reward(True, 400, 400, 2000, 5, cfg)
    assert frugal > exhaustive
|
reasonbudget_gym/training/__init__.py
ADDED
|
File without changes
|
reasonbudget_gym/training/ppo_train.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PPO training loop for the token budget allocation policy.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m reasonbudget_gym.training.ppo_train \
|
| 6 |
+
--n_episodes 500 \
|
| 7 |
+
--ppo_epochs 4 \
|
| 8 |
+
--clip_eps 0.2 \
|
| 9 |
+
--value_coef 0.5 \
|
| 10 |
+
--entropy_coef 0.01 \
|
| 11 |
+
--output_dir runs/ppo_run1
|
| 12 |
+
|
| 13 |
+
Training procedure:
|
| 14 |
+
1. Roll out N episodes using the current policy (stochastic)
|
| 15 |
+
2. Compute returns and GAE advantages
|
| 16 |
+
3. Run PPO update for K epochs over the rollout buffer
|
| 17 |
+
4. Log metrics and save checkpoint every 50 episodes
|
| 18 |
+
"""
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import math
|
| 22 |
+
import os
|
| 23 |
+
import random
|
| 24 |
+
import sys
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
_TORCH_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
_TORCH_AVAILABLE = False
|
| 33 |
+
|
| 34 |
+
from ..env.config import EnvConfig
|
| 35 |
+
from ..env.reason_budget_env import ReasonBudgetEnv
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def collect_rollout(env: ReasonBudgetEnv, policy, config: EnvConfig) -> dict:
    """Run one full episode with the (stochastic) policy.

    Returns a buffer dict with per-step observations, normalised action
    fractions, log-probs, rewards, value estimates and done flags, plus
    episode-level summary stats (total reward, accuracy, step count).
    """
    obs = env.reset()
    flat_obs = env.flat_observation(obs)

    buffer = {
        "observations": [],
        "action_fracs": [],
        "log_probs": [],
        "rewards": [],
        "values": [],
        "dones": [],
    }
    total_reward = 0.0
    correct_count = 0
    steps = 0
    token_span = max(1, config.max_tokens - config.min_tokens)

    done = False
    while not done:
        action_int, log_prob, value = policy.get_action(flat_obs)
        result = env.step(action_int)

        # Map the integer token budget into a clamped [0, 1] fraction.
        frac = min(1.0, max(0.0, (action_int - config.min_tokens) / token_span))

        buffer["observations"].append(flat_obs)
        buffer["action_fracs"].append(frac)
        buffer["log_probs"].append(log_prob)
        buffer["rewards"].append(result.reward)
        buffer["values"].append(value)
        buffer["dones"].append(result.done)

        if result.info.get("was_correct"):
            correct_count += 1
        total_reward += result.reward
        steps += 1
        done = result.done
        flat_obs = env.flat_observation(result.observation)

    buffer["total_reward"] = total_reward
    buffer["accuracy"] = correct_count / max(1, steps)
    buffer["steps"] = steps
    return buffer
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def compute_returns_and_advantages(
    rewards, values, dones, gamma=0.99, lam=0.95
):
    """Generalised Advantage Estimation (GAE).

    Walks the trajectory backwards, accumulating the exponentially weighted
    TD-error sum; `dones[t]` cuts the bootstrap across episode boundaries.
    Returns (advantages, returns) where returns[t] = advantages[t] + values[t].
    """
    advantages = []
    running_gae = 0.0
    bootstrap = 0.0  # value estimate of the state after t (0 past a terminal)

    for reward, value, terminal in zip(
        reversed(rewards), reversed(values), reversed(dones)
    ):
        keep = 0.0 if terminal else 1.0
        td_error = reward + gamma * bootstrap * keep - value
        running_gae = td_error + gamma * lam * keep * running_gae
        advantages.append(running_gae)
        bootstrap = value

    advantages.reverse()
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def ppo_update(policy, rollouts: list, clip_eps: float, value_coef: float,
               entropy_coef: float, ppo_epochs: int, batch_size: int = 64):
    """Run clipped-surrogate PPO updates over the collected rollouts.

    Flattens the rollout buffers, normalises advantages, then performs
    `ppo_epochs` passes of shuffled minibatch gradient steps. Returns the
    mean loss over all minibatch updates.
    """
    import torch

    # Concatenate every rollout into flat per-step lists.
    keys = ("observations", "action_fracs", "log_probs", "advantages", "returns")
    flat = {key: [] for key in keys}
    for rollout in rollouts:
        for key in keys:
            flat[key].extend(rollout[key])

    obs_t = torch.tensor(flat["observations"], dtype=torch.float32)
    fracs_t = torch.tensor(flat["action_fracs"], dtype=torch.float32).unsqueeze(1)
    old_lp_t = torch.tensor(flat["log_probs"], dtype=torch.float32).unsqueeze(1)
    adv_t = torch.tensor(flat["advantages"], dtype=torch.float32).unsqueeze(1)
    ret_t = torch.tensor(flat["returns"], dtype=torch.float32).unsqueeze(1)

    # Standardise advantages for a better-conditioned surrogate objective.
    adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)

    n_samples = obs_t.shape[0]
    loss_sum = 0.0
    update_count = 0

    for _ in range(ppo_epochs):
        order = torch.randperm(n_samples)
        for start in range(0, n_samples, batch_size):
            batch = order[start:start + batch_size]

            new_lp, entropy, values = policy.evaluate_actions(
                obs_t[batch], fracs_t[batch]
            )

            # Clipped PPO surrogate: pessimistic min of raw and clipped terms.
            ratio = torch.exp(new_lp - old_lp_t[batch])
            batch_adv = adv_t[batch]
            unclipped = ratio * batch_adv
            clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * batch_adv
            actor_loss = -torch.min(unclipped, clipped).mean()
            value_loss = F.mse_loss(values, ret_t[batch])
            entropy_loss = -entropy.mean()

            loss = actor_loss + value_coef * value_loss + entropy_coef * entropy_loss

            policy.optimizer.zero_grad()
            loss.backward()
            # Clip the global grad norm to stabilise updates.
            torch.nn.utils.clip_grad_norm_(policy.net.parameters(), 0.5)
            policy.optimizer.step()

            loss_sum += loss.item()
            update_count += 1

    return loss_sum / max(1, update_count)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def train(args):
    """Train the allocation policy with single-episode PPO updates.

    Rolls out one episode per iteration, computes GAE advantages, runs a
    PPO update, logs a 10-episode moving average every 10 episodes, and
    checkpoints every 50. Writes the final policy and a JSON training
    history into `args.output_dir`.
    """
    if not _TORCH_AVAILABLE:
        print("ERROR: PyTorch not installed. Run: pip install torch", file=sys.stderr)
        sys.exit(1)

    config = EnvConfig(seed=args.seed)
    env = ReasonBudgetEnv(config)
    obs_dim = env.observation_dim

    # Imported lazily so the module stays importable without torch installed.
    from ..policy.allocation_policy import AllocationPolicy
    policy = AllocationPolicy(
        obs_dim=obs_dim,
        max_tokens=config.max_tokens,
        min_tokens=config.min_tokens,
    )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    history = []
    print(f"Starting PPO training: {args.n_episodes} episodes, obs_dim={obs_dim}")
    print(f"Output dir: {output_dir}")

    for episode in range(1, args.n_episodes + 1):
        # Collect rollout
        rollout = collect_rollout(env, policy, config)
        advantages, returns = compute_returns_and_advantages(
            rollout["rewards"], rollout["values"], rollout["dones"]
        )
        rollout["advantages"] = advantages
        rollout["returns"] = returns

        # PPO update
        loss = ppo_update(
            policy, [rollout],
            clip_eps=args.clip_eps,
            value_coef=args.value_coef,
            entropy_coef=args.entropy_coef,
            ppo_epochs=args.ppo_epochs,
        )

        record = {
            "episode": episode,
            "reward": rollout["total_reward"],
            "accuracy": rollout["accuracy"],
            "loss": loss,
            "steps": rollout["steps"],
        }
        history.append(record)

        if episode % 10 == 0:
            recent = history[-10:]
            avg_r = sum(r["reward"] for r in recent) / len(recent)
            avg_a = sum(r["accuracy"] for r in recent) / len(recent)
            print(f" Ep {episode:4d} | avg_reward={avg_r:.3f} | avg_acc={avg_a:.3f} | loss={loss:.4f}")

        if episode % 50 == 0:
            ckpt_path = output_dir / f"policy_ep{episode}.pt"
            policy.save(str(ckpt_path))
            print(f" Saved checkpoint: {ckpt_path}")

    # Final save
    policy.save(str(output_dir / "policy_final.pt"))
    with open(output_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    print(f"\nTraining complete. Final checkpoint: {output_dir / 'policy_final.pt'}")
    # BUG FIX: the summary previously divided by a hard-coded 10, which
    # mis-reports the averages whenever fewer than 10 episodes were run.
    final_window = history[-10:]
    denom = max(1, len(final_window))
    print(f"Last 10 episodes — avg reward: {sum(r['reward'] for r in final_window)/denom:.3f}, "
          f"avg accuracy: {sum(r['accuracy'] for r in final_window)/denom:.3f}")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main():
    """CLI entry point: parse PPO hyperparameters and launch training."""
    parser = argparse.ArgumentParser(description="PPO training for ReasonBudgetEnv")
    arg_specs = [
        ("--n_episodes", int, 500),
        ("--ppo_epochs", int, 4),
        ("--clip_eps", float, 0.2),
        ("--value_coef", float, 0.5),
        ("--entropy_coef", float, 0.01),
        ("--seed", int, 42),
        ("--output_dir", str, "runs/ppo_run1"),
    ]
    for flag, flag_type, default in arg_specs:
        parser.add_argument(flag, type=flag_type, default=default)
    train(parser.parse_args())
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
numpy>=1.24
|
| 5 |
+
datasets>=2.18.0
|
| 6 |
+
sentence-transformers>=2.7.0
|
| 7 |
+
matplotlib>=3.8
|
| 8 |
+
seaborn>=0.13
|