Andrew Lara committed on
Commit
ee91164
·
0 Parent(s):

Deploy landing page update to Space

Browse files
Files changed (48) hide show
  1. .dockerignore +7 -0
  2. .env.example +5 -0
  3. .gitattributes +2 -0
  4. .github/workflows/ci.yml +18 -0
  5. .gitignore +6 -0
  6. CODEX_CONTEXT.md +110 -0
  7. Dockerfile +23 -0
  8. README.md +150 -0
  9. docs/agent_comparison.png +3 -0
  10. docs/budget_pacing.png +3 -0
  11. eval_results.json +0 -0
  12. openenv.yaml +12 -0
  13. pyproject.toml +37 -0
  14. reasonbudget_gym/__init__.py +3 -0
  15. reasonbudget_gym/baselines/__init__.py +11 -0
  16. reasonbudget_gym/baselines/bandit.py +82 -0
  17. reasonbudget_gym/baselines/greedy_max.py +18 -0
  18. reasonbudget_gym/baselines/oracle.py +46 -0
  19. reasonbudget_gym/baselines/uniform.py +22 -0
  20. reasonbudget_gym/client.py +51 -0
  21. reasonbudget_gym/data/__init__.py +0 -0
  22. reasonbudget_gym/data/embeddings.npy +3 -0
  23. reasonbudget_gym/data/generate_synthetic_cache.py +156 -0
  24. reasonbudget_gym/data/response_cache.json +0 -0
  25. reasonbudget_gym/env/__init__.py +5 -0
  26. reasonbudget_gym/env/config.py +37 -0
  27. reasonbudget_gym/env/episode_sampler.py +210 -0
  28. reasonbudget_gym/env/models.py +43 -0
  29. reasonbudget_gym/env/reason_budget_env.py +167 -0
  30. reasonbudget_gym/env/reward.py +44 -0
  31. reasonbudget_gym/eval/__init__.py +0 -0
  32. reasonbudget_gym/eval/evaluate.py +121 -0
  33. reasonbudget_gym/eval/plots.py +89 -0
  34. reasonbudget_gym/policy/__init__.py +3 -0
  35. reasonbudget_gym/policy/allocation_policy.py +127 -0
  36. reasonbudget_gym/server/__init__.py +0 -0
  37. reasonbudget_gym/server/app.py +233 -0
  38. reasonbudget_gym/solver/__init__.py +4 -0
  39. reasonbudget_gym/solver/base.py +26 -0
  40. reasonbudget_gym/solver/cached_solver.py +98 -0
  41. reasonbudget_gym/solver/live_solver.py +62 -0
  42. reasonbudget_gym/tests/__init__.py +0 -0
  43. reasonbudget_gym/tests/test_config.py +27 -0
  44. reasonbudget_gym/tests/test_integration.py +43 -0
  45. reasonbudget_gym/tests/test_reward.py +23 -0
  46. reasonbudget_gym/training/__init__.py +0 -0
  47. reasonbudget_gym/training/ppo_train.py +252 -0
  48. requirements.txt +8 -0
.dockerignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .git
2
+ .venv
3
+ .pytest_cache
4
+ __pycache__
5
+ *.py[cod]
6
+ .DS_Store
7
+ runs
.env.example ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Optional: only needed for live solver (GPU inference)
2
+ TOGETHER_API_KEY=your_key_here
3
+
4
+ # Optional: for HF Spaces deployment
5
+ HF_TOKEN=your_token_here
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ docs/*.png filter=lfs diff=lfs merge=lfs -text
2
+ reasonbudget_gym/data/*.npy filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci.yml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ pull_request:
6
+
7
+ jobs:
8
+ test:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v4
12
+ - uses: actions/setup-python@v5
13
+ with:
14
+ python-version: "3.11"
15
+ - name: Install dependencies
16
+ run: pip install -e ".[dev]"
17
+ - name: Run tests
18
+ run: python -m pytest reasonbudget_gym/tests/ -v
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .venv/
3
+ .pytest_cache/
4
+ __pycache__/
5
+ *.py[cod]
6
+ runs/
CODEX_CONTEXT.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codex Context — ReasoningEconomicsEnv
2
+
3
+ ## Project
4
+
5
+ - Repo root: `/Users/andrew/Mac/RL Research`
6
+ - GitHub repo: `git@github.com:laraandrew/reasoningeconomicsenv.git`
7
+ - Active branch: `polish-and-deploy`
8
+ - Hugging Face Space: `landrew9/CollabReasoning`
9
+ - Package: `reasonbudget_gym`
10
+ - Goal: RL environment for token-budget allocation, competition submission, Docker-based HF Space deployment
11
+
12
+ ## Remotes
13
+
14
+ - `origin`: `git@github.com:laraandrew/reasoningeconomicsenv.git`
15
+ - `hf`: `https://huggingface.co/spaces/landrew9/CollabReasoning`
16
+
17
+ ## Current State
18
+
19
+ - `main` and `polish-and-deploy` originally pointed to the same base commit.
20
+ - Work on `polish-and-deploy` is pushed to GitHub through commit `efdc42b`.
21
+ - The shipped cache works:
22
+ - `CachedSolver(EnvConfig())._cache` loads 500 entries.
23
+ - The environment now defaults to an offline-safe path for cached runs:
24
+ - `EpisodeSampler` uses deterministic bundled questions when the cached solver is active.
25
+ - Real question embeddings are enabled and cached at:
26
+ - `reasonbudget_gym/data/embeddings.npy`
27
+ - README now contains measured evaluation metrics and embedded plot assets.
28
+ - CI exists at `.github/workflows/ci.yml`.
29
+ - Dockerfile was slimmed to a runtime-only serving image suitable for HF Spaces.
30
+ - The Hugging Face Space repo was force-updated from a clean temporary clone because
31
+ Hugging Face rejected the branch's historical raw binary blobs.
32
+ - The live Space is currently:
33
+ - Hub page: `https://huggingface.co/spaces/landrew9/CollabReasoning`
34
+ - Host: `https://landrew9-collabreasoning.hf.space`
35
+ - Runtime stage: `RUNNING`
36
+ - Health endpoint: `/health`
37
+ - Root path originally returned `404`; a landing page at `/` was then added in `server/app.py`
38
+
39
+ ## Local Tooling
40
+
41
+ - Hugging Face CLI installed globally via the official installer.
42
+ - Binary path: `/Users/andrew/.local/bin/hf`
43
+ - Reported version at install time: `1.8.0`
44
+ - Installer added `/Users/andrew/.local/bin` to `/Users/andrew/.zshrc`
45
+ - `git-lfs` and `git-xet` are installed and initialized globally.
46
+ - `.gitattributes` now tracks:
47
+ - `docs/*.png`
48
+ - `reasonbudget_gym/data/*.npy`
49
+
50
+ ## Verified Commands
51
+
52
+ - Tests:
53
+ - `.venv/bin/python -m pytest reasonbudget_gym/tests/ -v`
54
+ - Result: `8 passed`
55
+ - Eval:
56
+ - `.venv/bin/python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json`
57
+ - Plot generation:
58
+ - `.venv/bin/python -c "from reasonbudget_gym.eval.plots import agent_comparison, budget_pacing; agent_comparison('eval_results.json', 'docs/agent_comparison.png'); budget_pacing('eval_results.json', 'docs/budget_pacing.png')"`
59
+ - PPO smoke test:
60
+ - `.venv/bin/python -m reasonbudget_gym.training.ppo_train --n_episodes 100 --output_dir runs/smoke`
61
+ - Completed successfully and wrote checkpoints.
62
+ - Docker:
63
+ - `docker build -t reasoning-economic-env .`
64
+ - `docker run -d -p 8000:8000 --name reasoning-economic-env-test reasoning-economic-env`
65
+ - `curl http://127.0.0.1:8000/health`
66
+ - Result: `{"status":"ok","env":"ReasonBudgetEnv","version":"0.1.0"}`
67
+
68
+ ## Current Eval Numbers
69
+
70
+ From `eval_results.json` with `--n_episodes 50 --seed 42`:
71
+
72
+ | Agent | Mean Accuracy | Mean Reward | Budget Used |
73
+ |---|---:|---:|---:|
74
+ | `uniform` | 0.780 | 7.620 | 100.0% |
75
+ | `greedy_max` | 0.840 | 4.163 | 100.0% |
76
+ | `oracle` | 0.728 | 6.933 | 98.3% |
77
+ | `bandit` | 0.744 | 6.526 | 98.8% |
78
+
79
+ ## Important Files
80
+
81
+ - `reasonbudget_gym/env/episode_sampler.py`
82
+ - `reasonbudget_gym/env/config.py`
83
+ - `reasonbudget_gym/solver/cached_solver.py`
84
+ - `reasonbudget_gym/eval/evaluate.py`
85
+ - `reasonbudget_gym/server/app.py`
86
+ - `Dockerfile`
87
+ - `README.md`
88
+ - `.github/workflows/ci.yml`
89
+ - `eval_results.json`
90
+ - `docs/agent_comparison.png`
91
+ - `docs/budget_pacing.png`
92
+
93
+ ## Git History Added On This Branch
94
+
95
+ - `29b6ad0` Add gitignore for local dev artifacts
96
+ - `ecd0ab1` Use bundled questions for cached offline runs
97
+ - `9e122a2` Cache MiniLM question embeddings
98
+ - `c4d6234` Add GitHub Actions test workflow
99
+ - `fc6c606` Add baseline eval results and README plots
100
+ - `280a6de` Slim Docker image for HF deployment
101
+ - `fc4c73c` Add living Codex context file
102
+ - `efdc42b` Track Space binaries with Xet
103
+
104
+ ## Notes For Next Codex
105
+
106
+ - Keep `HANDOFF.md` deleted; update this file instead.
107
+ - Do not remove `reasonbudget_gym/data/response_cache.json` or `reasonbudget_gym/data/embeddings.npy`; they are part of the current offline/demo story.
108
+ - The Docker image should stay lean; avoid reintroducing `sentence-transformers`, `datasets`, or training dependencies into the serving image unless truly needed.
109
+ - If enabling the live solver later, configure secrets in Hugging Face Space settings rather than hard-coding them.
110
+ - The local repo may also have an `hf` remote pointing at the Space repo; if so, pushes there will trigger Space rebuilds.
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+ ENV PYTHONDONTWRITEBYTECODE=1 \
5
+ PYTHONUNBUFFERED=1
6
+
7
+ # Copy only the files needed to serve the packaged environment.
8
+ COPY pyproject.toml README.md openenv.yaml ./
9
+ COPY reasonbudget_gym ./reasonbudget_gym
10
+
11
+ # The Space serves the bundled cached environment, so it only needs the
12
+ # lightweight runtime deps plus an editable install of this package.
13
+ RUN pip install --no-cache-dir \
14
+ "fastapi>=0.110.0" \
15
+ "uvicorn[standard]>=0.29.0" \
16
+ "pydantic>=2.0" \
17
+ "numpy>=1.24" \
18
+ "hatchling" \
19
+ && pip install --no-cache-dir --no-deps -e .
20
+
21
+ EXPOSE 8000
22
+
23
+ CMD ["uvicorn", "reasonbudget_gym.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ReasoningEconomicsEnv
3
+ sdk: docker
4
+ app_port: 8000
5
+ tags:
6
+ - openenv
7
+ - reasoning-economic-env
8
+ - rl
9
+ - math
10
+ ---
11
+
12
+ # ReasoningEconomicsEnv
13
+
14
+ **An RL environment for learning to allocate reasoning compute under budget constraints.**
15
+
16
+ > Modern reasoning models like DeepSeek-R1 "think" by generating internal tokens before
17
+ > answering. More tokens = deeper reasoning = better answers — but tokens cost compute and
18
+ > money. How should an agent decide how much to think on each problem?
19
+
20
+ ReasoningEconomicsEnv frames this as a sequential decision problem: an agent faces a series
21
+ of math questions with a fixed total token budget and must learn to **allocate tokens wisely**
22
+ — spending less on easy questions, more on hard ones.
23
+
24
+ Built on [Meta's OpenEnv framework](https://github.com/meta-pytorch/OpenEnv) for the
25
+ [AgentX–AgentBeats Competition](https://rdi.berkeley.edu/agentx-agentbeats) hosted by
26
+ Berkeley RDI.
27
+
28
+ ---
29
+
30
+ ## How It Works
31
+
32
+ ```
33
+ Episode (10 questions, 4000 token budget)
34
+ ┌─────────────────────────────────────────────────────────┐
35
+ │ 1. Agent observes: question embedding, remaining budget │
36
+ │ 2. Agent decides: token allocation (50–800) │
37
+ │ 3. Solver attempts question with that token limit │
38
+ │ 4. Reward = correctness − β·cost + γ·efficiency_bonus │
39
+ │ 5. Repeat until all questions answered or budget gone │
40
+ └─────────────────────────────────────────────────────────┘
41
+ ```
42
+
43
+ **Reward formula:** `R = correctness(+1 if correct / −0.1 if wrong) − β·(tokens_used/budget) + γ·(savings/budget)`
44
+
45
+ ---
46
+
47
+ ## Quick Start
48
+
49
+ ```bash
50
+ pip install -e .
51
+
52
+ # Run the OpenEnv server
53
+ uvicorn reasonbudget_gym.server.app:app --port 8000
54
+
55
+ # In another terminal — use the Python client
56
+ python -c "
57
+ from reasonbudget_gym.client import ReasonBudgetClient
58
+ client = ReasonBudgetClient()
59
+ obs = client.reset()
60
+ result = client.step(200)
61
+ print(result.reward, result.done)
62
+ "
63
+ ```
64
+
65
+ **Or run baseline evaluation locally:**
66
+
67
+ ```bash
68
+ python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
69
+ python -m reasonbudget_gym.eval.plots eval_results.json
70
+ ```
71
+
72
+ ---
73
+
74
+ ## Baselines
75
+
76
+ | Agent | Mean Accuracy | Mean Reward | Budget Used |
77
+ |-------|---------------|-------------|-------------|
78
+ | `uniform` | 0.780 | 7.620 | 100.0% |
79
+ | `greedy_max` | 0.840 | 4.163 | 100.0% |
80
+ | `oracle` | 0.728 | 6.933 | 98.3% |
81
+ | `bandit` | 0.744 | 6.526 | 98.8% |
82
+
83
+ Evaluation command:
84
+
85
+ ```bash
86
+ python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
87
+ ```
88
+
89
+ ![Baseline comparison](docs/agent_comparison.png)
90
+
91
+ ![Budget pacing](docs/budget_pacing.png)
92
+
93
+ ---
94
+
95
+ ## Observation Space
96
+
97
+ | Field | Shape | Description |
98
+ |-------|-------|-------------|
99
+ | `question_embedding` | 384-dim | Sentence-transformer encoding |
100
+ | `remaining_budget` | int | Tokens left in episode |
101
+ | `questions_remaining` | int | Questions left |
102
+ | `budget_per_remaining` | float | remaining / questions_left |
103
+ | `accuracy_so_far` | float | Running accuracy [0, 1] |
104
+ | `history` | list | Past (allocated, used, correct) tuples |
105
+
106
+ **Action:** integer token allocation, clamped to `[min_tokens, max_tokens]` and remaining budget.
107
+
108
+ ---
109
+
110
+ ## Data
111
+
112
+ The repo ships with a deterministic offline question bundle and response cache under
113
+ `reasonbudget_gym/data/`, so demos and tests work without external services.
114
+
115
+ A **synthetic cache** (`reasonbudget_gym/data/response_cache.json`) simulates realistic
116
+ DeepSeek-R1 accuracy curves across 4 difficulty tiers: `gsm8k`, `math_l1_l2`, `math_l3`,
117
+ `math_l4_l5`. The sampler also caches MiniLM embeddings to
118
+ `reasonbudget_gym/data/embeddings.npy` after the first run.
119
+
120
+ Regenerate the synthetic cache with:
121
+
122
+ ```bash
123
+ python reasonbudget_gym/data/generate_synthetic_cache.py
124
+ ```
125
+
126
+ ---
127
+
128
+ ## Deployment (Docker / HF Spaces)
129
+
130
+ ```bash
131
+ docker build -t reasoning-economic-env .
132
+ docker run -p 8000:8000 reasoning-economic-env
133
+ curl http://localhost:8000/health
134
+ ```
135
+
136
+ ---
137
+
138
+ ## Related Work
139
+
140
+ - **[MAS-TTS](https://github.com/jincan333/MAS-TTS):** Allocates reasoning across *agents* on
141
+ one problem vs. our approach of allocating across *questions* for a single agent.
142
+ - **[AgentTTS](https://arxiv.org/abs/2508.00890):** Test-time compute-optimal scaling across
143
+ multi-stage complex tasks.
144
+
145
+ ---
146
+
147
+ ## Citation
148
+
149
+ Part of the AgentX–AgentBeats Competition (Berkeley RDI, 2026).
150
+ Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) by Meta/PyTorch.
docs/agent_comparison.png ADDED

Git LFS Details

  • SHA256: 03e0666f33493659fdea9a45064a34556fa191a590544a2da2ed84bc65d21739
  • Pointer size: 130 Bytes
  • Size of remote file: 55.4 kB
docs/budget_pacing.png ADDED

Git LFS Details

  • SHA256: 2ab4d18ea1b3763880ac25fdd4ecd7f77b0289acf754dba9b7b85c3dfc54a525
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
eval_results.json ADDED
The diff for this file is too large to render. See raw diff
 
openenv.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: reasoning-economic-env
3
+ type: space
4
+ runtime: fastapi
5
+ app: reasonbudget_gym.server.app:app
6
+ port: 8000
7
+ tags:
8
+ - openenv
9
+ - reasoning-economic-env
10
+ - rl
11
+ - math
12
+ - token-budget
pyproject.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "reasonbudget-gym"
7
+ version = "0.1.0"
8
+ description = "RL environment for learning to allocate reasoning compute under budget constraints"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "fastapi>=0.110.0",
12
+ "uvicorn[standard]>=0.29.0",
13
+ "pydantic>=2.0",
14
+ "numpy>=1.24",
15
+ "datasets>=2.18.0",
16
+ "sentence-transformers>=2.7.0",
17
+ "matplotlib>=3.8",
18
+ "seaborn>=0.13",
19
+ ]
20
+
21
+ [project.optional-dependencies]
22
+ dev = [
23
+ "pytest>=8.0",
24
+ "httpx>=0.27",
25
+ ]
26
+ train = [
27
+ "torch>=2.2",
28
+ ]
29
+ live = [
30
+ "together>=1.2",
31
+ ]
32
+
33
+ [tool.hatch.build.targets.wheel]
34
+ packages = ["reasonbudget_gym"]
35
+
36
+ [tool.pytest.ini_options]
37
+ testpaths = ["reasonbudget_gym/tests"]
reasonbudget_gym/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """ReasoningEconomicsEnv — token budget allocation RL environment."""
2
+
3
+ __version__ = "0.1.0"
reasonbudget_gym/baselines/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .uniform import UniformBaseline
2
+ from .greedy_max import GreedyMaxBaseline
3
+ from .oracle import DifficultyOracleBaseline
4
+ from .bandit import LinUCBBaseline
5
+
6
+ __all__ = [
7
+ "UniformBaseline",
8
+ "GreedyMaxBaseline",
9
+ "DifficultyOracleBaseline",
10
+ "LinUCBBaseline",
11
+ ]
reasonbudget_gym/baselines/bandit.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LinUCB bandit baseline: learns token allocation from question embeddings.
3
+
4
+ Uses a simplified contextual bandit (LinUCB) where:
5
+ - Context = question embedding (384-dim, but we project to 16-dim for speed)
6
+ - Arms = discrete budget tiers [50, 100, 200, 400, 800]
7
+ - Reward = observed correctness signal
8
+ """
9
+ import math
10
+ import numpy as np
11
+ from ..env.models import Observation
12
+ from ..env.config import EnvConfig
13
+
14
PROJ_DIM = 16  # Projection dimension for efficiency


class LinUCBBaseline:
    """
    Linear UCB contextual bandit for token allocation.

    Projects 384-dim embeddings to PROJ_DIM via a fixed random projection,
    then maintains a separate LinUCB arm for each budget tier.
    Context = projected embedding + 4 budget/progress scalars;
    arms = the discrete budget tiers from the config.
    """

    name = "bandit"

    def __init__(self, config: "EnvConfig", alpha: float = 1.0, seed: int = 42):
        """
        Args:
            config: environment configuration (tiers, dims, budget totals).
            alpha: UCB exploration coefficient (larger = more exploration).
            seed: seed for the fixed random projection matrix.
        """
        self.config = config
        self.alpha = alpha
        self.tiers = config.budget_tiers
        # Use a private RandomState instead of np.random.seed() so we do not
        # clobber the process-global RNG for other components.
        # RandomState(seed).randn produces the exact same values the previous
        # global-seed implementation did.
        rng = np.random.RandomState(seed)
        # Fixed random projection matrix: embedding_dim -> PROJ_DIM.
        self._proj = rng.randn(config.embedding_dim, PROJ_DIM) / math.sqrt(PROJ_DIM)
        d = PROJ_DIM + 4  # projected embedding + 4 scalar features
        # Per-arm LinUCB parameters: A = I + sum(x x^T), b = sum(r x).
        self._A = {t: np.eye(d) for t in self.tiers}
        self._b = {t: np.zeros(d) for t in self.tiers}
        self._last_context = None
        self._last_arm = None

    def _context(self, obs: "Observation") -> np.ndarray:
        """Build the context vector: projected embedding + normalized scalars."""
        emb = np.array(obs.question_embedding, dtype=float)
        proj = emb @ self._proj
        scalars = np.array([
            obs.remaining_budget / self.config.total_budget,
            obs.questions_remaining / self.config.questions_per_episode,
            obs.budget_per_remaining / self.config.max_tokens,
            obs.accuracy_so_far,
        ])
        return np.concatenate([proj, scalars])

    def get_action(self, obs: "Observation") -> int:
        """Pick the affordable tier with the highest UCB score."""
        ctx = self._context(obs)
        self._last_context = ctx

        # Only consider tiers we can afford; fall back to the smallest tier.
        affordable = [t for t in self.tiers if t <= obs.remaining_budget]
        if not affordable:
            affordable = [self.tiers[0]]

        ucb_scores = {}
        for arm in affordable:
            A_inv = np.linalg.inv(self._A[arm])
            theta = A_inv @ self._b[arm]
            # UCB = point estimate + alpha * confidence width.
            ucb = theta @ ctx + self.alpha * math.sqrt(ctx @ A_inv @ ctx)
            ucb_scores[arm] = ucb

        best_arm = max(ucb_scores, key=ucb_scores.__getitem__)
        self._last_arm = best_arm
        return best_arm

    def update(self, reward: float):
        """Update LinUCB parameters after observing a reward for the last pull."""
        if self._last_context is None or self._last_arm is None:
            return
        ctx = self._last_context
        arm = self._last_arm
        self._A[arm] += np.outer(ctx, ctx)
        self._b[arm] += reward * ctx
        # Consume the pending (context, arm) pair so an accidental double
        # update() cannot apply the same observation twice.
        self._last_context = None
        self._last_arm = None

    def reset(self):
        """Clear per-episode transient state (learned parameters persist)."""
        self._last_context = None
        self._last_arm = None
reasonbudget_gym/baselines/greedy_max.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GreedyMax baseline: always allocate the maximum fair share."""
2
+ from ..env.models import Observation
3
+ from ..env.config import EnvConfig
4
+
5
+
6
+ class GreedyMaxBaseline:
7
+ """Allocates max_tokens each step regardless of budget state."""
8
+
9
+ name = "greedy_max"
10
+
11
+ def __init__(self, config: EnvConfig):
12
+ self.config = config
13
+
14
+ def get_action(self, obs: Observation) -> int:
15
+ return min(self.config.max_tokens, obs.remaining_budget)
16
+
17
+ def reset(self):
18
+ pass
reasonbudget_gym/baselines/oracle.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DifficultyOracle baseline: knows question difficulty, allocates proportionally.
3
+
4
+ This is an upper bound — in practice the agent doesn't see difficulty labels.
5
+ """
6
+ from ..env.models import Observation
7
+ from ..env.config import EnvConfig
8
+
9
+ # Token multipliers by difficulty (relative units)
10
+ DIFFICULTY_MULTIPLIERS = {
11
+ "gsm8k": 0.5,
12
+ "math_l1_l2": 0.75,
13
+ "math_l3": 1.25,
14
+ "math_l4_l5": 2.0,
15
+ }
16
+ DEFAULT_MULTIPLIER = 1.0
17
+
18
+
19
class DifficultyOracleBaseline:
    """
    Oracle policy that scales its allocation by the known difficulty tier.

    The evaluation harness feeds it the true label via set_difficulty() after
    env.step() returns info; unseen labels fall back to a neutral multiplier
    (uniform behaviour). This is an upper-bound reference, not a fair agent.
    """

    name = "oracle"

    def __init__(self, config: EnvConfig):
        self.config = config
        self._current_difficulty = "gsm8k"

    def set_difficulty(self, difficulty: str):
        """Called by evaluation harness after env.step() returns info."""
        self._current_difficulty = difficulty

    def get_action(self, obs: Observation) -> int:
        multiplier = DIFFICULTY_MULTIPLIERS.get(
            self._current_difficulty, DEFAULT_MULTIPLIER
        )
        # Fair share of the remaining budget, scaled by difficulty.
        fair_share = obs.remaining_budget / max(1, obs.questions_remaining)
        scaled = int(fair_share * multiplier)
        # Clamp to the configured per-question range, then to what is left.
        bounded = min(max(scaled, self.config.min_tokens), self.config.max_tokens)
        bounded = min(bounded, obs.remaining_budget)
        # Final floor mirrors the env's minimum even when the budget is nearly gone.
        return max(self.config.min_tokens, bounded)

    def reset(self):
        self._current_difficulty = "gsm8k"
reasonbudget_gym/baselines/uniform.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Uniform allocation baseline: split budget equally across all questions."""
2
+ from ..env.models import Observation
3
+ from ..env.config import EnvConfig
4
+
5
+
6
class UniformBaseline:
    """Even-split baseline: give every remaining question an equal share."""

    name = "uniform"

    def __init__(self, config: EnvConfig):
        self.config = config

    def get_action(self, obs: Observation) -> int:
        # Degenerate case: nothing left to answer, spend the minimum.
        if not obs.questions_remaining:
            return self.config.min_tokens
        share = int(obs.remaining_budget / obs.questions_remaining)
        lo, hi = self.config.min_tokens, self.config.max_tokens
        return min(max(share, lo), hi)

    def reset(self):
        pass
reasonbudget_gym/client.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Remote HTTP client for the ReasonBudgetEnv server."""
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ try:
6
+ import requests
7
+ _OK = True
8
+ except ImportError:
9
+ _OK = False
10
+
11
+
12
+ @dataclass
13
+ class RemoteObs:
14
+ question_embedding: List[float]
15
+ remaining_budget: int
16
+ questions_remaining: int
17
+ budget_per_remaining: float
18
+ accuracy_so_far: float
19
+ history: List[dict]
20
+ done: bool
21
+ episode_id: Optional[str] = None
22
+
23
+
24
+ @dataclass
25
+ class RemoteResult:
26
+ observation: RemoteObs
27
+ reward: float
28
+ done: bool
29
+ info: Dict[str, Any]
30
+
31
+
32
class ReasonBudgetClient:
    """Thin HTTP wrapper around a running ReasonBudgetEnv server."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        if not _OK:
            raise ImportError("pip install requests")
        self.url = base_url.rstrip("/")

    def health(self):
        """Return the server's /health payload."""
        return requests.get(f"{self.url}/health").json()

    def info(self):
        """Return the server's /info payload."""
        return requests.get(f"{self.url}/info").json()

    def reset(self, seed=None) -> RemoteObs:
        """Start a new episode; returns the initial observation."""
        resp = requests.post(f"{self.url}/reset", json={"seed": seed})
        resp.raise_for_status()
        return RemoteObs(**resp.json())

    def step(self, token_allocation: int) -> RemoteResult:
        """Allocate tokens for the current question and advance the episode."""
        resp = requests.post(f"{self.url}/step", json={"token_allocation": token_allocation})
        resp.raise_for_status()
        payload = resp.json()
        return RemoteResult(
            observation=RemoteObs(**payload["observation"]),
            reward=payload["reward"],
            done=payload["done"],
            info=payload["info"],
        )
reasonbudget_gym/data/__init__.py ADDED
File without changes
reasonbudget_gym/data/embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2f5b49d773c0f605b508c0c2d2c736305cd346654b9b63af8d3a3856cd71f12
3
+ size 768128
reasonbudget_gym/data/generate_synthetic_cache.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate a synthetic response cache for offline/demo use.
3
+
4
+ Simulates realistic DeepSeek-R1 accuracy curves across difficulty tiers
5
+ WITHOUT any API calls. Uses MetaMathQA dataset if available, otherwise
6
+ falls back to synthetic arithmetic questions.
7
+
8
+ Usage:
9
+ python reasonbudget_gym/data/generate_synthetic_cache.py
10
+ # or from package root:
11
+ python -m reasonbudget_gym.data.generate_synthetic_cache
12
+
13
+ Output: data/response_cache.json (~500 questions × 5 budget tiers)
14
+ """
15
+ import hashlib
16
+ import json
17
+ import random
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Accuracy model (from task spec — based on real DeepSeek-R1 behaviour)
23
+ # ---------------------------------------------------------------------------
24
+ ACCURACY_TABLE = {
25
+ # 50 100 200 400 800
26
+ "gsm8k": [0.85, 0.92, 0.95, 0.96, 0.97],
27
+ "math_l1_l2":[0.55, 0.70, 0.80, 0.88, 0.92],
28
+ "math_l3": [0.25, 0.40, 0.55, 0.70, 0.80],
29
+ "math_l4_l5":[0.08, 0.15, 0.30, 0.50, 0.65],
30
+ }
31
+ BUDGET_TIERS = [50, 100, 200, 400, 800]
32
+ N_QUESTIONS = 500
33
+ EVAL_FRACTION = 0.1
34
+ SEED = 42
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Helpers
38
+ # ---------------------------------------------------------------------------
39
+
40
+ def _make_qid(text: str, idx: int) -> str:
41
+ h = hashlib.md5(f"{idx}:{text}".encode()).hexdigest()[:12]
42
+ return f"q_{h}"
43
+
44
+
45
+ def _infer_difficulty(item: dict) -> str:
46
+ source = item.get("type", "").lower()
47
+ if "gsm" in source or "grade" in source:
48
+ return "gsm8k"
49
+ resp = item.get("response", "")
50
+ n = len(resp.split())
51
+ if n < 80: return "gsm8k"
52
+ if n < 150: return "math_l1_l2"
53
+ if n < 250: return "math_l3"
54
+ return "math_l4_l5"
55
+
56
+
57
+ def _extract_answer(response: str) -> str:
58
+ import re
59
+ for pat in [r"[Tt]he answer is[:\s]+([^\n.]+)", r"####\s*(.+)", r"=\s*([^\n]+)$"]:
60
+ m = re.search(pat, response)
61
+ if m:
62
+ return m.group(1).strip()
63
+ lines = [l.strip() for l in response.split("\n") if l.strip()]
64
+ return lines[-1] if lines else "unknown"
65
+
66
+
67
def _load_questions(n: int):
    """Load up to *n* questions, preferring MetaMathQA, else synthetic ones.

    Returns a list of (qid, question_text, answer, difficulty) tuples.
    """
    try:
        # Optional dependency: only taken when `datasets` is installed.
        from datasets import load_dataset
        ds = load_dataset("meta-math/MetaMathQA", split="train", trust_remote_code=True)
        items = list(ds.select(range(min(n, len(ds)))))
        questions = []
        for idx, item in enumerate(items):
            qtext = item.get("query", f"Question {idx}")
            answer = _extract_answer(item.get("response", ""))
            diff = _infer_difficulty(item)
            questions.append((_make_qid(qtext, idx), qtext, answer, diff))
        print(f" Loaded {len(questions)} questions from MetaMathQA")
        return questions
    except Exception as e:
        # Broad catch is deliberate: any import/network/dataset failure falls
        # back to the dependency-free synthetic generator.
        print(f" Dataset unavailable ({e}), using synthetic questions")
        return _synthetic_questions(n)
83
+
84
+
85
def _synthetic_questions(n: int):
    """Pure-Python fallback — no external deps."""
    rng = random.Random(SEED)
    templates = [
        ("If {a} people each have {b} apples, how many apples total?", lambda a, b, c: a * b),
        ("What is {a} + {b} * {c}?", lambda a, b, c: a + b * c),
        ("A car travels at {a} km/h for {b} hours. Distance?", lambda a, b, c: a * b),
        ("Solve: {a}x = {b}. What is x?", lambda a, b, c: b // a if a else 0),
    ]
    difficulties = ["gsm8k", "math_l1_l2", "math_l3", "math_l4_l5"]
    questions = []
    for i in range(n):
        # Draw the three operands in the same order as before so the RNG
        # stream (and thus the generated set) stays identical.
        a = rng.randint(2, 20)
        b = rng.randint(1, 15)
        c = rng.randint(1, 10)
        template, solve = templates[i % len(templates)]
        text = template.format(a=a, b=b, c=c)
        answer = str(solve(a, b, c))
        # Rotate through tiers so every difficulty is represented evenly.
        tier = difficulties[i % len(difficulties)]
        questions.append((_make_qid(text, i), text, answer, tier))
    return questions
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # Cache generation
108
+ # ---------------------------------------------------------------------------
109
+
110
def generate_cache(output_path: str | None = None):
    """Generate the synthetic response cache and write it as JSON.

    Simulates per-tier correctness from ACCURACY_TABLE and token usage at
    70-95% of each tier's budget, then holds out the tail EVAL_FRACTION of
    questions as an eval split.

    Args:
        output_path: destination file; defaults to response_cache.json next
            to this module.

    Returns:
        The path the cache was written to.
    """
    if output_path is None:
        # Resolve relative to this file's parent directory so the script
        # works regardless of the current working directory.
        output_path = str(Path(__file__).parent / "response_cache.json")

    rng = random.Random(SEED)
    print(f"Generating synthetic cache ({N_QUESTIONS} questions × {len(BUDGET_TIERS)} tiers)...")
    questions = _load_questions(N_QUESTIONS)

    entries = {}
    for qid, qtext, answer, difficulty in questions:
        entries[qid] = {}
        acc_curve = ACCURACY_TABLE[difficulty]
        for tier_idx, tier in enumerate(BUDGET_TIERS):
            # Sample correctness from the tier's accuracy curve.
            p_correct = acc_curve[tier_idx]
            correct = rng.random() < p_correct
            # tokens_used: 70-95% of budget with some noise
            used = int(tier * rng.uniform(0.70, 0.95))
            entries[qid][str(tier)] = {
                "answer": answer if correct else "unknown",
                "was_correct": correct,
                "tokens_used": used,
                "response_text": "[synthetic cache entry]",
            }

    # Hold out the tail fraction of questions as the eval split. Guard the
    # n_eval == 0 case: questions[-0:] would otherwise select the WHOLE
    # list, marking every question as held out.
    n_eval = int(len(questions) * EVAL_FRACTION)
    eval_ids = [qid for qid, *_ in questions[-n_eval:]] if n_eval else []

    cache = {"entries": entries, "eval_ids": eval_ids}
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(cache, f)

    n_correct = sum(
        1 for qdata in entries.values()
        for entry in qdata.values() if entry["was_correct"]
    )
    total = len(entries) * len(BUDGET_TIERS)
    print(f" Written: {output_path}")
    print(f" Questions: {len(entries)} | Overall accuracy: {n_correct/total:.1%}")
    print(f" Eval holdout: {len(eval_ids)} questions")
    return output_path
+
153
+
154
+ if __name__ == "__main__":
155
+ path = sys.argv[1] if len(sys.argv) > 1 else None
156
+ generate_cache(path)
reasonbudget_gym/data/response_cache.json ADDED
The diff for this file is too large to render. See raw diff
 
reasonbudget_gym/env/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .reason_budget_env import ReasonBudgetEnv
2
+ from .config import EnvConfig
3
+ from .models import Observation, StepResult, QuestionMeta
4
+
5
+ __all__ = ["ReasonBudgetEnv", "EnvConfig", "Observation", "StepResult", "QuestionMeta"]
reasonbudget_gym/env/config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Environment configuration dataclass."""
2
+ from dataclasses import dataclass, field
3
+ from typing import List
4
+
5
+
6
@dataclass
class EnvConfig:
    """Configuration for the ReasonBudgetEnv."""

    # Token budget
    total_budget: int = 4000   # total tokens available per episode
    min_tokens: int = 50       # floor for a single per-question allocation
    max_tokens: int = 800      # ceiling for a single per-question allocation
    budget_tiers: List[int] = field(default_factory=lambda: [50, 100, 200, 400, 800])

    # Episode structure
    questions_per_episode: int = 10
    seed: int = 42             # RNG seed for reproducible episode sampling

    # Reward weights (consumed by env.reward.compute_reward)
    correct_reward: float = 1.0
    wrong_penalty: float = -0.1
    cost_penalty_weight: float = 0.0002  # β
    efficiency_bonus_weight: float = 0.3  # γ

    # Solver
    solver_type: str = "cached"  # "cached" | "live"
    cache_path: str = "data/response_cache.json"  # resolved relative to package root or CWD

    # Dataset
    dataset_name: str = "meta-math/MetaMathQA"
    max_cache_questions: int = 500  # cap on questions loaded into the sampler

    # Embedding
    embedding_dim: int = 384
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_cache_path: str = "data/embeddings.npy"
reasonbudget_gym/env/episode_sampler.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sample episodes (sequences of questions) from the dataset."""
2
+ import hashlib
3
+ import random
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ from .config import EnvConfig
8
+ from .models import QuestionMeta
9
+
10
+
11
+ # Difficulty label heuristics based on dataset origin
12
+ _DIFFICULTY_LEVELS = ["gsm8k", "math_l1_l2", "math_l3", "math_l4_l5"]
13
+
14
+
15
+ def _infer_difficulty(item: dict) -> str:
16
+ """Heuristically assign difficulty from MetaMathQA metadata."""
17
+ source = item.get("type", "").lower()
18
+ query = item.get("query", "")
19
+
20
+ if "gsm" in source or "grade" in source:
21
+ return "gsm8k"
22
+
23
+ # Use a simple proxy: length of the chain of thought as difficulty signal
24
+ response = item.get("response", "")
25
+ cot_len = len(response.split())
26
+ if cot_len < 80:
27
+ return "gsm8k"
28
+ elif cot_len < 150:
29
+ return "math_l1_l2"
30
+ elif cot_len < 250:
31
+ return "math_l3"
32
+ else:
33
+ return "math_l4_l5"
34
+
35
+
36
+ def _extract_answer(response: str) -> str:
37
+ """Extract the final answer from a MetaMathQA response."""
38
+ # MetaMathQA answers follow "The answer is X" or "#### X"
39
+ import re
40
+ patterns = [
41
+ r"[Tt]he answer is[:\s]+([^\n.]+)",
42
+ r"####\s*(.+)",
43
+ r"=\s*([^\n]+)$",
44
+ ]
45
+ for pat in patterns:
46
+ m = re.search(pat, response)
47
+ if m:
48
+ return m.group(1).strip()
49
+ # Fallback: last non-empty line
50
+ lines = [l.strip() for l in response.split("\n") if l.strip()]
51
+ return lines[-1] if lines else "unknown"
52
+
53
+
54
+ def _make_question_id(text: str, idx: int) -> str:
55
+ h = hashlib.md5(f"{idx}:{text}".encode()).hexdigest()[:12]
56
+ return f"q_{h}"
57
+
58
+
59
class EpisodeSampler:
    """Loads questions from MetaMathQA and provides episode batches.

    Question loading is lazy (first call to sample_episode or
    get_all_questions) and happens exactly once per instance.
    """

    def __init__(self, config: EnvConfig, split: str = "train"):
        self.config = config
        # Dedicated RNG so episode sampling is reproducible per config.seed.
        self.rng = random.Random(config.seed)
        self._questions: List[QuestionMeta] = []
        self._loaded = False  # guards the one-time _load()
        self._split = split

    def _resolve_cache_path(self) -> Path:
        """Resolve the solver response-cache path relative to the project."""
        return self._resolve_project_path(self.config.cache_path)

    def _resolve_project_path(self, relative_path: str) -> Path:
        """Resolve a relative path against the package root, else the CWD.

        Prefers the package-relative location; the CWD-relative path is
        used only when it exists and the package-relative one does not.
        """
        path = Path(relative_path)
        if path.is_absolute():
            return path

        pkg_root = Path(__file__).resolve().parent.parent
        package_relative = pkg_root / relative_path
        cwd_relative = Path.cwd() / relative_path
        if package_relative.exists() or not cwd_relative.exists():
            return package_relative
        return cwd_relative

    def _load_bundled_questions(self) -> List[dict]:
        """Return deterministic synthetic questions that align with the shipped cache."""
        from ..data.generate_synthetic_cache import _synthetic_questions

        questions = []
        for qid, qtext, answer, difficulty in _synthetic_questions(self.config.max_cache_questions):
            questions.append({
                "question_id": qid,
                "query": qtext,
                "answer": answer,
                "difficulty": difficulty,
            })
        return questions

    def _apply_embeddings(self, embeddings) -> None:
        # Pair questions with embedding rows; zip stops at the shorter side.
        for question, embedding in zip(self._questions, embeddings):
            question.embedding = [float(x) for x in embedding]

    def _load_embeddings(self) -> None:
        """Attach sentence embeddings, reusing a cached .npy file when valid."""
        import numpy as np

        if not self._questions:
            return

        expected_shape = (len(self._questions), self.config.embedding_dim)
        embeddings_path = self._resolve_project_path(self.config.embedding_cache_path)

        # Fast path: reuse a cached embedding matrix of exactly the right shape.
        if embeddings_path.exists():
            try:
                cached = np.load(embeddings_path)
                if cached.shape == expected_shape:
                    self._apply_embeddings(cached)
                    return
            except Exception:
                pass  # corrupt/unreadable cache: fall through and recompute

        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer(self.config.embedding_model)
            texts = [question.question_text for question in self._questions]
            embeddings = np.asarray(
                model.encode(texts, show_progress_bar=False),
                dtype=np.float32,
            )
            if embeddings.shape != expected_shape:
                return
            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
            np.save(embeddings_path, embeddings)
            self._apply_embeddings(embeddings)
        except Exception:
            # Fall back to zero embeddings when the model is unavailable.
            return

    def _load(self):
        """Populate self._questions once; subsequent calls are no-ops."""
        if self._loaded:
            return

        cache_path = self._resolve_cache_path()
        # Cached-solver mode with a shipped cache: use the bundled synthetic
        # questions whose ids match the cache entries.
        if self.config.solver_type == "cached" and cache_path.exists():
            items = self._load_bundled_questions()
        else:
            try:
                from datasets import load_dataset

                ds = load_dataset(
                    self.config.dataset_name,
                    split=self._split,
                    trust_remote_code=True,
                )
                # Limit to first max_cache_questions for speed
                items = list(ds.select(range(min(self.config.max_cache_questions, len(ds)))))
            except Exception:
                # Offline fallback: generate synthetic questions
                items = self._synthetic_questions(self.config.max_cache_questions)

        self._questions = []
        for idx, item in enumerate(items):
            qtext = item.get("query", item.get("question", f"Question {idx}"))
            response = item.get("response", item.get("answer", ""))
            answer = _extract_answer(response)
            difficulty = item.get("difficulty", _infer_difficulty(item))
            qid = item.get("question_id", _make_question_id(qtext, idx))
            # Embedding is a zero-vector placeholder; real embedding done lazily
            self._questions.append(QuestionMeta(
                question_id=qid,
                question_text=qtext,
                ground_truth=answer,
                difficulty=difficulty,
                embedding=[0.0] * self.config.embedding_dim,
            ))
        self._load_embeddings()
        self._loaded = True

    def _synthetic_questions(self, n: int) -> List[dict]:
        """Generate synthetic questions when dataset is unavailable.

        NOTE(review): every template's answer is computed as a + b * c, so
        only the arithmetic template's text actually matches its answer —
        fine for an offline smoke-test corpus, not semantically consistent.
        """
        templates = [
            ("If a train travels at {v} km/h for {t} hours, how far does it go?", "{r}"),
            ("What is {a} + {b} * {c}?", "{r}"),
            ("Solve for x: {a}x + {b} = {c}", "{r}"),
            ("A store sells apples for ${p} each. How much do {n} apples cost?", "${r}"),
        ]
        items = []
        rng = random.Random(42)  # fixed seed: the corpus is deterministic
        difficulties = ["gsm8k", "math_l1_l2", "math_l3", "math_l4_l5"]
        for i in range(n):
            tmpl, ans_tmpl = templates[i % len(templates)]
            a, b, c = rng.randint(1, 20), rng.randint(1, 10), rng.randint(1, 15)
            v, t, p, nn = rng.randint(60, 200), rng.randint(1, 10), rng.randint(1, 5), rng.randint(2, 20)
            r = a + b * c
            query = tmpl.format(v=v, t=t, a=a, b=b, c=c, p=p, n=nn)
            answer = str(r)
            diff = difficulties[i % len(difficulties)]
            # Fake CoT length based on difficulty
            cot_lengths = {"gsm8k": 50, "math_l1_l2": 120, "math_l3": 200, "math_l4_l5": 300}
            fake_cot = " word" * cot_lengths[diff]
            items.append({"query": query, "response": f"{fake_cot} The answer is {answer}", "type": diff})
        return items

    def sample_episode(self) -> List[QuestionMeta]:
        """Sample a sequence of questions_per_episode questions.

        Uses random.choices, i.e. sampling WITH replacement — a question
        may repeat within one episode.
        """
        self._load()
        return self.rng.choices(self._questions, k=self.config.questions_per_episode)

    def get_all_questions(self) -> List[QuestionMeta]:
        """Return a shallow copy of the full loaded question list."""
        self._load()
        return list(self._questions)
reasonbudget_gym/env/models.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models shared across the environment."""
2
+ from dataclasses import dataclass, field
3
+ from typing import List, Optional
4
+
5
+
6
@dataclass
class QuestionMeta:
    """Metadata for a single question in an episode."""
    question_id: str        # stable id; also keys the solver response cache
    question_text: str      # the question shown to the solver
    ground_truth: str       # canonical final answer string
    difficulty: str         # gsm8k | math_l1_l2 | math_l3 | math_l4_l5
    embedding: List[float]  # 384-dim sentence-transformer encoding
14
+
15
+
16
@dataclass
class StepInfo:
    """Record of one completed step."""
    tokens_allocated: int  # tokens granted to the solver (charged to the budget)
    tokens_used: int       # tokens actually consumed (capped at tokens_allocated)
    was_correct: bool      # whether the solver's answer matched ground truth
22
+
23
+
24
@dataclass
class Observation:
    """Full observation returned to the agent at each step."""
    question_embedding: List[float]   # 384-dim; zero vector when episode is over
    remaining_budget: int             # tokens left in the episode budget
    questions_remaining: int          # questions not yet attempted
    budget_per_remaining: float       # remaining_budget / questions_remaining (0.0 when none left)
    accuracy_so_far: float            # fraction of attempted questions answered correctly
    history: List[StepInfo]           # one entry per completed step
    done: bool = False                # True once the episode has ended
    episode_id: Optional[str] = None  # short uuid assigned at reset()
35
+
36
+
37
@dataclass
class StepResult:
    """Result of env.step()."""
    observation: Observation  # next observation after the step
    reward: float             # scalar step reward (see env.reward.compute_reward)
    done: bool                # True when the episode terminated on this step
    info: dict = field(default_factory=dict)  # diagnostics: correctness, tokens, difficulty, ...
reasonbudget_gym/env/reason_budget_env.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ReasonBudgetEnv — core RL environment for token budget allocation.
3
+
4
+ Compatible with OpenEnv v0.2.x (step/reset interface).
5
+ """
6
+ import uuid
7
+ from typing import Optional
8
+
9
+ from .config import EnvConfig
10
+ from .episode_sampler import EpisodeSampler
11
+ from .models import Observation, StepInfo, StepResult
12
+ from .reward import compute_reward
13
+
14
+
15
class ReasonBudgetEnv:
    """
    Sequential token budget allocation environment.

    At each step the agent observes the current question embedding plus
    budget state and chooses how many tokens to allocate. The solver
    attempts the question with that limit and returns a correctness signal.
    """

    def __init__(self, config: Optional[EnvConfig] = None):
        self.config = config or EnvConfig()
        self.sampler = EpisodeSampler(self.config)
        self._solver = None  # created lazily on first access (see `solver`)
        self._reset_state()

    # ------------------------------------------------------------------
    # Solver lazy init
    # ------------------------------------------------------------------
    @property
    def solver(self):
        """Lazily instantiate the solver selected by config.solver_type."""
        if self._solver is None:
            if self.config.solver_type == "cached":
                from ..solver.cached_solver import CachedSolver
                self._solver = CachedSolver(self.config)
            else:
                from ..solver.live_solver import LiveSolver
                self._solver = LiveSolver(self.config)
        return self._solver

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a new episode. Returns first observation."""
        self._reset_state()
        self._episode = self.sampler.sample_episode()
        self._episode_id = str(uuid.uuid4())[:8]
        return self._make_observation()

    def step(self, token_allocation: int) -> StepResult:
        """Take a step: allocate tokens to the current question.

        The allocation is clamped to [min_tokens, max_tokens] and to the
        remaining budget (with a min_tokens floor even if that overspends
        the budget on the final question).

        Raises:
            RuntimeError: if called after the episode has finished.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() first.")

        # Clamp allocation to valid range and remaining budget
        token_allocation = max(self.config.min_tokens,
                               min(token_allocation, self.config.max_tokens))
        token_allocation = min(token_allocation, self._remaining_budget)
        if token_allocation < self.config.min_tokens:
            token_allocation = self.config.min_tokens  # allow overspend on last Q

        question = self._episode[self._step_idx]
        budget_before = self._remaining_budget

        # Solve
        result = self.solver.solve(
            question_id=question.question_id,
            question_text=question.question_text,
            ground_truth=question.ground_truth,
            token_budget=token_allocation,
        )

        # Accounting — the budget is charged the full ALLOCATION, not just
        # the tokens the solver actually used.
        tokens_used = min(result.tokens_used, token_allocation)
        self._remaining_budget -= token_allocation
        self._step_idx += 1
        self._history.append(StepInfo(
            tokens_allocated=token_allocation,
            tokens_used=tokens_used,
            was_correct=result.was_correct,
        ))
        if result.was_correct:
            self._correct_count += 1

        # Reward
        reward = compute_reward(
            was_correct=result.was_correct,
            tokens_allocated=token_allocation,
            tokens_used=tokens_used,
            remaining_budget_before=budget_before,
            questions_remaining=len(self._episode) - self._step_idx,
            config=self.config,
        )

        # Done? Either all questions answered or too little budget left
        # to fund even a minimum allocation.
        self._done = (
            self._step_idx >= len(self._episode)
            or self._remaining_budget < self.config.min_tokens
        )

        obs = self._make_observation()
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "was_correct": result.was_correct,
                "tokens_allocated": token_allocation,
                "tokens_used": tokens_used,
                "difficulty": question.difficulty,
                "remaining_budget": self._remaining_budget,
                "step": self._step_idx,
            },
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _reset_state(self):
        """Zero out all per-episode bookkeeping."""
        self._episode = []
        self._episode_id = None
        self._step_idx = 0
        self._remaining_budget = self.config.total_budget
        self._history = []
        self._correct_count = 0
        self._done = False

    def _make_observation(self) -> Observation:
        """Build the observation for the CURRENT step index.

        After the final question, the embedding is a zero vector.
        """
        if self._step_idx < len(self._episode):
            q = self._episode[self._step_idx]
            emb = q.embedding
        else:
            emb = [0.0] * self.config.embedding_dim

        n_remaining = max(0, len(self._episode) - self._step_idx)
        bpr = self._remaining_budget / n_remaining if n_remaining > 0 else 0.0
        acc = self._correct_count / self._step_idx if self._step_idx > 0 else 0.0

        return Observation(
            question_embedding=emb,
            remaining_budget=self._remaining_budget,
            questions_remaining=n_remaining,
            budget_per_remaining=bpr,
            accuracy_so_far=acc,
            history=list(self._history),
            done=self._done,
            episode_id=self._episode_id,
        )

    @property
    def observation_dim(self) -> int:
        """Flat observation dimension (embedding + 4 scalars)."""
        return self.config.embedding_dim + 4

    def flat_observation(self, obs: Observation) -> list:
        """Flatten observation to a list of floats for the policy.

        The four scalars are normalized to roughly [0, 1] using the
        corresponding config maxima.
        """
        scalars = [
            obs.remaining_budget / self.config.total_budget,
            obs.questions_remaining / self.config.questions_per_episode,
            obs.budget_per_remaining / self.config.max_tokens,
            obs.accuracy_so_far,
        ]
        return list(obs.question_embedding) + scalars
reasonbudget_gym/env/reward.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reward function for the ReasonBudgetEnv."""
2
+ from .config import EnvConfig
3
+
4
+
5
def compute_reward(
    was_correct: bool,
    tokens_allocated: int,
    tokens_used: int,
    remaining_budget_before: int,
    questions_remaining: int,
    config: EnvConfig,
) -> float:
    """
    Compute the per-step reward.

    R = correctness_term − β * cost_penalty + γ * efficiency_bonus

    correctness_term: +correct_reward if correct, else wrong_penalty
    cost_penalty: tokens_used / total_budget (fractional spend)
    efficiency_bonus: savings as fraction of budget if correct,
        0 if wrong (no credit for being cheap and wrong)
    """
    correctness_term = config.correct_reward if was_correct else config.wrong_penalty

    # Fractional token spend for this step.
    cost_penalty = tokens_used / config.total_budget

    # Underspend bonus is granted only for correct answers.
    efficiency_bonus = 0.0
    if was_correct and tokens_allocated > 0:
        savings = max(0, tokens_allocated - tokens_used)
        efficiency_bonus = savings / config.total_budget

    return float(
        correctness_term
        - config.cost_penalty_weight * cost_penalty
        + config.efficiency_bonus_weight * efficiency_bonus
    )
+ return float(reward)
reasonbudget_gym/eval/__init__.py ADDED
File without changes
reasonbudget_gym/eval/evaluate.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation harness: run all 4 baselines for N episodes and collect metrics.
3
+
4
+ Usage:
5
+ python -m reasonbudget_gym.eval.evaluate --n_episodes 50 --seed 42 --output eval_results.json
6
+ """
7
+ import argparse
8
+ import json
9
+ import numpy as np
10
+
11
+ from ..env.config import EnvConfig
12
+ from ..env.reason_budget_env import ReasonBudgetEnv
13
+ from ..baselines import UniformBaseline, GreedyMaxBaseline, DifficultyOracleBaseline, LinUCBBaseline
14
+
15
+
16
def run_episode(env, agent, config):
    """Roll out one full episode and collect per-step metrics.

    Works with any agent exposing get_action(obs); the optional hooks
    set_difficulty (oracle-style agents) and update (learning agents)
    are invoked only when present.
    """
    obs = env.reset()
    # Oracle-style agents are told the FIRST question's difficulty up front.
    if hasattr(agent, "set_difficulty") and env._episode:
        agent.set_difficulty(env._episode[0].difficulty)
    done = False
    total_reward = 0.0
    correct = 0
    tokens = 0
    steps = 0
    per_step = []

    while not done:
        action = agent.get_action(obs)
        result = env.step(action)
        # Learning agents (e.g. bandits) update online from the step reward.
        if hasattr(agent, "update"):
            agent.update(result.reward)
        # Reveal the NEXT question's difficulty before the next get_action.
        if hasattr(agent, "set_difficulty"):
            ep = env._episode
            if env._step_idx < len(ep):
                agent.set_difficulty(ep[env._step_idx].difficulty)

        total_reward += result.reward
        if result.info.get("was_correct"):
            correct += 1
        tokens += result.info.get("tokens_allocated", 0)
        steps += 1
        per_step.append({
            "step": steps,
            "tokens_allocated": result.info.get("tokens_allocated", 0),
            "was_correct": result.info.get("was_correct", False),
            "difficulty": result.info.get("difficulty", "unknown"),
            "remaining_budget": result.info.get("remaining_budget", 0),
            "reward": result.reward,
        })
        done = result.done
        obs = result.observation

    return {
        "total_reward": total_reward,
        "accuracy": correct / max(1, steps),  # guard against zero-step episodes
        "total_tokens_used": tokens,          # sum of ALLOCATED tokens
        "budget_utilization": tokens / config.total_budget,
        "steps": steps,
        "per_step": per_step,
    }
61
+
62
+
63
def evaluate_agent(name, agent, config, n_episodes, seed):
    """Run one agent for n_episodes and aggregate reward/accuracy stats.

    Each episode reseeds the sampler with seed + episode_index so runs
    are reproducible and comparable across agents.
    """
    env = ReasonBudgetEnv(config)
    agent.reset()

    episodes = []
    for offset in range(n_episodes):
        episode_seed = seed + offset
        env.config.seed = episode_seed
        env.sampler.rng.seed(episode_seed)
        episodes.append(run_episode(env, agent, config))

    def _stats(key):
        # Mean/std over episodes for one scalar metric.
        values = [ep[key] for ep in episodes]
        return float(np.mean(values)), float(np.std(values))

    mean_reward, std_reward = _stats("total_reward")
    mean_accuracy, std_accuracy = _stats("accuracy")
    mean_utilization, _ = _stats("budget_utilization")
    return {
        "agent": name,
        "n_episodes": n_episodes,
        "mean_reward": mean_reward,
        "std_reward": std_reward,
        "mean_accuracy": mean_accuracy,
        "std_accuracy": std_accuracy,
        "mean_budget_utilization": mean_utilization,
        "episodes": episodes,
    }
85
+
86
+
87
def main():
    """CLI entry point: evaluate every baseline and dump JSON results."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_episodes", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output", type=str, default="eval_results.json")
    args = parser.parse_args()

    config = EnvConfig(seed=args.seed)
    agents = {
        "uniform": UniformBaseline(config),
        "greedy_max": GreedyMaxBaseline(config),
        "oracle": DifficultyOracleBaseline(config),
        "bandit": LinUCBBaseline(config, seed=args.seed),
    }

    print(f"Evaluating {len(agents)} agents × {args.n_episodes} episodes")
    results = {}
    for agent_name, agent in agents.items():
        print(f" {agent_name}...", end=" ", flush=True)
        summary = evaluate_agent(agent_name, agent, config, args.n_episodes, args.seed)
        results[agent_name] = summary
        print(f"acc={summary['mean_accuracy']:.3f} reward={summary['mean_reward']:.3f}")

    # Summary table.
    print(f"\n{'Agent':<15} {'Accuracy':>10} {'Reward':>10} {'Budget%':>10}")
    print("-" * 50)
    for agent_name, summary in results.items():
        print(f"{agent_name:<15} {summary['mean_accuracy']:>10.3f} {summary['mean_reward']:>10.3f} {summary['mean_budget_utilization']:>9.1%}")

    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {args.output}")


if __name__ == "__main__":
    main()
reasonbudget_gym/eval/plots.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate comparison charts from eval_results.json."""
2
+ import json
3
+ from pathlib import Path
4
+
5
+
6
def agent_comparison(results_path: str, output_path: str = "docs/agent_comparison.png"):
    """Render a two-panel bar chart (accuracy, reward) comparing agents.

    Reads the evaluation results JSON and writes a PNG to output_path,
    creating parent directories as needed.
    """
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt

    with open(results_path) as f:
        results = json.load(f)

    agents = list(results.keys())
    accs = [results[a]["mean_accuracy"] for a in agents]
    rewards = [results[a]["mean_reward"] for a in agents]
    acc_stds = [results[a]["std_accuracy"] for a in agents]
    x = np.arange(len(agents))
    # NOTE(review): exactly four colors; with more than four agents the
    # color list no longer matches the bar count — confirm if agents grow.
    colors = ["#4C72B0", "#DD8452", "#55A868", "#C44E52"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.suptitle("Baseline Agent Comparison — ReasoningEconomicsEnv", fontsize=13, fontweight="bold")

    # Left panel: accuracy with std error bars and per-bar value labels.
    bars1 = ax1.bar(x, accs, 0.5, yerr=acc_stds, capsize=4, color=colors, alpha=0.85)
    ax1.set_title("Mean Accuracy")
    ax1.set_xticks(x); ax1.set_xticklabels(agents, rotation=15)
    ax1.set_ylim(0, 1.1)
    for b, v in zip(bars1, accs):
        ax1.text(b.get_x() + b.get_width()/2, v + 0.03, f"{v:.3f}", ha="center", fontsize=9)

    # Right panel: mean episode reward with per-bar value labels.
    bars2 = ax2.bar(x, rewards, 0.5, color=colors, alpha=0.85)
    ax2.set_title("Mean Episode Reward")
    ax2.set_xticks(x); ax2.set_xticklabels(agents, rotation=15)
    for b, v in zip(bars2, rewards):
        ax2.text(b.get_x() + b.get_width()/2, v + abs(v)*0.03 + 0.01, f"{v:.3f}", ha="center", fontsize=9)

    plt.tight_layout()
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")
43
+
44
+
45
def budget_pacing(results_path: str, output_path: str = "docs/budget_pacing.png", budget: int = 4000):
    """Plot mean cumulative token spend per question for every agent.

    Reads the evaluation results JSON, plots mean ± 1 SD cumulative
    allocation curves per agent, and draws the episode budget as a
    horizontal reference line.

    Args:
        results_path: JSON file produced by the evaluation harness.
        output_path: destination PNG path (parents created as needed).
        budget: episode token budget drawn as the reference line.
            Default 4000 preserves the previously hard-coded value
            (matches EnvConfig.total_budget's default).
    """
    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # headless backend: no display required
    import matplotlib.pyplot as plt

    with open(results_path) as f:
        results = json.load(f)

    fig, ax = plt.subplots(figsize=(10, 5))
    colors = {"uniform": "#4C72B0", "greedy_max": "#DD8452", "oracle": "#55A868", "bandit": "#C44E52"}

    for name, res in results.items():
        episodes = res["episodes"]
        max_steps = max(len(ep["per_step"]) for ep in episodes)
        # Cumulative-spend matrix [episode, step]; episodes that ended early
        # are padded with their final total so the mean stays flat.
        mat = np.zeros((len(episodes), max_steps))
        for i, ep in enumerate(episodes):
            cumsum = 0
            for j, s in enumerate(ep["per_step"]):
                cumsum += s["tokens_allocated"]
                mat[i, j] = cumsum
            mat[i, len(ep["per_step"]):] = cumsum

        mean = mat.mean(0); std = mat.std(0)
        steps = np.arange(1, max_steps + 1)
        c = colors.get(name, "gray")
        ax.plot(steps, mean, label=name, color=c, linewidth=2)
        ax.fill_between(steps, mean - std, mean + std, color=c, alpha=0.15)

    # Budget reference line (parameterized; was hard-coded at 4000).
    ax.axhline(y=budget, color="black", linestyle="--", linewidth=1.5, label=f"Budget ({budget})")
    ax.set_xlabel("Question #"); ax.set_ylabel("Cumulative Tokens")
    ax.set_title("Budget Pacing by Agent — Mean ± 1 SD", fontweight="bold")
    ax.legend(); ax.grid(True, alpha=0.3)
    plt.tight_layout()
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")
83
+
84
+
85
if __name__ == "__main__":
    import sys
    # Optional CLI argument: path to the evaluation results JSON.
    results_file = sys.argv[1] if len(sys.argv) > 1 else "eval_results.json"
    agent_comparison(results_file)
    budget_pacing(results_file)
reasonbudget_gym/policy/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .allocation_policy import AllocationPolicy
2
+
3
+ __all__ = ["AllocationPolicy"]
reasonbudget_gym/policy/allocation_policy.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MLP policy for token budget allocation.
3
+
4
+ Architecture:
5
+ - Shared trunk: FC(obs_dim → 256) → ReLU → FC(256 → 128) → ReLU
6
+ - Actor head: FC(128 → 1) producing mean, with learned log_std
7
+ Output: Gaussian(mean, std) over normalised token allocation
8
+ - Value head: FC(128 → 1)
9
+
10
+ The action is a continuous value in [0, 1] representing fraction of
11
+ max_tokens to allocate. It is scaled to an integer allocation at step time.
12
+ """
13
+ import math
14
+ import numpy as np
15
+
16
+ try:
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.distributions import Normal
21
+ _TORCH_AVAILABLE = True
22
+ except ImportError:
23
+ _TORCH_AVAILABLE = False
24
+
25
+
26
def _require_torch():
    """Raise a helpful ImportError when PyTorch is not installed."""
    if _TORCH_AVAILABLE:
        return
    raise ImportError(
        "PyTorch is required for AllocationPolicy. "
        "Install it with: pip install torch"
    )
32
+
33
+
34
class _PolicyNet(nn.Module if _TORCH_AVAILABLE else object):
    """Neural network backbone for the allocation policy.

    The base class degrades to `object` when torch is missing so this
    module can still be imported; instantiation then raises in __init__.
    """

    def __init__(self, obs_dim: int, hidden: int = 256):
        if not _TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        super().__init__()
        # Shared trunk: obs_dim -> hidden -> hidden//2, ReLU activations.
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
        )
        self.actor_mean = nn.Linear(hidden // 2, 1)
        # State-independent log standard deviation, learned directly.
        self.log_std = nn.Parameter(torch.zeros(1))
        self.value_head = nn.Linear(hidden // 2, 1)

    def forward(self, x):
        h = self.trunk(x)
        mean = torch.sigmoid(self.actor_mean(h))  # in (0, 1)
        std = torch.exp(self.log_std).clamp(0.01, 1.0)  # keep exploration bounded
        value = self.value_head(h)
        return mean, std, value
57
+
58
+
59
class AllocationPolicy:
    """
    High-level wrapper around the policy network.

    Provides get_action() compatible with the baseline interface,
    plus PPO-specific methods (evaluate_actions, value).
    """

    def __init__(self, obs_dim: int, max_tokens: int, min_tokens: int = 50):
        _require_torch()
        self.obs_dim = obs_dim
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.net = _PolicyNet(obs_dim)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4)

    def _obs_to_tensor(self, obs_flat: list) -> "torch.Tensor":
        # Shape [1, obs_dim]: the network expects a batch dimension.
        return torch.tensor(obs_flat, dtype=torch.float32).unsqueeze(0)

    def get_action(self, obs_flat: list) -> tuple:
        """
        Returns (action_int, log_prob, value_estimate).

        action_int: token allocation as integer
        """
        self.net.eval()
        with torch.no_grad():
            x = self._obs_to_tensor(obs_flat)
            mean, std, value = self.net(x)
            dist = Normal(mean, std)
            # NOTE(review): log_prob is taken of the CLAMPED sample, which
            # slightly mis-weights boundary actions — confirm this is
            # acceptable for the PPO update.
            sample = dist.sample().clamp(0.0, 1.0)
            log_prob = dist.log_prob(sample).squeeze()

        # Scale the [0, 1] fraction into [min_tokens, max_tokens].
        frac = sample.item()
        action_int = int(self.min_tokens + frac * (self.max_tokens - self.min_tokens))
        return action_int, log_prob.item(), value.squeeze().item()

    def evaluate_actions(self, obs_batch, action_fracs):
        """
        Compute log_probs, entropy, values for a batch.

        obs_batch: Tensor [B, obs_dim]
        action_fracs: Tensor [B, 1] (normalised actions in [0,1])
        """
        self.net.train()
        mean, std, values = self.net(obs_batch)
        dist = Normal(mean, std)
        log_probs = dist.log_prob(action_fracs)
        entropy = dist.entropy()
        return log_probs, entropy, values

    def save(self, path: str):
        """Serialize network + optimizer state and scaling params to path."""
        _require_torch()
        torch.save({
            "net_state": self.net.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "obs_dim": self.obs_dim,
            "max_tokens": self.max_tokens,
            "min_tokens": self.min_tokens,
        }, path)

    def load(self, path: str):
        """Restore network + optimizer state produced by save()."""
        _require_torch()
        ckpt = torch.load(path, map_location="cpu")
        self.net.load_state_dict(ckpt["net_state"])
        self.optimizer.load_state_dict(ckpt["optimizer_state"])

    def reset(self):
        # Baseline-interface hook; this policy keeps no episode state.
        pass
reasonbudget_gym/server/__init__.py ADDED
File without changes
reasonbudget_gym/server/app.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv-compatible FastAPI server for ReasonBudgetEnv.
3
+ Entry point: reasonbudget_gym.server.app:app
4
+ """
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.responses import HTMLResponse
9
+ from pydantic import BaseModel
10
+
11
+ from ..env.config import EnvConfig
12
+ from ..env.models import Observation
13
+ from ..env.reason_budget_env import ReasonBudgetEnv
14
+
15
+
16
class ResetRequest(BaseModel):
    """Request body for POST /reset."""
    # Optional RNG seed; when set, /reset re-seeds the episode sampler first.
    seed: Optional[int] = None
18
+
19
+
20
class StepRequest(BaseModel):
    """Request body for POST /step."""
    # Number of reasoning tokens to allocate to the current question.
    token_allocation: int
22
+
23
+
24
class ObsResponse(BaseModel):
    """JSON-serialisable mirror of the env's Observation (built by _to_obs_response)."""
    question_embedding: List[float]
    remaining_budget: int
    questions_remaining: int
    budget_per_remaining: float
    accuracy_so_far: float
    # One record per answered question:
    # {"tokens_allocated", "tokens_used", "was_correct"}.
    history: List[dict]
    done: bool
    episode_id: Optional[str] = None
33
+
34
+
35
class StepResponse(BaseModel):
    """Response body for POST /step: one environment transition."""
    observation: ObsResponse
    reward: float
    done: bool
    # Step metadata passed through from env.step; presumably includes
    # "was_correct" (training code reads it) — verify against the env.
    info: Dict[str, Any]
40
+
41
+
42
def _to_obs_response(obs: Observation) -> ObsResponse:
    """Convert an internal Observation into the wire-format ObsResponse."""
    step_records = [
        {
            "tokens_allocated": step.tokens_allocated,
            "tokens_used": step.tokens_used,
            "was_correct": step.was_correct,
        }
        for step in obs.history
    ]
    return ObsResponse(
        question_embedding=obs.question_embedding,
        remaining_budget=obs.remaining_budget,
        questions_remaining=obs.questions_remaining,
        budget_per_remaining=obs.budget_per_remaining,
        accuracy_so_far=obs.accuracy_so_far,
        history=step_records,
        done=obs.done,
        episode_id=obs.episode_id,
    )
58
+
59
+
60
def _landing_page() -> str:
    # Static, dependency-free HTML served at GET "/" — a single literal,
    # no templating. Edit the markup here directly.
    return """<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <title>ReasoningEconomicsEnv</title>
    <style>
      :root {
        color-scheme: light;
        --bg: #f6f4ec;
        --card: #fffdf6;
        --ink: #182022;
        --muted: #536067;
        --accent: #0e7c66;
        --border: #d8d1c0;
      }
      body {
        margin: 0;
        font-family: "Iowan Old Style", "Palatino Linotype", serif;
        background:
          radial-gradient(circle at top left, rgba(14, 124, 102, 0.16), transparent 30%),
          linear-gradient(180deg, #f8f6ef 0%, var(--bg) 100%);
        color: var(--ink);
      }
      main {
        max-width: 840px;
        margin: 0 auto;
        padding: 48px 24px 64px;
      }
      .eyebrow {
        text-transform: uppercase;
        letter-spacing: 0.12em;
        font-size: 0.78rem;
        color: var(--accent);
        margin-bottom: 12px;
      }
      h1 {
        font-size: clamp(2.2rem, 6vw, 4.2rem);
        line-height: 0.95;
        margin: 0 0 18px;
      }
      p {
        font-size: 1.08rem;
        line-height: 1.65;
        color: var(--muted);
        margin: 0 0 18px;
      }
      .card {
        margin-top: 32px;
        background: var(--card);
        border: 1px solid var(--border);
        border-radius: 18px;
        padding: 24px;
        box-shadow: 0 14px 50px rgba(24, 32, 34, 0.08);
      }
      .links {
        display: flex;
        flex-wrap: wrap;
        gap: 12px;
        margin: 24px 0 12px;
      }
      a {
        color: inherit;
      }
      .button {
        display: inline-block;
        text-decoration: none;
        border-radius: 999px;
        padding: 12px 18px;
        border: 1px solid var(--border);
        background: white;
        font-weight: 600;
      }
      .button.primary {
        background: var(--accent);
        border-color: var(--accent);
        color: white;
      }
      ul {
        margin: 18px 0 0;
        padding-left: 20px;
        color: var(--muted);
      }
      code {
        font-family: "SFMono-Regular", "Menlo", monospace;
        font-size: 0.95em;
      }
    </style>
  </head>
  <body>
    <main>
      <div class="eyebrow">AgentX&ndash;AgentBeats / OpenEnv</div>
      <h1>ReasoningEconomicsEnv</h1>
      <p>
        A reinforcement learning environment for allocating reasoning tokens across
        multi-question episodes under a fixed compute budget.
      </p>
      <p>
        This Space serves the environment as an API-first FastAPI app. Use the links
        below to inspect the schema, check runtime health, or query environment metadata.
      </p>

      <div class="links">
        <a class="button primary" href="/docs">Open API Docs</a>
        <a class="button" href="/health">Health Check</a>
        <a class="button" href="/info">Environment Info</a>
      </div>

      <section class="card">
        <p><strong>Core endpoints</strong></p>
        <ul>
          <li><code>GET /health</code> returns service status.</li>
          <li><code>GET /info</code> returns observation/action-space metadata.</li>
          <li><code>POST /reset</code> starts a fresh episode.</li>
          <li><code>POST /step</code> advances the environment by one allocation decision.</li>
        </ul>
      </section>
    </main>
  </body>
</html>
"""
182
+
183
+
184
def create_fastapi_app() -> FastAPI:
    """
    Build the FastAPI app wrapping a single ReasonBudgetEnv instance.

    NOTE(review): one env instance is shared by every request, so concurrent
    clients would interleave episodes. Fine for single-client OpenEnv use;
    revisit if the Space ever serves multiple agents at once.
    """
    app = FastAPI(title="ReasoningEconomicsEnv", version="0.1.0")
    config = EnvConfig()
    env = ReasonBudgetEnv(config)

    @app.get("/", response_class=HTMLResponse)
    def index():
        # Human-friendly landing page; API consumers use /docs, /info, /reset, /step.
        return _landing_page()

    @app.get("/health")
    def health():
        return {"status": "ok", "env": "ReasonBudgetEnv", "version": "0.1.0"}

    @app.get("/info")
    def info():
        # Metadata describing observation/action spaces for clients.
        return {
            "name": "ReasoningEconomicsEnv",
            "observation_dim": env.observation_dim,
            "action_space": {"type": "integer", "min": config.min_tokens, "max": config.max_tokens},
            "total_budget": config.total_budget,
            "questions_per_episode": config.questions_per_episode,
        }

    @app.post("/reset", response_model=ObsResponse)
    def reset(req: ResetRequest):
        # Re-seed before resetting so the new episode is reproducible.
        if req.seed is not None:
            env.config.seed = req.seed
            env.sampler.rng.seed(req.seed)
        return _to_obs_response(env.reset())

    @app.post("/step", response_model=StepResponse)
    def step(req: StepRequest):
        try:
            result = env.step(req.token_allocation)
        except RuntimeError as e:
            # Bad client action (e.g. stepping a finished episode) -> 400.
            # Chain the original exception so server logs keep the cause.
            raise HTTPException(status_code=400, detail=str(e)) from e
        return StepResponse(
            observation=_to_obs_response(result.observation),
            reward=result.reward,
            done=result.done,
            info=result.info,
        )

    return app
228
+
229
+
230
# OpenEnv looks for `app` at module level
# (e.g. `uvicorn reasonbudget_gym.server.app:app`).
app = create_fastapi_app()
# backwards-compat alias for callers expecting a `create_app` factory
create_app = create_fastapi_app
reasonbudget_gym/solver/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import BaseSolver, SolverResult
2
+ from .cached_solver import CachedSolver
3
+
4
+ __all__ = ["BaseSolver", "SolverResult", "CachedSolver"]
reasonbudget_gym/solver/base.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstract base class for solvers."""
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class SolverResult:
    """Result of a solver attempt."""
    answer: str               # extracted final answer text
    was_correct: bool         # whether `answer` matched the ground truth
    tokens_used: int          # tokens actually consumed by the attempt
    response_text: str = ""   # raw (possibly truncated) model output
13
+
14
+
15
class BaseSolver(ABC):
    """Attempt to answer a question within a token budget."""

    @abstractmethod
    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """
        Answer `question_text` using at most `token_budget` tokens.

        `ground_truth` is passed in so implementations can grade their own
        answer; `question_id` lets cache-backed solvers index pre-computed
        results.
        """
        ...
reasonbudget_gym/solver/cached_solver.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CachedSolver — serves pre-computed solver results from a JSON cache.
3
+
4
+ Cache schema:
5
+ {
6
+ "entries": {
7
+ "<question_id>": {
8
+ "<budget_tier>": {
9
+ "answer": str,
10
+ "was_correct": bool,
11
+ "tokens_used": int,
12
+ "response_text": str
13
+ }
14
+ }
15
+ },
16
+ "eval_ids": [list of question_ids for holdout]
17
+ }
18
+ """
19
+ import json
20
+ import os
21
+ from pathlib import Path
22
+
23
+ from .base import BaseSolver, SolverResult
24
+ from ..env.config import EnvConfig
25
+
26
+
27
class CachedSolver(BaseSolver):
    """
    Look up pre-computed results from a JSON cache file.

    If the cache file is missing (or a question is absent from it), solve()
    falls back to a deterministic heuristic so episodes still run.
    """

    # Tiers assumed when no cache file is present.
    _DEFAULT_TIERS = (50, 100, 200, 400, 800)

    def __init__(self, config: EnvConfig):
        self.config = config
        self._cache: dict = {}
        self._eval_ids: list = []
        # Sorted int budget tiers; refreshed by _load_cache when a cache exists.
        self._tiers: list = list(self._DEFAULT_TIERS)
        self._load_cache()

    def _load_cache(self):
        """Load the JSON cache, resolving a relative path against the package root, then cwd."""
        cache_path = Path(self.config.cache_path)
        if not cache_path.is_absolute():
            pkg_root = Path(__file__).parent.parent
            candidate = pkg_root / self.config.cache_path
            if candidate.exists():
                cache_path = candidate
        if not cache_path.exists():
            # Cache missing — solve() will use the fallback heuristic.
            return
        with open(cache_path, encoding="utf-8") as f:
            data = json.load(f)
        self._cache = data.get("entries", {})
        self._eval_ids = data.get("eval_ids", [])
        if self._cache:
            # Hoist tier parsing out of the per-call path. Assumes every
            # entry shares the same budget tiers as the first one (matches
            # the cache schema documented at the top of this module).
            first_entry = next(iter(self._cache.values()))
            if first_entry:
                self._tiers = sorted(int(t) for t in first_entry)

    def _nearest_tier(self, budget: int) -> str:
        """Return the largest tier <= budget (or the smallest tier overall)."""
        best = self._tiers[0]
        for tier in self._tiers:
            if tier <= budget:
                best = tier
        return str(best)

    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """Serve the cached result nearest to `token_budget`, or a fallback."""
        entries = self._cache.get(question_id)
        if not entries:
            # Unknown question (or empty tier dict, which previously crashed
            # with an IndexError) — simulate deterministically instead.
            return self._fallback(ground_truth, token_budget)

        entry = entries.get(self._nearest_tier(token_budget))
        if entry is None:
            # Tier missing for this question: use its smallest recorded tier.
            smallest = str(min(int(k) for k in entries))
            entry = entries[smallest]
        return SolverResult(
            answer=entry["answer"],
            was_correct=entry["was_correct"],
            tokens_used=entry["tokens_used"],
            response_text=entry.get("response_text", ""),
        )

    def _fallback(self, ground_truth: str, token_budget: int) -> SolverResult:
        """Deterministic heuristic used when a question is not in the cache."""
        import hashlib
        # Stable pseudo-randomness keyed on the answer text.
        h = int(hashlib.md5(ground_truth.encode()).hexdigest(), 16)
        # Correctness probability grows with budget, capped at 0.9.
        p = min(0.9, 0.3 + token_budget / 1600)
        correct = (h % 100) < int(p * 100)
        used = int(token_budget * 0.85)
        return SolverResult(
            answer=ground_truth if correct else "unknown",
            was_correct=correct,
            tokens_used=used,
            response_text="[fallback]",
        )
reasonbudget_gym/solver/live_solver.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiveSolver — calls DeepSeek-R1 via Together AI for real inference.
3
+
4
+ Requires TOGETHER_API_KEY environment variable.
5
+ Not used by default; switch solver_type to 'live' to enable.
6
+ """
7
+ import os
8
+ import re
9
+ from .base import BaseSolver, SolverResult
10
+ from ..env.config import EnvConfig
11
+
12
+
13
class LiveSolver(BaseSolver):
    """Call DeepSeek-R1 API for live reasoning."""

    MODEL = "deepseek-ai/DeepSeek-R1"
    # Compiled once: matches 'The answer is <answer>' up to end of line/sentence.
    _ANSWER_RE = re.compile(r"[Tt]he answer is[:\s]+([^\n.]+)")

    def __init__(self, config: EnvConfig):
        self.config = config
        try:
            from together import Together
        except ImportError as e:
            # Chain the cause so the missing-package traceback is preserved.
            raise ImportError("Install `together` package: pip install together") from e
        api_key = os.environ.get("TOGETHER_API_KEY")
        if api_key is None:
            raise EnvironmentError("TOGETHER_API_KEY not set")
        self._client = Together(api_key=api_key)

    def solve(
        self,
        question_id: str,
        question_text: str,
        ground_truth: str,
        token_budget: int,
    ) -> SolverResult:
        """Run one live completion capped at `token_budget` tokens and grade it."""
        prompt = (
            f"Solve the following math problem step by step.\n\n"
            f"Problem: {question_text}\n\n"
            f"Provide your final answer as: The answer is <answer>"
        )
        response = self._client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=token_budget,
        )
        text = response.choices[0].message.content or ""
        # If the API omits usage stats, assume the full budget was consumed.
        tokens_used = response.usage.completion_tokens if response.usage else token_budget

        # Extract the final answer; fall back to the last non-empty-ish line.
        m = self._ANSWER_RE.search(text)
        answer = m.group(1).strip() if m else text.strip().split("\n")[-1]

        # Loose symmetric-containment grading: tolerates formatting around
        # the answer, but can over-credit very short answers.
        was_correct = (
            ground_truth.strip().lower() in answer.lower()
            or answer.lower() in ground_truth.strip().lower()
        )
        return SolverResult(
            answer=answer,
            was_correct=was_correct,
            tokens_used=tokens_used,
            response_text=text[:500],  # cap the stored transcript
        )
reasonbudget_gym/tests/__init__.py ADDED
File without changes
reasonbudget_gym/tests/test_config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from pathlib import Path
# Make the repo root importable when this file is run directly (not installed).
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
4
+
5
+ from reasonbudget_gym.env.config import EnvConfig
6
+
7
+
8
def test_default_config():
    """Default EnvConfig values match the documented environment settings."""
    cfg = EnvConfig()
    assert cfg.total_budget == 4000
    assert cfg.min_tokens == 50
    assert cfg.max_tokens == 800
    assert len(cfg.budget_tiers) == 5
    assert cfg.questions_per_episode == 10
15
+
16
+
17
def test_custom_config():
    """Constructor overrides are honoured."""
    cfg = EnvConfig(total_budget=2000, questions_per_episode=5)
    assert cfg.total_budget == 2000
    assert cfg.questions_per_episode == 5
21
+
22
+
23
def test_reward_weights():
    """Reward weights must have sane signs by default."""
    cfg = EnvConfig()
    assert cfg.correct_reward > 0
    assert cfg.cost_penalty_weight >= 0
    assert cfg.efficiency_bonus_weight >= 0
reasonbudget_gym/tests/test_integration.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from pathlib import Path
# Make the repo root importable when this file is run directly (not installed).
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
4
+
5
+ from reasonbudget_gym.env.config import EnvConfig
6
+ from reasonbudget_gym.env.reason_budget_env import ReasonBudgetEnv
7
+ from reasonbudget_gym.baselines.uniform import UniformBaseline
8
+
9
+
10
def test_full_episode():
    """A uniform baseline finishes a 5-question episode in exactly 5 steps."""
    config = EnvConfig(questions_per_episode=5, total_budget=2000)
    env = ReasonBudgetEnv(config)
    agent = UniformBaseline(config)

    obs = env.reset()
    assert obs.questions_remaining == 5
    assert not obs.done

    episode_reward = 0.0
    step_count = 0
    done = False
    while not done:
        step_result = env.step(agent.get_action(obs))
        episode_reward += step_result.reward
        done = step_result.done
        obs = step_result.observation
        step_count += 1
        assert step_count <= 20  # guard against an env that never terminates

    assert obs.done
    assert step_count == 5
    assert -20 < episode_reward < 20
33
+
34
+
35
def test_reset_clears_state():
    """A second reset() must discard budget, progress, and history."""
    config = EnvConfig(questions_per_episode=3, total_budget=1000)
    env = ReasonBudgetEnv(config)
    env.reset()
    for _ in range(2):
        env.step(100)
    obs = env.reset()
    assert obs.questions_remaining == 3
    assert obs.remaining_budget == 1000
    assert len(obs.history) == 0
reasonbudget_gym/tests/test_reward.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from pathlib import Path
# Make the repo root importable when this file is run directly (not installed).
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
4
+
5
+ from reasonbudget_gym.env.config import EnvConfig
6
+ from reasonbudget_gym.env.reward import compute_reward
7
+
8
+
9
def test_correct_positive():
    """A correct answer well under budget should earn a positive reward."""
    reward = compute_reward(True, 200, 180, 2000, 5, EnvConfig())
    assert reward > 0
12
+
13
+
14
def test_wrong_expensive_negative():
    """A wrong, maximally expensive answer should be penalised below zero."""
    cfg = EnvConfig(cost_penalty_weight=0.5)
    reward = compute_reward(False, 800, 800, 4000, 1, cfg)
    assert reward < 0
17
+
18
+
19
def test_efficiency_bonus():
    """Spending fewer tokens than allocated should score higher when correct."""
    cfg = EnvConfig(efficiency_bonus_weight=0.5)
    frugal = compute_reward(True, 400, 50, 2000, 5, cfg)
    exhaustive = compute_reward(True, 400, 400, 2000, 5, cfg)
    assert frugal > exhaustive
reasonbudget_gym/training/__init__.py ADDED
File without changes
reasonbudget_gym/training/ppo_train.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PPO training loop for the token budget allocation policy.
3
+
4
+ Usage:
5
+ python -m reasonbudget_gym.training.ppo_train \
6
+ --n_episodes 500 \
7
+ --ppo_epochs 4 \
8
+ --clip_eps 0.2 \
9
+ --value_coef 0.5 \
10
+ --entropy_coef 0.01 \
11
+ --output_dir runs/ppo_run1
12
+
13
+ Training procedure:
14
+ 1. Roll out N episodes using the current policy (stochastic)
15
+ 2. Compute returns and GAE advantages
16
+ 3. Run PPO update for K epochs over the rollout buffer
17
+ 4. Log metrics and save checkpoint every 50 episodes
18
+ """
19
+ import argparse
20
+ import json
21
+ import math
22
+ import os
23
+ import random
24
+ import sys
25
+ from pathlib import Path
26
+
27
+ try:
28
+ import torch
29
+ import torch.nn.functional as F
30
+ _TORCH_AVAILABLE = True
31
+ except ImportError:
32
+ _TORCH_AVAILABLE = False
33
+
34
+ from ..env.config import EnvConfig
35
+ from ..env.reason_budget_env import ReasonBudgetEnv
36
+
37
+
38
def collect_rollout(env: ReasonBudgetEnv, policy, config: EnvConfig) -> dict:
    """Run one full episode with the current policy and record per-step data."""
    obs = env.reset()
    flat_obs = env.flat_observation(obs)

    buffer = {
        "observations": [],
        "action_fracs": [],
        "log_probs": [],
        "rewards": [],
        "values": [],
        "dones": [],
    }
    episode_reward = 0.0
    n_correct = 0
    n_steps = 0
    action_span = max(1, config.max_tokens - config.min_tokens)

    done = False
    while not done:
        action_int, log_prob, value = policy.get_action(flat_obs)
        result = env.step(action_int)

        # Store the action normalised to [0, 1] for the Gaussian policy.
        frac = (action_int - config.min_tokens) / action_span
        frac = max(0.0, min(1.0, frac))

        buffer["observations"].append(flat_obs)
        buffer["action_fracs"].append(frac)
        buffer["log_probs"].append(log_prob)
        buffer["rewards"].append(result.reward)
        buffer["values"].append(value)
        buffer["dones"].append(result.done)

        if result.info.get("was_correct"):
            n_correct += 1
        episode_reward += result.reward
        n_steps += 1
        done = result.done
        flat_obs = env.flat_observation(result.observation)

    buffer["total_reward"] = episode_reward
    buffer["accuracy"] = n_correct / max(1, n_steps)
    buffer["steps"] = n_steps
    return buffer
82
+
83
+
84
def compute_returns_and_advantages(
    rewards, values, dones, gamma=0.99, lam=0.95
):
    """
    Compute GAE(lambda) advantages and the matching discounted returns.

    rewards/values/dones are equal-length per-step lists from one rollout.
    returns[t] == advantages[t] + values[t], so the critic target is the
    advantage-corrected value estimate.
    """
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    next_value = 0.0

    for t in reversed(range(n)):
        # A terminal step cuts the bootstrap: no value flows past episode end.
        nonterminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        gae = delta + gamma * lam * nonterminal * gae
        advantages[t] = gae
        next_value = values[t]

    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns
105
+
106
+
107
def ppo_update(policy, rollouts: list, clip_eps: float, value_coef: float,
               entropy_coef: float, ppo_epochs: int, batch_size: int = 64):
    """
    Run PPO-Clip updates over the collected rollouts.

    rollouts: list of dicts from collect_rollout, each augmented with
    "advantages" and "returns". Mutates the policy in place via its
    optimizer and returns the mean loss over all minibatch updates.
    """
    # Local import keeps module import safe when torch is absent (train()
    # already guards on _TORCH_AVAILABLE before calling this).
    import torch

    # Flatten all rollouts into one buffer.
    all_obs, all_fracs, all_lp, all_adv, all_ret = [], [], [], [], []
    for r in rollouts:
        all_obs.extend(r["observations"])
        all_fracs.extend(r["action_fracs"])
        all_lp.extend(r["log_probs"])
        all_adv.extend(r["advantages"])
        all_ret.extend(r["returns"])

    # Shape everything to [N, 1] to line up with the network outputs.
    obs_t = torch.tensor(all_obs, dtype=torch.float32)
    fracs_t = torch.tensor(all_fracs, dtype=torch.float32).unsqueeze(1)
    old_lp_t = torch.tensor(all_lp, dtype=torch.float32).unsqueeze(1)
    adv_t = torch.tensor(all_adv, dtype=torch.float32).unsqueeze(1)
    ret_t = torch.tensor(all_ret, dtype=torch.float32).unsqueeze(1)

    # Normalise advantages (zero mean, unit variance) for update stability.
    adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)

    n = obs_t.shape[0]
    total_loss = 0.0
    n_updates = 0

    for _ in range(ppo_epochs):
        # Fresh shuffle each epoch so minibatches differ between passes.
        idx = torch.randperm(n)
        for start in range(0, n, batch_size):
            b_idx = idx[start:start + batch_size]
            b_obs = obs_t[b_idx]
            b_fracs = fracs_t[b_idx]
            b_old_lp = old_lp_t[b_idx]
            b_adv = adv_t[b_idx]
            b_ret = ret_t[b_idx]

            new_lp, entropy, values = policy.evaluate_actions(b_obs, b_fracs)

            # PPO-Clip surrogate: take the pessimistic (min) of the raw and
            # clipped importance-weighted advantages.
            ratio = torch.exp(new_lp - b_old_lp)
            surr1 = ratio * b_adv
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_adv
            actor_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, b_ret)
            # Negative entropy so that maximising entropy lowers the loss.
            entropy_loss = -entropy.mean()

            loss = actor_loss + value_coef * value_loss + entropy_coef * entropy_loss

            policy.optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to keep updates bounded.
            torch.nn.utils.clip_grad_norm_(policy.net.parameters(), 0.5)
            policy.optimizer.step()

            total_loss += loss.item()
            n_updates += 1

    return total_loss / max(1, n_updates)
164
+
165
+
166
def train(args):
    """
    Run the PPO training loop described in the module docstring.

    Collects one episode per iteration, performs a PPO update on it, logs a
    10-episode rolling average, checkpoints every 50 episodes, and writes
    the full metric history to <output_dir>/training_history.json.
    """
    if not _TORCH_AVAILABLE:
        print("ERROR: PyTorch not installed. Run: pip install torch", file=sys.stderr)
        sys.exit(1)

    config = EnvConfig(seed=args.seed)
    env = ReasonBudgetEnv(config)
    obs_dim = env.observation_dim

    # Imported lazily: the policy module requires torch.
    from ..policy.allocation_policy import AllocationPolicy
    policy = AllocationPolicy(
        obs_dim=obs_dim,
        max_tokens=config.max_tokens,
        min_tokens=config.min_tokens,
    )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    history = []
    print(f"Starting PPO training: {args.n_episodes} episodes, obs_dim={obs_dim}")
    print(f"Output dir: {output_dir}")

    for episode in range(1, args.n_episodes + 1):
        # Collect one on-policy rollout and attach its GAE targets.
        rollout = collect_rollout(env, policy, config)
        advantages, returns = compute_returns_and_advantages(
            rollout["rewards"], rollout["values"], rollout["dones"]
        )
        rollout["advantages"] = advantages
        rollout["returns"] = returns

        # PPO update on this single-episode buffer.
        loss = ppo_update(
            policy, [rollout],
            clip_eps=args.clip_eps,
            value_coef=args.value_coef,
            entropy_coef=args.entropy_coef,
            ppo_epochs=args.ppo_epochs,
        )

        record = {
            "episode": episode,
            "reward": rollout["total_reward"],
            "accuracy": rollout["accuracy"],
            "loss": loss,
            "steps": rollout["steps"],
        }
        history.append(record)

        if episode % 10 == 0:
            recent = history[-10:]
            avg_r = sum(r["reward"] for r in recent) / len(recent)
            avg_a = sum(r["accuracy"] for r in recent) / len(recent)
            print(f" Ep {episode:4d} | avg_reward={avg_r:.3f} | avg_acc={avg_a:.3f} | loss={loss:.4f}")

        if episode % 50 == 0:
            ckpt_path = output_dir / f"policy_ep{episode}.pt"
            policy.save(str(ckpt_path))
            print(f" Saved checkpoint: {ckpt_path}")

    # Final save of weights and metrics.
    policy.save(str(output_dir / "policy_final.pt"))
    with open(output_dir / "training_history.json", "w") as f:
        json.dump(history, f, indent=2)

    print(f"\nTraining complete. Final checkpoint: {output_dir / 'policy_final.pt'}")
    # BUG FIX: the summary previously divided by a hard-coded 10, which
    # mis-reported the averages whenever fewer than 10 episodes were run.
    final_window = history[-10:]
    avg_reward = sum(r["reward"] for r in final_window) / len(final_window)
    avg_accuracy = sum(r["accuracy"] for r in final_window) / len(final_window)
    print(f"Last {len(final_window)} episodes — avg reward: {avg_reward:.3f}, "
          f"avg accuracy: {avg_accuracy:.3f}")
236
+
237
+
238
def main():
    """CLI entry point: parse PPO hyperparameters and launch training."""
    parser = argparse.ArgumentParser(description="PPO training for ReasonBudgetEnv")
    parser.add_argument("--n_episodes", type=int, default=500)
    parser.add_argument("--ppo_epochs", type=int, default=4)
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--value_coef", type=float, default=0.5)
    parser.add_argument("--entropy_coef", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, default="runs/ppo_run1")
    parsed = parser.parse_args()
    train(parsed)
249
+
250
+
251
# Support `python -m reasonbudget_gym.training.ppo_train ...` execution.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.0
4
+ numpy>=1.24
5
+ datasets>=2.18.0
6
+ sentence-transformers>=2.7.0
7
+ matplotlib>=3.8
8
+ seaborn>=0.13