Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files- .dockerignore +12 -0
- Dockerfile +24 -0
- README.md +88 -10
- client.py +103 -0
- docs/BEGINNER_PROJECT_EXPLANATION.md +271 -0
- models.py +134 -0
- openenv.yaml +44 -0
- pyproject.toml +75 -0
- server/__init__.py +1 -0
- server/app.py +105 -0
- server/environment.py +457 -0
- server/rewards/__init__.py +116 -0
- server/rewards/correctness_rubric.py +57 -0
- server/rewards/diagnosis_rubric.py +100 -0
- server/rewards/portability_rubric.py +45 -0
- server/rewards/rubrics.py +184 -0
- server/rewards/self_correction_rubric.py +61 -0
- server/rewards/speedup_rubric.py +58 -0
- server/scenarios/__init__.py +22 -0
- server/scenarios/adaptive_curriculum.py +148 -0
- server/scenarios/dataset_loader.py +249 -0
- server/scenarios/generator.py +320 -0
- server/scenarios/hardware_profiles.py +72 -0
- server/scenarios/trap_library.py +489 -0
- server/tools/__init__.py +39 -0
- server/tools/_runtime.py +255 -0
- server/tools/bottleneck_reporter.py +103 -0
- server/tools/cpp_compiler.py +382 -0
- server/tools/hardware_profiler.py +56 -0
- server/tools/portability_checker.py +123 -0
- server/tools/python_analyzer.py +219 -0
- server/tools/submit.py +114 -0
- server/tools/verifier.py +356 -0
- tests/__init__.py +0 -0
- tests/smoke_llm_hf.py +487 -0
- tests/test_rewards.py +368 -0
- tests/test_runtime_dispatch.py +225 -0
- tests/test_scenarios.py +310 -0
- tests/test_skeleton.py +178 -0
- tests/test_smoke_gate.py +272 -0
- tests/test_smoke_gate_deep.py +410 -0
- tests/test_tools.py +222 -0
- training/openenv_hackathon_training.ipynb +434 -0
.dockerignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
.mypy_cache/
|
| 7 |
+
.ruff_cache/
|
| 8 |
+
.git/
|
| 9 |
+
.gitignore
|
| 10 |
+
.env
|
| 11 |
+
artifacts/
|
| 12 |
+
docs/plots/
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
build-essential \
|
| 8 |
+
g++ \
|
| 9 |
+
git \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
COPY . /app
|
| 15 |
+
|
| 16 |
+
RUN python -m pip install --upgrade pip && \
|
| 17 |
+
python -m pip install .
|
| 18 |
+
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
ENV OPENENV_SERVER_MODE=simulation
|
| 22 |
+
ENV ENABLE_WEB_INTERFACE=1
|
| 23 |
+
|
| 24 |
+
CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,88 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Polyglot-Optima
|
| 2 |
+
|
| 3 |
+
Polyglot-Optima is an OpenEnv environment for training an LLM to translate Python functions into hardware-aware C++ that is both fast and correct.
|
| 4 |
+
|
| 5 |
+
## Problem
|
| 6 |
+
|
| 7 |
+
LLMs can generate optimized code, but often fail on edge-case correctness, portability, and anti-gaming behavior (fast but wrong outputs). This environment targets that gap with closed-loop tool use and verifiable rewards.
|
| 8 |
+
|
| 9 |
+
## Environment Design
|
| 10 |
+
|
| 11 |
+
- **API shape:** Gym-style `reset`, `step`, `state`.
|
| 12 |
+
- **3-round episodes:** iterative refinement, final submission at round 3.
|
| 13 |
+
- **9 tools:** profiling, complexity analysis, memory analysis, compile+benchmark, equivalence verifier, portability checker, and final submit.
|
| 14 |
+
- **Reward DAG:** composable rubrics for speedup, correctness, diagnosis quality, portability, and self-correction.
|
| 15 |
+
- **Continuous rewards:** no hard 0/1 optimization cliff in the main learning path.
|
| 16 |
+
|
| 17 |
+
## Innovation Highlights
|
| 18 |
+
|
| 19 |
+
1. **Adaptive 4-axis curriculum** updates global difficulty over batches.
|
| 20 |
+
2. **Adversarial trap library** with category-focused adaptive resampling from recent failures.
|
| 21 |
+
3. **Semantic trap variation** (AST-level no-op rewrites) to reduce memorization.
|
| 22 |
+
4. **Roofline-aware speedup scoring** for hardware-grounded performance reward.
|
| 23 |
+
5. **Anti-gaming verification** through fuzzing + adversarial pass checks.
|
| 24 |
+
|
| 25 |
+
## Why This Matters
|
| 26 |
+
|
| 27 |
+
The target behavior is not just "compile and run", but robust optimization under realistic constraints: correctness under adversarial inputs, reasoning about bottlenecks, and hardware-aware strategy selection.
|
| 28 |
+
|
| 29 |
+
## Local Usage
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
python -m pytest -q
|
| 33 |
+
python -m ruff check .
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Run smoke LLM integration:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
python tests/smoke_llm_hf.py
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Cursor/OpenAI-compatible provider mode:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
export LLM_PROVIDER=cursor
|
| 46 |
+
export CURSOR_API_KEY=...
|
| 47 |
+
export CURSOR_MODEL=gpt-4.1-nano
|
| 48 |
+
python tests/smoke_llm_hf.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Notebook Usage and HF Spaces
|
| 52 |
+
|
| 53 |
+
You can use this environment directly in a local notebook without deploying to HF Spaces.
|
| 54 |
+
|
| 55 |
+
- **For development/training:** local usage is enough.
|
| 56 |
+
- **For hackathon submission:** deploy to HF Spaces and link it in README per requirements.
|
| 57 |
+
|
| 58 |
+
## Current Validation Snapshot
|
| 59 |
+
|
| 60 |
+
- Unit/integration tests passing.
|
| 61 |
+
- Smoke integration path validates parseability/tool-loop behavior.
|
| 62 |
+
- Reward and gate tests verify coherent scoring behavior.
|
| 63 |
+
|
| 64 |
+
## Results (Judge-facing)
|
| 65 |
+
|
| 66 |
+
After running `training/openenv_hackathon_training.ipynb`, add:
|
| 67 |
+
|
| 68 |
+
- Reward distribution plot: `docs/plots/reward_distribution_baseline_vs_trained.png`
|
| 69 |
+
- Correctness curve plot: `docs/plots/correctness_baseline_vs_trained.png`
|
| 70 |
+
- Baseline vs trained metrics table (reward mean, correctness, compile rate, portability).
|
| 71 |
+
|
| 72 |
+
## Required Submission Links
|
| 73 |
+
|
| 74 |
+
Add these links before final submission:
|
| 75 |
+
|
| 76 |
+
- **HF Space (environment URL judges will pull):** `TODO_ADD_HF_SPACE_URL`
|
| 77 |
+
- **Training notebook/script:** `training/openenv_hackathon_training.ipynb`
|
| 78 |
+
- **W&B run (or equivalent training evidence):** `TODO_ADD_WANDB_RUN_URL`
|
| 79 |
+
- **Short writeup/video/slides (<2 min video or mini blog):** `TODO_ADD_STORY_URL`
|
| 80 |
+
|
| 81 |
+
## Submission Checklist (from hackathon PDF)
|
| 82 |
+
|
| 83 |
+
- [ ] Environment deployed to HF Space and URL added above
|
| 84 |
+
- [x] Valid OpenEnv manifest (`openenv.yaml`) present
|
| 85 |
+
- [x] Training notebook/script using TRL/Unsloth path present
|
| 86 |
+
- [ ] Real training evidence linked (loss/reward curves from an actual run)
|
| 87 |
+
- [ ] README includes all judge-facing links (Space + writeup/video/slides + run logs)
|
| 88 |
+
- [ ] Key plots embedded and committed in repo (`docs/plots/*.png`)
|
client.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Polyglot-Optima client — typed wrapper around the WebSocket env API.
|
| 2 |
+
|
| 3 |
+
Two clients are provided:
|
| 4 |
+
- PolyglotOptimaClient: async (the canonical OpenEnv pattern)
|
| 5 |
+
- PolyglotOptimaSyncClient: synchronous wrapper, used inside the TRL training loop
|
| 6 |
+
|
| 7 |
+
Both are typed: `reset()` returns OptimizationObservation, `step()` returns
|
| 8 |
+
StepResult containing OptimizationObservation. No raw dicts.
|
| 9 |
+
|
| 10 |
+
Strict client/server boundary: this module imports nothing from `server/`. All
|
| 11 |
+
communication is over HTTP/WebSocket via the OpenEnv EnvClient base.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from openenv.core.client import EnvClient, SyncEnvClient # type: ignore
|
| 20 |
+
except ImportError:
|
| 21 |
+
# Local-dev stub; real client imported once openenv is installed
|
| 22 |
+
class EnvClient: # type: ignore
|
| 23 |
+
def __init__(self, base_url: str, action_cls=None, observation_cls=None):
|
| 24 |
+
self.base_url = base_url
|
| 25 |
+
self.action_cls = action_cls
|
| 26 |
+
self.observation_cls = observation_cls
|
| 27 |
+
|
| 28 |
+
async def reset(self, seed: int | None = None):
|
| 29 |
+
raise NotImplementedError("Install openenv to use the real client")
|
| 30 |
+
|
| 31 |
+
async def step(self, action):
|
| 32 |
+
raise NotImplementedError("Install openenv to use the real client")
|
| 33 |
+
|
| 34 |
+
class SyncEnvClient(EnvClient): # type: ignore
|
| 35 |
+
def reset(self, seed: int | None = None):
|
| 36 |
+
raise NotImplementedError
|
| 37 |
+
|
| 38 |
+
def step(self, action):
|
| 39 |
+
raise NotImplementedError
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
from models import OptimizationAction, OptimizationObservation
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PolyglotOptimaClient(EnvClient):
|
| 46 |
+
"""Async typed client.
|
| 47 |
+
|
| 48 |
+
Usage:
|
| 49 |
+
async with PolyglotOptimaClient("ws://localhost:8000") as client:
|
| 50 |
+
obs = await client.reset(seed=42)
|
| 51 |
+
obs = await client.step(OptimizationAction(
|
| 52 |
+
tool_name="profile_python_hotspots",
|
| 53 |
+
tool_args={"code": obs.python_code},
|
| 54 |
+
reasoning_trace="<think>...</think>",
|
| 55 |
+
))
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self, base_url: str = "ws://localhost:8000"):
|
| 59 |
+
super().__init__(
|
| 60 |
+
base_url=base_url,
|
| 61 |
+
action_cls=OptimizationAction,
|
| 62 |
+
observation_cls=OptimizationObservation,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Convenience wrappers — strongly typed
|
| 66 |
+
async def reset(self, seed: int | None = None) -> OptimizationObservation: # type: ignore[override]
|
| 67 |
+
return await super().reset(seed=seed)
|
| 68 |
+
|
| 69 |
+
async def step(self, action: OptimizationAction) -> Any: # type: ignore[override]
|
| 70 |
+
# Returns StepResult with .observation : OptimizationObservation
|
| 71 |
+
return await super().step(action)
|
| 72 |
+
|
| 73 |
+
async def close(self) -> None:
|
| 74 |
+
# OpenEnv-base lifecycle teardown
|
| 75 |
+
if hasattr(super(), "close"):
|
| 76 |
+
await super().close() # type: ignore
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class PolyglotOptimaSyncClient(SyncEnvClient):
|
| 80 |
+
"""Synchronous wrapper for use inside synchronous training loops (TRL GRPOTrainer).
|
| 81 |
+
|
| 82 |
+
Per plan §12 A: SyncEnvClient is the recommended pattern when the host loop
|
| 83 |
+
is synchronous (TRL's training loop is). Internally calls the async client.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(self, base_url: str = "http://localhost:8000"):
|
| 87 |
+
super().__init__(
|
| 88 |
+
base_url=base_url,
|
| 89 |
+
action_cls=OptimizationAction,
|
| 90 |
+
observation_cls=OptimizationObservation,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def reset(self, seed: int | None = None) -> OptimizationObservation: # type: ignore[override]
|
| 94 |
+
return super().reset(seed=seed)
|
| 95 |
+
|
| 96 |
+
def step(self, action: OptimizationAction) -> Any: # type: ignore[override]
|
| 97 |
+
return super().step(action)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
__all__ = [
|
| 101 |
+
"PolyglotOptimaClient",
|
| 102 |
+
"PolyglotOptimaSyncClient",
|
| 103 |
+
]
|
docs/BEGINNER_PROJECT_EXPLANATION.md
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Polyglot-Optima Beginner + Technical Explanation
|
| 2 |
+
|
| 3 |
+
This document explains the project from zero, then gradually adds technical depth.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 1) One-line idea
|
| 8 |
+
|
| 9 |
+
`Polyglot-Optima` is a training environment where an AI learns to convert Python functions into fast C++ **without breaking correctness**.
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 2) Why this project exists
|
| 14 |
+
|
| 15 |
+
Most code models can produce "fast-looking" code, but in real systems that is not enough.
|
| 16 |
+
|
| 17 |
+
Common failure modes:
|
| 18 |
+
- code compiles but gives wrong outputs,
|
| 19 |
+
- code is fast only on one machine but fails elsewhere,
|
| 20 |
+
- reward is easy to game (model hacks scoring instead of solving task),
|
| 21 |
+
- model does not improve over multiple refinement rounds.
|
| 22 |
+
|
| 23 |
+
This project is built to fix those problems using:
|
| 24 |
+
- strict compile checks,
|
| 25 |
+
- fuzz-based correctness verification,
|
| 26 |
+
- cross-hardware portability checks,
|
| 27 |
+
- anti-gaming trap tasks,
|
| 28 |
+
- curriculum learning (easy -> hard),
|
| 29 |
+
- structured continuous reward.
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## 3) Mental model (simple)
|
| 34 |
+
|
| 35 |
+
Think of this project as a game with rules:
|
| 36 |
+
|
| 37 |
+
- **Input:** a Python function + a hardware profile.
|
| 38 |
+
- **Player (AI):** can call tools to analyze and optimize.
|
| 39 |
+
- **Goal:** submit C++ that is fast *and* correct.
|
| 40 |
+
- **Score (reward):** combines speed, correctness, reasoning quality, and portability.
|
| 41 |
+
|
| 42 |
+
The AI plays this game many times and learns better strategies.
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## 4) Core architecture
|
| 47 |
+
|
| 48 |
+
Main folders:
|
| 49 |
+
|
| 50 |
+
- `models.py`
|
| 51 |
+
Defines typed data objects for actions, observations, and state.
|
| 52 |
+
|
| 53 |
+
- `server/environment.py`
|
| 54 |
+
The main OpenEnv environment implementation (`reset`, `step`, `state`, `close`).
|
| 55 |
+
|
| 56 |
+
- `server/tools/`
|
| 57 |
+
Actual capability tools (compiler, verifier, profiling, portability, submit).
|
| 58 |
+
|
| 59 |
+
- `server/rewards/`
|
| 60 |
+
Reward rubrics and reward composition logic.
|
| 61 |
+
|
| 62 |
+
- `server/scenarios/`
|
| 63 |
+
Task generators, hardware profiles, trap library, and adaptive curriculum.
|
| 64 |
+
|
| 65 |
+
- `tests/`
|
| 66 |
+
Unit + integration tests validating behavior and quality.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## 5) Episode lifecycle (what happens in one training sample)
|
| 71 |
+
|
| 72 |
+
Each episode has 3 rounds.
|
| 73 |
+
|
| 74 |
+
### Round flow
|
| 75 |
+
1. Environment samples:
|
| 76 |
+
- Python code task
|
| 77 |
+
- hardware profile
|
| 78 |
+
- hidden bottleneck labels (for diagnosis scoring)
|
| 79 |
+
2. Model calls tools (analyze, compile, verify, etc.).
|
| 80 |
+
3. Model eventually calls `submit_optimization`.
|
| 81 |
+
4. Environment computes round reward.
|
| 82 |
+
5. Repeat for rounds 2 and 3.
|
| 83 |
+
6. Final episode reward is computed from round rewards.
|
| 84 |
+
|
| 85 |
+
### Important implementation details
|
| 86 |
+
- `max_calls_per_round` is enforced.
|
| 87 |
+
- If call budget is exhausted, environment forces submit for that round.
|
| 88 |
+
- Adaptive curriculum can update global difficulty after batch outcomes.
|
| 89 |
+
|
| 90 |
+
---
|
| 91 |
+
|
| 92 |
+
## 6) The 9 tools (what the model can do)
|
| 93 |
+
|
| 94 |
+
The AI does not directly "guess" everything. It uses tools:
|
| 95 |
+
|
| 96 |
+
1. `get_hardware_profile`
|
| 97 |
+
2. `profile_python_hotspots`
|
| 98 |
+
3. `analyze_complexity`
|
| 99 |
+
4. `check_memory_access`
|
| 100 |
+
5. `compile_and_benchmark`
|
| 101 |
+
6. `verify_equivalence`
|
| 102 |
+
7. `check_portability`
|
| 103 |
+
8. `get_bottleneck_report`
|
| 104 |
+
9. `submit_optimization` (round-closing action)
|
| 105 |
+
|
| 106 |
+
The most important tools for trustworthiness are:
|
| 107 |
+
- `compile_and_benchmark` (real compile/runtime behavior),
|
| 108 |
+
- `verify_equivalence` (catches wrong-but-fast code),
|
| 109 |
+
- `check_portability` (checks behavior across profiles).
|
| 110 |
+
|
| 111 |
+
---
|
| 112 |
+
|
| 113 |
+
## 7) Reward system explained simply
|
| 114 |
+
|
| 115 |
+
Reward is **continuous**, not just pass/fail.
|
| 116 |
+
|
| 117 |
+
That means:
|
| 118 |
+
- weak solutions get small score,
|
| 119 |
+
- better solutions get higher score,
|
| 120 |
+
- fully good solutions get top score.
|
| 121 |
+
|
| 122 |
+
This is important for RL because the model needs gradient/signal to improve.
|
| 123 |
+
|
| 124 |
+
### Reward components
|
| 125 |
+
- **SpeedupRubric:** how much faster C++ is vs Python baseline
|
| 126 |
+
- **CorrectnessRubric:** fuzz pass-rate quality
|
| 127 |
+
- **CompilationRubric:** compile quality/status
|
| 128 |
+
- **DiagnosisRubric:** quality/coherence of bottleneck reasoning
|
| 129 |
+
- **PortabilityRubric:** cross-profile robustness
|
| 130 |
+
- **SelfCorrectionRubric:** improvement from earlier rounds
|
| 131 |
+
|
| 132 |
+
### Composition
|
| 133 |
+
Reward is composed using rubric operators (`Sequential`, `Gate`, `WeightedSum`), so it is easier to reason about and tune than one large monolithic score function.
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## 8) Anti-gaming design
|
| 138 |
+
|
| 139 |
+
This project assumes the model will try shortcuts. So it includes defenses:
|
| 140 |
+
|
| 141 |
+
- Trap functions (overflow, NaN/Inf, aliasing, semantic edge cases)
|
| 142 |
+
- Adversarial fuzzing
|
| 143 |
+
- Correctness + adversarial pass-rate signals
|
| 144 |
+
- Portability checks across hardware profiles
|
| 145 |
+
- Reasoning/diagnosis quality signal
|
| 146 |
+
|
| 147 |
+
Net effect: "fast but wrong" should score poorly.
|
| 148 |
+
|
| 149 |
+
---
|
| 150 |
+
|
| 151 |
+
## 9) Curriculum learning (easy -> hard)
|
| 152 |
+
|
| 153 |
+
Difficulty axes include:
|
| 154 |
+
- function complexity tier,
|
| 155 |
+
- hardware difficulty class,
|
| 156 |
+
- verifier strictness,
|
| 157 |
+
- portability requirement.
|
| 158 |
+
|
| 159 |
+
Curriculum controller monitors success in batches and adjusts:
|
| 160 |
+
- high success -> increase difficulty,
|
| 161 |
+
- low success -> reduce difficulty,
|
| 162 |
+
- middle zone -> hold.
|
| 163 |
+
|
| 164 |
+
This stabilizes learning and prevents early collapse.
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## 10) Adaptive traps (what was improved)
|
| 169 |
+
|
| 170 |
+
Adaptive traps now do two things:
|
| 171 |
+
- prioritize categories where the model recently failed,
|
| 172 |
+
- create semantic-preserving trap variants (not only naive renaming).
|
| 173 |
+
|
| 174 |
+
Why this matters:
|
| 175 |
+
- reduces memorization,
|
| 176 |
+
- improves robustness,
|
| 177 |
+
- increases novelty/innovation signal for judges.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## 11) What "good performance" means here
|
| 182 |
+
|
| 183 |
+
Not just one high speedup number.
|
| 184 |
+
|
| 185 |
+
A good policy should show:
|
| 186 |
+
- increasing reward trend,
|
| 187 |
+
- high correctness/adversarial pass-rate,
|
| 188 |
+
- high compile success,
|
| 189 |
+
- better portability over time,
|
| 190 |
+
- stable behavior on held-out/edge-case tasks.
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## 12) How to run and verify locally
|
| 195 |
+
|
| 196 |
+
From `polyglot_optima/`:
|
| 197 |
+
|
| 198 |
+
```bash
|
| 199 |
+
python -m ruff check .
|
| 200 |
+
python -m pytest -q
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
Smoke test (LLM-in-the-loop):
|
| 204 |
+
|
| 205 |
+
```bash
|
| 206 |
+
python tests/smoke_llm_hf.py
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
Cursor/OpenAI-compatible mode:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
set LLM_PROVIDER=cursor
|
| 213 |
+
set CURSOR_API_KEY=...
|
| 214 |
+
set CURSOR_MODEL=gpt-4.1-nano
|
| 215 |
+
python tests/smoke_llm_hf.py
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
## 13) Training workflow for beginners
|
| 221 |
+
|
| 222 |
+
Use `training/openenv_hackathon_training.ipynb`:
|
| 223 |
+
|
| 224 |
+
1. Configure model + episodes + logging.
|
| 225 |
+
2. Run baseline eval first (fixed seeds).
|
| 226 |
+
3. Run RL training (TRL scaffold cell).
|
| 227 |
+
4. Run post-training eval with same seed protocol.
|
| 228 |
+
5. Export plots to `docs/plots`.
|
| 229 |
+
6. Add results to `README.md`.
|
| 230 |
+
|
| 231 |
+
Track at least:
|
| 232 |
+
- reward,
|
| 233 |
+
- correctness pass rate,
|
| 234 |
+
- compile success rate,
|
| 235 |
+
- portability metrics.
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## 14) How this maps to hackathon judging
|
| 240 |
+
|
| 241 |
+
The project can score well if you clearly show:
|
| 242 |
+
|
| 243 |
+
- **Innovation:** adaptive curriculum + anti-gaming traps + structured reward
|
| 244 |
+
- **Storytelling:** clear problem -> method -> before/after outcome
|
| 245 |
+
- **Improvement evidence:** baseline vs trained plots
|
| 246 |
+
- **Pipeline quality:** reproducible notebook/script + OpenEnv-compliant deployment
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## 15) Most important files to read next
|
| 251 |
+
|
| 252 |
+
Recommended reading order:
|
| 253 |
+
|
| 254 |
+
1. `README.md`
|
| 255 |
+
2. `models.py`
|
| 256 |
+
3. `server/environment.py`
|
| 257 |
+
4. `server/tools/submit.py`
|
| 258 |
+
5. `server/tools/cpp_compiler.py`
|
| 259 |
+
6. `server/tools/verifier.py`
|
| 260 |
+
7. `server/rewards/__init__.py`
|
| 261 |
+
8. `server/scenarios/dataset_loader.py`
|
| 262 |
+
9. `tests/test_skeleton.py`
|
| 263 |
+
|
| 264 |
+
---
|
| 265 |
+
|
| 266 |
+
## 16) Beginner takeaway
|
| 267 |
+
|
| 268 |
+
If you remember one thing:
|
| 269 |
+
|
| 270 |
+
This is not just "code generation."
|
| 271 |
+
It is a full RL environment that teaches an AI to do **correct, robust, hardware-aware optimization** under realistic constraints.
|
models.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic data models for Polyglot-Optima environment.
|
| 2 |
+
|
| 3 |
+
Three core types:
|
| 4 |
+
- OptimizationAction: what the agent sends to the env each turn
|
| 5 |
+
- OptimizationObservation: what the env returns each step
|
| 6 |
+
- OptimizationState: episode state tracked by the env (episode_id, step_count, round_number, etc.)
|
| 7 |
+
|
| 8 |
+
These map onto the OpenEnv Action/Observation/State base classes.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any, Literal
|
| 14 |
+
|
| 15 |
+
from pydantic import BaseModel, Field
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ----------------------------- Action -----------------------------
|
| 19 |
+
|
| 20 |
+
class OptimizationAction(BaseModel):
|
| 21 |
+
"""One agent turn.
|
| 22 |
+
|
| 23 |
+
Either a tool call (most turns) or a final submission (last turn of round 3).
|
| 24 |
+
The agent's reasoning trace is required so the DiagnosisRubric can score it.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
tool_name: str = Field(..., description="Name of the MCP tool to call")
|
| 28 |
+
tool_args: dict[str, Any] = Field(default_factory=dict, description="Arguments to the tool")
|
| 29 |
+
reasoning_trace: str = Field(
|
| 30 |
+
default="",
|
| 31 |
+
description="Agent's <think>...</think> trace before this action. "
|
| 32 |
+
"Required to be non-empty for DiagnosisRubric scoring.",
|
| 33 |
+
max_length=2048,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
model_config = {"extra": "forbid"}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --------------------------- Observation ---------------------------
|
| 40 |
+
|
| 41 |
+
class OptimizationObservation(BaseModel):
|
| 42 |
+
"""One env response.
|
| 43 |
+
|
| 44 |
+
Returned by env.step() and env.reset(). Contains tool result, episode state,
|
| 45 |
+
and per-step debug telemetry in `metadata` (sub-rubric scores, axis levels,
|
| 46 |
+
fuzz failure samples, etc.).
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
# Standard OpenEnv Observation fields
|
| 50 |
+
done: bool = Field(default=False, description="True iff episode is over")
|
| 51 |
+
reward: float = Field(default=0.0, description="Reward for this step (0 unless terminal)")
|
| 52 |
+
|
| 53 |
+
# Domain-specific payload
|
| 54 |
+
tool_result: dict[str, Any] = Field(default_factory=dict, description="Output of the tool just called")
|
| 55 |
+
|
| 56 |
+
# Environment context exposed to the agent
|
| 57 |
+
python_code: str = Field(default="", description="The Python function the agent is optimizing")
|
| 58 |
+
hardware_profile: dict[str, Any] = Field(
|
| 59 |
+
default_factory=dict,
|
| 60 |
+
description="Synthetic hardware spec for this episode (cores, simd, bandwidth, roofline_bound)",
|
| 61 |
+
)
|
| 62 |
+
round_number: int = Field(default=1, description="Current refinement round (1, 2, or 3)")
|
| 63 |
+
rounds_remaining: int = Field(default=2)
|
| 64 |
+
|
| 65 |
+
# Cumulative state visible to the agent
|
| 66 |
+
best_speedup_so_far: float = Field(default=0.0)
|
| 67 |
+
last_compile_status: Literal["pending", "success", "syntax_error", "link_error", "timeout"] = "pending"
|
| 68 |
+
last_correctness_pass_rate: float = Field(default=0.0)
|
| 69 |
+
|
| 70 |
+
# Telemetry — used by training infra, not necessarily shown to the model
|
| 71 |
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
| 72 |
+
|
| 73 |
+
model_config = {"extra": "forbid"}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ----------------------------- State ------------------------------
|
| 77 |
+
|
| 78 |
+
class OptimizationState(BaseModel):
|
| 79 |
+
"""Episode-level state tracked by the environment server.
|
| 80 |
+
|
| 81 |
+
Not every field is exposed to the agent in each Observation. Some are
|
| 82 |
+
server-internal (e.g., the ground-truth bottleneck label, the trap function
|
| 83 |
+
metadata, the curriculum axis levels).
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
# Identity
|
| 87 |
+
episode_id: str
|
| 88 |
+
step_count: int = 0
|
| 89 |
+
round_number: int = 1
|
| 90 |
+
is_terminal: bool = False
|
| 91 |
+
|
| 92 |
+
# Problem instance
|
| 93 |
+
python_code: str = ""
|
| 94 |
+
function_signature_cpp: str = "" # extern "C" void agent_function(...) — derived from AST
|
| 95 |
+
hardware_profile: dict[str, Any] = Field(default_factory=dict)
|
| 96 |
+
|
| 97 |
+
# Ground-truth (server-only — never sent to agent)
|
| 98 |
+
bottleneck_ground_truth: list[str] = Field(default_factory=list) # e.g., ["compute-bound", "vectorizable"]
|
| 99 |
+
bottleneck_distractors: list[str] = Field(default_factory=list)
|
| 100 |
+
rtol_override: float | None = None # Some functions need bit-exact (rtol=0); most use 1e-5
|
| 101 |
+
|
| 102 |
+
# Per-round history
|
| 103 |
+
round_results: list[dict[str, Any]] = Field(default_factory=list)
|
| 104 |
+
best_speedup: float = 0.0
|
| 105 |
+
best_cpp_code: str = ""
|
| 106 |
+
|
| 107 |
+
# Tool-call history within the current round (for action-coherence diagnosis bonus)
|
| 108 |
+
current_round_tool_calls: list[str] = Field(default_factory=list)
|
| 109 |
+
current_round_reasoning: str = ""
|
| 110 |
+
|
| 111 |
+
# Adaptive curriculum axis levels at episode start (frozen for the episode)
|
| 112 |
+
difficulty_axes: dict[str, int] = Field(
|
| 113 |
+
default_factory=lambda: {
|
| 114 |
+
"function_tier": 0, # 0..3
|
| 115 |
+
"hardware_class": 0, # 0..2
|
| 116 |
+
"fuzzer_strictness": 0, # 0..2
|
| 117 |
+
"portability_required": 0, # 0..1
|
| 118 |
+
}
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Trap flag — is this episode a known anti-gaming trap?
|
| 122 |
+
is_trap: bool = False
|
| 123 |
+
trap_id: str | None = None
|
| 124 |
+
|
| 125 |
+
model_config = {"extra": "forbid"}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ------------------------- Public re-exports ----------------------
|
| 129 |
+
|
| 130 |
+
__all__ = [
|
| 131 |
+
"OptimizationAction",
|
| 132 |
+
"OptimizationObservation",
|
| 133 |
+
"OptimizationState",
|
| 134 |
+
]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: polyglot-optima
|
| 2 |
+
version: 1.0.0
|
| 3 |
+
description: |
|
| 4 |
+
Adversarial Neural JIT Compiler. Trains a reasoning LLM to translate Python
|
| 5 |
+
functions into hardware-aware optimized C++20 that beats GCC -O3 of a naive
|
| 6 |
+
translation. Uses an adaptive 4-axis curriculum, Roofline-grounded reward,
|
| 7 |
+
reasoning-trace-as-RL-signal, cross-hardware portability bonus, and a 30-trap
|
| 8 |
+
anti-gaming library.
|
| 9 |
+
|
| 10 |
+
# Informal metadata (schema not yet published in openenv.yaml; surfaced for catalog)
|
| 11 |
+
metadata:
|
| 12 |
+
themes:
|
| 13 |
+
- world-modeling-professional
|
| 14 |
+
- self-improvement
|
| 15 |
+
hackathon: meta-pytorch-openenv-india-2026
|
| 16 |
+
max_turns: 12
|
| 17 |
+
episode_rounds: 3
|
| 18 |
+
model_targets:
|
| 19 |
+
optimizer: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
|
| 20 |
+
generator: Qwen/Qwen2.5-Coder-1.5B-Instruct
|
| 21 |
+
hardware_profiles_count: 8
|
| 22 |
+
difficulty_axes: 4
|
| 23 |
+
|
| 24 |
+
server:
|
| 25 |
+
entry_point: server.app:app
|
| 26 |
+
module: server.app
|
| 27 |
+
app_factory: build_app
|
| 28 |
+
|
| 29 |
+
client:
|
| 30 |
+
module: client
|
| 31 |
+
class: PolyglotOptimaClient
|
| 32 |
+
|
| 33 |
+
# Tool list — auto-discovered from @tool decorators in server/tools/*.py
|
| 34 |
+
# but listed here for catalog/discoverability
|
| 35 |
+
tools:
|
| 36 |
+
- get_hardware_profile
|
| 37 |
+
- profile_python_hotspots
|
| 38 |
+
- analyze_complexity
|
| 39 |
+
- check_memory_access
|
| 40 |
+
- compile_and_benchmark
|
| 41 |
+
- verify_equivalence
|
| 42 |
+
- check_portability
|
| 43 |
+
- get_bottleneck_report
|
| 44 |
+
- submit_optimization
|
pyproject.toml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "polyglot-optima"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Adversarial Neural JIT Compiler — Python to hardware-optimized C++ via RL"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "Apache-2.0" }
|
| 12 |
+
authors = [
|
| 13 |
+
{ name = "Swastik R", email = "swastik.r.900@gmail.com" }
|
| 14 |
+
]
|
| 15 |
+
keywords = ["openenv", "rl", "compiler", "code-optimization", "grpo", "agentic"]
|
| 16 |
+
|
| 17 |
+
dependencies = [
|
| 18 |
+
# OpenEnv core
|
| 19 |
+
"openenv>=0.3.0",
|
| 20 |
+
# Server
|
| 21 |
+
"fastapi>=0.110",
|
| 22 |
+
"uvicorn[standard]>=0.27",
|
| 23 |
+
"pydantic>=2.6",
|
| 24 |
+
"websockets>=12.0",
|
| 25 |
+
# Tools
|
| 26 |
+
"numpy>=1.26",
|
| 27 |
+
"scipy>=1.12",
|
| 28 |
+
"scikit-learn>=1.4",
|
| 29 |
+
# Code analysis
|
| 30 |
+
"astroid>=3.0",
|
| 31 |
+
# Compilation + execution
|
| 32 |
+
"pybind11>=2.13",
|
| 33 |
+
# Datasets
|
| 34 |
+
"datasets>=2.18",
|
| 35 |
+
"huggingface_hub>=0.22",
|
| 36 |
+
# UI
|
| 37 |
+
"gradio>=4.0",
|
| 38 |
+
# Logging
|
| 39 |
+
"wandb>=0.16",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
[project.optional-dependencies]
|
| 43 |
+
training = [
|
| 44 |
+
# GRPO + Unsloth
|
| 45 |
+
"trl>=0.14.0",
|
| 46 |
+
"unsloth",
|
| 47 |
+
"transformers>=4.40",
|
| 48 |
+
"accelerate>=0.30",
|
| 49 |
+
"peft>=0.11",
|
| 50 |
+
"bitsandbytes>=0.43",
|
| 51 |
+
"vllm>=0.5.0",
|
| 52 |
+
"torch>=2.3",
|
| 53 |
+
]
|
| 54 |
+
dev = [
|
| 55 |
+
"pytest>=8.0",
|
| 56 |
+
"pytest-asyncio>=0.23",
|
| 57 |
+
"ruff>=0.4",
|
| 58 |
+
"mypy>=1.10",
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
[project.urls]
|
| 62 |
+
Repository = "https://github.com/QuantumByte-01/Openenv-Hack-finale"
|
| 63 |
+
HFSpace = "https://huggingface.co/spaces/swastik/polyglot-optima"
|
| 64 |
+
|
| 65 |
+
[tool.setuptools.packages.find]
|
| 66 |
+
where = ["."]
|
| 67 |
+
include = ["server*", "training*", "eval*"]
|
| 68 |
+
|
| 69 |
+
[tool.ruff]
|
| 70 |
+
line-length = 110
|
| 71 |
+
target-version = "py310"
|
| 72 |
+
|
| 73 |
+
[tool.pytest.ini_options]
|
| 74 |
+
testpaths = ["tests"]
|
| 75 |
+
asyncio_mode = "auto"
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Polyglot-Optima OpenEnv server package."""
|
server/app.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI app factory for Polyglot-Optima.
|
| 2 |
+
|
| 3 |
+
Uses OpenEnv's create_app() to wire the MCPEnvironment to HTTP/WebSocket transport.
|
| 4 |
+
Optionally mounts a Gradio /web UI via gradio_builder for the live demo.
|
| 5 |
+
|
| 6 |
+
Entry point referenced by openenv.yaml:
|
| 7 |
+
server: entry_point: server.app:app
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
# OpenEnv imports — confirmed APIs per plan §12
|
| 16 |
+
try:
|
| 17 |
+
from openenv.core import create_app, ConcurrencyConfig, ServerMode # type: ignore
|
| 18 |
+
except ImportError:
|
| 19 |
+
# Fallback factory for local development before openenv is installed
|
| 20 |
+
def create_app(env, action_cls, observation_cls, env_name, **kwargs): # type: ignore
|
| 21 |
+
from fastapi import FastAPI
|
| 22 |
+
app = FastAPI(title=env_name)
|
| 23 |
+
|
| 24 |
+
@app.get("/health")
|
| 25 |
+
def health():
|
| 26 |
+
return {"ok": True, "env": env_name, "stub": True}
|
| 27 |
+
|
| 28 |
+
return app
|
| 29 |
+
|
| 30 |
+
class ConcurrencyConfig: # type: ignore
|
| 31 |
+
def __init__(self, max_concurrent_envs=8, session_timeout=300):
|
| 32 |
+
self.max_concurrent_envs = max_concurrent_envs
|
| 33 |
+
self.session_timeout = session_timeout
|
| 34 |
+
|
| 35 |
+
class ServerMode: # type: ignore
|
| 36 |
+
SIMULATION = "simulation"
|
| 37 |
+
PRODUCTION = "production"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
from models import OptimizationAction, OptimizationObservation
|
| 41 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def build_gradio_ui(web_manager, action_fields, metadata, is_chat_env, title, quick_start_md):
|
| 45 |
+
"""Custom Gradio /web UI for the live Polyglot-Optima demo.
|
| 46 |
+
|
| 47 |
+
Wired into create_app() via the gradio_builder parameter (per plan §12 F).
|
| 48 |
+
Full implementation lives in Hour 42-48; for now this returns a minimal
|
| 49 |
+
Blocks instance so the framework's web-interface mount succeeds.
|
| 50 |
+
"""
|
| 51 |
+
try:
|
| 52 |
+
import gradio as gr
|
| 53 |
+
except ImportError:
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
with gr.Blocks(title="Polyglot-Optima — Python → Optimized C++") as demo:
|
| 57 |
+
gr.Markdown(f"# {title}\n\n{quick_start_md or ''}")
|
| 58 |
+
gr.Markdown(
|
| 59 |
+
"**Status**: Skeleton (Hour 0-4). The live demo (paste Python → see C++ + speedup) "
|
| 60 |
+
"ships in Hour 42-48 of the build."
|
| 61 |
+
)
|
| 62 |
+
with gr.Row():
|
| 63 |
+
gr.Code(
|
| 64 |
+
label="Paste Python function",
|
| 65 |
+
language="python",
|
| 66 |
+
value="def sum_squares(arr):\n total = 0\n for x in arr:\n total += x * x\n return total\n",
|
| 67 |
+
)
|
| 68 |
+
gr.Code(label="Agent's optimized C++", language="cpp", value="// Coming soon")
|
| 69 |
+
gr.Button("Optimize", interactive=False)
|
| 70 |
+
gr.Markdown("_Demo wires up in Hour 42-48 — current build is the skeleton._")
|
| 71 |
+
|
| 72 |
+
return demo
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def build_app() -> Any:
|
| 76 |
+
"""Build and return the FastAPI app (OpenEnv create_app pattern)."""
|
| 77 |
+
enable_adaptive_curriculum = os.environ.get("POLYGLOT_OPTIMA_ENABLE_ADAPTIVE_CURRICULUM", "1") == "1"
|
| 78 |
+
curriculum_batch_size = int(os.environ.get("POLYGLOT_OPTIMA_CURRICULUM_BATCH_SIZE", "8"))
|
| 79 |
+
env = PolyglotOptimaEnvironment(
|
| 80 |
+
max_rounds=3,
|
| 81 |
+
max_calls_per_round=5,
|
| 82 |
+
enable_adaptive_curriculum=enable_adaptive_curriculum,
|
| 83 |
+
curriculum_batch_size=curriculum_batch_size,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
server_mode_str = os.environ.get("OPENENV_SERVER_MODE", "simulation").lower()
|
| 87 |
+
server_mode = ServerMode.PRODUCTION if server_mode_str == "production" else ServerMode.SIMULATION
|
| 88 |
+
|
| 89 |
+
enable_web = os.environ.get("ENABLE_WEB_INTERFACE", "1") == "1"
|
| 90 |
+
|
| 91 |
+
app = create_app(
|
| 92 |
+
env=env,
|
| 93 |
+
action_cls=OptimizationAction,
|
| 94 |
+
observation_cls=OptimizationObservation,
|
| 95 |
+
env_name="polyglot-optima",
|
| 96 |
+
max_concurrent_envs=8,
|
| 97 |
+
session_timeout=600,
|
| 98 |
+
server_mode=server_mode,
|
| 99 |
+
gradio_builder=build_gradio_ui if enable_web else None,
|
| 100 |
+
)
|
| 101 |
+
return app
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# OpenEnv discovers the FastAPI instance via this module-level binding
|
| 105 |
+
app = build_app()
|
server/environment.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PolyglotOptimaEnvironment — MCPEnvironment subclass with explicit Gym API.
|
| 2 |
+
|
| 3 |
+
Implements:
|
| 4 |
+
- reset(seed=None) -> Observation # samples a Python function + hardware profile
|
| 5 |
+
- step(action) -> StepResult # routes tool calls, advances rounds, computes reward
|
| 6 |
+
- state() -> State # episode_id, step_count, round_number
|
| 7 |
+
- close() # releases compiler subprocesses, fuzzer pool
|
| 8 |
+
|
| 9 |
+
Round structure per episode:
|
| 10 |
+
round 1: agent has up to N tool calls, then submits via submit_optimization → R1 reward
|
| 11 |
+
round 2: same, with R1 result available in observation → R2 reward
|
| 12 |
+
round 3: same, FINAL strict gate (≥95% fuzz pass) → R3 reward
|
| 13 |
+
episode_reward = 0.3 * R1_reward + 0.7 * R3_reward (R2 is informational)
|
| 14 |
+
|
| 15 |
+
The four difficulty axes are frozen at reset() time for each episode but the
|
| 16 |
+
adaptive_curriculum module updates them across batches based on success rates.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import random
|
| 22 |
+
import uuid
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from typing import Any
|
| 25 |
+
|
| 26 |
+
# OpenEnv imports — actual class names per the framework docs.
|
| 27 |
+
# We accept that some specific imports may need to be adjusted at integration time;
|
| 28 |
+
# all are documented as confirmed in §12 of the plan.
|
| 29 |
+
try:
|
| 30 |
+
from openenv.core import MCPEnvironment, StepResult # type: ignore
|
| 31 |
+
from openenv.core.exceptions import OpenEnvError # type: ignore
|
| 32 |
+
except ImportError:
|
| 33 |
+
# Allow stubs for local development before openenv is installed
|
| 34 |
+
class MCPEnvironment: # type: ignore
|
| 35 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 36 |
+
async def reset_async(self, seed=None): raise NotImplementedError
|
| 37 |
+
async def step_async(self, action): raise NotImplementedError
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class StepResult: # type: ignore
|
| 41 |
+
observation: Any
|
| 42 |
+
reward: float
|
| 43 |
+
done: bool
|
| 44 |
+
info: dict[str, Any] | None = None
|
| 45 |
+
|
| 46 |
+
class OpenEnvError(Exception): # type: ignore
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
from models import (
|
| 51 |
+
OptimizationAction,
|
| 52 |
+
OptimizationObservation,
|
| 53 |
+
OptimizationState,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Reserved names that MUST NOT be used as MCP tool names per OpenEnv spec
|
| 58 |
+
_RESERVED_TOOL_NAMES = {"reset", "step", "state", "close"}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class PolyglotOptimaEnvironment(MCPEnvironment):
|
| 62 |
+
"""The hardware-aware Python→C++ optimization environment.
|
| 63 |
+
|
| 64 |
+
Public API:
|
| 65 |
+
env.reset(seed=...) -> OptimizationObservation
|
| 66 |
+
env.step(action: OptimizationAction) -> StepResult
|
| 67 |
+
env.state() -> OptimizationState
|
| 68 |
+
env.close()
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 72 |
+
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
max_rounds: int = 3,
|
| 76 |
+
max_calls_per_round: int = 5,
|
| 77 |
+
adaptive_axes: dict[str, int] | None = None,
|
| 78 |
+
enable_adaptive_curriculum: bool = True,
|
| 79 |
+
curriculum_batch_size: int = 8,
|
| 80 |
+
):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.max_rounds = max_rounds
|
| 83 |
+
self.max_calls_per_round = max_calls_per_round
|
| 84 |
+
self.enable_adaptive_curriculum = enable_adaptive_curriculum
|
| 85 |
+
self.curriculum_batch_size = max(1, int(curriculum_batch_size))
|
| 86 |
+
# Default axes — overridden by adaptive_curriculum across batches
|
| 87 |
+
self._global_axes = adaptive_axes or {
|
| 88 |
+
"function_tier": 0,
|
| 89 |
+
"hardware_class": 0,
|
| 90 |
+
"fuzzer_strictness": 0,
|
| 91 |
+
"portability_required": 0,
|
| 92 |
+
}
|
| 93 |
+
self._sessions: dict[str, OptimizationState] = {}
|
| 94 |
+
self._active_episode_id: str | None = None
|
| 95 |
+
|
| 96 |
+
# Lazy imports — modules built in subsequent hours
|
| 97 |
+
self._tool_registry: dict[str, Any] = {}
|
| 98 |
+
self._dataset_loader = None
|
| 99 |
+
self._hardware_profiles = None
|
| 100 |
+
self._reward_dag = None
|
| 101 |
+
self._curriculum = None
|
| 102 |
+
self._episode_success_buffer: list[float] = []
|
| 103 |
+
|
| 104 |
+
# -------------------- Gym-style explicit API --------------------
|
| 105 |
+
|
| 106 |
+
def reset(self, seed: int | None = None) -> OptimizationObservation:
|
| 107 |
+
"""Initialize a new episode.
|
| 108 |
+
|
| 109 |
+
Samples (Python function, hardware profile, difficulty axes) deterministically
|
| 110 |
+
from `seed` if provided. Returns the initial Observation.
|
| 111 |
+
"""
|
| 112 |
+
rng = random.Random(seed)
|
| 113 |
+
episode_id = str(uuid.uuid4())
|
| 114 |
+
|
| 115 |
+
# Lazy init of subsystems (built in later hours; placeholders for now)
|
| 116 |
+
self._ensure_subsystems_loaded()
|
| 117 |
+
|
| 118 |
+
# Sample the problem instance
|
| 119 |
+
problem = self._sample_problem(rng)
|
| 120 |
+
|
| 121 |
+
state = OptimizationState(
|
| 122 |
+
episode_id=episode_id,
|
| 123 |
+
step_count=0,
|
| 124 |
+
round_number=1,
|
| 125 |
+
is_terminal=False,
|
| 126 |
+
python_code=problem["python_code"],
|
| 127 |
+
function_signature_cpp=problem["cpp_signature"],
|
| 128 |
+
hardware_profile=problem["hardware_profile"],
|
| 129 |
+
bottleneck_ground_truth=problem["bottleneck_labels"],
|
| 130 |
+
bottleneck_distractors=problem["bottleneck_distractors"],
|
| 131 |
+
rtol_override=problem.get("rtol_override"),
|
| 132 |
+
difficulty_axes=dict(self._global_axes),
|
| 133 |
+
is_trap=problem.get("is_trap", False),
|
| 134 |
+
trap_id=problem.get("trap_id"),
|
| 135 |
+
)
|
| 136 |
+
self._sessions[episode_id] = state
|
| 137 |
+
self._active_episode_id = episode_id
|
| 138 |
+
|
| 139 |
+
return OptimizationObservation(
|
| 140 |
+
done=False,
|
| 141 |
+
reward=0.0,
|
| 142 |
+
tool_result={"event": "episode_start", "episode_id": episode_id},
|
| 143 |
+
python_code=state.python_code,
|
| 144 |
+
hardware_profile=state.hardware_profile,
|
| 145 |
+
round_number=1,
|
| 146 |
+
rounds_remaining=self.max_rounds - 1,
|
| 147 |
+
best_speedup_so_far=0.0,
|
| 148 |
+
metadata={
|
| 149 |
+
"episode_id": episode_id,
|
| 150 |
+
"difficulty_axes": state.difficulty_axes,
|
| 151 |
+
# NOTE: bottleneck_ground_truth is NOT exposed to the agent —
|
| 152 |
+
# only used by the server when scoring DiagnosisRubric
|
| 153 |
+
},
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def step(self, action: OptimizationAction) -> StepResult:
|
| 157 |
+
"""Execute one tool call or final submission.
|
| 158 |
+
|
| 159 |
+
The action.tool_name routes to a registered MCP tool. If the tool is
|
| 160 |
+
`submit_optimization`, the current round closes — reward is computed,
|
| 161 |
+
round advances, and on round 3 the episode terminates.
|
| 162 |
+
"""
|
| 163 |
+
if not self._sessions:
|
| 164 |
+
raise OpenEnvError("No active episode. Call reset() first.")
|
| 165 |
+
if self._active_episode_id and self._active_episode_id in self._sessions:
|
| 166 |
+
state = self._sessions[self._active_episode_id]
|
| 167 |
+
else:
|
| 168 |
+
# Fall back to the most recently created episode.
|
| 169 |
+
latest_episode_id = next(reversed(self._sessions))
|
| 170 |
+
self._active_episode_id = latest_episode_id
|
| 171 |
+
state = self._sessions[latest_episode_id]
|
| 172 |
+
|
| 173 |
+
if state.is_terminal:
|
| 174 |
+
raise OpenEnvError("Episode is already terminal. Call reset() to start a new one.")
|
| 175 |
+
|
| 176 |
+
forced_submit = False
|
| 177 |
+
effective_tool_name = action.tool_name
|
| 178 |
+
effective_tool_args = dict(action.tool_args or {})
|
| 179 |
+
if (
|
| 180 |
+
action.tool_name != "submit_optimization"
|
| 181 |
+
and len(state.current_round_tool_calls) >= self.max_calls_per_round
|
| 182 |
+
):
|
| 183 |
+
forced_submit = True
|
| 184 |
+
effective_tool_name = "submit_optimization"
|
| 185 |
+
effective_tool_args = {
|
| 186 |
+
"cpp_code": effective_tool_args.get("cpp_code", "// auto-forced submit: call budget reached"),
|
| 187 |
+
"reasoning_trace": action.reasoning_trace or "auto forced submit after max tool calls",
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
if effective_tool_name in _RESERVED_TOOL_NAMES:
|
| 191 |
+
raise OpenEnvError(
|
| 192 |
+
f"Tool name '{effective_tool_name}' is reserved. "
|
| 193 |
+
f"Reserved names: {sorted(_RESERVED_TOOL_NAMES)}"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Track tool call + reasoning trace for this round
|
| 197 |
+
state.step_count += 1
|
| 198 |
+
state.current_round_tool_calls.append(effective_tool_name)
|
| 199 |
+
if action.reasoning_trace:
|
| 200 |
+
state.current_round_reasoning += action.reasoning_trace + "\n"
|
| 201 |
+
|
| 202 |
+
# Route to the named tool — full implementation in Hour 4–10
|
| 203 |
+
tool_result = self._dispatch_tool(effective_tool_name, effective_tool_args, state)
|
| 204 |
+
|
| 205 |
+
# Is this a round-closing submission?
|
| 206 |
+
is_submit = effective_tool_name == "submit_optimization"
|
| 207 |
+
round_reward = 0.0
|
| 208 |
+
terminal = False
|
| 209 |
+
|
| 210 |
+
if is_submit:
|
| 211 |
+
# Compute reward for this round (Hour 10–16 implementation)
|
| 212 |
+
round_reward = self._compute_round_reward(state, tool_result)
|
| 213 |
+
if self._dataset_loader is not None and hasattr(self._dataset_loader, "record_submission_outcome"):
|
| 214 |
+
self._dataset_loader.record_submission_outcome(state, tool_result)
|
| 215 |
+
state.round_results.append({
|
| 216 |
+
"round": state.round_number,
|
| 217 |
+
"reward": round_reward,
|
| 218 |
+
"tool_calls": list(state.current_round_tool_calls),
|
| 219 |
+
"reasoning": state.current_round_reasoning,
|
| 220 |
+
"submission": tool_result,
|
| 221 |
+
})
|
| 222 |
+
# Reset per-round buffers
|
| 223 |
+
state.current_round_tool_calls.clear()
|
| 224 |
+
state.current_round_reasoning = ""
|
| 225 |
+
# Advance round
|
| 226 |
+
state.round_number += 1
|
| 227 |
+
if state.round_number > self.max_rounds:
|
| 228 |
+
terminal = True
|
| 229 |
+
state.is_terminal = True
|
| 230 |
+
|
| 231 |
+
observation = OptimizationObservation(
|
| 232 |
+
done=terminal,
|
| 233 |
+
reward=round_reward,
|
| 234 |
+
tool_result=tool_result,
|
| 235 |
+
python_code=state.python_code,
|
| 236 |
+
hardware_profile=state.hardware_profile,
|
| 237 |
+
round_number=min(state.round_number, self.max_rounds),
|
| 238 |
+
rounds_remaining=max(0, self.max_rounds - state.round_number),
|
| 239 |
+
best_speedup_so_far=state.best_speedup,
|
| 240 |
+
last_compile_status=tool_result.get("compile_status", "pending"),
|
| 241 |
+
last_correctness_pass_rate=tool_result.get("pass_rate", 0.0),
|
| 242 |
+
metadata={
|
| 243 |
+
"episode_id": state.episode_id,
|
| 244 |
+
"step_count": state.step_count,
|
| 245 |
+
"tool_called": effective_tool_name,
|
| 246 |
+
"forced_submit": forced_submit,
|
| 247 |
+
},
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Final episode reward = 0.3*R1 + 0.7*R3 (per plan §10)
|
| 251 |
+
if terminal:
|
| 252 |
+
r1 = next((r["reward"] for r in state.round_results if r["round"] == 1), 0.0)
|
| 253 |
+
r3 = next((r["reward"] for r in state.round_results if r["round"] == 3), 0.0)
|
| 254 |
+
observation.reward = 0.3 * r1 + 0.7 * r3
|
| 255 |
+
observation.metadata["episode_reward_breakdown"] = {
|
| 256 |
+
"r1": r1,
|
| 257 |
+
"r3": r3,
|
| 258 |
+
"episode_total": observation.reward,
|
| 259 |
+
}
|
| 260 |
+
self._record_episode_outcome(state, observation)
|
| 261 |
+
|
| 262 |
+
return StepResult(
|
| 263 |
+
observation=observation,
|
| 264 |
+
reward=observation.reward,
|
| 265 |
+
done=terminal,
|
| 266 |
+
info={"state_snapshot_id": state.episode_id, "step": state.step_count},
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def state(self) -> OptimizationState:
|
| 270 |
+
"""Return current episode state (Gym-style state introspection)."""
|
| 271 |
+
if not self._sessions:
|
| 272 |
+
raise OpenEnvError("No active episode.")
|
| 273 |
+
if self._active_episode_id and self._active_episode_id in self._sessions:
|
| 274 |
+
return self._sessions[self._active_episode_id]
|
| 275 |
+
latest_episode_id = next(reversed(self._sessions))
|
| 276 |
+
self._active_episode_id = latest_episode_id
|
| 277 |
+
return self._sessions[latest_episode_id]
|
| 278 |
+
|
| 279 |
+
def close(self) -> None:
|
| 280 |
+
"""Release all resources (compiler subprocesses, fuzzer pool)."""
|
| 281 |
+
self._sessions.clear()
|
| 282 |
+
self._active_episode_id = None
|
| 283 |
+
# Subsystem-specific cleanup — implemented as tools come online
|
| 284 |
+
if self._tool_registry:
|
| 285 |
+
for tool in self._tool_registry.values():
|
| 286 |
+
if hasattr(tool, "close"):
|
| 287 |
+
tool.close()
|
| 288 |
+
|
| 289 |
+
# -------------------- Async variants for parallel rollouts ----
|
| 290 |
+
|
| 291 |
+
async def reset_async(self, seed: int | None = None) -> OptimizationObservation:
|
| 292 |
+
return self.reset(seed)
|
| 293 |
+
|
| 294 |
+
async def step_async(self, action: OptimizationAction) -> StepResult:
|
| 295 |
+
return self.step(action)
|
| 296 |
+
|
| 297 |
+
async def close_async(self) -> None:
|
| 298 |
+
self.close()
|
| 299 |
+
|
| 300 |
+
# -------------------- Internal scaffolding --------------------
|
| 301 |
+
|
| 302 |
+
def _ensure_subsystems_loaded(self) -> None:
|
| 303 |
+
"""Lazy-load tools/dataset/profiles. Real implementations land at Hour 16."""
|
| 304 |
+
# Tools registry
|
| 305 |
+
if not self._tool_registry:
|
| 306 |
+
try:
|
| 307 |
+
from server.tools import TOOL_REGISTRY
|
| 308 |
+
self._tool_registry = TOOL_REGISTRY
|
| 309 |
+
except ImportError:
|
| 310 |
+
self._tool_registry = {}
|
| 311 |
+
|
| 312 |
+
# Dataset loader (real, post-Hour 16)
|
| 313 |
+
if self._dataset_loader is None:
|
| 314 |
+
try:
|
| 315 |
+
from server.scenarios import DatasetLoader
|
| 316 |
+
self._dataset_loader = DatasetLoader(prefer_real_datasets=False)
|
| 317 |
+
except ImportError:
|
| 318 |
+
self._dataset_loader = _StubDatasetLoader()
|
| 319 |
+
|
| 320 |
+
# Hardware profiles (full 8-profile set, post-Hour 16)
|
| 321 |
+
if self._hardware_profiles is None:
|
| 322 |
+
try:
|
| 323 |
+
from server.scenarios.hardware_profiles import HARDWARE_PROFILES
|
| 324 |
+
# Filter held-out for training; eval scripts override this
|
| 325 |
+
self._hardware_profiles = [p for p in HARDWARE_PROFILES if not p.get("held_out")]
|
| 326 |
+
except ImportError:
|
| 327 |
+
self._hardware_profiles = _STUB_PROFILES
|
| 328 |
+
|
| 329 |
+
if self._curriculum is None and self.enable_adaptive_curriculum:
|
| 330 |
+
try:
|
| 331 |
+
from server.scenarios import AdaptiveCurriculum
|
| 332 |
+
self._curriculum = AdaptiveCurriculum(initial_axes=dict(self._global_axes))
|
| 333 |
+
except ImportError:
|
| 334 |
+
self._curriculum = None
|
| 335 |
+
|
| 336 |
+
def _sample_problem(self, rng: random.Random) -> dict[str, Any]:
|
| 337 |
+
"""Sample (function, hw_profile, ground_truth_labels) for an episode.
|
| 338 |
+
|
| 339 |
+
Uses the DatasetLoader to draw a (function, hardware) tuple weighted by
|
| 340 |
+
the current global difficulty axes. Falls back to a built-in stub if
|
| 341 |
+
the loader is the local dev fallback.
|
| 342 |
+
"""
|
| 343 |
+
# Real loader path (post-Hour 16)
|
| 344 |
+
if isinstance(self._dataset_loader, _StubDatasetLoader):
|
| 345 |
+
hw = rng.choice(self._hardware_profiles)
|
| 346 |
+
return {
|
| 347 |
+
"python_code": _STUB_PYTHON_FUNCTION,
|
| 348 |
+
"cpp_signature": 'extern "C" double agent_function(const double* arr, size_t n);',
|
| 349 |
+
"hardware_profile": hw,
|
| 350 |
+
"bottleneck_labels": ["compute-bound", "vectorizable"],
|
| 351 |
+
"bottleneck_distractors": ["memory-bound", "branch-heavy", "io-bound"],
|
| 352 |
+
"rtol_override": None,
|
| 353 |
+
"is_trap": False,
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
return self._dataset_loader.sample(self._global_axes, rng)
|
| 357 |
+
|
| 358 |
+
def _record_episode_outcome(self, state: OptimizationState, observation: OptimizationObservation) -> None:
|
| 359 |
+
"""Update adaptive curriculum after fixed-size batches of completed episodes."""
|
| 360 |
+
if not self.enable_adaptive_curriculum or self._curriculum is None:
|
| 361 |
+
return
|
| 362 |
+
final_submission = state.round_results[-1]["submission"] if state.round_results else {}
|
| 363 |
+
pass_rate = float(final_submission.get("correctness_pass_rate", 0.0))
|
| 364 |
+
compile_ok = final_submission.get("compile_status") == "success"
|
| 365 |
+
episode_success = 1.0 if (compile_ok and pass_rate >= 0.8) else 0.0
|
| 366 |
+
self._episode_success_buffer.append(episode_success)
|
| 367 |
+
observation.metadata["curriculum_pending_batch_count"] = len(self._episode_success_buffer)
|
| 368 |
+
if len(self._episode_success_buffer) < self.curriculum_batch_size:
|
| 369 |
+
return
|
| 370 |
+
success_rate = sum(self._episode_success_buffer) / len(self._episode_success_buffer)
|
| 371 |
+
action = self._curriculum.observe_batch(success_rate)
|
| 372 |
+
self._global_axes = dict(self._curriculum.axes)
|
| 373 |
+
self._episode_success_buffer.clear()
|
| 374 |
+
observation.metadata["curriculum"] = {
|
| 375 |
+
"success_rate": success_rate,
|
| 376 |
+
"action": action,
|
| 377 |
+
"axes": dict(self._global_axes),
|
| 378 |
+
"batches_seen": self._curriculum.n_batches_seen,
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
def _dispatch_tool(self, tool_name: str, tool_args: dict[str, Any], state: OptimizationState) -> dict[str, Any]:
|
| 382 |
+
"""Route a tool call to the registered handler.
|
| 383 |
+
|
| 384 |
+
Real implementations land in Hour 4–10. Until then, stub responses keep the
|
| 385 |
+
Gym API live for smoke tests.
|
| 386 |
+
"""
|
| 387 |
+
if tool_name not in self._tool_registry:
|
| 388 |
+
return {
|
| 389 |
+
"stub": True,
|
| 390 |
+
"tool": tool_name,
|
| 391 |
+
"message": f"Tool '{tool_name}' not yet implemented (Hour 4-10).",
|
| 392 |
+
}
|
| 393 |
+
return self._tool_registry[tool_name](tool_args, state)
|
| 394 |
+
|
| 395 |
+
def _compute_round_reward(self, state: OptimizationState, submission: dict[str, Any]) -> float:
|
| 396 |
+
"""Apply the round-appropriate Sequential(Gate, Gate, WeightedSum) rubric.
|
| 397 |
+
|
| 398 |
+
Per plan §10:
|
| 399 |
+
R1: soft gate (60% correctness), 3 components
|
| 400 |
+
R2: medium gate (80%), informational
|
| 401 |
+
R3: strict gate (95%), 5 components incl. portability + self-correction
|
| 402 |
+
|
| 403 |
+
Returns the rubric DAG's score in [0, 1], or 0.0 if any gate fails.
|
| 404 |
+
"""
|
| 405 |
+
try:
|
| 406 |
+
from server.rewards import build_round_reward_dag
|
| 407 |
+
except ImportError:
|
| 408 |
+
return 0.0
|
| 409 |
+
|
| 410 |
+
# Append a synthetic round_result entry NOW so DiagnosisRubric / SelfCorrectionRubric
|
| 411 |
+
# can read the just-completed round's tool calls. The caller (step()) appends the
|
| 412 |
+
# *real* round_results entry after this returns; we only need a temp lookup.
|
| 413 |
+
# Note: we already appended state.round_results in step() BEFORE computing reward,
|
| 414 |
+
# so this is fine. Diagnosis and SelfCorrection both read state.round_results.
|
| 415 |
+
|
| 416 |
+
dag = build_round_reward_dag(state.round_number)
|
| 417 |
+
score = dag.score(state, submission)
|
| 418 |
+
|
| 419 |
+
# Stash breakdown in submission for telemetry / wandb logging
|
| 420 |
+
submission["_rubric_breakdown"] = getattr(dag, "last_breakdown", {})
|
| 421 |
+
return score
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# --------------------------- Stubs (Hour 0–4 only) -------------------
|
| 425 |
+
|
| 426 |
+
class _StubDatasetLoader:
|
| 427 |
+
"""Placeholder. Replaced in Hour 16 by server.scenarios.dataset_loader."""
|
| 428 |
+
|
| 429 |
+
def sample(self, axes: dict[str, int], rng: random.Random) -> dict[str, Any]:
|
| 430 |
+
return {"python_code": _STUB_PYTHON_FUNCTION}
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
_STUB_PROFILES = [
|
| 434 |
+
{
|
| 435 |
+
"id": "desktop_avx2",
|
| 436 |
+
"cores": 8,
|
| 437 |
+
"freq_ghz": 3.8,
|
| 438 |
+
"l1_kb": 32,
|
| 439 |
+
"simd": "AVX2",
|
| 440 |
+
"bw_gbs": 51,
|
| 441 |
+
"roofline_bound_gflops": 25.5,
|
| 442 |
+
},
|
| 443 |
+
]
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
_STUB_PYTHON_FUNCTION = '''def sum_squares(arr):
|
| 447 |
+
"""Compute the sum of squares of an array — placeholder during Hour 0-4."""
|
| 448 |
+
total = 0.0
|
| 449 |
+
for x in arr:
|
| 450 |
+
total += x * x
|
| 451 |
+
return total
|
| 452 |
+
'''
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
__all__ = [
|
| 456 |
+
"PolyglotOptimaEnvironment",
|
| 457 |
+
]
|
server/rewards/__init__.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Composable reward rubric system for Polyglot-Optima.
|
| 2 |
+
|
| 3 |
+
Per plan §12 D, this is a 4-level composition tree using only the OpenEnv
|
| 4 |
+
documented primitives (Sequential, Gate, WeightedSum) plus 5 custom Rubric
|
| 5 |
+
subclasses (Speedup, Correctness, Compilation, Diagnosis, Portability,
|
| 6 |
+
SelfCorrection).
|
| 7 |
+
|
| 8 |
+
The composition tree (per plan §10):
|
| 9 |
+
|
| 10 |
+
round1_reward = Sequential(
|
| 11 |
+
Gate(CorrectnessRubric, threshold=0.6),
|
| 12 |
+
Gate(CompilationRubric, threshold=1.0),
|
| 13 |
+
WeightedSum(
|
| 14 |
+
SpeedupRubric w=0.40
|
| 15 |
+
CorrectnessRubric w=0.30
|
| 16 |
+
DiagnosisRubric w=0.30
|
| 17 |
+
)
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
round3_reward = Sequential(
|
| 21 |
+
Gate(CorrectnessRubric, threshold=0.95),
|
| 22 |
+
Gate(CompilationRubric, threshold=1.0),
|
| 23 |
+
WeightedSum(
|
| 24 |
+
SpeedupRubric w=0.35
|
| 25 |
+
CorrectnessRubric w=0.25
|
| 26 |
+
DiagnosisRubric w=0.20
|
| 27 |
+
SelfCorrectionRubric w=0.10
|
| 28 |
+
PortabilityRubric w=0.10 (only counts if portability_required axis on)
|
| 29 |
+
)
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
episode_reward = 0.3 * round1_reward + 0.7 * round3_reward
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
from .rubrics import (
|
| 38 |
+
Rubric,
|
| 39 |
+
Sequential,
|
| 40 |
+
Gate,
|
| 41 |
+
WeightedSum,
|
| 42 |
+
GateFailedError,
|
| 43 |
+
)
|
| 44 |
+
from .speedup_rubric import SpeedupRubric
|
| 45 |
+
from .correctness_rubric import CorrectnessRubric, CompilationRubric
|
| 46 |
+
from .diagnosis_rubric import DiagnosisRubric
|
| 47 |
+
from .portability_rubric import PortabilityRubric
|
| 48 |
+
from .self_correction_rubric import SelfCorrectionRubric
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_round_reward_dag(round_number: int):
|
| 52 |
+
"""Construct the reward DAG appropriate for a given round (1, 2, or 3).
|
| 53 |
+
|
| 54 |
+
Round 1: soft gate (60%), 3 components (Speedup, Correctness, Diagnosis)
|
| 55 |
+
Round 2: medium gate (80%), same 3 components (informational)
|
| 56 |
+
Round 3: strict gate (95%), 5 components (adds SelfCorrection + Portability)
|
| 57 |
+
"""
|
| 58 |
+
correctness = CorrectnessRubric()
|
| 59 |
+
compilation = CompilationRubric()
|
| 60 |
+
|
| 61 |
+
# Continuous reward shaping: no hard cliffs in the main training signal.
|
| 62 |
+
# Compilation and correctness both use smooth gates to keep gradient flow alive.
|
| 63 |
+
if round_number == 1:
|
| 64 |
+
return Sequential(
|
| 65 |
+
Gate(correctness, threshold=0.6, ramp_min=0.05, ramp_max=1.0, exponent=2.0),
|
| 66 |
+
Gate(compilation, threshold=1.0, ramp_min=0.10, ramp_max=1.0, exponent=1.5),
|
| 67 |
+
WeightedSum(
|
| 68 |
+
{"speedup": SpeedupRubric(),
|
| 69 |
+
"correctness": correctness,
|
| 70 |
+
"diagnosis": DiagnosisRubric()},
|
| 71 |
+
weights={"speedup": 0.40, "correctness": 0.30, "diagnosis": 0.30},
|
| 72 |
+
),
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if round_number == 2:
|
| 76 |
+
return Sequential(
|
| 77 |
+
Gate(correctness, threshold=0.80, ramp_min=0.05, ramp_max=1.0, exponent=2.0),
|
| 78 |
+
Gate(compilation, threshold=1.0, ramp_min=0.10, ramp_max=1.0, exponent=1.5),
|
| 79 |
+
WeightedSum(
|
| 80 |
+
{"speedup": SpeedupRubric(),
|
| 81 |
+
"correctness": correctness,
|
| 82 |
+
"diagnosis": DiagnosisRubric()},
|
| 83 |
+
weights={"speedup": 0.40, "correctness": 0.30, "diagnosis": 0.30},
|
| 84 |
+
),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Round 3 — strict gate (95%), full 5 components
|
| 88 |
+
return Sequential(
|
| 89 |
+
Gate(correctness, threshold=0.95, ramp_min=0.05, ramp_max=1.0, exponent=2.0),
|
| 90 |
+
Gate(compilation, threshold=1.0, ramp_min=0.10, ramp_max=1.0, exponent=1.5),
|
| 91 |
+
WeightedSum(
|
| 92 |
+
{"speedup": SpeedupRubric(),
|
| 93 |
+
"correctness": correctness,
|
| 94 |
+
"diagnosis": DiagnosisRubric(),
|
| 95 |
+
"self_correction": SelfCorrectionRubric(),
|
| 96 |
+
"portability": PortabilityRubric()},
|
| 97 |
+
weights={"speedup": 0.35, "correctness": 0.25,
|
| 98 |
+
"diagnosis": 0.20, "self_correction": 0.10, "portability": 0.10},
|
| 99 |
+
),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
__all__ = [
|
| 104 |
+
"Rubric",
|
| 105 |
+
"Sequential",
|
| 106 |
+
"Gate",
|
| 107 |
+
"WeightedSum",
|
| 108 |
+
"GateFailedError",
|
| 109 |
+
"SpeedupRubric",
|
| 110 |
+
"CorrectnessRubric",
|
| 111 |
+
"CompilationRubric",
|
| 112 |
+
"DiagnosisRubric",
|
| 113 |
+
"PortabilityRubric",
|
| 114 |
+
"SelfCorrectionRubric",
|
| 115 |
+
"build_round_reward_dag",
|
| 116 |
+
]
|
server/rewards/correctness_rubric.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CorrectnessRubric + CompilationRubric (binary).
|
| 2 |
+
|
| 3 |
+
CorrectnessRubric returns the fuzzer pass_rate directly (∈ [0,1]). Used both as
|
| 4 |
+
a Gate target and as a weighted component.
|
| 5 |
+
|
| 6 |
+
CompilationRubric is binary: 1.0 if compile succeeded, 0.0 otherwise. Used only
|
| 7 |
+
as a Gate (a compile failure is a hard reward = 0).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from .rubrics import Rubric
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CorrectnessRubric(Rubric):
|
| 18 |
+
name = "correctness"
|
| 19 |
+
|
| 20 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 21 |
+
pass_rate = float(submission.get("correctness_pass_rate", 0.0))
|
| 22 |
+
adv_pass_rate = float(submission.get("adversarial_pass_rate", 0.0))
|
| 23 |
+
|
| 24 |
+
# Hard penalty if adversarial sub-pool is below 0.9 (per plan §10b)
|
| 25 |
+
if adv_pass_rate < 0.9:
|
| 26 |
+
penalty = 0.5 # halve the score if adversarial cases are failing
|
| 27 |
+
pass_rate *= penalty
|
| 28 |
+
|
| 29 |
+
self.last_breakdown = {
|
| 30 |
+
"raw_pass_rate": float(submission.get("correctness_pass_rate", 0.0)),
|
| 31 |
+
"adversarial_pass_rate": adv_pass_rate,
|
| 32 |
+
"adversarial_penalty_applied": adv_pass_rate < 0.9,
|
| 33 |
+
"score": pass_rate,
|
| 34 |
+
}
|
| 35 |
+
return max(0.0, min(1.0, pass_rate))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CompilationRubric(Rubric):
|
| 39 |
+
"""Continuous compile quality score from compile status."""
|
| 40 |
+
|
| 41 |
+
name = "compilation"
|
| 42 |
+
|
| 43 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 44 |
+
compile_status = submission.get("compile_status", "pending")
|
| 45 |
+
status_to_score = {
|
| 46 |
+
"success": 1.0,
|
| 47 |
+
"link_error": 0.55,
|
| 48 |
+
"timeout": 0.35,
|
| 49 |
+
"syntax_error": 0.10,
|
| 50 |
+
"pending": 0.0,
|
| 51 |
+
}
|
| 52 |
+
score = float(status_to_score.get(str(compile_status), 0.0))
|
| 53 |
+
self.last_breakdown = {"compile_status": compile_status, "score": score}
|
| 54 |
+
return max(0.0, min(1.0, score))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
__all__ = ["CorrectnessRubric", "CompilationRubric"]
|
server/rewards/diagnosis_rubric.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DiagnosisRubric — multi-signal anti-gaming hypothesis scoring (per plan §10b).
|
| 2 |
+
|
| 3 |
+
Pure keyword match is gameable (agent stuffs all bottleneck keywords into <think>).
|
| 4 |
+
Defense-in-depth:
|
| 5 |
+
|
| 6 |
+
raw = (correct_kw / |ground_truth|) - 0.5 * (distractor_kw / |distractors|)
|
| 7 |
+
raw = max(0, raw)
|
| 8 |
+
length_penalty = 1 - 0.1 * (len(thinking) / 256) # concise > verbose
|
| 9 |
+
coherence_bonus = 0.2 if first_tool_call matches diagnosis else 0
|
| 10 |
+
score = raw * length_penalty + coherence_bonus
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
from .rubrics import Rubric
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Map each diagnosis category to the tool that's "coherent" with it
|
| 21 |
+
DIAGNOSIS_TO_FIRST_TOOL = {
|
| 22 |
+
"memory-bound": "check_memory_access",
|
| 23 |
+
"compute-bound": "get_hardware_profile", # check SIMD width before vectorizing
|
| 24 |
+
"vectorizable": "get_hardware_profile",
|
| 25 |
+
"branch-heavy": "profile_python_hotspots",
|
| 26 |
+
"io-bound": "profile_python_hotspots", # confirm where time goes
|
| 27 |
+
"cache-unfriendly": "check_memory_access",
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DiagnosisRubric(Rubric):
|
| 32 |
+
name = "diagnosis"
|
| 33 |
+
|
| 34 |
+
def __init__(self, max_thinking_len: int = 256, length_penalty_rate: float = 0.1,
|
| 35 |
+
distractor_penalty_weight: float = 0.5, coherence_bonus: float = 0.2):
|
| 36 |
+
self.max_thinking_len = max_thinking_len
|
| 37 |
+
self.length_penalty_rate = length_penalty_rate
|
| 38 |
+
self.distractor_penalty_weight = distractor_penalty_weight
|
| 39 |
+
self.coherence_bonus = coherence_bonus
|
| 40 |
+
|
| 41 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 42 |
+
thinking = (submission.get("reasoning_trace", "") or state.current_round_reasoning or "").lower()
|
| 43 |
+
ground_truth = state.bottleneck_ground_truth or []
|
| 44 |
+
distractors = state.bottleneck_distractors or []
|
| 45 |
+
|
| 46 |
+
# Keyword counts (use word-boundary-ish substring match)
|
| 47 |
+
correct_kw = sum(1 for kw in ground_truth if kw.lower() in thinking)
|
| 48 |
+
distractor_kw = sum(1 for kw in distractors if kw.lower() in thinking)
|
| 49 |
+
|
| 50 |
+
if not ground_truth:
|
| 51 |
+
self.last_breakdown = {"score": 0.0, "reason": "no_ground_truth_labels"}
|
| 52 |
+
return 0.0
|
| 53 |
+
|
| 54 |
+
raw = (correct_kw / len(ground_truth))
|
| 55 |
+
if distractors:
|
| 56 |
+
raw -= self.distractor_penalty_weight * (distractor_kw / len(distractors))
|
| 57 |
+
raw = max(0.0, raw)
|
| 58 |
+
|
| 59 |
+
length = len(thinking.encode("utf-8")) # bytes — closer to token cost
|
| 60 |
+
length_penalty = max(0.0, 1.0 - self.length_penalty_rate * (length / self.max_thinking_len))
|
| 61 |
+
|
| 62 |
+
# Coherence bonus: was the FIRST tool call in this round consistent with the diagnosis?
|
| 63 |
+
# During reward computation, current round calls are in state.current_round_tool_calls.
|
| 64 |
+
# Fall back to round_results only when current calls are unavailable.
|
| 65 |
+
first_tool = ""
|
| 66 |
+
calls = list(state.current_round_tool_calls or [])
|
| 67 |
+
if not calls:
|
| 68 |
+
round_idx = state.round_number - 1
|
| 69 |
+
if 0 <= round_idx < len(state.round_results):
|
| 70 |
+
calls = list(state.round_results[round_idx].get("tool_calls", []))
|
| 71 |
+
if calls:
|
| 72 |
+
first_tool = calls[0]
|
| 73 |
+
if first_tool == "get_hardware_profile" and len(calls) > 1:
|
| 74 |
+
first_tool = calls[1]
|
| 75 |
+
|
| 76 |
+
# Match: any ground_truth label whose preferred tool == first_tool counts as coherent
|
| 77 |
+
coherence = 0.0
|
| 78 |
+
for label in ground_truth:
|
| 79 |
+
preferred = DIAGNOSIS_TO_FIRST_TOOL.get(label.lower())
|
| 80 |
+
if preferred and preferred == first_tool:
|
| 81 |
+
coherence = self.coherence_bonus
|
| 82 |
+
break
|
| 83 |
+
|
| 84 |
+
score = raw * length_penalty + coherence
|
| 85 |
+
score = max(0.0, min(1.0, score))
|
| 86 |
+
|
| 87 |
+
self.last_breakdown = {
|
| 88 |
+
"correct_kw": correct_kw,
|
| 89 |
+
"distractor_kw": distractor_kw,
|
| 90 |
+
"raw": raw,
|
| 91 |
+
"thinking_bytes": length,
|
| 92 |
+
"length_penalty": length_penalty,
|
| 93 |
+
"first_tool": first_tool,
|
| 94 |
+
"coherence_bonus": coherence,
|
| 95 |
+
"score": score,
|
| 96 |
+
}
|
| 97 |
+
return score
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
__all__ = ["DiagnosisRubric", "DIAGNOSIS_TO_FIRST_TOOL"]
|
server/rewards/portability_rubric.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PortabilityRubric — bonus for code that works across hardware profiles.
|
| 2 |
+
|
| 3 |
+
Only contributes when state.difficulty_axes['portability_required'] is on.
|
| 4 |
+
If the axis is off, returns 0 (i.e., this component contributes nothing to the
|
| 5 |
+
weighted sum, freeing the 10% weight to be implicit-zero).
|
| 6 |
+
|
| 7 |
+
Score = n_profiles_passing / n_other_profiles, clamped [0, 1]. Eligible only if
|
| 8 |
+
n_profiles_passing ≥ 3 (per plan §3 axis 4).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from .rubrics import Rubric
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PortabilityRubric(Rubric):
|
| 19 |
+
name = "portability"
|
| 20 |
+
|
| 21 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 22 |
+
# If the axis is off, this rubric contributes 0 (it's still in the weighted sum,
|
| 23 |
+
# but it neutralizes the 0.10 weight automatically).
|
| 24 |
+
axis_on = state.difficulty_axes.get("portability_required", 0) >= 1
|
| 25 |
+
portability = submission.get("portability", {}) or {}
|
| 26 |
+
n_passing = int(portability.get("n_profiles_passing", 0))
|
| 27 |
+
|
| 28 |
+
if not axis_on:
|
| 29 |
+
self.last_breakdown = {"axis_on": False, "score": 0.0}
|
| 30 |
+
return 0.0
|
| 31 |
+
|
| 32 |
+
# Need at least 3 to count
|
| 33 |
+
if n_passing < 3:
|
| 34 |
+
self.last_breakdown = {"axis_on": True, "n_passing": n_passing, "score": 0.0,
|
| 35 |
+
"reason": "below_3_profile_threshold"}
|
| 36 |
+
return 0.0
|
| 37 |
+
|
| 38 |
+
# Normalize against other-profile count (7 = total profiles minus the home one)
|
| 39 |
+
denom = max(7, 1)
|
| 40 |
+
score = min(1.0, n_passing / denom)
|
| 41 |
+
self.last_breakdown = {"axis_on": True, "n_passing": n_passing, "denom": denom, "score": score}
|
| 42 |
+
return score
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
__all__ = ["PortabilityRubric"]
|
server/rewards/rubrics.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base Rubric class + 3 composers (Sequential, Gate, WeightedSum).
|
| 2 |
+
|
| 3 |
+
These mirror OpenEnv's documented rubric primitives. Only Sequential, Gate, and
|
| 4 |
+
WeightedSum are confirmed in the framework — MaxOf/MinOf/Conditional were
|
| 5 |
+
*removed* from the plan in §12 D because they are not in upstream OpenEnv.
|
| 6 |
+
|
| 7 |
+
A Rubric is a callable: rubric.score(state, submission) -> float in [0, 1].
|
| 8 |
+
Rubric subclasses also expose .name (str) and may expose per-call breakdown
|
| 9 |
+
via the .last_breakdown dict (used by named_rubrics() introspection).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Any, Mapping
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GateFailedError(Exception):
|
| 18 |
+
"""Raised by Gate when its child rubric is below threshold.
|
| 19 |
+
|
| 20 |
+
Sequential catches this and short-circuits to 0.0.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Rubric:
|
| 25 |
+
"""Base class — concrete subclasses must override score()."""
|
| 26 |
+
|
| 27 |
+
name: str = "rubric"
|
| 28 |
+
|
| 29 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 30 |
+
raise NotImplementedError("subclass must implement .score()")
|
| 31 |
+
|
| 32 |
+
# Optional debug — populated by score() for introspection
|
| 33 |
+
last_breakdown: dict[str, Any] = {}
|
| 34 |
+
|
| 35 |
+
def __repr__(self) -> str:
|
| 36 |
+
return f"<{self.__class__.__name__} name={self.name!r}>"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# -------------------------- Composers --------------------------
|
| 40 |
+
|
| 41 |
+
class Sequential(Rubric):
|
| 42 |
+
"""Run rubrics in order. Returns (product of Gate multipliers) × (last non-Gate child).
|
| 43 |
+
|
| 44 |
+
Each `Gate` child yields a multiplier ∈ [0, 1]:
|
| 45 |
+
hard pass → 1.0
|
| 46 |
+
hard fail → raises (Sequential returns 0)
|
| 47 |
+
graduated full → 1.0
|
| 48 |
+
graduated ramp → fractional in (0, 1)
|
| 49 |
+
graduated dead → raises (Sequential returns 0)
|
| 50 |
+
|
| 51 |
+
Non-Gate children produce the actual reward score. Sequential outputs
|
| 52 |
+
the final score scaled by the product of gate multipliers — giving GRPO
|
| 53 |
+
a continuous gradient even when the agent is below threshold (per plan §3).
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
name = "sequential"
|
| 57 |
+
|
| 58 |
+
def __init__(self, *children: Rubric):
|
| 59 |
+
if not children:
|
| 60 |
+
raise ValueError("Sequential needs at least one child rubric")
|
| 61 |
+
self.children = children
|
| 62 |
+
|
| 63 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 64 |
+
gate_product = 1.0
|
| 65 |
+
final_score: float | None = None
|
| 66 |
+
breakdown: dict[str, Any] = {}
|
| 67 |
+
for child in self.children:
|
| 68 |
+
try:
|
| 69 |
+
s = child.score(state, submission)
|
| 70 |
+
breakdown[child.name] = s
|
| 71 |
+
except GateFailedError as e:
|
| 72 |
+
breakdown[child.name] = 0.0
|
| 73 |
+
breakdown["_gate_failed"] = str(e)
|
| 74 |
+
self.last_breakdown = breakdown
|
| 75 |
+
return 0.0
|
| 76 |
+
if isinstance(child, Gate):
|
| 77 |
+
gate_product *= s
|
| 78 |
+
else:
|
| 79 |
+
final_score = s
|
| 80 |
+
|
| 81 |
+
breakdown["_gate_product"] = gate_product
|
| 82 |
+
breakdown["_final_score"] = final_score if final_score is not None else gate_product
|
| 83 |
+
self.last_breakdown = breakdown
|
| 84 |
+
|
| 85 |
+
if final_score is None:
|
| 86 |
+
return gate_product
|
| 87 |
+
return float(max(0.0, min(1.0, gate_product * final_score)))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Gate(Rubric):
|
| 91 |
+
"""Continuous gate multiplier for shaping reward without binary cliffs.
|
| 92 |
+
|
| 93 |
+
In default mode, this gate never raises and always returns a multiplier in
|
| 94 |
+
[ramp_min, 1.0], where `ramp_min` is small but non-zero. That preserves
|
| 95 |
+
gradient signal even for weak submissions.
|
| 96 |
+
|
| 97 |
+
`hard=True` is kept only for backward compatibility.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, child: Rubric, threshold: float, dead_floor: float = 0.0,
|
| 101 |
+
ramp_max: float = 1.0, hard: bool = False, ramp_min: float = 0.05,
|
| 102 |
+
exponent: float = 2.0):
|
| 103 |
+
self.child = child
|
| 104 |
+
self.threshold = threshold
|
| 105 |
+
self.dead_floor = dead_floor
|
| 106 |
+
self.ramp_max = ramp_max
|
| 107 |
+
self.hard = hard
|
| 108 |
+
self.ramp_min = ramp_min
|
| 109 |
+
self.exponent = exponent
|
| 110 |
+
self.name = f"gate({child.name}>={threshold:.2f})"
|
| 111 |
+
|
| 112 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 113 |
+
"""Returns a MULTIPLIER ∈ [0, 1] for Sequential to multiply the final score by.
|
| 114 |
+
|
| 115 |
+
Hard mode:
|
| 116 |
+
score >= threshold → 1.0
|
| 117 |
+
score < threshold → raises GateFailedError
|
| 118 |
+
Continuous mode:
|
| 119 |
+
score >= threshold → 1.0
|
| 120 |
+
score < threshold → smooth multiplier in [ramp_min, ramp_max]
|
| 121 |
+
"""
|
| 122 |
+
s = self.child.score(state, submission)
|
| 123 |
+
|
| 124 |
+
if self.hard:
|
| 125 |
+
self.last_breakdown = {
|
| 126 |
+
"child": s, "threshold": self.threshold,
|
| 127 |
+
"zone": "hard_pass" if s >= self.threshold else "hard_fail",
|
| 128 |
+
}
|
| 129 |
+
if s < self.threshold:
|
| 130 |
+
raise GateFailedError(f"{self.child.name} = {s:.3f} < {self.threshold} (hard)")
|
| 131 |
+
return 1.0
|
| 132 |
+
|
| 133 |
+
if s >= self.threshold:
|
| 134 |
+
self.last_breakdown = {"child": s, "threshold": self.threshold, "zone": "full"}
|
| 135 |
+
return 1.0
|
| 136 |
+
|
| 137 |
+
# Smooth ramp in [0, threshold) with non-zero floor.
|
| 138 |
+
normalized = max(0.0, s) / max(self.threshold, 1e-9)
|
| 139 |
+
progress = max(0.0, min(1.0, normalized)) ** self.exponent
|
| 140 |
+
multiplier = self.ramp_min + (self.ramp_max - self.ramp_min) * progress
|
| 141 |
+
|
| 142 |
+
self.last_breakdown = {
|
| 143 |
+
"child": s, "threshold": self.threshold,
|
| 144 |
+
"zone": "ramp", "progress": progress, "multiplier": multiplier,
|
| 145 |
+
}
|
| 146 |
+
return float(max(0.0, min(1.0, multiplier)))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class WeightedSum(Rubric):
|
| 150 |
+
"""Sum of children weighted. weights must be a dict matching children keys.
|
| 151 |
+
|
| 152 |
+
children: Mapping[str, Rubric] — name → rubric
|
| 153 |
+
weights: Mapping[str, float] — name → weight (need not sum to 1; we DO NOT normalize)
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
name = "weighted_sum"
|
| 157 |
+
|
| 158 |
+
def __init__(self, children: Mapping[str, Rubric], weights: Mapping[str, float]):
|
| 159 |
+
if set(children.keys()) != set(weights.keys()):
|
| 160 |
+
raise ValueError(
|
| 161 |
+
f"children keys {set(children.keys())} != weights keys {set(weights.keys())}"
|
| 162 |
+
)
|
| 163 |
+
self.children = dict(children)
|
| 164 |
+
self.weights = dict(weights)
|
| 165 |
+
|
| 166 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 167 |
+
breakdown: dict[str, Any] = {}
|
| 168 |
+
total = 0.0
|
| 169 |
+
for name, rubric in self.children.items():
|
| 170 |
+
child_score = rubric.score(state, submission)
|
| 171 |
+
breakdown[name] = {"score": child_score, "weight": self.weights[name]}
|
| 172 |
+
total += child_score * self.weights[name]
|
| 173 |
+
self.last_breakdown = breakdown
|
| 174 |
+
# Clamp to [0, 1]; weights nominally sum to 1 but we don't enforce
|
| 175 |
+
return float(max(0.0, min(1.0, total)))
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
__all__ = [
|
| 179 |
+
"Rubric",
|
| 180 |
+
"Sequential",
|
| 181 |
+
"Gate",
|
| 182 |
+
"WeightedSum",
|
| 183 |
+
"GateFailedError",
|
| 184 |
+
]
|
server/rewards/self_correction_rubric.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SelfCorrectionRubric — rewards improvement from R1 to R3.
|
| 2 |
+
|
| 3 |
+
Per plan §10 anti-gaming rule: agent could deliberately submit a bad R1 to
|
| 4 |
+
maximize R1→R3 delta. Defense: R1 must compile (CompilationRubric pass)
|
| 5 |
+
or this rubric returns 0. That makes a deliberately-broken R1 a net loss.
|
| 6 |
+
|
| 7 |
+
Score = clamp((R3_speedup - R1_speedup) / R1_speedup, 0, 1)
|
| 8 |
+
but only if R1.compile_status == "success".
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from .rubrics import Rubric
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SelfCorrectionRubric(Rubric):
|
| 19 |
+
name = "self_correction"
|
| 20 |
+
|
| 21 |
+
def score(self, state, submission: dict[str, Any]) -> float:
|
| 22 |
+
# Only meaningful at round 3
|
| 23 |
+
if state.round_number != 3:
|
| 24 |
+
self.last_breakdown = {"score": 0.0, "reason": "not_round_3"}
|
| 25 |
+
return 0.0
|
| 26 |
+
|
| 27 |
+
# Find R1 result
|
| 28 |
+
r1_result = next((r for r in state.round_results if r["round"] == 1), None)
|
| 29 |
+
if r1_result is None:
|
| 30 |
+
self.last_breakdown = {"score": 0.0, "reason": "no_r1_result"}
|
| 31 |
+
return 0.0
|
| 32 |
+
|
| 33 |
+
r1_submission = r1_result.get("submission", {})
|
| 34 |
+
r1_compile = r1_submission.get("compile_status")
|
| 35 |
+
|
| 36 |
+
# Floor: R1 must have at least compiled (defeats deliberate-bad-R1 cheating)
|
| 37 |
+
if r1_compile != "success":
|
| 38 |
+
self.last_breakdown = {"score": 0.0, "reason": "r1_did_not_compile",
|
| 39 |
+
"r1_compile": r1_compile}
|
| 40 |
+
return 0.0
|
| 41 |
+
|
| 42 |
+
r1_speedup = float(r1_submission.get("speedup", 0.0))
|
| 43 |
+
r3_speedup = float(submission.get("speedup", 0.0))
|
| 44 |
+
|
| 45 |
+
if r1_speedup <= 0:
|
| 46 |
+
self.last_breakdown = {"score": 0.0, "reason": "r1_speedup_zero"}
|
| 47 |
+
return 0.0
|
| 48 |
+
|
| 49 |
+
delta = (r3_speedup - r1_speedup) / r1_speedup
|
| 50 |
+
score = max(0.0, min(1.0, delta))
|
| 51 |
+
|
| 52 |
+
self.last_breakdown = {
|
| 53 |
+
"r1_speedup": r1_speedup,
|
| 54 |
+
"r3_speedup": r3_speedup,
|
| 55 |
+
"delta_pct": delta,
|
| 56 |
+
"score": score,
|
| 57 |
+
}
|
| 58 |
+
return score
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
__all__ = ["SelfCorrectionRubric"]
|
server/rewards/speedup_rubric.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SpeedupRubric — Roofline-grounded reward (per plan §10).
|
| 2 |
+
|
| 3 |
+
reward = log2(1 + speedup / roofline_peak(hw)) / LOG_NORM
|
| 4 |
+
|
| 5 |
+
This is physically interpretable: the agent's reward maxes out at exactly the
|
| 6 |
+
hardware's theoretical ceiling. An agent that hits the Roofline gets 1.0;
|
| 7 |
+
an agent at half the ceiling gets ~0.6; no reward grows unbounded.
|
| 8 |
+
|
| 9 |
+
Why log not linear: a 100x speedup is not 10x more impressive than a 10x
|
| 10 |
+
speedup once you've blown past the Roofline; you've hit a different bottleneck
|
| 11 |
+
and the marginal reward should plateau.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
from .rubrics import Rubric
|
| 19 |
+
from server.tools.hardware_profiler import roofline_bound
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Normalize so that hitting the Roofline ceiling yields ~1.0 reward
|
| 23 |
+
# log2(1 + 1.0) = 1.0, so LOG_NORM = 1.0 means speedup == roofline_peak yields exactly 1.0.
|
| 24 |
+
# We allow the agent to slightly exceed the ceiling (up to ~2x) which gives ~1.6 reward,
|
| 25 |
+
# clamped to 1.0 by WeightedSum.
|
| 26 |
+
LOG_NORM = 1.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SpeedupRubric(Rubric):
|
| 30 |
+
name = "speedup"
|
| 31 |
+
|
| 32 |
+
def score(self, state, submission: dict[str, Any]) -> float: # type: ignore[override]
|
| 33 |
+
speedup = float(submission.get("speedup", 0.0))
|
| 34 |
+
if speedup <= 0:
|
| 35 |
+
self.last_breakdown = {"speedup": 0.0, "reward": 0.0}
|
| 36 |
+
return 0.0
|
| 37 |
+
|
| 38 |
+
peak = roofline_bound(state.hardware_profile)
|
| 39 |
+
normalized = speedup / max(peak, 1e-6)
|
| 40 |
+
reward = math.log2(1 + normalized) / LOG_NORM
|
| 41 |
+
|
| 42 |
+
# Clamp to [0, 1]
|
| 43 |
+
reward = max(0.0, min(1.0, reward))
|
| 44 |
+
|
| 45 |
+
self.last_breakdown = {
|
| 46 |
+
"speedup": speedup,
|
| 47 |
+
"roofline_peak": peak,
|
| 48 |
+
"normalized": normalized,
|
| 49 |
+
"reward": reward,
|
| 50 |
+
}
|
| 51 |
+
return reward
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Re-import after definition
|
| 55 |
+
from typing import Any # noqa: E402
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
__all__ = ["SpeedupRubric"]
|
server/scenarios/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario subsystem: hardware profiles, datasets, generators, curriculum."""
|
| 2 |
+
|
| 3 |
+
from .hardware_profiles import HARDWARE_PROFILES, HARDWARE_BY_CLASS, profile_by_id
|
| 4 |
+
from .trap_library import TRAP_LIBRARY, get_trap_by_id, sample_trap
|
| 5 |
+
from .generator import TemplateGenerator, generate_from_template
|
| 6 |
+
from .dataset_loader import DatasetLoader, sample_function
|
| 7 |
+
from .adaptive_curriculum import AdaptiveCurriculum, MAX_LEVEL
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"HARDWARE_PROFILES",
|
| 11 |
+
"HARDWARE_BY_CLASS",
|
| 12 |
+
"profile_by_id",
|
| 13 |
+
"TRAP_LIBRARY",
|
| 14 |
+
"get_trap_by_id",
|
| 15 |
+
"sample_trap",
|
| 16 |
+
"TemplateGenerator",
|
| 17 |
+
"generate_from_template",
|
| 18 |
+
"DatasetLoader",
|
| 19 |
+
"sample_function",
|
| 20 |
+
"AdaptiveCurriculum",
|
| 21 |
+
"MAX_LEVEL",
|
| 22 |
+
]
|
server/scenarios/adaptive_curriculum.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adaptive 4-axis difficulty controller (per plan §3 — MAX INNOVATION).
|
| 2 |
+
|
| 3 |
+
After every 8-rollout batch the controller computes success_rate and adjusts
|
| 4 |
+
ONE of four orthogonal axes:
|
| 5 |
+
|
| 6 |
+
function_tier: 0..3 (Tier 1..Tier 4 problem complexity)
|
| 7 |
+
hardware_class: 0..2 (easy → hard hardware profiles)
|
| 8 |
+
fuzzer_strictness: 0..2 (n_cases 100→1000, rtol 1e-3→1e-5 + edge cases)
|
| 9 |
+
portability_required: 0..1 (off → must pass on 3+ profiles for any reward)
|
| 10 |
+
|
| 11 |
+
Logic:
|
| 12 |
+
success ≥ 0.75 → escalate one random axis (the model is too good)
|
| 13 |
+
success ≤ 0.25 → de-escalate the highest axis (the model is stuck)
|
| 14 |
+
0.25 < success < 0.75 → Goldilocks zone, hold (max variance for GRPO)
|
| 15 |
+
|
| 16 |
+
Why 4-axis adaptation: prior curriculum work (PLR 2021, SPIRAL 2025, Code-A1
|
| 17 |
+
2026) escalates a SINGLE difficulty dimension. We escalate four orthogonal
|
| 18 |
+
dimensions, giving a much richer adaptation surface and preventing the model
|
| 19 |
+
from "specializing" in one axis. This is the central novelty in §2.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import random
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
MAX_LEVEL = {
|
| 30 |
+
"function_tier": 3,
|
| 31 |
+
"hardware_class": 2,
|
| 32 |
+
"fuzzer_strictness": 2,
|
| 33 |
+
"portability_required": 1,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
MIN_LEVEL = {axis: 0 for axis in MAX_LEVEL}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class CurriculumSnapshot:
|
| 41 |
+
"""A point-in-time view of the axes + recent batch stats — for wandb logging."""
|
| 42 |
+
axes: dict[str, int]
|
| 43 |
+
success_rate: float
|
| 44 |
+
n_batches_seen: int
|
| 45 |
+
last_action: str = "" # "escalate function_tier", "de-escalate hardware_class", "hold"
|
| 46 |
+
n_escalations: dict[str, int] = field(default_factory=dict)
|
| 47 |
+
n_deescalations: dict[str, int] = field(default_factory=dict)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class AdaptiveCurriculum:
|
| 51 |
+
"""Controller that mutates difficulty axes based on batch success rates.
|
| 52 |
+
|
| 53 |
+
Use:
|
| 54 |
+
curriculum = AdaptiveCurriculum()
|
| 55 |
+
for batch_idx in range(n_batches):
|
| 56 |
+
# rollout 8 episodes using curriculum.axes
|
| 57 |
+
# ...
|
| 58 |
+
success_rate = compiles_and_passes / 8
|
| 59 |
+
curriculum.observe_batch(success_rate)
|
| 60 |
+
snapshot = curriculum.snapshot()
|
| 61 |
+
wandb.log({"curriculum/axes": snapshot.axes, ...})
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
HIGH_THRESHOLD = 0.75
|
| 65 |
+
LOW_THRESHOLD = 0.25
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
initial_axes: dict[str, int] | None = None,
|
| 70 |
+
seed: int | None = None,
|
| 71 |
+
min_level: dict[str, int] | None = None,
|
| 72 |
+
max_level: dict[str, int] | None = None,
|
| 73 |
+
):
|
| 74 |
+
self.axes = dict(initial_axes or {axis: 0 for axis in MAX_LEVEL})
|
| 75 |
+
self.min_level = dict(min_level or MIN_LEVEL)
|
| 76 |
+
self.max_level = dict(max_level or MAX_LEVEL)
|
| 77 |
+
self.rng = random.Random(seed)
|
| 78 |
+
self.n_batches_seen = 0
|
| 79 |
+
self.last_action = "init"
|
| 80 |
+
self.n_escalations = {axis: 0 for axis in MAX_LEVEL}
|
| 81 |
+
self.n_deescalations = {axis: 0 for axis in MAX_LEVEL}
|
| 82 |
+
self._recent_success = 0.0 # last observed batch success_rate
|
| 83 |
+
|
| 84 |
+
def observe_batch(self, success_rate: float) -> str:
|
| 85 |
+
"""Process one batch result. Returns the action taken as a human-readable string."""
|
| 86 |
+
self.n_batches_seen += 1
|
| 87 |
+
self._recent_success = float(success_rate)
|
| 88 |
+
|
| 89 |
+
if success_rate >= self.HIGH_THRESHOLD:
|
| 90 |
+
action = self._escalate()
|
| 91 |
+
elif success_rate <= self.LOW_THRESHOLD:
|
| 92 |
+
action = self._deescalate()
|
| 93 |
+
else:
|
| 94 |
+
action = "hold (Goldilocks zone)"
|
| 95 |
+
|
| 96 |
+
self.last_action = action
|
| 97 |
+
return action
|
| 98 |
+
|
| 99 |
+
def _escalate(self) -> str:
|
| 100 |
+
"""Pick a random axis (uniform over those still below max) and increment it."""
|
| 101 |
+
candidates = [a for a, v in self.axes.items() if v < self.max_level[a]]
|
| 102 |
+
if not candidates:
|
| 103 |
+
return "hold (all axes at max)"
|
| 104 |
+
axis = self.rng.choice(candidates)
|
| 105 |
+
self.axes[axis] = min(self.axes[axis] + 1, self.max_level[axis])
|
| 106 |
+
self.n_escalations[axis] += 1
|
| 107 |
+
return f"escalate {axis} → {self.axes[axis]}"
|
| 108 |
+
|
| 109 |
+
def _deescalate(self) -> str:
|
| 110 |
+
"""De-escalate the axis currently at the highest level (break ties randomly)."""
|
| 111 |
+
candidates = [a for a, v in self.axes.items() if v > self.min_level[a]]
|
| 112 |
+
if not candidates:
|
| 113 |
+
return "hold (all axes at min)"
|
| 114 |
+
max_value = max(self.axes[a] for a in candidates)
|
| 115 |
+
top = [a for a in candidates if self.axes[a] == max_value]
|
| 116 |
+
axis = self.rng.choice(top)
|
| 117 |
+
self.axes[axis] = max(self.axes[axis] - 1, self.min_level[axis])
|
| 118 |
+
self.n_deescalations[axis] += 1
|
| 119 |
+
return f"de-escalate {axis} → {self.axes[axis]}"
|
| 120 |
+
|
| 121 |
+
def snapshot(self) -> CurriculumSnapshot:
|
| 122 |
+
return CurriculumSnapshot(
|
| 123 |
+
axes=dict(self.axes),
|
| 124 |
+
success_rate=self._recent_success,
|
| 125 |
+
n_batches_seen=self.n_batches_seen,
|
| 126 |
+
last_action=self.last_action,
|
| 127 |
+
n_escalations=dict(self.n_escalations),
|
| 128 |
+
n_deescalations=dict(self.n_deescalations),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def to_dict(self) -> dict[str, Any]:
|
| 132 |
+
s = self.snapshot()
|
| 133 |
+
return {
|
| 134 |
+
"axes": s.axes,
|
| 135 |
+
"success_rate": s.success_rate,
|
| 136 |
+
"n_batches_seen": s.n_batches_seen,
|
| 137 |
+
"last_action": s.last_action,
|
| 138 |
+
"n_escalations": s.n_escalations,
|
| 139 |
+
"n_deescalations": s.n_deescalations,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
__all__ = [
|
| 144 |
+
"AdaptiveCurriculum",
|
| 145 |
+
"CurriculumSnapshot",
|
| 146 |
+
"MAX_LEVEL",
|
| 147 |
+
"MIN_LEVEL",
|
| 148 |
+
]
|
server/scenarios/dataset_loader.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DatasetLoader: pulls Python functions from existing public datasets.
|
| 2 |
+
|
| 3 |
+
Per plan §4 the training pool is constructed from:
|
| 4 |
+
- IBM CodeNet (~80K filtered, primary)
|
| 5 |
+
- TransCoder (852 pairs, cross-validation)
|
| 6 |
+
- Pyperformance (60 fns, real-world calibration)
|
| 7 |
+
- Polybench/C (30 kernels, back-translated)
|
| 8 |
+
- Templates (this module's TemplateGenerator, dynamic)
|
| 9 |
+
- Trap library (15% of every batch)
|
| 10 |
+
|
| 11 |
+
For Hour 16-22 we ship a working loader for templates + traps. CodeNet/
|
| 12 |
+
TransCoder/Pyperformance are wired in via lazy load (HF datasets) — failing
|
| 13 |
+
gracefully to template-only when offline. The Hour 22 smoke test gate verifies
|
| 14 |
+
either path works.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import ast
|
| 20 |
+
import random
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
from .generator import TemplateGenerator, generate_from_template
|
| 24 |
+
from .trap_library import get_trap_by_id, sample_trap, sample_trap_by_category, trap_to_problem_dict
|
| 25 |
+
from .hardware_profiles import sample_profile
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DatasetLoader:
|
| 29 |
+
"""Unified sampler. The environment calls .sample(axes, rng) per reset()."""
|
| 30 |
+
|
| 31 |
+
# Probability that a sampled function is a trap (per plan §4.3 — "15% of every batch")
|
| 32 |
+
TRAP_PROBABILITY = 0.15
|
| 33 |
+
|
| 34 |
+
def __init__(self, prefer_real_datasets: bool = False):
|
| 35 |
+
"""`prefer_real_datasets=True` triggers CodeNet/TransCoder loading.
|
| 36 |
+
|
| 37 |
+
Default False = template-only (Hour 16-22 default; flip in Hour 22+ if
|
| 38 |
+
training has bandwidth to download HF datasets).
|
| 39 |
+
"""
|
| 40 |
+
self.prefer_real = prefer_real_datasets
|
| 41 |
+
self.template_generator = TemplateGenerator()
|
| 42 |
+
self._codenet_cache: list[dict[str, Any]] | None = None
|
| 43 |
+
self._trap_failure_counts: dict[str, int] = {}
|
| 44 |
+
self._adaptive_trap_boost: float = 0.0
|
| 45 |
+
|
| 46 |
+
def sample(self, axes: dict[str, int], rng: random.Random) -> dict[str, Any]:
|
| 47 |
+
"""Sample one (function, hw_profile, ground_truth) tuple given axis levels."""
|
| 48 |
+
# Pick the hardware profile per the hardware_class axis
|
| 49 |
+
hw = sample_profile(rng, axis_level=axes.get("hardware_class", 0))
|
| 50 |
+
|
| 51 |
+
# 15% of the time, draw a trap
|
| 52 |
+
if rng.random() < self.TRAP_PROBABILITY:
|
| 53 |
+
return self._sample_trap_problem(rng, hw)
|
| 54 |
+
|
| 55 |
+
# Otherwise — template, biased to current tier (or real dataset if enabled)
|
| 56 |
+
if self.prefer_real and self._codenet_loaded():
|
| 57 |
+
return self._sample_codenet(rng, hw, axes)
|
| 58 |
+
|
| 59 |
+
# Template path
|
| 60 |
+
tier = axes.get("function_tier", 0)
|
| 61 |
+
template = self.template_generator.sample(tier=tier, rng=rng)
|
| 62 |
+
return generate_from_template(template, hw)
|
| 63 |
+
|
| 64 |
+
def record_submission_outcome(self, state, submission: dict[str, Any]) -> None:
|
| 65 |
+
"""Update adaptive trap priorities from recent trap outcomes."""
|
| 66 |
+
if not getattr(state, "is_trap", False):
|
| 67 |
+
# Slow decay when solving non-trap episodes so adaptation doesn't stick forever.
|
| 68 |
+
self._adaptive_trap_boost = max(0.0, self._adaptive_trap_boost - 0.01)
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
trap_id = getattr(state, "trap_id", None)
|
| 72 |
+
trap = get_trap_by_id(trap_id) if trap_id else None
|
| 73 |
+
if trap is None:
|
| 74 |
+
return
|
| 75 |
+
|
| 76 |
+
pass_rate = float(submission.get("correctness_pass_rate", 0.0))
|
| 77 |
+
adv_rate = float(submission.get("adversarial_pass_rate", 0.0))
|
| 78 |
+
failed = pass_rate < 0.8 or adv_rate < 0.9
|
| 79 |
+
if failed:
|
| 80 |
+
self._trap_failure_counts[trap.category] = self._trap_failure_counts.get(trap.category, 0) + 1
|
| 81 |
+
self._adaptive_trap_boost = min(0.25, self._adaptive_trap_boost + 0.03)
|
| 82 |
+
else:
|
| 83 |
+
self._adaptive_trap_boost = max(0.0, self._adaptive_trap_boost - 0.02)
|
| 84 |
+
|
| 85 |
+
def _sample_trap_problem(self, rng: random.Random, hw: dict[str, Any]) -> dict[str, Any]:
|
| 86 |
+
"""Sample a static or adaptive trap depending on recent failure patterns."""
|
| 87 |
+
use_adaptive = bool(self._trap_failure_counts) and rng.random() < min(0.85, 0.55 + self._adaptive_trap_boost)
|
| 88 |
+
if use_adaptive:
|
| 89 |
+
categories = list(self._trap_failure_counts.keys())
|
| 90 |
+
weights = [max(1, self._trap_failure_counts[c]) for c in categories]
|
| 91 |
+
chosen_category = rng.choices(categories, weights=weights, k=1)[0]
|
| 92 |
+
base_trap = sample_trap_by_category(chosen_category, rng, exclude_held_out=True)
|
| 93 |
+
if base_trap is None:
|
| 94 |
+
base_trap = sample_trap(rng, exclude_held_out=True)
|
| 95 |
+
return self._build_adaptive_trap_variant(base_trap, hw, rng)
|
| 96 |
+
|
| 97 |
+
trap = sample_trap(rng, exclude_held_out=True)
|
| 98 |
+
p = trap_to_problem_dict(trap, hw)
|
| 99 |
+
p["source"] = "trap_library"
|
| 100 |
+
return p
|
| 101 |
+
|
| 102 |
+
def _build_adaptive_trap_variant(self, trap, hw: dict[str, Any], rng: random.Random) -> dict[str, Any]:
|
| 103 |
+
"""Generate a semantic-preserving variant to reduce memorization."""
|
| 104 |
+
python_code = trap.python_code
|
| 105 |
+
if "def " in python_code and "(" in python_code:
|
| 106 |
+
suffix = rng.randint(1000, 9999)
|
| 107 |
+
start = python_code.find("def ")
|
| 108 |
+
end = python_code.find("(", start)
|
| 109 |
+
fn_name = python_code[start + 4:end].strip()
|
| 110 |
+
if fn_name:
|
| 111 |
+
python_code = python_code.replace(f"def {fn_name}(", f"def {fn_name}_adapt_{suffix}(", 1)
|
| 112 |
+
python_code = self._semantic_noop_mutation(python_code, rng)
|
| 113 |
+
|
| 114 |
+
variant = trap_to_problem_dict(trap, hw)
|
| 115 |
+
variant["python_code"] = python_code
|
| 116 |
+
variant["trap_id"] = f"{trap.id}::adaptive"
|
| 117 |
+
variant["trap_parent_id"] = trap.id
|
| 118 |
+
variant["trap_category"] = trap.category
|
| 119 |
+
variant["source"] = "adaptive_trap"
|
| 120 |
+
return variant
|
| 121 |
+
|
| 122 |
+
def _semantic_noop_mutation(self, python_code: str, rng: random.Random) -> str:
|
| 123 |
+
"""Apply semantic no-op AST rewrites so adaptive traps are not pure renames."""
|
| 124 |
+
|
| 125 |
+
class _NoopTransformer(ast.NodeTransformer):
|
| 126 |
+
def __init__(self, seed: int):
|
| 127 |
+
self._rng = random.Random(seed)
|
| 128 |
+
|
| 129 |
+
def visit_For(self, node: ast.For):
|
| 130 |
+
self.generic_visit(node)
|
| 131 |
+
# Insert a no-op guard branch to perturb structure while preserving behavior.
|
| 132 |
+
if self._rng.random() < 0.45:
|
| 133 |
+
noop = ast.If(
|
| 134 |
+
test=ast.Constant(value=False),
|
| 135 |
+
body=[ast.Expr(value=ast.Constant(value=None))],
|
| 136 |
+
orelse=[],
|
| 137 |
+
)
|
| 138 |
+
node.body = [noop, *node.body]
|
| 139 |
+
return node
|
| 140 |
+
|
| 141 |
+
def visit_Assign(self, node: ast.Assign):
|
| 142 |
+
self.generic_visit(node)
|
| 143 |
+
# Occasionally wrap RHS in (+ 0) no-op for numeric expressions.
|
| 144 |
+
if self._rng.random() < 0.30:
|
| 145 |
+
node.value = ast.BinOp(left=node.value, op=ast.Add(), right=ast.Constant(value=0))
|
| 146 |
+
return node
|
| 147 |
+
|
| 148 |
+
try:
|
| 149 |
+
tree = ast.parse(python_code)
|
| 150 |
+
transformer = _NoopTransformer(seed=rng.randint(0, 10_000_000))
|
| 151 |
+
mutated = transformer.visit(tree)
|
| 152 |
+
ast.fix_missing_locations(mutated)
|
| 153 |
+
code = ast.unparse(mutated)
|
| 154 |
+
if not code.endswith("\n"):
|
| 155 |
+
code += "\n"
|
| 156 |
+
return code
|
| 157 |
+
except Exception:
|
| 158 |
+
# Fallback: minimally perturb whitespace/comments while keeping code valid.
|
| 159 |
+
lines = python_code.splitlines()
|
| 160 |
+
if lines and not lines[0].lstrip().startswith("#"):
|
| 161 |
+
lines.insert(0, "# adaptive trap variant")
|
| 162 |
+
return "\n".join(lines) + ("\n" if lines else "")
|
| 163 |
+
|
| 164 |
+
# -------- CodeNet integration (lazy, optional) --------
|
| 165 |
+
|
| 166 |
+
def _codenet_loaded(self) -> bool:
|
| 167 |
+
return self._codenet_cache is not None and len(self._codenet_cache) > 0
|
| 168 |
+
|
| 169 |
+
def _try_load_codenet(self) -> bool:
|
| 170 |
+
"""Lazy-load CodeNet from HF datasets. Returns True iff load succeeded.
|
| 171 |
+
|
| 172 |
+
Handles offline / no-token gracefully.
|
| 173 |
+
"""
|
| 174 |
+
if self._codenet_loaded():
|
| 175 |
+
return True
|
| 176 |
+
try:
|
| 177 |
+
from datasets import load_dataset # type: ignore
|
| 178 |
+
ds = load_dataset(
|
| 179 |
+
"codeparrot/codenet",
|
| 180 |
+
split="train",
|
| 181 |
+
streaming=True,
|
| 182 |
+
)
|
| 183 |
+
cache: list[dict[str, Any]] = []
|
| 184 |
+
for example in ds:
|
| 185 |
+
if len(cache) >= 1000: # bounded preload
|
| 186 |
+
break
|
| 187 |
+
if example.get("language") != "Python3":
|
| 188 |
+
continue
|
| 189 |
+
code = example.get("code", "")
|
| 190 |
+
if 200 <= len(code) <= 4000:
|
| 191 |
+
cache.append({"code": code, "source": "codenet"})
|
| 192 |
+
self._codenet_cache = cache
|
| 193 |
+
return len(cache) > 0
|
| 194 |
+
except Exception:
|
| 195 |
+
self._codenet_cache = []
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
def _sample_codenet(self, rng: random.Random, hw: dict[str, Any], axes: dict[str, int]) -> dict[str, Any]:
|
| 199 |
+
if not self._codenet_loaded() and not self._try_load_codenet():
|
| 200 |
+
# Fall back to template
|
| 201 |
+
template = self.template_generator.sample(tier=axes.get("function_tier", 0), rng=rng)
|
| 202 |
+
return generate_from_template(template, hw)
|
| 203 |
+
|
| 204 |
+
cache = self._codenet_cache or []
|
| 205 |
+
if not cache:
|
| 206 |
+
template = self.template_generator.sample(tier=axes.get("function_tier", 0), rng=rng)
|
| 207 |
+
return generate_from_template(template, hw)
|
| 208 |
+
|
| 209 |
+
# Pick a random function from the cache
|
| 210 |
+
sample = rng.choice(cache)
|
| 211 |
+
return {
|
| 212 |
+
"python_code": sample["code"],
|
| 213 |
+
"cpp_signature": _infer_cpp_signature_simple(sample["code"]),
|
| 214 |
+
"hardware_profile": hw,
|
| 215 |
+
# Without ground-truth labels we use a generic catch-all; DiagnosisRubric will
|
| 216 |
+
# award partial credit for any of these. CodeNet samples are not the primary
|
| 217 |
+
# training source for diagnosis training — the templates are.
|
| 218 |
+
"bottleneck_labels": ["compute-bound"],
|
| 219 |
+
"bottleneck_distractors": ["memory-bound", "branch-heavy", "io-bound"],
|
| 220 |
+
"rtol_override": None,
|
| 221 |
+
"is_trap": False,
|
| 222 |
+
"source": "codenet",
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _infer_cpp_signature_simple(python_code: str) -> str:
|
| 227 |
+
import ast
|
| 228 |
+
try:
|
| 229 |
+
tree = ast.parse(python_code)
|
| 230 |
+
fn = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
|
| 231 |
+
if fn:
|
| 232 |
+
return f'extern "C" void agent_function(/* {len(fn.args.args)} args */);'
|
| 233 |
+
except Exception:
|
| 234 |
+
pass
|
| 235 |
+
return 'extern "C" void agent_function(void* in, size_t n, void* out);'
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Module-level convenience function (no class needed)
|
| 239 |
+
_default_loader: DatasetLoader | None = None
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def sample_function(axes: dict[str, int], rng: random.Random) -> dict[str, Any]:
|
| 243 |
+
global _default_loader
|
| 244 |
+
if _default_loader is None:
|
| 245 |
+
_default_loader = DatasetLoader(prefer_real_datasets=False)
|
| 246 |
+
return _default_loader.sample(axes, rng)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
__all__ = ["DatasetLoader", "sample_function"]
|
server/scenarios/generator.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Template-based + adversarial Python function generator.
|
| 2 |
+
|
| 3 |
+
Per plan §16 hard cutoff: ship template-only first, add LLM-based adversarial
|
| 4 |
+
generation only if Hour 22 budget allows. This module currently implements the
|
| 5 |
+
deterministic template generator. The LLM-adversarial path is wired through a
|
| 6 |
+
`generate_adversarial(...)` stub that we can switch to in Hour 22 if time permits.
|
| 7 |
+
|
| 8 |
+
Templates are tier-parameterized (per plan §9 four tiers):
|
| 9 |
+
Tier 0: Algorithmic — simple loops, sum/argmax/count/prefix
|
| 10 |
+
Tier 1: Memory-aware — transpose, sliding window, histogram
|
| 11 |
+
Tier 2: SIMD+parallel — pairwise distance, batch_norm, RLE
|
| 12 |
+
Tier 3: Frontier — fused attention, sparse, conv2d
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import random
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class Template:
|
| 24 |
+
id: str
|
| 25 |
+
tier: int
|
| 26 |
+
python_code: str
|
| 27 |
+
bottleneck_label: list[str] = field(default_factory=list)
|
| 28 |
+
description: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# -------- Tier 0: Algorithmic --------
|
| 32 |
+
|
| 33 |
+
_TIER_0_TEMPLATES: list[Template] = [
|
| 34 |
+
Template(
|
| 35 |
+
id="t0_simple_sum",
|
| 36 |
+
tier=0,
|
| 37 |
+
python_code=(
|
| 38 |
+
"def total(arr):\n"
|
| 39 |
+
" s = 0.0\n"
|
| 40 |
+
" for x in arr:\n"
|
| 41 |
+
" s += x\n"
|
| 42 |
+
" return s\n"
|
| 43 |
+
),
|
| 44 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 45 |
+
),
|
| 46 |
+
Template(
|
| 47 |
+
id="t0_argmax",
|
| 48 |
+
tier=0,
|
| 49 |
+
python_code=(
|
| 50 |
+
"def argmax(arr):\n"
|
| 51 |
+
" if not arr:\n"
|
| 52 |
+
" return -1\n"
|
| 53 |
+
" best_i, best_v = 0, arr[0]\n"
|
| 54 |
+
" for i in range(1, len(arr)):\n"
|
| 55 |
+
" if arr[i] > best_v:\n"
|
| 56 |
+
" best_v, best_i = arr[i], i\n"
|
| 57 |
+
" return best_i\n"
|
| 58 |
+
),
|
| 59 |
+
bottleneck_label=["branch-heavy", "compute-bound"],
|
| 60 |
+
),
|
| 61 |
+
Template(
|
| 62 |
+
id="t0_count_if",
|
| 63 |
+
tier=0,
|
| 64 |
+
python_code=(
|
| 65 |
+
"def count_pos(arr):\n"
|
| 66 |
+
" n = 0\n"
|
| 67 |
+
" for x in arr:\n"
|
| 68 |
+
" if x > 0:\n"
|
| 69 |
+
" n += 1\n"
|
| 70 |
+
" return n\n"
|
| 71 |
+
),
|
| 72 |
+
bottleneck_label=["branch-heavy", "vectorizable"],
|
| 73 |
+
),
|
| 74 |
+
Template(
|
| 75 |
+
id="t0_prefix_sum",
|
| 76 |
+
tier=0,
|
| 77 |
+
python_code=(
|
| 78 |
+
"def prefix_sum(arr):\n"
|
| 79 |
+
" out = [0.0] * len(arr)\n"
|
| 80 |
+
" s = 0.0\n"
|
| 81 |
+
" for i, x in enumerate(arr):\n"
|
| 82 |
+
" s += x\n"
|
| 83 |
+
" out[i] = s\n"
|
| 84 |
+
" return out\n"
|
| 85 |
+
),
|
| 86 |
+
bottleneck_label=["compute-bound"],
|
| 87 |
+
),
|
| 88 |
+
Template(
|
| 89 |
+
id="t0_sum_squares",
|
| 90 |
+
tier=0,
|
| 91 |
+
python_code=(
|
| 92 |
+
"def sum_squares(arr):\n"
|
| 93 |
+
" s = 0.0\n"
|
| 94 |
+
" for x in arr:\n"
|
| 95 |
+
" s += x * x\n"
|
| 96 |
+
" return s\n"
|
| 97 |
+
),
|
| 98 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 99 |
+
),
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# -------- Tier 1: Memory-aware --------
|
| 104 |
+
|
| 105 |
+
_TIER_1_TEMPLATES: list[Template] = [
|
| 106 |
+
Template(
|
| 107 |
+
id="t1_matrix_transpose",
|
| 108 |
+
tier=1,
|
| 109 |
+
python_code=(
|
| 110 |
+
"def transpose(a, n: int, m: int):\n"
|
| 111 |
+
" out = [[0.0]*n for _ in range(m)]\n"
|
| 112 |
+
" for i in range(n):\n"
|
| 113 |
+
" for j in range(m):\n"
|
| 114 |
+
" out[j][i] = a[i][j]\n"
|
| 115 |
+
" return out\n"
|
| 116 |
+
),
|
| 117 |
+
bottleneck_label=["memory-bound", "cache-unfriendly"],
|
| 118 |
+
),
|
| 119 |
+
Template(
|
| 120 |
+
id="t1_sliding_window",
|
| 121 |
+
tier=1,
|
| 122 |
+
python_code=(
|
| 123 |
+
"def moving_avg(arr, k: int):\n"
|
| 124 |
+
" n = len(arr)\n"
|
| 125 |
+
" out = [0.0] * (n - k + 1)\n"
|
| 126 |
+
" for i in range(n - k + 1):\n"
|
| 127 |
+
" s = 0.0\n"
|
| 128 |
+
" for j in range(k):\n"
|
| 129 |
+
" s += arr[i + j]\n"
|
| 130 |
+
" out[i] = s / k\n"
|
| 131 |
+
" return out\n"
|
| 132 |
+
),
|
| 133 |
+
bottleneck_label=["compute-bound", "memory-bound"],
|
| 134 |
+
),
|
| 135 |
+
Template(
|
| 136 |
+
id="t1_histogram",
|
| 137 |
+
tier=1,
|
| 138 |
+
python_code=(
|
| 139 |
+
"def histogram(arr, n_bins: int):\n"
|
| 140 |
+
" bins = [0] * n_bins\n"
|
| 141 |
+
" lo = min(arr)\n"
|
| 142 |
+
" hi = max(arr)\n"
|
| 143 |
+
" width = (hi - lo) / n_bins if hi > lo else 1.0\n"
|
| 144 |
+
" for x in arr:\n"
|
| 145 |
+
" b = min(int((x - lo) / width), n_bins - 1)\n"
|
| 146 |
+
" bins[b] += 1\n"
|
| 147 |
+
" return bins\n"
|
| 148 |
+
),
|
| 149 |
+
bottleneck_label=["memory-bound", "branch-heavy"],
|
| 150 |
+
),
|
| 151 |
+
Template(
|
| 152 |
+
id="t1_bitmask_filter",
|
| 153 |
+
tier=1,
|
| 154 |
+
python_code=(
|
| 155 |
+
"def masked_sum(arr, mask):\n"
|
| 156 |
+
" return sum(arr[i] for i in range(len(arr)) if mask[i])\n"
|
| 157 |
+
),
|
| 158 |
+
bottleneck_label=["branch-heavy", "vectorizable"],
|
| 159 |
+
),
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# -------- Tier 2: SIMD + parallel --------
|
| 164 |
+
|
| 165 |
+
_TIER_2_TEMPLATES: list[Template] = [
|
| 166 |
+
Template(
|
| 167 |
+
id="t2_pairwise_dist",
|
| 168 |
+
tier=2,
|
| 169 |
+
python_code=(
|
| 170 |
+
"def pairwise_dist_sq(X, n: int, d: int):\n"
|
| 171 |
+
" out = [[0.0]*n for _ in range(n)]\n"
|
| 172 |
+
" for i in range(n):\n"
|
| 173 |
+
" for j in range(n):\n"
|
| 174 |
+
" s = 0.0\n"
|
| 175 |
+
" for k in range(d):\n"
|
| 176 |
+
" diff = X[i][k] - X[j][k]\n"
|
| 177 |
+
" s += diff * diff\n"
|
| 178 |
+
" out[i][j] = s\n"
|
| 179 |
+
" return out\n"
|
| 180 |
+
),
|
| 181 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 182 |
+
),
|
| 183 |
+
Template(
|
| 184 |
+
id="t2_batch_norm",
|
| 185 |
+
tier=2,
|
| 186 |
+
python_code=(
|
| 187 |
+
"def batch_norm(X, gamma, beta, eps: float):\n"
|
| 188 |
+
" n = len(X)\n"
|
| 189 |
+
" mean = sum(X) / n\n"
|
| 190 |
+
" var = sum((x - mean) ** 2 for x in X) / n\n"
|
| 191 |
+
" inv_std = 1.0 / ((var + eps) ** 0.5)\n"
|
| 192 |
+
" return [gamma * (x - mean) * inv_std + beta for x in X]\n"
|
| 193 |
+
),
|
| 194 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 195 |
+
),
|
| 196 |
+
Template(
|
| 197 |
+
id="t2_inner_product_batch",
|
| 198 |
+
tier=2,
|
| 199 |
+
python_code=(
|
| 200 |
+
"def batch_inner(A, B, n: int, d: int):\n"
|
| 201 |
+
" out = [0.0] * n\n"
|
| 202 |
+
" for i in range(n):\n"
|
| 203 |
+
" s = 0.0\n"
|
| 204 |
+
" for k in range(d):\n"
|
| 205 |
+
" s += A[i][k] * B[i][k]\n"
|
| 206 |
+
" out[i] = s\n"
|
| 207 |
+
" return out\n"
|
| 208 |
+
),
|
| 209 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 210 |
+
),
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# -------- Tier 3: Frontier --------
|
| 215 |
+
|
| 216 |
+
_TIER_3_TEMPLATES: list[Template] = [
|
| 217 |
+
Template(
|
| 218 |
+
id="t3_attention_score",
|
| 219 |
+
tier=3,
|
| 220 |
+
python_code=(
|
| 221 |
+
"def attention_score(Q, K, n: int, d: int):\n"
|
| 222 |
+
" out = [[0.0]*n for _ in range(n)]\n"
|
| 223 |
+
" for i in range(n):\n"
|
| 224 |
+
" for j in range(n):\n"
|
| 225 |
+
" s = 0.0\n"
|
| 226 |
+
" for k in range(d):\n"
|
| 227 |
+
" s += Q[i][k] * K[j][k]\n"
|
| 228 |
+
" out[i][j] = s / (d ** 0.5)\n"
|
| 229 |
+
" return out\n"
|
| 230 |
+
),
|
| 231 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 232 |
+
),
|
| 233 |
+
Template(
|
| 234 |
+
id="t3_softmax_log",
|
| 235 |
+
tier=3,
|
| 236 |
+
python_code=(
|
| 237 |
+
"import math\n"
|
| 238 |
+
"def log_softmax(arr):\n"
|
| 239 |
+
" m = max(arr)\n"
|
| 240 |
+
" s = sum(math.exp(x - m) for x in arr)\n"
|
| 241 |
+
" log_s = m + math.log(s)\n"
|
| 242 |
+
" return [x - log_s for x in arr]\n"
|
| 243 |
+
),
|
| 244 |
+
bottleneck_label=["compute-bound"],
|
| 245 |
+
),
|
| 246 |
+
Template(
|
| 247 |
+
id="t3_conv2d_naive",
|
| 248 |
+
tier=3,
|
| 249 |
+
python_code=(
|
| 250 |
+
"def conv2d(img, kernel, h: int, w: int, kh: int, kw: int):\n"
|
| 251 |
+
" oh, ow = h - kh + 1, w - kw + 1\n"
|
| 252 |
+
" out = [[0.0]*ow for _ in range(oh)]\n"
|
| 253 |
+
" for i in range(oh):\n"
|
| 254 |
+
" for j in range(ow):\n"
|
| 255 |
+
" s = 0.0\n"
|
| 256 |
+
" for ki in range(kh):\n"
|
| 257 |
+
" for kj in range(kw):\n"
|
| 258 |
+
" s += img[i+ki][j+kj] * kernel[ki][kj]\n"
|
| 259 |
+
" out[i][j] = s\n"
|
| 260 |
+
" return out\n"
|
| 261 |
+
),
|
| 262 |
+
bottleneck_label=["compute-bound", "memory-bound"],
|
| 263 |
+
),
|
| 264 |
+
]
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
_TEMPLATES_BY_TIER = {
|
| 268 |
+
0: _TIER_0_TEMPLATES,
|
| 269 |
+
1: _TIER_1_TEMPLATES,
|
| 270 |
+
2: _TIER_2_TEMPLATES,
|
| 271 |
+
3: _TIER_3_TEMPLATES,
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
_DEFAULT_DISTRACTORS = ["memory-bound", "branch-heavy", "io-bound", "cache-unfriendly", "compute-bound"]
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class TemplateGenerator:
|
| 279 |
+
"""Deterministic template generator (no LLM call). Hour 16-22 deliverable."""
|
| 280 |
+
|
| 281 |
+
def sample(self, tier: int, rng: random.Random) -> Template:
|
| 282 |
+
"""Sample a template at the given tier (or below — gives easier mix in early training)."""
|
| 283 |
+
pool: list[Template] = []
|
| 284 |
+
for t in range(min(tier, 3) + 1):
|
| 285 |
+
pool.extend(_TEMPLATES_BY_TIER[t])
|
| 286 |
+
if not pool:
|
| 287 |
+
pool = _TEMPLATES_BY_TIER[0]
|
| 288 |
+
return rng.choice(pool)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def generate_from_template(template: Template, hw_profile: dict[str, Any]) -> dict[str, Any]:
|
| 292 |
+
"""Convert a Template into the env._sample_problem() return shape."""
|
| 293 |
+
distractors = [d for d in _DEFAULT_DISTRACTORS if d not in template.bottleneck_label]
|
| 294 |
+
from .trap_library import _infer_cpp_signature
|
| 295 |
+
return {
|
| 296 |
+
"python_code": template.python_code,
|
| 297 |
+
"cpp_signature": _infer_cpp_signature(template.python_code),
|
| 298 |
+
"hardware_profile": hw_profile,
|
| 299 |
+
"bottleneck_labels": template.bottleneck_label,
|
| 300 |
+
"bottleneck_distractors": distractors,
|
| 301 |
+
"rtol_override": None,
|
| 302 |
+
"is_trap": False,
|
| 303 |
+
"template_id": template.id,
|
| 304 |
+
"tier": template.tier,
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Public counts
|
| 309 |
+
N_TEMPLATES_TIER_0 = len(_TIER_0_TEMPLATES)
|
| 310 |
+
N_TEMPLATES_TIER_1 = len(_TIER_1_TEMPLATES)
|
| 311 |
+
N_TEMPLATES_TIER_2 = len(_TIER_2_TEMPLATES)
|
| 312 |
+
N_TEMPLATES_TIER_3 = len(_TIER_3_TEMPLATES)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
__all__ = [
|
| 316 |
+
"Template",
|
| 317 |
+
"TemplateGenerator",
|
| 318 |
+
"generate_from_template",
|
| 319 |
+
"N_TEMPLATES_TIER_0", "N_TEMPLATES_TIER_1", "N_TEMPLATES_TIER_2", "N_TEMPLATES_TIER_3",
|
| 320 |
+
]
|
server/scenarios/hardware_profiles.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""8 Roofline-calibrated synthetic hardware profiles (per plan §10).
|
| 2 |
+
|
| 3 |
+
Profile classes (for the `hardware_class` curriculum axis):
|
| 4 |
+
Class 0 (easy): laptop_sse, desktop_avx2
|
| 5 |
+
Class 1 (medium): workstation, arm_neon_a, laptop_sse2
|
| 6 |
+
Class 2 (hard): server_avx512, embedded, arm_neon_b (held-out for Gen-2 eval)
|
| 7 |
+
|
| 8 |
+
`arm_neon_b` is the held-out profile (never sampled during training). Used for
|
| 9 |
+
the Gen-2 evaluation split that tests hardware-reasoning generalization.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
HARDWARE_PROFILES: list[dict[str, Any]] = [
|
| 18 |
+
# Class 0 — easy, common consumer hardware
|
| 19 |
+
{"id": "laptop_sse", "cores": 4, "freq_ghz": 3.2, "l1_kb": 32, "simd": "SSE4.2", "bw_gbs": 40, "class": 0},
|
| 20 |
+
{"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8, "l1_kb": 32, "simd": "AVX2", "bw_gbs": 51, "class": 0},
|
| 21 |
+
|
| 22 |
+
# Class 1 — medium, varied
|
| 23 |
+
{"id": "workstation", "cores": 12, "freq_ghz": 4.0, "l1_kb": 48, "simd": "AVX2", "bw_gbs": 76, "class": 1},
|
| 24 |
+
{"id": "arm_neon_a", "cores": 6, "freq_ghz": 2.4, "l1_kb": 64, "simd": "NEON", "bw_gbs": 68, "class": 1},
|
| 25 |
+
{"id": "laptop_sse2", "cores": 4, "freq_ghz": 2.6, "l1_kb": 64, "simd": "SSE4.2", "bw_gbs": 35, "class": 1},
|
| 26 |
+
|
| 27 |
+
# Class 2 — hard, demands real hardware reasoning
|
| 28 |
+
{"id": "server_avx512", "cores": 16, "freq_ghz": 3.0, "l1_kb": 48, "simd": "AVX-512", "bw_gbs": 89, "class": 2},
|
| 29 |
+
{"id": "embedded", "cores": 2, "freq_ghz": 1.8, "l1_kb": 16, "simd": "none", "bw_gbs": 25, "class": 2},
|
| 30 |
+
|
| 31 |
+
# HELD-OUT for Gen-2 evaluation — never sampled during training
|
| 32 |
+
{"id": "arm_neon_b", "cores": 8, "freq_ghz": 2.8, "l1_kb": 32, "simd": "NEON", "bw_gbs": 68, "class": 2, "held_out": True},
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
HARDWARE_BY_CLASS: dict[int, list[dict[str, Any]]] = {
|
| 37 |
+
0: [p for p in HARDWARE_PROFILES if p.get("class") == 0 and not p.get("held_out")],
|
| 38 |
+
1: [p for p in HARDWARE_PROFILES if p.get("class") == 1 and not p.get("held_out")],
|
| 39 |
+
2: [p for p in HARDWARE_PROFILES if p.get("class") == 2 and not p.get("held_out")],
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
HELD_OUT_PROFILES: list[dict[str, Any]] = [p for p in HARDWARE_PROFILES if p.get("held_out")]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def profile_by_id(profile_id: str) -> dict[str, Any] | None:
|
| 47 |
+
return next((p for p in HARDWARE_PROFILES if p["id"] == profile_id), None)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def sample_profile(rng, axis_level: int = 0) -> dict[str, Any]:
|
| 51 |
+
"""Sample a hardware profile appropriate for the given axis level.
|
| 52 |
+
|
| 53 |
+
Per plan §3, axis_level escalates the hardware-class pool:
|
| 54 |
+
level 0 → only Class 0 (easy)
|
| 55 |
+
level 1 → Class 0 + 1
|
| 56 |
+
level 2 → all training profiles (Class 0 + 1 + 2 minus held-out)
|
| 57 |
+
"""
|
| 58 |
+
pool: list[dict[str, Any]] = []
|
| 59 |
+
for level in range(min(axis_level, 2) + 1):
|
| 60 |
+
pool.extend(HARDWARE_BY_CLASS[level])
|
| 61 |
+
if not pool:
|
| 62 |
+
pool = HARDWARE_BY_CLASS[0]
|
| 63 |
+
return rng.choice(pool)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
__all__ = [
|
| 67 |
+
"HARDWARE_PROFILES",
|
| 68 |
+
"HARDWARE_BY_CLASS",
|
| 69 |
+
"HELD_OUT_PROFILES",
|
| 70 |
+
"profile_by_id",
|
| 71 |
+
"sample_profile",
|
| 72 |
+
]
|
server/scenarios/trap_library.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""30 anti-gaming trap functions (per plan §10b).
|
| 2 |
+
|
| 3 |
+
Each trap is a Python function designed to fail naive C++ translation through
|
| 4 |
+
one of these failure modes:
|
| 5 |
+
overflow — Python int unbounded; C++ int wraps at 2^31
|
| 6 |
+
fp_order — float accumulation order changes result
|
| 7 |
+
aliasing — numpy arrays may alias; C++ `restrict` breaks them
|
| 8 |
+
edge_empty — empty input
|
| 9 |
+
nan_inf — special float values
|
| 10 |
+
unicode — string handling
|
| 11 |
+
boundary — INT_MAX, denormals
|
| 12 |
+
semantics — Python-specific behavior (None, slicing, generators)
|
| 13 |
+
|
| 14 |
+
Each trap has metadata:
|
| 15 |
+
- id: stable identifier
|
| 16 |
+
- category: one of the failure modes above
|
| 17 |
+
- python_code: the source
|
| 18 |
+
- bottleneck_label: ground-truth labels for DiagnosisRubric
|
| 19 |
+
- rtol_override: None (default 1e-5) or 0 for bit-exact
|
| 20 |
+
|
| 21 |
+
15% of every batch comes from this library (per plan §4.3). 10 traps are held
|
| 22 |
+
out for the Gen-4 evaluation split.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Any
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class Trap:
|
| 33 |
+
id: str
|
| 34 |
+
category: str
|
| 35 |
+
python_code: str
|
| 36 |
+
bottleneck_label: list[str] = field(default_factory=list)
|
| 37 |
+
bottleneck_distractors: list[str] = field(default_factory=list)
|
| 38 |
+
rtol_override: float | None = None
|
| 39 |
+
held_out: bool = False
|
| 40 |
+
description: str = ""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Default distractor pool — used by all traps unless overridden
|
| 44 |
+
_DEFAULT_DISTRACTORS = ["memory-bound", "branch-heavy", "io-bound", "cache-unfriendly"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
TRAP_LIBRARY: list[Trap] = [
|
| 48 |
+
# -------- Category 1: int overflow (4 traps) --------
|
| 49 |
+
Trap(
|
| 50 |
+
id="overflow_factorial",
|
| 51 |
+
category="overflow",
|
| 52 |
+
python_code=(
|
| 53 |
+
"def factorial(n: int) -> int:\n"
|
| 54 |
+
" r = 1\n"
|
| 55 |
+
" for i in range(2, n + 1):\n"
|
| 56 |
+
" r *= i\n"
|
| 57 |
+
" return r\n"
|
| 58 |
+
),
|
| 59 |
+
bottleneck_label=["compute-bound"],
|
| 60 |
+
bottleneck_distractors=_DEFAULT_DISTRACTORS,
|
| 61 |
+
rtol_override=0, # bit-exact integer
|
| 62 |
+
description="Python big-int math; C++ int overflows past 12!",
|
| 63 |
+
),
|
| 64 |
+
Trap(
|
| 65 |
+
id="overflow_power",
|
| 66 |
+
category="overflow",
|
| 67 |
+
python_code=(
|
| 68 |
+
"def power_sum(base: int, exp: int) -> int:\n"
|
| 69 |
+
" total = 0\n"
|
| 70 |
+
" for i in range(1, exp + 1):\n"
|
| 71 |
+
" total += base ** i\n"
|
| 72 |
+
" return total\n"
|
| 73 |
+
),
|
| 74 |
+
bottleneck_label=["compute-bound"],
|
| 75 |
+
rtol_override=0,
|
| 76 |
+
),
|
| 77 |
+
Trap(
|
| 78 |
+
id="overflow_signed_bitshift",
|
| 79 |
+
category="overflow",
|
| 80 |
+
python_code=(
|
| 81 |
+
"def shift_accumulate(arr: list) -> int:\n"
|
| 82 |
+
" total = 0\n"
|
| 83 |
+
" for x in arr:\n"
|
| 84 |
+
" total += (x << 30)\n"
|
| 85 |
+
" return total\n"
|
| 86 |
+
),
|
| 87 |
+
bottleneck_label=["compute-bound"],
|
| 88 |
+
rtol_override=0,
|
| 89 |
+
),
|
| 90 |
+
Trap(
|
| 91 |
+
id="overflow_int64_sum",
|
| 92 |
+
category="overflow",
|
| 93 |
+
python_code=(
|
| 94 |
+
"def big_sum(arr: list) -> int:\n"
|
| 95 |
+
" total = 0\n"
|
| 96 |
+
" for x in arr:\n"
|
| 97 |
+
" total += x * x * x\n"
|
| 98 |
+
" return total\n"
|
| 99 |
+
),
|
| 100 |
+
bottleneck_label=["compute-bound", "vectorizable"],
|
| 101 |
+
rtol_override=0,
|
| 102 |
+
),
|
| 103 |
+
|
| 104 |
+
# -------- Category 2: floating point accumulation order (5 traps) --------
|
| 105 |
+
Trap(
|
| 106 |
+
id="fp_kahan_drift",
|
| 107 |
+
category="fp_order",
|
| 108 |
+
python_code=(
|
| 109 |
+
"def kahan_sum(arr):\n"
|
| 110 |
+
" s = 0.0\n"
|
| 111 |
+
" c = 0.0\n"
|
| 112 |
+
" for x in arr:\n"
|
| 113 |
+
" y = x - c\n"
|
| 114 |
+
" t = s + y\n"
|
| 115 |
+
" c = (t - s) - y\n"
|
| 116 |
+
" s = t\n"
|
| 117 |
+
" return s\n"
|
| 118 |
+
),
|
| 119 |
+
bottleneck_label=["compute-bound"],
|
| 120 |
+
description="Kahan compensated summation — C++ reorder breaks compensation",
|
| 121 |
+
),
|
| 122 |
+
Trap(
|
| 123 |
+
id="fp_pairwise_var",
|
| 124 |
+
category="fp_order",
|
| 125 |
+
python_code=(
|
| 126 |
+
"def variance(arr):\n"
|
| 127 |
+
" n = len(arr)\n"
|
| 128 |
+
" mean = sum(arr) / n\n"
|
| 129 |
+
" return sum((x - mean) ** 2 for x in arr) / n\n"
|
| 130 |
+
),
|
| 131 |
+
bottleneck_label=["compute-bound"],
|
| 132 |
+
),
|
| 133 |
+
Trap(
|
| 134 |
+
id="fp_chained_mul",
|
| 135 |
+
category="fp_order",
|
| 136 |
+
python_code=(
|
| 137 |
+
"def chain_mul(arr):\n"
|
| 138 |
+
" p = 1.0\n"
|
| 139 |
+
" for x in arr:\n"
|
| 140 |
+
" p *= x\n"
|
| 141 |
+
" return p\n"
|
| 142 |
+
),
|
| 143 |
+
bottleneck_label=["compute-bound"],
|
| 144 |
+
),
|
| 145 |
+
Trap(
|
| 146 |
+
id="fp_subnormal_handling",
|
| 147 |
+
category="fp_order",
|
| 148 |
+
python_code=(
|
| 149 |
+
"def near_zero_sum(arr):\n"
|
| 150 |
+
" return sum(x for x in arr if abs(x) > 1e-300)\n"
|
| 151 |
+
),
|
| 152 |
+
bottleneck_label=["compute-bound", "branch-heavy"],
|
| 153 |
+
),
|
| 154 |
+
Trap(
|
| 155 |
+
id="fp_log_sum_exp",
|
| 156 |
+
category="fp_order",
|
| 157 |
+
python_code=(
|
| 158 |
+
"import math\n"
|
| 159 |
+
"def log_sum_exp(arr):\n"
|
| 160 |
+
" m = max(arr)\n"
|
| 161 |
+
" return m + math.log(sum(math.exp(x - m) for x in arr))\n"
|
| 162 |
+
),
|
| 163 |
+
bottleneck_label=["compute-bound"],
|
| 164 |
+
),
|
| 165 |
+
|
| 166 |
+
# -------- Category 3: aliasing (3 traps) --------
|
| 167 |
+
Trap(
|
| 168 |
+
id="aliasing_in_place",
|
| 169 |
+
category="aliasing",
|
| 170 |
+
python_code=(
|
| 171 |
+
"def in_place_smooth(a):\n"
|
| 172 |
+
" n = len(a)\n"
|
| 173 |
+
" for i in range(1, n - 1):\n"
|
| 174 |
+
" a[i] = (a[i-1] + a[i] + a[i+1]) / 3.0\n"
|
| 175 |
+
" return a\n"
|
| 176 |
+
),
|
| 177 |
+
bottleneck_label=["memory-bound"],
|
| 178 |
+
bottleneck_distractors=["compute-bound", "branch-heavy", "io-bound"],
|
| 179 |
+
description="Read-after-write across iterations; `restrict` would break correctness",
|
| 180 |
+
),
|
| 181 |
+
Trap(
|
| 182 |
+
id="aliasing_two_views",
|
| 183 |
+
category="aliasing",
|
| 184 |
+
python_code=(
|
| 185 |
+
"def add_views(a, b):\n"
|
| 186 |
+
" n = len(a)\n"
|
| 187 |
+
" for i in range(n):\n"
|
| 188 |
+
" a[i] += b[i] * 2\n"
|
| 189 |
+
" return a\n"
|
| 190 |
+
),
|
| 191 |
+
bottleneck_label=["memory-bound", "vectorizable"],
|
| 192 |
+
description="`a` and `b` may overlap; agent must not blindly add `__restrict__`",
|
| 193 |
+
),
|
| 194 |
+
Trap(
|
| 195 |
+
id="aliasing_self_copy",
|
| 196 |
+
category="aliasing",
|
| 197 |
+
python_code=(
|
| 198 |
+
"def shift_left(a):\n"
|
| 199 |
+
" n = len(a)\n"
|
| 200 |
+
" for i in range(n - 1):\n"
|
| 201 |
+
" a[i] = a[i + 1]\n"
|
| 202 |
+
" return a\n"
|
| 203 |
+
),
|
| 204 |
+
bottleneck_label=["memory-bound"],
|
| 205 |
+
),
|
| 206 |
+
|
| 207 |
+
# -------- Category 4: edge case empty / single (3 traps) --------
|
| 208 |
+
Trap(
|
| 209 |
+
id="edge_empty_max",
|
| 210 |
+
category="edge_empty",
|
| 211 |
+
python_code=(
|
| 212 |
+
"def safe_max(arr):\n"
|
| 213 |
+
" if len(arr) == 0:\n"
|
| 214 |
+
" return 0.0\n"
|
| 215 |
+
" return max(arr)\n"
|
| 216 |
+
),
|
| 217 |
+
bottleneck_label=["branch-heavy"],
|
| 218 |
+
),
|
| 219 |
+
Trap(
|
| 220 |
+
id="edge_singleton",
|
| 221 |
+
category="edge_empty",
|
| 222 |
+
python_code=(
|
| 223 |
+
"def doubled_diff(arr):\n"
|
| 224 |
+
" if len(arr) <= 1:\n"
|
| 225 |
+
" return 0.0\n"
|
| 226 |
+
" return sum(arr[i+1] - arr[i] for i in range(len(arr) - 1))\n"
|
| 227 |
+
),
|
| 228 |
+
bottleneck_label=["compute-bound", "branch-heavy"],
|
| 229 |
+
),
|
| 230 |
+
Trap(
|
| 231 |
+
id="edge_zero_division",
|
| 232 |
+
category="edge_empty",
|
| 233 |
+
python_code=(
|
| 234 |
+
"def normalize(arr):\n"
|
| 235 |
+
" s = sum(arr)\n"
|
| 236 |
+
" if s == 0:\n"
|
| 237 |
+
" return [0.0 for _ in arr]\n"
|
| 238 |
+
" return [x / s for x in arr]\n"
|
| 239 |
+
),
|
| 240 |
+
bottleneck_label=["compute-bound", "branch-heavy"],
|
| 241 |
+
),
|
| 242 |
+
|
| 243 |
+
# -------- Category 5: NaN/Inf (3 traps) --------
|
| 244 |
+
Trap(
|
| 245 |
+
id="nan_propagation",
|
| 246 |
+
category="nan_inf",
|
| 247 |
+
python_code=(
|
| 248 |
+
"import math\n"
|
| 249 |
+
"def filter_finite(arr):\n"
|
| 250 |
+
" return sum(x for x in arr if math.isfinite(x))\n"
|
| 251 |
+
),
|
| 252 |
+
bottleneck_label=["branch-heavy"],
|
| 253 |
+
),
|
| 254 |
+
Trap(
|
| 255 |
+
id="inf_arithmetic",
|
| 256 |
+
category="nan_inf",
|
| 257 |
+
python_code=(
|
| 258 |
+
"import math\n"
|
| 259 |
+
"def soft_clamp(arr):\n"
|
| 260 |
+
" return [x if math.isfinite(x) else 0.0 for x in arr]\n"
|
| 261 |
+
),
|
| 262 |
+
bottleneck_label=["branch-heavy"],
|
| 263 |
+
),
|
| 264 |
+
Trap(
|
| 265 |
+
id="nan_aware_min",
|
| 266 |
+
category="nan_inf",
|
| 267 |
+
python_code=(
|
| 268 |
+
"import math\n"
|
| 269 |
+
"def nan_aware_min(arr):\n"
|
| 270 |
+
" finite = [x for x in arr if not math.isnan(x)]\n"
|
| 271 |
+
" return min(finite) if finite else 0.0\n"
|
| 272 |
+
),
|
| 273 |
+
bottleneck_label=["branch-heavy"],
|
| 274 |
+
),
|
| 275 |
+
|
| 276 |
+
# -------- Category 6: boundary values (3 traps) --------
|
| 277 |
+
Trap(
|
| 278 |
+
id="boundary_signed_compare",
|
| 279 |
+
category="boundary",
|
| 280 |
+
python_code=(
|
| 281 |
+
"def count_negatives(arr: list) -> int:\n"
|
| 282 |
+
" return sum(1 for x in arr if x < 0)\n"
|
| 283 |
+
),
|
| 284 |
+
bottleneck_label=["branch-heavy", "vectorizable"],
|
| 285 |
+
rtol_override=0,
|
| 286 |
+
),
|
| 287 |
+
Trap(
|
| 288 |
+
id="boundary_min_int",
|
| 289 |
+
category="boundary",
|
| 290 |
+
python_code=(
|
| 291 |
+
"def abs_sum(arr: list) -> int:\n"
|
| 292 |
+
" return sum(abs(x) for x in arr)\n"
|
| 293 |
+
),
|
| 294 |
+
bottleneck_label=["compute-bound"],
|
| 295 |
+
rtol_override=0,
|
| 296 |
+
description="abs(INT_MIN) overflows in C++; Python handles transparently",
|
| 297 |
+
),
|
| 298 |
+
Trap(
|
| 299 |
+
id="boundary_denormal_threshold",
|
| 300 |
+
category="boundary",
|
| 301 |
+
python_code=(
|
| 302 |
+
"def threshold_count(arr):\n"
|
| 303 |
+
" return sum(1 for x in arr if abs(x) > 1e-308)\n"
|
| 304 |
+
),
|
| 305 |
+
bottleneck_label=["branch-heavy"],
|
| 306 |
+
),
|
| 307 |
+
|
| 308 |
+
# -------- Category 7: semantics (5 traps) --------
|
| 309 |
+
Trap(
|
| 310 |
+
id="semantics_negative_index",
|
| 311 |
+
category="semantics",
|
| 312 |
+
python_code=(
|
| 313 |
+
"def last_diff(arr):\n"
|
| 314 |
+
" return arr[-1] - arr[0] if len(arr) >= 1 else 0\n"
|
| 315 |
+
),
|
| 316 |
+
bottleneck_label=["compute-bound"],
|
| 317 |
+
description="Python a[-1] = last element; C++ a[-1] = UB",
|
| 318 |
+
),
|
| 319 |
+
Trap(
|
| 320 |
+
id="semantics_empty_sum",
|
| 321 |
+
category="semantics",
|
| 322 |
+
python_code=(
|
| 323 |
+
"def opt_avg(arr):\n"
|
| 324 |
+
" return sum(arr) / len(arr) if arr else 0.0\n"
|
| 325 |
+
),
|
| 326 |
+
bottleneck_label=["compute-bound", "branch-heavy"],
|
| 327 |
+
),
|
| 328 |
+
Trap(
|
| 329 |
+
id="semantics_truthy_filter",
|
| 330 |
+
category="semantics",
|
| 331 |
+
python_code=(
|
| 332 |
+
"def count_truthy(arr):\n"
|
| 333 |
+
" return sum(1 for x in arr if x)\n"
|
| 334 |
+
),
|
| 335 |
+
bottleneck_label=["branch-heavy"],
|
| 336 |
+
description="Python truthy includes [], 0, '', None; C++ has different semantics",
|
| 337 |
+
rtol_override=0,
|
| 338 |
+
),
|
| 339 |
+
Trap(
|
| 340 |
+
id="semantics_int_div",
|
| 341 |
+
category="semantics",
|
| 342 |
+
python_code=(
|
| 343 |
+
"def floor_avg(arr: list) -> int:\n"
|
| 344 |
+
" return sum(arr) // len(arr) if arr else 0\n"
|
| 345 |
+
),
|
| 346 |
+
bottleneck_label=["compute-bound"],
|
| 347 |
+
rtol_override=0,
|
| 348 |
+
description="// is floor div in Python (correct for negatives); C++ / truncates toward zero",
|
| 349 |
+
),
|
| 350 |
+
Trap(
|
| 351 |
+
id="semantics_modulo_negative",
|
| 352 |
+
category="semantics",
|
| 353 |
+
python_code=(
|
| 354 |
+
"def positive_mod_sum(arr: list, m: int) -> int:\n"
|
| 355 |
+
" return sum(x % m for x in arr)\n"
|
| 356 |
+
),
|
| 357 |
+
bottleneck_label=["compute-bound"],
|
| 358 |
+
rtol_override=0,
|
| 359 |
+
description="Python % always returns non-negative for positive m; C++ may return negative",
|
| 360 |
+
),
|
| 361 |
+
|
| 362 |
+
# -------- Category 8: held-out for Gen-4 (4 traps) --------
|
| 363 |
+
Trap(
|
| 364 |
+
id="holdout_kahan_sum_2",
|
| 365 |
+
category="fp_order",
|
| 366 |
+
python_code=(
|
| 367 |
+
"def stable_total(arr):\n"
|
| 368 |
+
" s = 0.0\n"
|
| 369 |
+
" err = 0.0\n"
|
| 370 |
+
" for x in arr:\n"
|
| 371 |
+
" y = x + err\n"
|
| 372 |
+
" new_s = s + y\n"
|
| 373 |
+
" err = y - (new_s - s)\n"
|
| 374 |
+
" s = new_s\n"
|
| 375 |
+
" return s\n"
|
| 376 |
+
),
|
| 377 |
+
bottleneck_label=["compute-bound"],
|
| 378 |
+
held_out=True,
|
| 379 |
+
),
|
| 380 |
+
Trap(
|
| 381 |
+
id="holdout_overflow_combinations",
|
| 382 |
+
category="overflow",
|
| 383 |
+
python_code=(
|
| 384 |
+
"def n_choose_k(n: int, k: int) -> int:\n"
|
| 385 |
+
" if k > n - k:\n"
|
| 386 |
+
" k = n - k\n"
|
| 387 |
+
" r = 1\n"
|
| 388 |
+
" for i in range(k):\n"
|
| 389 |
+
" r = r * (n - i) // (i + 1)\n"
|
| 390 |
+
" return r\n"
|
| 391 |
+
),
|
| 392 |
+
bottleneck_label=["compute-bound"],
|
| 393 |
+
rtol_override=0,
|
| 394 |
+
held_out=True,
|
| 395 |
+
),
|
| 396 |
+
Trap(
|
| 397 |
+
id="holdout_aliasing_swap",
|
| 398 |
+
category="aliasing",
|
| 399 |
+
python_code=(
|
| 400 |
+
"def reverse_in_place(a):\n"
|
| 401 |
+
" n = len(a)\n"
|
| 402 |
+
" for i in range(n // 2):\n"
|
| 403 |
+
" a[i], a[n - 1 - i] = a[n - 1 - i], a[i]\n"
|
| 404 |
+
" return a\n"
|
| 405 |
+
),
|
| 406 |
+
bottleneck_label=["memory-bound"],
|
| 407 |
+
held_out=True,
|
| 408 |
+
),
|
| 409 |
+
Trap(
|
| 410 |
+
id="holdout_semantics_chained_compare",
|
| 411 |
+
category="semantics",
|
| 412 |
+
python_code=(
|
| 413 |
+
"def in_range_count(arr, lo: float, hi: float) -> int:\n"
|
| 414 |
+
" return sum(1 for x in arr if lo < x < hi)\n"
|
| 415 |
+
),
|
| 416 |
+
bottleneck_label=["branch-heavy"],
|
| 417 |
+
rtol_override=0,
|
| 418 |
+
held_out=True,
|
| 419 |
+
description="Python a < x < b is single test; agent may write incorrect (a < x) < b in C++",
|
| 420 |
+
),
|
| 421 |
+
]
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def get_trap_by_id(trap_id: str) -> Trap | None:
|
| 425 |
+
return next((t for t in TRAP_LIBRARY if t.id == trap_id), None)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def sample_trap(rng, exclude_held_out: bool = True) -> Trap:
|
| 429 |
+
"""Sample a random trap. By default excludes the Gen-4 held-out subset."""
|
| 430 |
+
pool = [t for t in TRAP_LIBRARY if not (exclude_held_out and t.held_out)]
|
| 431 |
+
return rng.choice(pool)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def sample_trap_by_category(category: str, rng, exclude_held_out: bool = True) -> Trap | None:
|
| 435 |
+
"""Sample one trap from a specific category. Returns None if unavailable."""
|
| 436 |
+
pool = [
|
| 437 |
+
t for t in TRAP_LIBRARY
|
| 438 |
+
if t.category == category and not (exclude_held_out and t.held_out)
|
| 439 |
+
]
|
| 440 |
+
if not pool:
|
| 441 |
+
return None
|
| 442 |
+
return rng.choice(pool)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def trap_to_problem_dict(trap: Trap, hw_profile: dict[str, Any]) -> dict[str, Any]:
|
| 446 |
+
"""Convert a Trap into the env._sample_problem() return shape."""
|
| 447 |
+
# Default distractor pool excluding the trap's true labels
|
| 448 |
+
distractors = [d for d in (trap.bottleneck_distractors or _DEFAULT_DISTRACTORS)
|
| 449 |
+
if d not in trap.bottleneck_label]
|
| 450 |
+
return {
|
| 451 |
+
"python_code": trap.python_code,
|
| 452 |
+
"cpp_signature": _infer_cpp_signature(trap.python_code),
|
| 453 |
+
"hardware_profile": hw_profile,
|
| 454 |
+
"bottleneck_labels": trap.bottleneck_label,
|
| 455 |
+
"bottleneck_distractors": distractors,
|
| 456 |
+
"rtol_override": trap.rtol_override,
|
| 457 |
+
"is_trap": True,
|
| 458 |
+
"trap_id": trap.id,
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _infer_cpp_signature(python_code: str) -> str:
|
| 463 |
+
"""Best-effort C++ signature derivation from a Python def. Refined in Hour 22 smoke test."""
|
| 464 |
+
import ast
|
| 465 |
+
try:
|
| 466 |
+
tree = ast.parse(python_code)
|
| 467 |
+
fn = next(n for n in tree.body if isinstance(n, ast.FunctionDef))
|
| 468 |
+
return f'extern "C" void agent_function(/* {len(fn.args.args)} args from Python */ );'
|
| 469 |
+
except Exception:
|
| 470 |
+
return 'extern "C" void agent_function(void* in, size_t n, void* out);'
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Public counts for assertions
|
| 474 |
+
N_TRAPS_TOTAL = len(TRAP_LIBRARY)
|
| 475 |
+
N_TRAPS_TRAINING = sum(1 for t in TRAP_LIBRARY if not t.held_out)
|
| 476 |
+
N_TRAPS_HELDOUT = sum(1 for t in TRAP_LIBRARY if t.held_out)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
__all__ = [
|
| 480 |
+
"Trap",
|
| 481 |
+
"TRAP_LIBRARY",
|
| 482 |
+
"get_trap_by_id",
|
| 483 |
+
"sample_trap",
|
| 484 |
+
"sample_trap_by_category",
|
| 485 |
+
"trap_to_problem_dict",
|
| 486 |
+
"N_TRAPS_TOTAL",
|
| 487 |
+
"N_TRAPS_TRAINING",
|
| 488 |
+
"N_TRAPS_HELDOUT",
|
| 489 |
+
]
|
server/tools/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MCP tool registry for Polyglot-Optima.
|
| 2 |
+
|
| 3 |
+
Exposes 9 tools per plan §9. The TOOL_REGISTRY dict is loaded by the environment
|
| 4 |
+
at startup and dispatched from PolyglotOptimaEnvironment._dispatch_tool.
|
| 5 |
+
|
| 6 |
+
Each tool is a plain Python callable (tool_args: dict, state: OptimizationState) -> dict.
|
| 7 |
+
The @tool decorator (Hour 22 deployment-time wrapper) adds Pydantic schema
|
| 8 |
+
validation, mode tagging, and async dispatch — for now, plain functions.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from .hardware_profiler import get_hardware_profile_tool
|
| 14 |
+
from .python_analyzer import (
|
| 15 |
+
profile_python_hotspots_tool,
|
| 16 |
+
analyze_complexity_tool,
|
| 17 |
+
check_memory_access_tool,
|
| 18 |
+
)
|
| 19 |
+
from .cpp_compiler import compile_and_benchmark_tool
|
| 20 |
+
from .verifier import verify_equivalence_tool
|
| 21 |
+
from .portability_checker import check_portability_tool
|
| 22 |
+
from .bottleneck_reporter import get_bottleneck_report_tool
|
| 23 |
+
from .submit import submit_optimization_tool
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
TOOL_REGISTRY = {
|
| 27 |
+
"get_hardware_profile": get_hardware_profile_tool,
|
| 28 |
+
"profile_python_hotspots": profile_python_hotspots_tool,
|
| 29 |
+
"analyze_complexity": analyze_complexity_tool,
|
| 30 |
+
"check_memory_access": check_memory_access_tool,
|
| 31 |
+
"compile_and_benchmark": compile_and_benchmark_tool,
|
| 32 |
+
"verify_equivalence": verify_equivalence_tool,
|
| 33 |
+
"check_portability": check_portability_tool,
|
| 34 |
+
"get_bottleneck_report": get_bottleneck_report_tool,
|
| 35 |
+
"submit_optimization": submit_optimization_tool,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
__all__ = ["TOOL_REGISTRY"]
|
server/tools/_runtime.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ctypes-based runtime dispatch for compiled agent C++.
|
| 2 |
+
|
| 3 |
+
Replaces the Hour 4-10 stubs in cpp_compiler._benchmark_cpp and verifier._exec_cpp_via_so
|
| 4 |
+
with real measurement.
|
| 5 |
+
|
| 6 |
+
Canonical agent function signature (system-prompted, enforced by all training data):
|
| 7 |
+
|
| 8 |
+
extern "C" void agent_function(
|
| 9 |
+
const double* in_ptr, // flattened input (all args concatenated to float64)
|
| 10 |
+
size_t in_n, // total input length
|
| 11 |
+
double* out_ptr, // preallocated output buffer (caller-allocated, agent fills)
|
| 12 |
+
size_t out_n // output buffer size
|
| 13 |
+
);
|
| 14 |
+
|
| 15 |
+
This uniform signature trades some type richness (everything's float64) for:
|
| 16 |
+
- Simple ctypes binding (no per-function ABI generation)
|
| 17 |
+
- Trivial for the agent to write
|
| 18 |
+
- Covers all numeric training functions (sklearn loops, NumPy ops, math kernels)
|
| 19 |
+
|
| 20 |
+
Inputs/outputs are float64 (8 bytes). For integer functions we cast at the
|
| 21 |
+
boundary; for the few bit-exact integer functions in the trap library, the
|
| 22 |
+
fuzzer's `rtol=0` semantics still catch divergence (e.g., int overflow modes
|
| 23 |
+
that propagate as different float values).
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import ctypes
|
| 29 |
+
import time
|
| 30 |
+
from typing import Any, Callable
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------- Argument marshalling ----------------------
|
| 36 |
+
|
| 37 |
+
def _flatten_args(args: tuple) -> tuple[np.ndarray, list]:
|
| 38 |
+
"""Concatenate all args into one flat float64 array; remember per-arg shapes for the agent.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
flat: a single contiguous float64 array (the in_ptr buffer)
|
| 42 |
+
shapes: list of (kind, shape, dtype) for each arg — informational, not used by the
|
| 43 |
+
ABI itself but useful for debugging
|
| 44 |
+
"""
|
| 45 |
+
flats: list[np.ndarray] = []
|
| 46 |
+
shapes: list[tuple] = []
|
| 47 |
+
for a in args:
|
| 48 |
+
if isinstance(a, np.ndarray):
|
| 49 |
+
shapes.append(("ndarray", a.shape, a.dtype))
|
| 50 |
+
flats.append(np.ascontiguousarray(a, dtype=np.float64).ravel())
|
| 51 |
+
elif isinstance(a, (int, float, np.integer, np.floating)):
|
| 52 |
+
shapes.append(("scalar", (), type(a)))
|
| 53 |
+
flats.append(np.array([float(a)], dtype=np.float64))
|
| 54 |
+
elif isinstance(a, (list, tuple)):
|
| 55 |
+
arr = np.array(a, dtype=np.float64)
|
| 56 |
+
shapes.append(("list", arr.shape, np.float64))
|
| 57 |
+
flats.append(arr.ravel())
|
| 58 |
+
else:
|
| 59 |
+
raise TypeError(f"unsupported arg type for agent_function: {type(a).__name__}")
|
| 60 |
+
if not flats:
|
| 61 |
+
return np.array([], dtype=np.float64), shapes
|
| 62 |
+
return np.concatenate(flats).astype(np.float64, copy=False), shapes
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _infer_output_meta(py_fn: Callable, args: tuple) -> dict[str, Any]:
|
| 66 |
+
"""Run py_fn once to discover output shape + dtype. Used to size the C++ output buffer."""
|
| 67 |
+
out = py_fn(*args)
|
| 68 |
+
if isinstance(out, (int, np.integer)):
|
| 69 |
+
return {"kind": "int", "size": 1, "shape": (), "dtype": int}
|
| 70 |
+
if isinstance(out, (float, np.floating)):
|
| 71 |
+
return {"kind": "float", "size": 1, "shape": (), "dtype": float}
|
| 72 |
+
if isinstance(out, np.ndarray):
|
| 73 |
+
return {"kind": "ndarray", "size": int(out.size), "shape": tuple(out.shape), "dtype": out.dtype}
|
| 74 |
+
if isinstance(out, (list, tuple)):
|
| 75 |
+
arr = np.array(out, dtype=np.float64)
|
| 76 |
+
return {"kind": "list", "size": int(arr.size), "shape": tuple(arr.shape), "dtype": np.float64}
|
| 77 |
+
raise TypeError(f"unsupported py_fn output type: {type(out).__name__}")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _reshape_cpp_output(out_arr: np.ndarray, meta: dict[str, Any]) -> Any:
|
| 81 |
+
"""Reshape the flat output buffer back to py_fn's original output kind/shape."""
|
| 82 |
+
if meta["kind"] == "int":
|
| 83 |
+
return int(round(float(out_arr[0])))
|
| 84 |
+
if meta["kind"] == "float":
|
| 85 |
+
return float(out_arr[0])
|
| 86 |
+
if meta["kind"] == "ndarray":
|
| 87 |
+
return out_arr[: meta["size"]].reshape(meta["shape"]).astype(meta["dtype"], copy=False)
|
| 88 |
+
if meta["kind"] == "list":
|
| 89 |
+
return out_arr[: meta["size"]].reshape(meta["shape"]).tolist()
|
| 90 |
+
return out_arr
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ---------------------- .so loader (cached) ----------------------
|
| 94 |
+
|
| 95 |
+
class _SOLoader:
|
| 96 |
+
"""Cache loaded ctypes libraries by path. Each .so loaded only once."""
|
| 97 |
+
_cache: dict[str, ctypes.CDLL] = {}
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def load(cls, so_path: str) -> ctypes.CDLL:
|
| 101 |
+
if so_path in cls._cache:
|
| 102 |
+
return cls._cache[so_path]
|
| 103 |
+
lib = ctypes.CDLL(so_path)
|
| 104 |
+
if not hasattr(lib, "agent_function"):
|
| 105 |
+
raise RuntimeError(f"{so_path} does not export `agent_function`")
|
| 106 |
+
lib.agent_function.argtypes = [
|
| 107 |
+
ctypes.POINTER(ctypes.c_double), # in_ptr
|
| 108 |
+
ctypes.c_size_t, # in_n
|
| 109 |
+
ctypes.POINTER(ctypes.c_double), # out_ptr
|
| 110 |
+
ctypes.c_size_t, # out_n
|
| 111 |
+
]
|
| 112 |
+
lib.agent_function.restype = None
|
| 113 |
+
cls._cache[so_path] = lib
|
| 114 |
+
return lib
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def clear(cls) -> None:
|
| 118 |
+
cls._cache.clear()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------- Public dispatch API ----------------------
|
| 122 |
+
|
| 123 |
+
def call_compiled(so_path: str, py_fn: Callable, args: tuple) -> Any:
|
| 124 |
+
"""Call agent_function in the .so on args. Return value matches py_fn's output shape.
|
| 125 |
+
|
| 126 |
+
Raises:
|
| 127 |
+
RuntimeError: if .so can't be loaded or `agent_function` symbol is missing
|
| 128 |
+
"""
|
| 129 |
+
lib = _SOLoader.load(so_path)
|
| 130 |
+
|
| 131 |
+
in_flat, _ = _flatten_args(args)
|
| 132 |
+
in_arr = np.ascontiguousarray(in_flat, dtype=np.float64)
|
| 133 |
+
in_ptr = in_arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
|
| 134 |
+
|
| 135 |
+
out_meta = _infer_output_meta(py_fn, args)
|
| 136 |
+
out_arr = np.zeros(out_meta["size"], dtype=np.float64)
|
| 137 |
+
out_ptr = out_arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
|
| 138 |
+
|
| 139 |
+
lib.agent_function(in_ptr, ctypes.c_size_t(in_arr.size),
|
| 140 |
+
out_ptr, ctypes.c_size_t(out_meta["size"]))
|
| 141 |
+
|
| 142 |
+
return _reshape_cpp_output(out_arr, out_meta)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def benchmark_python_vs_cpp(
|
| 146 |
+
so_path: str,
|
| 147 |
+
py_fn: Callable,
|
| 148 |
+
args: tuple,
|
| 149 |
+
n_per_repeat: int = 5,
|
| 150 |
+
repeats: int = 3,
|
| 151 |
+
) -> dict[str, float]:
|
| 152 |
+
"""Median-of-(repeats×n_per_repeat) wall time for both Python and C++ on the SAME args.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
py_median_ms: float — median ms per Python call
|
| 156 |
+
cpp_median_ms: float — median ms per C++ call (via ctypes)
|
| 157 |
+
speedup: float — py_median_ms / cpp_median_ms
|
| 158 |
+
"""
|
| 159 |
+
lib = _SOLoader.load(so_path)
|
| 160 |
+
|
| 161 |
+
# Pre-flatten inputs ONCE — re-flattening would pollute timing
|
| 162 |
+
in_flat, _ = _flatten_args(args)
|
| 163 |
+
in_arr = np.ascontiguousarray(in_flat, dtype=np.float64)
|
| 164 |
+
in_ptr = in_arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
|
| 165 |
+
|
| 166 |
+
out_meta = _infer_output_meta(py_fn, args)
|
| 167 |
+
out_arr = np.zeros(out_meta["size"], dtype=np.float64)
|
| 168 |
+
out_ptr = out_arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
|
| 169 |
+
|
| 170 |
+
in_n = ctypes.c_size_t(in_arr.size)
|
| 171 |
+
out_n = ctypes.c_size_t(out_meta["size"])
|
| 172 |
+
|
| 173 |
+
# ---- Python timing ----
|
| 174 |
+
py_times: list[float] = []
|
| 175 |
+
for _ in range(repeats):
|
| 176 |
+
t0 = time.perf_counter()
|
| 177 |
+
for _ in range(n_per_repeat):
|
| 178 |
+
py_fn(*args)
|
| 179 |
+
elapsed = time.perf_counter() - t0
|
| 180 |
+
py_times.append((elapsed / n_per_repeat) * 1000)
|
| 181 |
+
py_times.sort()
|
| 182 |
+
py_median = py_times[len(py_times) // 2]
|
| 183 |
+
|
| 184 |
+
# ---- C++ timing ----
|
| 185 |
+
cpp_times: list[float] = []
|
| 186 |
+
for _ in range(repeats):
|
| 187 |
+
t0 = time.perf_counter()
|
| 188 |
+
for _ in range(n_per_repeat):
|
| 189 |
+
lib.agent_function(in_ptr, in_n, out_ptr, out_n)
|
| 190 |
+
elapsed = time.perf_counter() - t0
|
| 191 |
+
cpp_times.append((elapsed / n_per_repeat) * 1000)
|
| 192 |
+
cpp_times.sort()
|
| 193 |
+
cpp_median = cpp_times[len(cpp_times) // 2]
|
| 194 |
+
|
| 195 |
+
return {
|
| 196 |
+
"py_median_ms": py_median,
|
| 197 |
+
"cpp_median_ms": cpp_median,
|
| 198 |
+
"speedup": py_median / max(cpp_median, 1e-6),
|
| 199 |
+
"n_per_repeat": n_per_repeat,
|
| 200 |
+
"repeats": repeats,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def time_python_only(py_fn: Callable, args: tuple, n_per_repeat: int = 5, repeats: int = 3) -> float:
|
| 205 |
+
"""Pure Python baseline timing (no .so needed). Returns median ms per call."""
|
| 206 |
+
times: list[float] = []
|
| 207 |
+
for _ in range(repeats):
|
| 208 |
+
t0 = time.perf_counter()
|
| 209 |
+
for _ in range(n_per_repeat):
|
| 210 |
+
py_fn(*args)
|
| 211 |
+
times.append((time.perf_counter() - t0) / n_per_repeat * 1000)
|
| 212 |
+
times.sort()
|
| 213 |
+
return times[len(times) // 2]
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ---------------------- Sample-input synthesizer ----------------------
|
| 217 |
+
|
| 218 |
+
def make_default_args_for(py_fn: Callable, n: int = 1024, seed: int = 0) -> tuple:
|
| 219 |
+
"""Construct a default (numeric ndarray + scalars) arg tuple for py_fn from its signature.
|
| 220 |
+
|
| 221 |
+
Used for the benchmark baseline when no specific input is provided.
|
| 222 |
+
Falls back to a 1024-element float64 array if introspection fails.
|
| 223 |
+
"""
|
| 224 |
+
import inspect
|
| 225 |
+
rng = np.random.default_rng(seed)
|
| 226 |
+
try:
|
| 227 |
+
sig = inspect.signature(py_fn)
|
| 228 |
+
params = list(sig.parameters.values())
|
| 229 |
+
except (ValueError, TypeError):
|
| 230 |
+
return (rng.standard_normal(n).astype(np.float64),)
|
| 231 |
+
|
| 232 |
+
out = []
|
| 233 |
+
for p in params:
|
| 234 |
+
ann = str(p.annotation).lower() if p.annotation is not inspect.Parameter.empty else ""
|
| 235 |
+
default = p.default if p.default is not inspect.Parameter.empty else None
|
| 236 |
+
if "int" in ann and "ndarray" not in ann and "list" not in ann:
|
| 237 |
+
out.append(default if isinstance(default, int) else int(rng.integers(2, 16)))
|
| 238 |
+
elif "float" in ann and "ndarray" not in ann and "list" not in ann:
|
| 239 |
+
out.append(default if isinstance(default, float) else float(rng.standard_normal()))
|
| 240 |
+
elif "list" in ann or "ndarray" in ann or ann == "":
|
| 241 |
+
out.append(rng.standard_normal(n).astype(np.float64))
|
| 242 |
+
elif "str" in ann:
|
| 243 |
+
out.append("hello world")
|
| 244 |
+
else:
|
| 245 |
+
out.append(rng.standard_normal(n).astype(np.float64))
|
| 246 |
+
return tuple(out)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
__all__ = [
|
| 250 |
+
"call_compiled",
|
| 251 |
+
"benchmark_python_vs_cpp",
|
| 252 |
+
"time_python_only",
|
| 253 |
+
"make_default_args_for",
|
| 254 |
+
"_SOLoader",
|
| 255 |
+
]
|
server/tools/bottleneck_reporter.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 8/9: get_bottleneck_report.
|
| 2 |
+
|
| 3 |
+
Returns a `perf stat`-style report for the agent's compiled C++ — instructions
|
| 4 |
+
per cycle, cache miss rate, vectorization status. Helps the agent diagnose
|
| 5 |
+
*why* its C++ is slow before refining.
|
| 6 |
+
|
| 7 |
+
Real implementation (Hour 16) reads /proc/perf_event or uses Linux perf_event_open
|
| 8 |
+
to collect counters during the benchmark run. For Hour 4-10, this is a heuristic
|
| 9 |
+
estimate based on static C++ analysis (looks for SIMD intrinsics, OpenMP, etc.).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
_SIMD_INTRINSIC_PATTERN = re.compile(
|
| 19 |
+
r"_mm\d+_|_mm_|vld\d+q?_|vst\d+q?_|vmul[a-z]?_|vadd[a-z]?_|"
|
| 20 |
+
r"__m\d+|svfloat|svint"
|
| 21 |
+
)
|
| 22 |
+
_OPENMP_PATTERN = re.compile(r"#\s*pragma\s+omp")
|
| 23 |
+
_RESTRICT_PATTERN = re.compile(r"\b__restrict__\b|\brestrict\b")
|
| 24 |
+
_LIKELY_PATTERN = re.compile(r"\[\[\s*(un)?likely\s*\]\]")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_bottleneck_report_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 28 |
+
"""Static analysis of agent's C++ → estimate of vectorization, parallelism, etc.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
cpp_code (str)
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
uses_simd (bool)
|
| 35 |
+
uses_openmp (bool)
|
| 36 |
+
uses_restrict (bool)
|
| 37 |
+
uses_branch_hints (bool)
|
| 38 |
+
estimated_ipc (float) — heuristic
|
| 39 |
+
estimated_cache_miss_rate (float)
|
| 40 |
+
estimated_vectorization_pct (float)
|
| 41 |
+
suggestions (list[str]) — hints for next round
|
| 42 |
+
"""
|
| 43 |
+
cpp_code = tool_args.get("cpp_code", "")
|
| 44 |
+
if not cpp_code.strip():
|
| 45 |
+
return {"error": "empty cpp_code"}
|
| 46 |
+
|
| 47 |
+
uses_simd = bool(_SIMD_INTRINSIC_PATTERN.search(cpp_code))
|
| 48 |
+
uses_openmp = bool(_OPENMP_PATTERN.search(cpp_code))
|
| 49 |
+
uses_restrict = bool(_RESTRICT_PATTERN.search(cpp_code))
|
| 50 |
+
uses_hints = bool(_LIKELY_PATTERN.search(cpp_code))
|
| 51 |
+
|
| 52 |
+
# Heuristic IPC estimate (1.0 = scalar, 4.0 = AVX2 SIMD, 8.0 = AVX-512)
|
| 53 |
+
simd_w = {"SSE4.2": 4, "AVX2": 8, "AVX-512": 16, "NEON": 4, "none": 1}.get(
|
| 54 |
+
state.hardware_profile.get("simd", "none"), 1
|
| 55 |
+
)
|
| 56 |
+
estimated_ipc = 0.8
|
| 57 |
+
if uses_simd:
|
| 58 |
+
estimated_ipc = min(simd_w * 0.6, 8.0)
|
| 59 |
+
if uses_openmp:
|
| 60 |
+
estimated_ipc *= min(state.hardware_profile.get("cores", 1), 4) * 0.7
|
| 61 |
+
|
| 62 |
+
estimated_cache_miss = 0.20
|
| 63 |
+
if uses_restrict:
|
| 64 |
+
estimated_cache_miss *= 0.7
|
| 65 |
+
|
| 66 |
+
estimated_vec_pct = 5.0
|
| 67 |
+
if uses_simd:
|
| 68 |
+
estimated_vec_pct = 80.0
|
| 69 |
+
elif uses_openmp:
|
| 70 |
+
estimated_vec_pct = 20.0 # GCC may auto-vectorize OpenMP loops
|
| 71 |
+
|
| 72 |
+
suggestions: list[str] = []
|
| 73 |
+
if not uses_simd and simd_w >= 4:
|
| 74 |
+
suggestions.append(
|
| 75 |
+
f"Hardware supports {state.hardware_profile['simd']} (width {simd_w}). "
|
| 76 |
+
f"Consider explicit SIMD intrinsics."
|
| 77 |
+
)
|
| 78 |
+
if not uses_openmp and state.hardware_profile.get("cores", 1) >= 4:
|
| 79 |
+
suggestions.append(
|
| 80 |
+
f"Hardware has {state.hardware_profile['cores']} cores. "
|
| 81 |
+
f"Add `#pragma omp parallel for` to outer loops."
|
| 82 |
+
)
|
| 83 |
+
if not uses_restrict and "ndarray" in state.python_code.lower():
|
| 84 |
+
suggestions.append(
|
| 85 |
+
"Add `__restrict__` to pointer args — tells the compiler arrays don't alias."
|
| 86 |
+
)
|
| 87 |
+
if not suggestions:
|
| 88 |
+
suggestions.append("Looks well-optimized. Refining further may yield marginal gains.")
|
| 89 |
+
|
| 90 |
+
return {
|
| 91 |
+
"uses_simd": uses_simd,
|
| 92 |
+
"uses_openmp": uses_openmp,
|
| 93 |
+
"uses_restrict": uses_restrict,
|
| 94 |
+
"uses_branch_hints": uses_hints,
|
| 95 |
+
"estimated_ipc": estimated_ipc,
|
| 96 |
+
"estimated_cache_miss_rate": estimated_cache_miss,
|
| 97 |
+
"estimated_vectorization_pct": estimated_vec_pct,
|
| 98 |
+
"suggestions": suggestions,
|
| 99 |
+
"method": "static_pattern_match",
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
__all__ = ["get_bottleneck_report_tool"]
|
server/tools/cpp_compiler.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 5/9: compile_and_benchmark.
|
| 2 |
+
|
| 3 |
+
Compiles agent C++ with `g++ -O3 -march=native -fopenmp -std=c++20 -Wall -Werror`
|
| 4 |
+
and benchmarks against the Python baseline using median-of-15 wall time.
|
| 5 |
+
|
| 6 |
+
Caching: the (cpp_code + hardware_profile_id) sha256 keys a persistent on-disk
|
| 7 |
+
cache of compiled `.so` files. Per plan §7 risk #2, a high cache hit rate is
|
| 8 |
+
critical to keeping training cost within budget.
|
| 9 |
+
|
| 10 |
+
Output language enforcement (per plan §10a): the wrapper signature is auto-
|
| 11 |
+
generated from the Python AST and the agent's code MUST define `extern "C"`
|
| 12 |
+
function with that exact signature. Compile errors → reward = 0.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import hashlib
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import re
|
| 21 |
+
import shutil
|
| 22 |
+
import subprocess
|
| 23 |
+
import tempfile
|
| 24 |
+
import time
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
# Persistent compile cache directory (shared across episodes within a process run)
|
| 29 |
+
_CACHE_ROOT = Path(os.environ.get("POLYGLOT_OPTIMA_CACHE", str(Path(tempfile.gettempdir()) / "polyglot_optima_cache")))
|
| 30 |
+
_CACHE_ROOT.mkdir(parents=True, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
# Compile std — locked to C++20 in production per plan §10a.
|
| 33 |
+
# Allowing C++17/C++14 silently would let the agent learn code that fails on the
|
| 34 |
+
# real GCC 14 deploy. Therefore: production = c++20 only. Dev fallback requires
|
| 35 |
+
# the explicit POLYGLOT_OPTIMA_DEV_FALLBACK=1 env var (used by tests on machines
|
| 36 |
+
# with old MinGW); even then we warn loudly so the divergence isn't invisible.
|
| 37 |
+
_PRODUCTION_CXX_STD = "c++20"
|
| 38 |
+
_DEV_FALLBACK_ALLOWED = os.environ.get("POLYGLOT_OPTIMA_DEV_FALLBACK", "0") == "1"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _detect_supported_cxx_std() -> str:
|
| 42 |
+
"""Return c++20 if the compiler supports it; else c++20 anyway in production
|
| 43 |
+
(so the compile fails informatively and the gate registers it as syntax_error).
|
| 44 |
+
|
| 45 |
+
With POLYGLOT_OPTIMA_DEV_FALLBACK=1 set, we fall back to the highest std the
|
| 46 |
+
compiler accepts and emit a stderr warning. That mode is for local dev tests
|
| 47 |
+
only — never for training or deploy."""
|
| 48 |
+
compiler = shutil.which("g++") or shutil.which("clang++")
|
| 49 |
+
if not compiler:
|
| 50 |
+
return _PRODUCTION_CXX_STD
|
| 51 |
+
|
| 52 |
+
# Probe c++20 first
|
| 53 |
+
try:
|
| 54 |
+
r = subprocess.run([compiler, f"-std={_PRODUCTION_CXX_STD}", "-x", "c++", "-E", "-"],
|
| 55 |
+
input="", capture_output=True, text=True, timeout=5)
|
| 56 |
+
if r.returncode == 0 and "unrecognized" not in (r.stderr or "").lower():
|
| 57 |
+
return _PRODUCTION_CXX_STD
|
| 58 |
+
except Exception:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
if not _DEV_FALLBACK_ALLOWED:
|
| 62 |
+
# Production: stay on c++20. If the compiler can't, every compile will fail
|
| 63 |
+
# — that's the right signal (deploy with old GCC needs upgrading, not lowering).
|
| 64 |
+
return _PRODUCTION_CXX_STD
|
| 65 |
+
|
| 66 |
+
# Dev fallback only — emit warning so the divergence is visible
|
| 67 |
+
import sys as _sys
|
| 68 |
+
for std in ("c++17", "c++14"):
|
| 69 |
+
try:
|
| 70 |
+
r = subprocess.run([compiler, f"-std={std}", "-x", "c++", "-E", "-"],
|
| 71 |
+
input="", capture_output=True, text=True, timeout=5)
|
| 72 |
+
if r.returncode == 0 and "unrecognized" not in (r.stderr or "").lower():
|
| 73 |
+
print(
|
| 74 |
+
f"⚠ POLYGLOT_OPTIMA: dev fallback to -std={std} (compiler does not support c++20). "
|
| 75 |
+
f"This is for local tests only — production training/deploy MUST use c++20.",
|
| 76 |
+
file=_sys.stderr,
|
| 77 |
+
)
|
| 78 |
+
return std
|
| 79 |
+
except Exception:
|
| 80 |
+
continue
|
| 81 |
+
return _PRODUCTION_CXX_STD
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _detect_openmp() -> bool:
|
| 85 |
+
"""Test whether `-fopenmp` actually links — MinGW often lacks pthread libs."""
|
| 86 |
+
compiler = shutil.which("g++") or shutil.which("clang++")
|
| 87 |
+
if not compiler:
|
| 88 |
+
return False
|
| 89 |
+
try:
|
| 90 |
+
# Try to compile + LINK a trivial OpenMP program. Compile-only succeeds even
|
| 91 |
+
# without pthread; we need the link step to confirm the runtime is available.
|
| 92 |
+
import tempfile
|
| 93 |
+
with tempfile.TemporaryDirectory() as td:
|
| 94 |
+
src = Path(td) / "_omp_probe.cpp"
|
| 95 |
+
obj = Path(td) / "_omp_probe.so"
|
| 96 |
+
src.write_text("#include <omp.h>\nint main(){return omp_get_num_threads();}\n")
|
| 97 |
+
r = subprocess.run([compiler, "-fopenmp", str(src), "-shared", "-fPIC", "-o", str(obj)],
|
| 98 |
+
capture_output=True, text=True, timeout=10)
|
| 99 |
+
return r.returncode == 0
|
| 100 |
+
except Exception:
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _detect_dispatchable() -> bool:
|
| 105 |
+
"""Compile + ctypes-load a tiny probe. Returns True iff the toolchain produces a
|
| 106 |
+
.so loadable by THIS Python interpreter (catches bitness mismatch on MinGW)."""
|
| 107 |
+
compiler = shutil.which("g++") or shutil.which("clang++")
|
| 108 |
+
if not compiler:
|
| 109 |
+
return False
|
| 110 |
+
try:
|
| 111 |
+
import ctypes as _ct
|
| 112 |
+
import tempfile
|
| 113 |
+
with tempfile.TemporaryDirectory() as td:
|
| 114 |
+
src = Path(td) / "_probe.cpp"
|
| 115 |
+
so = Path(td) / "_probe.so"
|
| 116 |
+
src.write_text(
|
| 117 |
+
'extern "C" void agent_function(const double*, '
|
| 118 |
+
'unsigned long long, double* o, unsigned long long n)'
|
| 119 |
+
'{ if (n) o[0] = 1.0; }\n'
|
| 120 |
+
)
|
| 121 |
+
r = subprocess.run(
|
| 122 |
+
[compiler, "-O0", "-fPIC", "-shared", str(src), "-o", str(so)],
|
| 123 |
+
capture_output=True, text=True, timeout=15,
|
| 124 |
+
)
|
| 125 |
+
if r.returncode != 0:
|
| 126 |
+
return False
|
| 127 |
+
lib = _ct.CDLL(str(so))
|
| 128 |
+
return hasattr(lib, "agent_function")
|
| 129 |
+
except Exception:
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
_DETECTED_CXX_STD = _detect_supported_cxx_std()
|
| 134 |
+
_HAS_OPENMP = _detect_openmp()
|
| 135 |
+
_DISPATCHABLE = _detect_dispatchable()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
_BASE_COMPILE_FLAGS = [
|
| 139 |
+
"-O3",
|
| 140 |
+
"-march=native",
|
| 141 |
+
f"-std={_DETECTED_CXX_STD}",
|
| 142 |
+
"-Wall",
|
| 143 |
+
# `-Werror` removed: many MinGW builds emit warnings on default flags.
|
| 144 |
+
# Production deploy can re-add via POLYGLOT_OPTIMA_STRICT=1
|
| 145 |
+
"-fPIC",
|
| 146 |
+
"-shared",
|
| 147 |
+
]
|
| 148 |
+
if _HAS_OPENMP:
|
| 149 |
+
_BASE_COMPILE_FLAGS.insert(2, "-fopenmp")
|
| 150 |
+
if os.environ.get("POLYGLOT_OPTIMA_STRICT", "0") == "1":
|
| 151 |
+
_BASE_COMPILE_FLAGS.append("-Werror")
|
| 152 |
+
|
| 153 |
+
# Banned headers (per plan §10a — would mask agent's actual contribution)
|
| 154 |
+
_BANNED_INCLUDES = [
|
| 155 |
+
"<mkl.h>", "<mkl", # Intel MKL
|
| 156 |
+
"<Eigen/", "Eigen/", # Eigen
|
| 157 |
+
"<cblas.h>", "<lapack.h>", # BLAS/LAPACK
|
| 158 |
+
"<cuda_runtime.h>", "<cuda.h>", # CUDA
|
| 159 |
+
"<hip/", # HIP
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _sha256(*parts: str) -> str:
|
| 164 |
+
h = hashlib.sha256()
|
| 165 |
+
for p in parts:
|
| 166 |
+
h.update(p.encode("utf-8"))
|
| 167 |
+
h.update(b"\x00")
|
| 168 |
+
return h.hexdigest()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _check_for_banned_headers(cpp_code: str) -> str | None:
|
| 172 |
+
"""Return error string if the code uses a banned header, else None."""
|
| 173 |
+
for banned in _BANNED_INCLUDES:
|
| 174 |
+
if banned in cpp_code:
|
| 175 |
+
return (
|
| 176 |
+
f"Banned header detected: {banned}. "
|
| 177 |
+
f"We measure YOUR optimization, not a library call. "
|
| 178 |
+
f"Allowed: STL, <immintrin.h>, <arm_neon.h>, <omp.h>, <pybind11/*>"
|
| 179 |
+
)
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _has_required_entry_point(cpp_code: str) -> bool:
|
| 184 |
+
"""Validate canonical ABI expected by runtime dispatcher.
|
| 185 |
+
|
| 186 |
+
Required signature:
|
| 187 |
+
extern "C" void agent_function(const double*, size_t|unsigned long long,
|
| 188 |
+
double*, size_t|unsigned long long)
|
| 189 |
+
"""
|
| 190 |
+
pattern = (
|
| 191 |
+
r'extern\s*"C"\s+void\s+agent_function\s*\('
|
| 192 |
+
r'\s*const\s+double\s*\*\s*(?:\w+)?\s*,'
|
| 193 |
+
r'\s*(?:size_t|unsigned\s+long\s+long)\s*(?:\w+)?\s*,'
|
| 194 |
+
r'\s*double\s*\*\s*(?:\w+)?\s*,'
|
| 195 |
+
r'\s*(?:size_t|unsigned\s+long\s+long)\s*(?:\w+)?\s*'
|
| 196 |
+
r'\)'
|
| 197 |
+
)
|
| 198 |
+
return re.search(pattern, cpp_code, flags=re.IGNORECASE | re.DOTALL) is not None
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _compile(cpp_code: str, hw_profile: dict[str, Any], cache_key: str, timeout_s: int = 30) -> dict[str, Any]:
|
| 202 |
+
"""Run g++; cache the .so by cache_key. Return dict with status + path/error."""
|
| 203 |
+
cache_dir = _CACHE_ROOT / cache_key[:2]
|
| 204 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 205 |
+
so_path = cache_dir / f"{cache_key}.so"
|
| 206 |
+
|
| 207 |
+
# Cache hit
|
| 208 |
+
if so_path.exists():
|
| 209 |
+
return {"status": "success", "so_path": str(so_path), "cached": True}
|
| 210 |
+
|
| 211 |
+
# Banned headers → reject before invoking compiler
|
| 212 |
+
banned_err = _check_for_banned_headers(cpp_code)
|
| 213 |
+
if banned_err:
|
| 214 |
+
return {"status": "syntax_error", "error": banned_err, "cached": False}
|
| 215 |
+
|
| 216 |
+
# Write source + invoke compiler
|
| 217 |
+
src_path = cache_dir / f"{cache_key}.cpp"
|
| 218 |
+
src_path.write_text(cpp_code, encoding="utf-8")
|
| 219 |
+
|
| 220 |
+
# Resolve compiler — prefer g++ on Linux, fall back to clang++ on macOS
|
| 221 |
+
compiler = shutil.which("g++") or shutil.which("clang++") or "g++"
|
| 222 |
+
|
| 223 |
+
cmd = [compiler, *_BASE_COMPILE_FLAGS, str(src_path), "-o", str(so_path)]
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
proc = subprocess.run(
|
| 227 |
+
cmd, capture_output=True, text=True, timeout=timeout_s,
|
| 228 |
+
)
|
| 229 |
+
except subprocess.TimeoutExpired:
|
| 230 |
+
return {"status": "timeout", "error": f"Compilation exceeded {timeout_s}s", "cached": False}
|
| 231 |
+
except FileNotFoundError:
|
| 232 |
+
return {"status": "syntax_error",
|
| 233 |
+
"error": f"Compiler {compiler!r} not found. Install GCC 14 or clang++.",
|
| 234 |
+
"cached": False}
|
| 235 |
+
|
| 236 |
+
if proc.returncode != 0:
|
| 237 |
+
return {
|
| 238 |
+
"status": "syntax_error",
|
| 239 |
+
"error": (proc.stderr or proc.stdout)[:2000],
|
| 240 |
+
"cmd": " ".join(cmd),
|
| 241 |
+
"cached": False,
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
return {"status": "success", "so_path": str(so_path), "cached": False}
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _load_python_function(python_code: str):
|
| 248 |
+
"""Exec python_code in a fresh namespace, return the first FunctionDef as a callable."""
|
| 249 |
+
import ast
|
| 250 |
+
tree = ast.parse(python_code)
|
| 251 |
+
fn_node = next((n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)), None)
|
| 252 |
+
if fn_node is None:
|
| 253 |
+
raise RuntimeError("python_code defines no function")
|
| 254 |
+
ns: dict[str, Any] = {}
|
| 255 |
+
exec(compile(tree, filename="<agent_python>", mode="exec"), ns)
|
| 256 |
+
fn = ns.get(fn_node.name)
|
| 257 |
+
if fn is None:
|
| 258 |
+
raise RuntimeError(f"function {fn_node.name!r} not found after exec")
|
| 259 |
+
return fn
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _benchmark_python_baseline(python_code: str, sample_input_size: int = 1024) -> dict[str, Any]:
|
| 263 |
+
"""Real median-of-15 wall time of the Python function on a default-typed input."""
|
| 264 |
+
from server.tools._runtime import time_python_only, make_default_args_for
|
| 265 |
+
try:
|
| 266 |
+
py_fn = _load_python_function(python_code)
|
| 267 |
+
args = make_default_args_for(py_fn, n=sample_input_size)
|
| 268 |
+
median_ms = time_python_only(py_fn, args, n_per_repeat=5, repeats=3)
|
| 269 |
+
return {
|
| 270 |
+
"median_ms": float(median_ms),
|
| 271 |
+
"method": "perf_counter_median_5x3",
|
| 272 |
+
"n_samples": sample_input_size,
|
| 273 |
+
}
|
| 274 |
+
except Exception as e:
|
| 275 |
+
# Don't crash the env on a broken Python function; signal "0 baseline" → speedup goes to 0
|
| 276 |
+
return {
|
| 277 |
+
"median_ms": 0.0,
|
| 278 |
+
"method": "error",
|
| 279 |
+
"error": str(e)[:200],
|
| 280 |
+
"n_samples": sample_input_size,
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _benchmark_cpp(so_path: str, python_code: str, sample_input_size: int = 1024) -> dict[str, Any]:
|
| 285 |
+
"""Real median-of-15 wall time of the compiled .so via ctypes dispatch."""
|
| 286 |
+
from server.tools._runtime import benchmark_python_vs_cpp, make_default_args_for
|
| 287 |
+
try:
|
| 288 |
+
py_fn = _load_python_function(python_code)
|
| 289 |
+
args = make_default_args_for(py_fn, n=sample_input_size)
|
| 290 |
+
result = benchmark_python_vs_cpp(so_path, py_fn, args, n_per_repeat=5, repeats=3)
|
| 291 |
+
return {
|
| 292 |
+
"median_ms": float(result["cpp_median_ms"]),
|
| 293 |
+
"py_median_ms": float(result["py_median_ms"]),
|
| 294 |
+
"speedup_internal": float(result["speedup"]),
|
| 295 |
+
"method": "ctypes_perf_counter_median_5x3",
|
| 296 |
+
"n_samples": sample_input_size,
|
| 297 |
+
}
|
| 298 |
+
except Exception as e:
|
| 299 |
+
return {
|
| 300 |
+
"median_ms": 0.0,
|
| 301 |
+
"method": "error",
|
| 302 |
+
"error": str(e)[:200],
|
| 303 |
+
"n_samples": sample_input_size,
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def compile_and_benchmark_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 308 |
+
"""Compile agent C++ and report compile status + speedup measurement.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
cpp_code (str): The C++20 source to compile.
|
| 312 |
+
|
| 313 |
+
Returns dict with:
|
| 314 |
+
compile_status: "success" | "syntax_error" | "link_error" | "timeout"
|
| 315 |
+
speedup: float (python_ms / cpp_ms) — only valid if compile_status == "success"
|
| 316 |
+
python_ms: median-of-15 Python baseline
|
| 317 |
+
cpp_ms: median-of-15 agent C++ wall time
|
| 318 |
+
error: str (if compile_status != "success")
|
| 319 |
+
cache_hit: bool
|
| 320 |
+
"""
|
| 321 |
+
cpp_code = tool_args.get("cpp_code", "")
|
| 322 |
+
if not cpp_code.strip():
|
| 323 |
+
return {"compile_status": "syntax_error", "error": "empty cpp_code", "speedup": 0.0}
|
| 324 |
+
|
| 325 |
+
if not _has_required_entry_point(cpp_code):
|
| 326 |
+
return {
|
| 327 |
+
"compile_status": "syntax_error",
|
| 328 |
+
"error": (
|
| 329 |
+
'Missing required entry point: must define `extern "C" ... agent_function(...)`'
|
| 330 |
+
),
|
| 331 |
+
"speedup": 0.0,
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
# Cache key
|
| 335 |
+
hw = state.hardware_profile
|
| 336 |
+
cache_key = _sha256(cpp_code, json.dumps(hw, sort_keys=True))
|
| 337 |
+
|
| 338 |
+
t_compile_start = time.perf_counter()
|
| 339 |
+
compile_result = _compile(cpp_code, hw, cache_key)
|
| 340 |
+
compile_time_s = time.perf_counter() - t_compile_start
|
| 341 |
+
|
| 342 |
+
if compile_result["status"] != "success":
|
| 343 |
+
return {
|
| 344 |
+
"compile_status": compile_result["status"],
|
| 345 |
+
"error": compile_result.get("error", "compilation failed"),
|
| 346 |
+
"speedup": 0.0,
|
| 347 |
+
"compile_time_s": compile_time_s,
|
| 348 |
+
"cache_hit": False,
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
# Real benchmark via ctypes dispatch — joint timing of python + cpp on same args
|
| 352 |
+
cpp_bench = _benchmark_cpp(compile_result["so_path"], state.python_code)
|
| 353 |
+
|
| 354 |
+
if cpp_bench.get("method") == "error":
|
| 355 |
+
# Compilation succeeded but the .so couldn't be dispatched (wrong signature, missing symbol)
|
| 356 |
+
return {
|
| 357 |
+
"compile_status": "link_error",
|
| 358 |
+
"error": cpp_bench.get("error", "ctypes dispatch failed"),
|
| 359 |
+
"speedup": 0.0,
|
| 360 |
+
"python_ms": 0.0,
|
| 361 |
+
"cpp_ms": 0.0,
|
| 362 |
+
"compile_time_s": compile_time_s,
|
| 363 |
+
"cache_hit": compile_result.get("cached", False),
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
py_ms = cpp_bench.get("py_median_ms", 0.0)
|
| 367 |
+
cpp_ms = cpp_bench["median_ms"]
|
| 368 |
+
speedup = py_ms / max(cpp_ms, 1e-6) if py_ms > 0 else 0.0
|
| 369 |
+
|
| 370 |
+
return {
|
| 371 |
+
"compile_status": "success",
|
| 372 |
+
"speedup": speedup,
|
| 373 |
+
"python_ms": py_ms,
|
| 374 |
+
"cpp_ms": cpp_ms,
|
| 375 |
+
"compile_time_s": compile_time_s,
|
| 376 |
+
"cache_hit": compile_result.get("cached", False),
|
| 377 |
+
"so_path": compile_result["so_path"],
|
| 378 |
+
"method": "ctypes_median_5x3_walltime",
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
__all__ = ["compile_and_benchmark_tool", "_sha256", "_BASE_COMPILE_FLAGS"]
|
server/tools/hardware_profiler.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 1/9: get_hardware_profile.
|
| 2 |
+
|
| 3 |
+
Returns the hardware profile for the current episode along with the precomputed
|
| 4 |
+
Roofline bound. The profile is sampled at reset() time and frozen for the episode;
|
| 5 |
+
this tool just exposes it to the agent.
|
| 6 |
+
|
| 7 |
+
Roofline math (per plan §10):
|
| 8 |
+
simd_w = {"SSE4.2": 4, "AVX2": 8, "AVX-512": 16, "NEON": 4, "none": 1}
|
| 9 |
+
peak_flops = cores × freq_ghz × simd_w × 2 (FMA = 2 ops/cycle)
|
| 10 |
+
peak_bandwidth_flops = bandwidth_gbs × 0.5 (rough flop-per-byte ceiling)
|
| 11 |
+
roofline_bound = min(peak_flops, peak_bandwidth_flops)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
SIMD_WIDTH = {
|
| 20 |
+
"SSE4.2": 4,
|
| 21 |
+
"AVX2": 8,
|
| 22 |
+
"AVX-512": 16,
|
| 23 |
+
"NEON": 4,
|
| 24 |
+
"none": 1,
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def roofline_bound(hw: dict[str, Any]) -> float:
|
| 29 |
+
"""Compute the Roofline-model peak GFLOPS for a hardware profile."""
|
| 30 |
+
simd_w = SIMD_WIDTH.get(hw["simd"], 1)
|
| 31 |
+
peak_flops = hw["cores"] * hw["freq_ghz"] * simd_w * 2
|
| 32 |
+
peak_bw = hw["bw_gbs"] * 0.5
|
| 33 |
+
return float(min(peak_flops, peak_bw))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_hardware_profile_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 37 |
+
"""Return the episode's hardware profile + Roofline bound.
|
| 38 |
+
|
| 39 |
+
No arguments — the profile is fixed at episode start.
|
| 40 |
+
"""
|
| 41 |
+
hw = state.hardware_profile
|
| 42 |
+
return {
|
| 43 |
+
"id": hw.get("id", "unknown"),
|
| 44 |
+
"cores": hw["cores"],
|
| 45 |
+
"freq_ghz": hw["freq_ghz"],
|
| 46 |
+
"l1_kb": hw["l1_kb"],
|
| 47 |
+
"simd": hw["simd"],
|
| 48 |
+
"bandwidth_gbs": hw["bw_gbs"],
|
| 49 |
+
"roofline_bound_gflops": roofline_bound(hw),
|
| 50 |
+
# Extra context the agent may use
|
| 51 |
+
"simd_width_floats": SIMD_WIDTH.get(hw["simd"], 1),
|
| 52 |
+
"bytes_per_flop_threshold": 1.0 / max(roofline_bound(hw), 0.001),
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
__all__ = ["get_hardware_profile_tool", "roofline_bound", "SIMD_WIDTH"]
|
server/tools/portability_checker.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 7/9: check_portability.
|
| 2 |
+
|
| 3 |
+
Compiles the agent's C++ against each of the 8 hardware profile flag-sets
|
| 4 |
+
and runs a quick correctness check (subset of the fuzzer) on each. Awards
|
| 5 |
+
the portability bonus if 3+ profiles pass.
|
| 6 |
+
|
| 7 |
+
Per plan §3 axis 4 (`portability_required`), the agent only earns the
|
| 8 |
+
PortabilityRubric bonus when this axis is escalated. Otherwise the result
|
| 9 |
+
is informational.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
from server.tools.cpp_compiler import _compile, _sha256
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Per-profile compile flag overrides (in addition to the base `_BASE_COMPILE_FLAGS`).
|
| 21 |
+
# `-march=native` is replaced with the appropriate -m* flag matching the profile's SIMD level.
|
| 22 |
+
PROFILE_COMPILE_OVERRIDES = {
|
| 23 |
+
"SSE4.2": ["-msse4.2", "-mno-avx", "-mno-avx2", "-mno-avx512f"],
|
| 24 |
+
"AVX2": ["-mavx2", "-mfma", "-mno-avx512f"],
|
| 25 |
+
"AVX-512": ["-mavx512f", "-mavx512cd", "-mavx512vl"],
|
| 26 |
+
"NEON": ["-mfpu=neon"], # ARM-only — for cross-compile mode
|
| 27 |
+
"none": ["-mno-sse", "-mno-avx", "-mno-avx2"],
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _override_flags(base_flags: list[str], simd: str) -> list[str]:
|
| 32 |
+
"""Replace -march=native with the profile-specific SIMD flag set."""
|
| 33 |
+
out = [f for f in base_flags if not f.startswith("-march=")]
|
| 34 |
+
out += PROFILE_COMPILE_OVERRIDES.get(simd, [])
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def check_portability_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 39 |
+
"""Test compile + quick correctness on all 8 hardware profiles.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
cpp_code (str)
|
| 43 |
+
n_cases_per_profile (int=50) — quick smoke check per profile
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
per_profile (dict[str, dict]) — id → {compile, correctness}
|
| 47 |
+
n_profiles_passing (int)
|
| 48 |
+
portability_bonus_eligible (bool) — True if ≥3 profiles compile + pass correctness
|
| 49 |
+
"""
|
| 50 |
+
cpp_code = tool_args.get("cpp_code", "")
|
| 51 |
+
if not cpp_code.strip():
|
| 52 |
+
return {"per_profile": {}, "n_profiles_passing": 0, "portability_bonus_eligible": False, "error": "empty cpp_code"}
|
| 53 |
+
|
| 54 |
+
# Lazy-import the full profile list — provided by scenarios.hardware_profiles in Hour 16
|
| 55 |
+
try:
|
| 56 |
+
from server.scenarios.hardware_profiles import HARDWARE_PROFILES
|
| 57 |
+
except ImportError:
|
| 58 |
+
# During Hour 4-10 use a stub list with all 8 profiles inlined
|
| 59 |
+
HARDWARE_PROFILES = _STUB_PROFILES
|
| 60 |
+
|
| 61 |
+
per_profile: dict[str, dict[str, Any]] = {}
|
| 62 |
+
n_passing = 0
|
| 63 |
+
|
| 64 |
+
# Reuse the simple verifier over a small sample
|
| 65 |
+
from server.tools.verifier import verify_equivalence_tool
|
| 66 |
+
|
| 67 |
+
for hw in HARDWARE_PROFILES:
|
| 68 |
+
if hw["id"] == state.hardware_profile.get("id"):
|
| 69 |
+
# Skip the home profile — we test it via the main verifier
|
| 70 |
+
continue
|
| 71 |
+
cache_key = _sha256(cpp_code, json.dumps(hw, sort_keys=True), "portability")
|
| 72 |
+
compile_result = _compile(cpp_code, hw, cache_key)
|
| 73 |
+
compile_ok = compile_result["status"] == "success"
|
| 74 |
+
|
| 75 |
+
correctness_ok = False
|
| 76 |
+
if compile_ok:
|
| 77 |
+
# Quick fuzz on this profile (50 cases)
|
| 78 |
+
verifier_args = {
|
| 79 |
+
"cpp_code": cpp_code,
|
| 80 |
+
"python_code": state.python_code,
|
| 81 |
+
"n_cases": int(tool_args.get("n_cases_per_profile", 50)),
|
| 82 |
+
}
|
| 83 |
+
# Temporarily swap the state's hw profile so the verifier compiles for this one
|
| 84 |
+
saved_hw = state.hardware_profile
|
| 85 |
+
state.hardware_profile = hw
|
| 86 |
+
try:
|
| 87 |
+
v = verify_equivalence_tool(verifier_args, state)
|
| 88 |
+
correctness_ok = v.get("pass_rate", 0.0) >= 0.95
|
| 89 |
+
finally:
|
| 90 |
+
state.hardware_profile = saved_hw
|
| 91 |
+
|
| 92 |
+
per_profile[hw["id"]] = {
|
| 93 |
+
"compile": "success" if compile_ok else "fail",
|
| 94 |
+
"correctness_ok": correctness_ok,
|
| 95 |
+
"compile_error": compile_result.get("error", "")[:300] if not compile_ok else "",
|
| 96 |
+
}
|
| 97 |
+
if compile_ok and correctness_ok:
|
| 98 |
+
n_passing += 1
|
| 99 |
+
|
| 100 |
+
eligible = n_passing >= 3
|
| 101 |
+
|
| 102 |
+
return {
|
| 103 |
+
"per_profile": per_profile,
|
| 104 |
+
"n_profiles_passing": n_passing,
|
| 105 |
+
"portability_bonus_eligible": eligible,
|
| 106 |
+
"tested_profiles": [p["id"] for p in HARDWARE_PROFILES if p["id"] != state.hardware_profile.get("id")],
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Inline 8-profile stub used during Hour 4-10 before scenarios module is built
|
| 111 |
+
_STUB_PROFILES = [
|
| 112 |
+
{"id": "laptop_sse", "cores": 4, "freq_ghz": 3.2, "l1_kb": 32, "simd": "SSE4.2", "bw_gbs": 40},
|
| 113 |
+
{"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8, "l1_kb": 32, "simd": "AVX2", "bw_gbs": 51},
|
| 114 |
+
{"id": "server_avx512", "cores": 16, "freq_ghz": 3.0, "l1_kb": 48, "simd": "AVX-512", "bw_gbs": 89},
|
| 115 |
+
{"id": "arm_neon_a", "cores": 6, "freq_ghz": 2.4, "l1_kb": 64, "simd": "NEON", "bw_gbs": 68},
|
| 116 |
+
{"id": "embedded", "cores": 2, "freq_ghz": 1.8, "l1_kb": 16, "simd": "none", "bw_gbs": 25},
|
| 117 |
+
{"id": "workstation", "cores": 12, "freq_ghz": 4.0, "l1_kb": 48, "simd": "AVX2", "bw_gbs": 76},
|
| 118 |
+
{"id": "arm_neon_b", "cores": 8, "freq_ghz": 2.8, "l1_kb": 32, "simd": "NEON", "bw_gbs": 68},
|
| 119 |
+
{"id": "laptop_sse2", "cores": 4, "freq_ghz": 2.6, "l1_kb": 64, "simd": "SSE4.2", "bw_gbs": 35},
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
__all__ = ["check_portability_tool", "PROFILE_COMPILE_OVERRIDES"]
|
server/tools/python_analyzer.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tools 2-4/9: profile_python_hotspots, analyze_complexity, check_memory_access.
|
| 2 |
+
|
| 3 |
+
Three static-analysis tools the agent uses to *understand the input code* before
|
| 4 |
+
writing C++. All run on the AST — no Python execution required for these tools
|
| 5 |
+
(the verifier and benchmarker do the actual execution, sandboxed).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import ast
|
| 11 |
+
import re
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ----------------- Tool 2: profile_python_hotspots ----------------
|
| 16 |
+
|
| 17 |
+
def profile_python_hotspots_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 18 |
+
"""Return the top hot lines of the Python function (static cost estimate).
|
| 19 |
+
|
| 20 |
+
For a static-analysis-only tool, we approximate hotness via:
|
| 21 |
+
- loop nesting depth at the line
|
| 22 |
+
- operations inside loops (multiplied by estimated trip count)
|
| 23 |
+
- presence of np.* calls (vectorized but still expensive on large arrays)
|
| 24 |
+
|
| 25 |
+
For a more accurate dynamic profile (cProfile run), pass `dynamic=True` —
|
| 26 |
+
that path will be wired to a sandboxed run in Hour 16+.
|
| 27 |
+
"""
|
| 28 |
+
code = tool_args.get("code") or state.python_code
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
tree = ast.parse(code)
|
| 32 |
+
except SyntaxError as e:
|
| 33 |
+
return {"error": f"Python parse error: {e}", "hotspots": []}
|
| 34 |
+
|
| 35 |
+
hotspots: list[dict[str, Any]] = []
|
| 36 |
+
line_costs: dict[int, int] = {}
|
| 37 |
+
|
| 38 |
+
class HotspotVisitor(ast.NodeVisitor):
|
| 39 |
+
def __init__(self):
|
| 40 |
+
self.loop_depth = 0
|
| 41 |
+
|
| 42 |
+
def visit_For(self, node):
|
| 43 |
+
self.loop_depth += 1
|
| 44 |
+
self.generic_visit(node)
|
| 45 |
+
self.loop_depth -= 1
|
| 46 |
+
|
| 47 |
+
def visit_While(self, node):
|
| 48 |
+
self.loop_depth += 1
|
| 49 |
+
self.generic_visit(node)
|
| 50 |
+
self.loop_depth -= 1
|
| 51 |
+
|
| 52 |
+
def visit_BinOp(self, node):
|
| 53 |
+
cost = 1 << self.loop_depth # 2^depth — exponential weight per nesting
|
| 54 |
+
line_costs[node.lineno] = line_costs.get(node.lineno, 0) + cost
|
| 55 |
+
self.generic_visit(node)
|
| 56 |
+
|
| 57 |
+
def visit_Call(self, node):
|
| 58 |
+
# Penalize np.* calls inside loops more
|
| 59 |
+
cost = (1 << self.loop_depth) * 2
|
| 60 |
+
line_costs[node.lineno] = line_costs.get(node.lineno, 0) + cost
|
| 61 |
+
self.generic_visit(node)
|
| 62 |
+
|
| 63 |
+
HotspotVisitor().visit(tree)
|
| 64 |
+
|
| 65 |
+
code_lines = code.splitlines()
|
| 66 |
+
sorted_lines = sorted(line_costs.items(), key=lambda x: -x[1])
|
| 67 |
+
for lineno, cost in sorted_lines[:5]:
|
| 68 |
+
if 0 < lineno <= len(code_lines):
|
| 69 |
+
hotspots.append({
|
| 70 |
+
"line_number": lineno,
|
| 71 |
+
"estimated_cost": cost,
|
| 72 |
+
"source": code_lines[lineno - 1].strip(),
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
total_cost = sum(line_costs.values())
|
| 76 |
+
return {
|
| 77 |
+
"hotspots": hotspots,
|
| 78 |
+
"total_estimated_cost": total_cost,
|
| 79 |
+
"method": "static_ast_analysis",
|
| 80 |
+
"hint": "Lines deep in loops dominate; vectorize or parallelize them first.",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ----------------- Tool 3: analyze_complexity ----------------
|
| 85 |
+
|
| 86 |
+
def analyze_complexity_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 87 |
+
"""Return Big-O class + max loop nesting depth via AST.
|
| 88 |
+
|
| 89 |
+
A loop nesting depth of k suggests O(n^k) in the typical case. Recursion
|
| 90 |
+
detection is naive (treats every recursive call as +1 to complexity).
|
| 91 |
+
"""
|
| 92 |
+
code = tool_args.get("code") or state.python_code
|
| 93 |
+
try:
|
| 94 |
+
tree = ast.parse(code)
|
| 95 |
+
except SyntaxError as e:
|
| 96 |
+
return {"error": f"Python parse error: {e}"}
|
| 97 |
+
|
| 98 |
+
max_depth = [0]
|
| 99 |
+
|
| 100 |
+
class DepthVisitor(ast.NodeVisitor):
|
| 101 |
+
def __init__(self):
|
| 102 |
+
self.depth = 0
|
| 103 |
+
|
| 104 |
+
def visit_For(self, node):
|
| 105 |
+
self.depth += 1
|
| 106 |
+
max_depth[0] = max(max_depth[0], self.depth)
|
| 107 |
+
self.generic_visit(node)
|
| 108 |
+
self.depth -= 1
|
| 109 |
+
|
| 110 |
+
def visit_While(self, node):
|
| 111 |
+
self.depth += 1
|
| 112 |
+
max_depth[0] = max(max_depth[0], self.depth)
|
| 113 |
+
self.generic_visit(node)
|
| 114 |
+
self.depth -= 1
|
| 115 |
+
|
| 116 |
+
DepthVisitor().visit(tree)
|
| 117 |
+
|
| 118 |
+
depth = max_depth[0]
|
| 119 |
+
if depth == 0:
|
| 120 |
+
big_o = "O(1)"
|
| 121 |
+
elif depth == 1:
|
| 122 |
+
big_o = "O(n)"
|
| 123 |
+
else:
|
| 124 |
+
big_o = f"O(n^{depth})"
|
| 125 |
+
|
| 126 |
+
# Detect simple recursion (function calls itself)
|
| 127 |
+
func_names = {n.name for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)}
|
| 128 |
+
has_recursion = any(
|
| 129 |
+
isinstance(c.func, ast.Name) and c.func.id in func_names
|
| 130 |
+
for c in ast.walk(tree) if isinstance(c, ast.Call)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return {
|
| 134 |
+
"big_o_estimate": big_o,
|
| 135 |
+
"max_loop_nesting_depth": depth,
|
| 136 |
+
"has_recursion": has_recursion,
|
| 137 |
+
"method": "static_ast_loop_depth",
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ----------------- Tool 4: check_memory_access ----------------
|
| 142 |
+
|
| 143 |
+
# Patterns that suggest cache-unfriendly access
|
| 144 |
+
_STRIDE_PATTERN = re.compile(r"\[\s*j\s*,\s*i\s*\]|\[\s*i\s*\]\s*\[\s*j\s*\]")
|
| 145 |
+
_TRANSPOSE_PATTERN = re.compile(r"\.T\s*\[")
|
| 146 |
+
_NON_CONTIG_PATTERN = re.compile(r"\bnp\.ascontiguousarray\b|\bnp\.asfortranarray\b")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def check_memory_access_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 150 |
+
"""Detect cache-unfriendly stride patterns / aliasing risks via static patterns.
|
| 151 |
+
|
| 152 |
+
This is a heuristic — not perfect, but catches the common cases:
|
| 153 |
+
- column-major access in row-major arrays (D[j, i] inside i,j loops)
|
| 154 |
+
- non-contiguous arrays passed in
|
| 155 |
+
- explicit transpose in hot expression
|
| 156 |
+
"""
|
| 157 |
+
code = tool_args.get("code") or state.python_code
|
| 158 |
+
|
| 159 |
+
issues: list[dict[str, str]] = []
|
| 160 |
+
|
| 161 |
+
if _STRIDE_PATTERN.search(code):
|
| 162 |
+
issues.append({
|
| 163 |
+
"type": "non_unit_stride",
|
| 164 |
+
"severity": "high",
|
| 165 |
+
"hint": "Detected D[j,i]-style access — likely column-major in a row-major array. "
|
| 166 |
+
"Cache misses dominate. Transpose the layout or swap loop order."
|
| 167 |
+
})
|
| 168 |
+
if _TRANSPOSE_PATTERN.search(code):
|
| 169 |
+
issues.append({
|
| 170 |
+
"type": "in_loop_transpose",
|
| 171 |
+
"severity": "med",
|
| 172 |
+
"hint": "`.T` in hot path may force a copy or non-contiguous access."
|
| 173 |
+
})
|
| 174 |
+
if _NON_CONTIG_PATTERN.search(code):
|
| 175 |
+
issues.append({
|
| 176 |
+
"type": "explicit_layout_handling",
|
| 177 |
+
"severity": "info",
|
| 178 |
+
"hint": "Code already handles contiguity — good; preserve in C++ via `restrict`."
|
| 179 |
+
})
|
| 180 |
+
|
| 181 |
+
# Inspect AST for "for i in range" + "for j in range" + a 2D index
|
| 182 |
+
try:
|
| 183 |
+
tree = ast.parse(code)
|
| 184 |
+
nested_for = False
|
| 185 |
+
for node in ast.walk(tree):
|
| 186 |
+
if isinstance(node, ast.For):
|
| 187 |
+
for sub in ast.walk(node):
|
| 188 |
+
if isinstance(sub, ast.For) and sub is not node:
|
| 189 |
+
nested_for = True
|
| 190 |
+
break
|
| 191 |
+
if nested_for and not issues:
|
| 192 |
+
issues.append({
|
| 193 |
+
"type": "nested_loop_unanalyzed",
|
| 194 |
+
"severity": "low",
|
| 195 |
+
"hint": "Nested loops detected. Verify that inner-loop index varies the contiguous dimension."
|
| 196 |
+
})
|
| 197 |
+
except SyntaxError:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
aliasing_risk = "low"
|
| 201 |
+
if "np.ndarray" in code or "ndarray" in code:
|
| 202 |
+
aliasing_risk = "med" # numpy arrays can alias; agent should consider `restrict`
|
| 203 |
+
|
| 204 |
+
return {
|
| 205 |
+
"issues": issues,
|
| 206 |
+
"aliasing_risk": aliasing_risk,
|
| 207 |
+
"recommendation": (
|
| 208 |
+
"Use `__restrict__` qualifier on non-aliasing pointers in C++. "
|
| 209 |
+
"Prefer SoA over AoS for SIMD-friendly access."
|
| 210 |
+
if issues else "No obvious memory-access issues; proceed with default layout."
|
| 211 |
+
),
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
__all__ = [
|
| 216 |
+
"profile_python_hotspots_tool",
|
| 217 |
+
"analyze_complexity_tool",
|
| 218 |
+
"check_memory_access_tool",
|
| 219 |
+
]
|
server/tools/submit.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 9/9: submit_optimization — closes the current round.
|
| 2 |
+
|
| 3 |
+
This is the only round-closing tool. The environment recognizes its name and:
|
| 4 |
+
1. Triggers full-strength verification (n_cases=1000)
|
| 5 |
+
2. Triggers portability check (cross-profile compile + correctness)
|
| 6 |
+
3. Computes the round's reward via the rubric DAG
|
| 7 |
+
4. Stores the submission as the round result
|
| 8 |
+
|
| 9 |
+
The agent must call this exactly once per round. After 3 calls the episode terminates.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
from server.tools.cpp_compiler import compile_and_benchmark_tool
|
| 17 |
+
from server.tools.verifier import verify_equivalence_tool
|
| 18 |
+
from server.tools.portability_checker import check_portability_tool
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def submit_optimization_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 22 |
+
"""Final submission for this round. Runs full verifier + portability + benchmark.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
cpp_code (str) — required
|
| 26 |
+
reasoning_trace (str) — agent's overall <think> trace for this round
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
compile_status (str)
|
| 30 |
+
speedup (float)
|
| 31 |
+
correctness_pass_rate (float)
|
| 32 |
+
adversarial_pass_rate (float)
|
| 33 |
+
portability (dict)
|
| 34 |
+
n_profiles_passing (int)
|
| 35 |
+
ready_for_reward (bool) — True iff hard gates pass; informs the rubric
|
| 36 |
+
cpp_code (str) — echoed for the round_results history
|
| 37 |
+
reasoning_trace (str) — echoed
|
| 38 |
+
"""
|
| 39 |
+
cpp_code = tool_args.get("cpp_code", "")
|
| 40 |
+
reasoning_trace = tool_args.get("reasoning_trace", state.current_round_reasoning)
|
| 41 |
+
|
| 42 |
+
if not cpp_code.strip():
|
| 43 |
+
return {
|
| 44 |
+
"compile_status": "syntax_error",
|
| 45 |
+
"error": "empty cpp_code",
|
| 46 |
+
"speedup": 0.0,
|
| 47 |
+
"correctness_pass_rate": 0.0,
|
| 48 |
+
"ready_for_reward": False,
|
| 49 |
+
"cpp_code": "",
|
| 50 |
+
"reasoning_trace": reasoning_trace,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# Step 1: compile + benchmark
|
| 54 |
+
bench = compile_and_benchmark_tool({"cpp_code": cpp_code}, state)
|
| 55 |
+
if bench["compile_status"] != "success":
|
| 56 |
+
return {
|
| 57 |
+
"compile_status": bench["compile_status"],
|
| 58 |
+
"error": bench.get("error", ""),
|
| 59 |
+
"speedup": 0.0,
|
| 60 |
+
"correctness_pass_rate": 0.0,
|
| 61 |
+
"adversarial_pass_rate": 0.0,
|
| 62 |
+
"portability": {"n_profiles_passing": 0, "portability_bonus_eligible": False},
|
| 63 |
+
"ready_for_reward": False,
|
| 64 |
+
"cpp_code": cpp_code,
|
| 65 |
+
"reasoning_trace": reasoning_trace,
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Step 2: full 1000-case verifier (or whatever n_cases the curriculum specifies)
|
| 69 |
+
n_cases = 1000 if state.difficulty_axes.get("fuzzer_strictness", 0) >= 2 else 500
|
| 70 |
+
verifier_result = verify_equivalence_tool(
|
| 71 |
+
{"cpp_code": cpp_code, "n_cases": n_cases},
|
| 72 |
+
state,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# Step 3: portability check (only if axis is on; informational otherwise)
|
| 76 |
+
portability_result = check_portability_tool({"cpp_code": cpp_code, "n_cases_per_profile": 50}, state)
|
| 77 |
+
|
| 78 |
+
# Update episode-best speedup tracker
|
| 79 |
+
if bench["speedup"] > state.best_speedup:
|
| 80 |
+
state.best_speedup = bench["speedup"]
|
| 81 |
+
state.best_cpp_code = cpp_code
|
| 82 |
+
|
| 83 |
+
# Round-aware readiness score (continuous) + boolean convenience flag
|
| 84 |
+
round_thresholds = {1: 0.6, 2: 0.8, 3: 0.95}
|
| 85 |
+
threshold = round_thresholds.get(state.round_number, 0.6)
|
| 86 |
+
correctness_ratio = verifier_result["pass_rate"] / max(threshold, 1e-9)
|
| 87 |
+
adversarial_ratio = verifier_result.get("adversarial_pass_rate", 0.0) / 0.9
|
| 88 |
+
compile_quality = 1.0 if bench["compile_status"] == "success" else 0.0
|
| 89 |
+
readiness_score = (
|
| 90 |
+
0.55 * min(1.0, correctness_ratio)
|
| 91 |
+
+ 0.30 * min(1.0, adversarial_ratio)
|
| 92 |
+
+ 0.15 * compile_quality
|
| 93 |
+
)
|
| 94 |
+
ready = readiness_score >= 0.9
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"compile_status": bench["compile_status"],
|
| 98 |
+
"speedup": bench["speedup"],
|
| 99 |
+
"python_ms": bench.get("python_ms"),
|
| 100 |
+
"cpp_ms": bench.get("cpp_ms"),
|
| 101 |
+
"correctness_pass_rate": verifier_result["pass_rate"],
|
| 102 |
+
"adversarial_pass_rate": verifier_result.get("adversarial_pass_rate", 0.0),
|
| 103 |
+
"first_correctness_failure": verifier_result.get("first_failure"),
|
| 104 |
+
"portability": portability_result,
|
| 105 |
+
"n_profiles_passing": portability_result.get("n_profiles_passing", 0),
|
| 106 |
+
"readiness_score": readiness_score,
|
| 107 |
+
"ready_for_reward": ready,
|
| 108 |
+
"cpp_code": cpp_code,
|
| 109 |
+
"reasoning_trace": reasoning_trace,
|
| 110 |
+
"round_threshold_correctness": threshold,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
__all__ = ["submit_optimization_tool"]
|
server/tools/verifier.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tool 6/9: verify_equivalence — anti-cheating fuzzer.
|
| 2 |
+
|
| 3 |
+
Per plan §10b, this is the single most important defense against the agent
|
| 4 |
+
cheating by producing a fast-but-wrong implementation.
|
| 5 |
+
|
| 6 |
+
8 cheating modes defended:
|
| 7 |
+
1. Wrong algorithm with plausible output — random fuzz inputs
|
| 8 |
+
2. Edge-case overflow (int32 wraps int64) — typed inputs include int64, INT_MAX/MIN
|
| 9 |
+
3. Approximation drift — rtol=1e-5 (or rtol=0 per metadata)
|
| 10 |
+
4. Cached lookup table — seed randomized per call
|
| 11 |
+
5. Tail variance — 10% adversarial sub-pool
|
| 12 |
+
6. Returns 0 / empty — exact shape+dtype check
|
| 13 |
+
7. Detects benchmark context — same input pipeline as benchmarker
|
| 14 |
+
8. Side-channel access — sandboxed subprocess
|
| 15 |
+
|
| 16 |
+
Returns: pass_rate ∈ [0, 1], first_failure dict, n_adversarial_failures.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import ast
|
| 22 |
+
import random
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_ALLOWED_IMPORT_MODULES = {"math", "numpy"}
|
| 29 |
+
_BANNED_CALLS = {"eval", "exec", "compile", "open", "__import__", "input"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _safe_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 33 |
+
root = name.split(".")[0]
|
| 34 |
+
if root not in _ALLOWED_IMPORT_MODULES:
|
| 35 |
+
raise RuntimeError(f"import '{name}' is not allowed in verifier")
|
| 36 |
+
return __import__(name, globals, locals, fromlist, level)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _validate_python_code_safety(tree: ast.AST) -> None:
|
| 40 |
+
"""Reject high-risk constructs before running user-provided Python code."""
|
| 41 |
+
for node in ast.walk(tree):
|
| 42 |
+
if isinstance(node, ast.Import):
|
| 43 |
+
for alias in node.names:
|
| 44 |
+
root = alias.name.split(".")[0]
|
| 45 |
+
if root not in _ALLOWED_IMPORT_MODULES:
|
| 46 |
+
raise RuntimeError(f"import '{alias.name}' is not allowed in verifier")
|
| 47 |
+
if isinstance(node, ast.ImportFrom):
|
| 48 |
+
module = (node.module or "").split(".")[0]
|
| 49 |
+
if module and module not in _ALLOWED_IMPORT_MODULES:
|
| 50 |
+
raise RuntimeError(f"from '{node.module}' import ... is not allowed in verifier")
|
| 51 |
+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
|
| 52 |
+
if node.func.id in _BANNED_CALLS:
|
| 53 |
+
raise RuntimeError(f"call '{node.func.id}(...)' is not allowed in verifier")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _safe_exec_function(python_code: str, fn_name: str):
|
| 57 |
+
"""Compile and execute Python in a constrained namespace, then return fn."""
|
| 58 |
+
tree = ast.parse(python_code)
|
| 59 |
+
_validate_python_code_safety(tree)
|
| 60 |
+
|
| 61 |
+
safe_builtins = {
|
| 62 |
+
"abs": abs,
|
| 63 |
+
"all": all,
|
| 64 |
+
"any": any,
|
| 65 |
+
"bool": bool,
|
| 66 |
+
"dict": dict,
|
| 67 |
+
"enumerate": enumerate,
|
| 68 |
+
"Exception": Exception,
|
| 69 |
+
"float": float,
|
| 70 |
+
"int": int,
|
| 71 |
+
"len": len,
|
| 72 |
+
"list": list,
|
| 73 |
+
"max": max,
|
| 74 |
+
"min": min,
|
| 75 |
+
"TypeError": TypeError,
|
| 76 |
+
"pow": pow,
|
| 77 |
+
"range": range,
|
| 78 |
+
"round": round,
|
| 79 |
+
"set": set,
|
| 80 |
+
"sorted": sorted,
|
| 81 |
+
"sum": sum,
|
| 82 |
+
"tuple": tuple,
|
| 83 |
+
"ValueError": ValueError,
|
| 84 |
+
"__import__": _safe_import,
|
| 85 |
+
"zip": zip,
|
| 86 |
+
}
|
| 87 |
+
ns: dict[str, Any] = {"__builtins__": safe_builtins, "np": np}
|
| 88 |
+
exec(compile(tree, filename="<verifier_python>", mode="exec"), ns)
|
| 89 |
+
fn = ns.get(fn_name)
|
| 90 |
+
if fn is None:
|
| 91 |
+
raise RuntimeError(f"function '{fn_name}' not defined in python_code")
|
| 92 |
+
return fn
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ---------- Input generation from Python AST ----------
|
| 96 |
+
|
| 97 |
+
def _infer_input_signature(python_code: str) -> list[dict[str, str]]:
|
| 98 |
+
"""Inspect the Python function's signature + annotations to pick fuzz input types.
|
| 99 |
+
|
| 100 |
+
Returns a list of {"name": str, "kind": "ndarray|int|float|list|str", "dtype": str}.
|
| 101 |
+
Without explicit annotations, we fall back to ndarray of float64.
|
| 102 |
+
"""
|
| 103 |
+
try:
|
| 104 |
+
tree = ast.parse(python_code)
|
| 105 |
+
except SyntaxError:
|
| 106 |
+
return [{"name": "x", "kind": "ndarray", "dtype": "float64"}]
|
| 107 |
+
|
| 108 |
+
fn = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
|
| 109 |
+
if fn is None:
|
| 110 |
+
return [{"name": "x", "kind": "ndarray", "dtype": "float64"}]
|
| 111 |
+
|
| 112 |
+
sig: list[dict[str, str]] = []
|
| 113 |
+
for arg in fn.args.args:
|
| 114 |
+
ann = ast.unparse(arg.annotation) if arg.annotation else ""
|
| 115 |
+
kind = "ndarray"
|
| 116 |
+
dtype = "float64"
|
| 117 |
+
if "int" in ann.lower() and "ndarray" not in ann.lower() and "list" not in ann.lower():
|
| 118 |
+
kind = "int"
|
| 119 |
+
elif "float" in ann.lower() and "ndarray" not in ann.lower() and "list" not in ann.lower():
|
| 120 |
+
kind = "float"
|
| 121 |
+
elif "list" in ann.lower():
|
| 122 |
+
kind = "list"
|
| 123 |
+
elif "str" in ann.lower():
|
| 124 |
+
kind = "str"
|
| 125 |
+
if "int32" in ann:
|
| 126 |
+
dtype = "int32"
|
| 127 |
+
elif "int64" in ann:
|
| 128 |
+
dtype = "int64"
|
| 129 |
+
elif "float32" in ann:
|
| 130 |
+
dtype = "float32"
|
| 131 |
+
sig.append({"name": arg.arg, "kind": kind, "dtype": dtype})
|
| 132 |
+
|
| 133 |
+
# Default fallback: assume one ndarray
|
| 134 |
+
if not sig:
|
| 135 |
+
sig = [{"name": "x", "kind": "ndarray", "dtype": "float64"}]
|
| 136 |
+
return sig
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _generate_typed_input(spec: dict[str, str], rng: np.random.Generator, adversarial: bool = False) -> Any:
|
| 140 |
+
"""Generate one input matching spec. If adversarial, sample boundary/edge values."""
|
| 141 |
+
kind = spec["kind"]
|
| 142 |
+
dtype = spec["dtype"]
|
| 143 |
+
|
| 144 |
+
if kind == "int":
|
| 145 |
+
if adversarial:
|
| 146 |
+
return int(rng.choice([0, 1, -1, 2**31 - 1, -(2**31), 2**62, -(2**62)]))
|
| 147 |
+
return int(rng.integers(-1000, 1000))
|
| 148 |
+
|
| 149 |
+
if kind == "float":
|
| 150 |
+
if adversarial:
|
| 151 |
+
return float(rng.choice([0.0, -0.0, np.inf, -np.inf, np.nan, 1e-300, 1e300]))
|
| 152 |
+
return float(rng.standard_normal())
|
| 153 |
+
|
| 154 |
+
if kind == "str":
|
| 155 |
+
# Short ascii strings
|
| 156 |
+
return "".join(chr(int(rng.integers(97, 123))) for _ in range(int(rng.integers(1, 16))))
|
| 157 |
+
|
| 158 |
+
# Default: ndarray
|
| 159 |
+
n = int(rng.integers(10, 1000))
|
| 160 |
+
if adversarial:
|
| 161 |
+
choices = [
|
| 162 |
+
np.zeros(n, dtype=dtype),
|
| 163 |
+
np.ones(n, dtype=dtype),
|
| 164 |
+
np.array([], dtype=dtype), # empty
|
| 165 |
+
np.array([0.0], dtype=dtype), # singleton
|
| 166 |
+
np.full(n, np.inf, dtype=dtype) if "float" in dtype else np.full(n, np.iinfo(np.dtype(dtype)).max, dtype=dtype),
|
| 167 |
+
(rng.standard_normal(n) * 1e-300).astype(dtype) if "float" in dtype else rng.integers(-1, 2, n).astype(dtype),
|
| 168 |
+
]
|
| 169 |
+
idx = int(rng.integers(0, len(choices)))
|
| 170 |
+
return choices[idx]
|
| 171 |
+
|
| 172 |
+
if "int" in dtype:
|
| 173 |
+
return rng.integers(-100, 100, size=n).astype(dtype)
|
| 174 |
+
return rng.standard_normal(n).astype(dtype)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _numerically_equivalent(a: Any, b: Any, rtol: float) -> bool:
|
| 178 |
+
"""Compare two outputs accounting for float tolerance, exact for int."""
|
| 179 |
+
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
| 180 |
+
if rtol == 0:
|
| 181 |
+
return a == b
|
| 182 |
+
if not np.isfinite(a) or not np.isfinite(b):
|
| 183 |
+
return (np.isnan(a) and np.isnan(b)) or a == b
|
| 184 |
+
return abs(a - b) <= rtol * (1 + abs(a))
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
a = np.asarray(a)
|
| 188 |
+
b = np.asarray(b)
|
| 189 |
+
except Exception:
|
| 190 |
+
return a == b
|
| 191 |
+
|
| 192 |
+
if a.shape != b.shape:
|
| 193 |
+
return False
|
| 194 |
+
if a.dtype != b.dtype:
|
| 195 |
+
# We don't allow dtype-mismatch — that's a hard fail per plan §10b
|
| 196 |
+
return False
|
| 197 |
+
|
| 198 |
+
if rtol == 0:
|
| 199 |
+
return bool(np.array_equal(a, b))
|
| 200 |
+
|
| 201 |
+
# Use allclose with NaN-equality
|
| 202 |
+
return bool(np.allclose(a, b, rtol=rtol, atol=rtol * 0.1, equal_nan=True))
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _exec_python_in_sandbox(python_code: str, fn_name: str, args: tuple) -> Any:
|
| 206 |
+
"""Run python_code's function on args in a constrained namespace."""
|
| 207 |
+
fn = _safe_exec_function(python_code, fn_name)
|
| 208 |
+
return fn(*args)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _exec_cpp_via_so(so_path: str, fn_name: str, args: tuple, py_fn=None, py_code: str = "") -> Any:
|
| 212 |
+
"""Load the compiled .so via ctypes and dispatch on `args`.
|
| 213 |
+
|
| 214 |
+
The agent's C++ uses the canonical signature
|
| 215 |
+
extern "C" void agent_function(const double*, size_t, double*, size_t);
|
| 216 |
+
so we need the Python reference function to know the output shape. Either
|
| 217 |
+
pass `py_fn` directly, or pass `py_code` and we'll compile it.
|
| 218 |
+
|
| 219 |
+
Raises:
|
| 220 |
+
RuntimeError: ctypes can't load the .so or symbol is missing
|
| 221 |
+
"""
|
| 222 |
+
from server.tools._runtime import call_compiled
|
| 223 |
+
if py_fn is None:
|
| 224 |
+
if not py_code:
|
| 225 |
+
raise RuntimeError("verifier: need py_fn or py_code to dispatch C++")
|
| 226 |
+
py_fn = _safe_exec_function(py_code, fn_name)
|
| 227 |
+
return call_compiled(so_path, py_fn, args)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def verify_equivalence_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
|
| 231 |
+
"""Fuzz-verify cpp_code against python_code on n_cases random + adversarial inputs.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
cpp_code (str) — agent's C++
|
| 235 |
+
python_code (str) — reference Python (defaults to state.python_code)
|
| 236 |
+
n_cases (int=1000) — total fuzz cases (10% adversarial sub-pool)
|
| 237 |
+
rtol (float=1e-5) — float tolerance; 0 = bit-exact
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
pass_rate (float)
|
| 241 |
+
first_failure (dict | None)
|
| 242 |
+
n_adversarial_failures (int)
|
| 243 |
+
n_random_failures (int)
|
| 244 |
+
seed (int) — randomized per call (defeats lookup tables)
|
| 245 |
+
"""
|
| 246 |
+
cpp_code = tool_args.get("cpp_code", "")
|
| 247 |
+
python_code = tool_args.get("python_code") or state.python_code
|
| 248 |
+
n_cases = int(tool_args.get("n_cases", 1000))
|
| 249 |
+
rtol = float(tool_args.get("rtol", state.rtol_override if state.rtol_override is not None else 1e-5))
|
| 250 |
+
|
| 251 |
+
if not cpp_code.strip():
|
| 252 |
+
return {"pass_rate": 0.0, "error": "empty cpp_code"}
|
| 253 |
+
if n_cases <= 0:
|
| 254 |
+
return {"pass_rate": 0.0, "error": "n_cases must be >= 1", "n_cases": n_cases}
|
| 255 |
+
|
| 256 |
+
# Defeat lookup-table cheating mode 4: seed varies per call
|
| 257 |
+
seed = random.randint(0, 2**32 - 1)
|
| 258 |
+
rng = np.random.default_rng(seed)
|
| 259 |
+
|
| 260 |
+
# Discover Python function name (first FunctionDef)
|
| 261 |
+
try:
|
| 262 |
+
tree = ast.parse(python_code)
|
| 263 |
+
except SyntaxError as e:
|
| 264 |
+
return {"pass_rate": 0.0, "error": f"python parse: {e}"}
|
| 265 |
+
fn_node = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
|
| 266 |
+
if fn_node is None:
|
| 267 |
+
return {"pass_rate": 0.0, "error": "no function in python_code"}
|
| 268 |
+
fn_name = fn_node.name
|
| 269 |
+
|
| 270 |
+
sig = _infer_input_signature(python_code)
|
| 271 |
+
|
| 272 |
+
# Compile (or get cached .so) — uses cpp_compiler tool's pathway
|
| 273 |
+
from server.tools.cpp_compiler import _compile, _sha256
|
| 274 |
+
import json as _json
|
| 275 |
+
cache_key = _sha256(cpp_code, _json.dumps(state.hardware_profile, sort_keys=True))
|
| 276 |
+
compile_result = _compile(cpp_code, state.hardware_profile, cache_key)
|
| 277 |
+
if compile_result["status"] != "success":
|
| 278 |
+
return {
|
| 279 |
+
"pass_rate": 0.0,
|
| 280 |
+
"error": f"cpp compile failed: {compile_result.get('error', '')[:300]}",
|
| 281 |
+
"compile_status": compile_result["status"],
|
| 282 |
+
}
|
| 283 |
+
so_path = compile_result["so_path"]
|
| 284 |
+
|
| 285 |
+
# Pre-load the Python reference function once (avoids repeated exec overhead)
|
| 286 |
+
try:
|
| 287 |
+
py_fn = _safe_exec_function(python_code, fn_name)
|
| 288 |
+
except Exception as e:
|
| 289 |
+
return {"pass_rate": 0.0, "error": f"python exec failed: {e}"}
|
| 290 |
+
|
| 291 |
+
failures: list[dict[str, Any]] = []
|
| 292 |
+
n_adversarial_failures = 0
|
| 293 |
+
n_random_failures = 0
|
| 294 |
+
|
| 295 |
+
for i in range(n_cases):
|
| 296 |
+
adversarial = (i % 10 == 9) # 10% adversarial sub-pool
|
| 297 |
+
try:
|
| 298 |
+
args = tuple(_generate_typed_input(spec, rng, adversarial=adversarial) for spec in sig)
|
| 299 |
+
except Exception:
|
| 300 |
+
continue # Skip if input generation itself fails
|
| 301 |
+
|
| 302 |
+
# Run Python first; if it raises, skip (don't penalize the C++ for invalid input)
|
| 303 |
+
try:
|
| 304 |
+
py_out = py_fn(*args)
|
| 305 |
+
except Exception:
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
# Run C++ via ctypes dispatch — REAL execution now (not stub)
|
| 309 |
+
try:
|
| 310 |
+
cpp_out = _exec_cpp_via_so(so_path, fn_name, args, py_fn=py_fn)
|
| 311 |
+
except Exception as e:
|
| 312 |
+
if adversarial:
|
| 313 |
+
n_adversarial_failures += 1
|
| 314 |
+
else:
|
| 315 |
+
n_random_failures += 1
|
| 316 |
+
if not failures:
|
| 317 |
+
failures.append({
|
| 318 |
+
"case": i, "reason": "cpp_exec_error", "error": str(e)[:200],
|
| 319 |
+
"adversarial": adversarial,
|
| 320 |
+
})
|
| 321 |
+
continue
|
| 322 |
+
|
| 323 |
+
if not _numerically_equivalent(py_out, cpp_out, rtol):
|
| 324 |
+
if adversarial:
|
| 325 |
+
n_adversarial_failures += 1
|
| 326 |
+
else:
|
| 327 |
+
n_random_failures += 1
|
| 328 |
+
if not failures:
|
| 329 |
+
# Capture only first failure to bound observation size
|
| 330 |
+
py_repr = repr(py_out)[:120]
|
| 331 |
+
cpp_repr = repr(cpp_out)[:120]
|
| 332 |
+
failures.append({
|
| 333 |
+
"case": i, "reason": "output_mismatch",
|
| 334 |
+
"adversarial": adversarial,
|
| 335 |
+
"py_out": py_repr, "cpp_out": cpp_repr,
|
| 336 |
+
})
|
| 337 |
+
|
| 338 |
+
pass_count = n_cases - (n_adversarial_failures + n_random_failures)
|
| 339 |
+
pass_rate = pass_count / n_cases
|
| 340 |
+
|
| 341 |
+
n_adversarial_total = n_cases // 10
|
| 342 |
+
adversarial_pass_rate = (n_adversarial_total - n_adversarial_failures) / max(n_adversarial_total, 1)
|
| 343 |
+
|
| 344 |
+
return {
|
| 345 |
+
"pass_rate": pass_rate,
|
| 346 |
+
"n_cases": n_cases,
|
| 347 |
+
"first_failure": failures[0] if failures else None,
|
| 348 |
+
"n_adversarial_failures": n_adversarial_failures,
|
| 349 |
+
"n_random_failures": n_random_failures,
|
| 350 |
+
"adversarial_pass_rate": adversarial_pass_rate,
|
| 351 |
+
"rtol_used": rtol,
|
| 352 |
+
"seed": seed,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
__all__ = ["verify_equivalence_tool", "_infer_input_signature", "_numerically_equivalent"]
|
tests/__init__.py
ADDED
|
File without changes
|
tests/smoke_llm_hf.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM smoke test via HuggingFace Inference API or Cursor API.
|
| 2 |
+
|
| 3 |
+
Runs 3 short episodes against the env using a remote LLM, validates that:
|
| 4 |
+
1. The model emits parseable `<think>...</think>` blocks (DiagnosisRubric needs this)
|
| 5 |
+
2. Tool calls extract cleanly from the response
|
| 6 |
+
3. The agent's C++ output respects the `extern "C" agent_function` contract
|
| 7 |
+
4. End-to-end env<-->LLM loop completes without crashing
|
| 8 |
+
5. Reward DAG produces non-zero reward at least once
|
| 9 |
+
|
| 10 |
+
Run (HF provider):
|
| 11 |
+
export HF_TOKEN=hf_...
|
| 12 |
+
cd polyglot_optima && python tests/smoke_llm_hf.py
|
| 13 |
+
|
| 14 |
+
Run (Cursor provider):
|
| 15 |
+
export LLM_PROVIDER=cursor
|
| 16 |
+
export CURSOR_API_KEY=...
|
| 17 |
+
export CURSOR_MODEL=gpt-4.1-mini
|
| 18 |
+
# optional: export CURSOR_API_BASE_URL=https://api.cursor.com/v1
|
| 19 |
+
cd polyglot_optima && python tests/smoke_llm_hf.py
|
| 20 |
+
|
| 21 |
+
Without a token: anonymous access (very limited rate; may fail randomly).
|
| 22 |
+
|
| 23 |
+
Cost: free tier on HF Inference API. ~45 model calls across 3 episodes.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import re
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
from urllib import request, error
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
from typing import Any
|
| 36 |
+
|
| 37 |
+
# Make the package importable when run as a script
|
| 38 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 39 |
+
|
| 40 |
+
from models import OptimizationAction
|
| 41 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------- Models to try (free-tier-friendly, instruct-tuned, in order of preference) ----------
|
| 45 |
+
|
| 46 |
+
MODEL_CANDIDATES = [
|
| 47 |
+
"Qwen/Qwen2.5-Coder-7B-Instruct", # Code-focused, primary fallback
|
| 48 |
+
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", # Plan-target reasoning model
|
| 49 |
+
"meta-llama/Llama-3.1-8B-Instruct", # Generic instruct fallback
|
| 50 |
+
"mistralai/Mistral-7B-Instruct-v0.3", # Last-resort fallback
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------- System prompt (canonical per plan §11) ----------
|
| 55 |
+
|
| 56 |
+
SYSTEM_PROMPT = """You are a senior C++ performance engineer specializing in hardware-aware code.
|
| 57 |
+
|
| 58 |
+
YOUR TASK: each turn, choose ONE of the 9 tools to call. After 3 rounds of refinement, you submit your final optimized C++.
|
| 59 |
+
|
| 60 |
+
OUTPUT FORMAT (STRICT -- non-conforming responses score 0):
|
| 61 |
+
|
| 62 |
+
<think>
|
| 63 |
+
1. What is the bottleneck? (memory-bound / compute-bound / branch-heavy / vectorizable)
|
| 64 |
+
2. What does the hardware imply about strategy?
|
| 65 |
+
3. Which tool should I call next, and why?
|
| 66 |
+
</think>
|
| 67 |
+
```json
|
| 68 |
+
{"tool_name": "<one of the 9 tools>", "tool_args": { ... }}
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
THE 9 TOOLS:
|
| 72 |
+
- get_hardware_profile() -- returns hw spec + Roofline
|
| 73 |
+
- profile_python_hotspots(code) -- top hot lines
|
| 74 |
+
- analyze_complexity(code) -- Big-O + nesting depth
|
| 75 |
+
- check_memory_access(code) -- stride / aliasing flags
|
| 76 |
+
- compile_and_benchmark(cpp_code) -- speedup measurement
|
| 77 |
+
- verify_equivalence(cpp_code) -- fuzzer pass rate
|
| 78 |
+
- check_portability(cpp_code) -- cross-profile pass count
|
| 79 |
+
- get_bottleneck_report(cpp_code) -- perf-stat-style report on YOUR C++
|
| 80 |
+
- submit_optimization(cpp_code, reasoning_trace) -- FINAL submission for the round
|
| 81 |
+
|
| 82 |
+
HARD CONSTRAINTS for cpp_code:
|
| 83 |
+
- C++20, single canonical signature:
|
| 84 |
+
extern "C" void agent_function(const double* in_ptr, size_t in_n, double* out_ptr, size_t out_n);
|
| 85 |
+
- Compiles with: g++ -O3 -march=native -fopenmp -std=c++20 -Wall
|
| 86 |
+
- BANNED: <mkl.h>, <Eigen/...>, BLAS/LAPACK, CUDA. We measure YOUR optimization.
|
| 87 |
+
- Allowed: full STL, <immintrin.h>, <arm_neon.h>, <omp.h>, <pybind11/*>
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------- LLM call (HF Inference API) ----------
|
| 92 |
+
|
| 93 |
+
def call_llm_hf(messages: list[dict[str, str]], model: str, hf_token: str | None) -> str:
|
| 94 |
+
"""One inference call. Returns the assistant's text content. Raises on hard errors."""
|
| 95 |
+
from huggingface_hub import InferenceClient
|
| 96 |
+
client = InferenceClient(token=hf_token)
|
| 97 |
+
resp = client.chat_completion(
|
| 98 |
+
messages=messages,
|
| 99 |
+
model=model,
|
| 100 |
+
max_tokens=512,
|
| 101 |
+
temperature=0.5,
|
| 102 |
+
)
|
| 103 |
+
return resp.choices[0].message.content or ""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def pick_model_hf(hf_token: str | None) -> str | None:
|
| 107 |
+
"""Probe the free-tier API for the first available candidate model."""
|
| 108 |
+
from huggingface_hub import InferenceClient
|
| 109 |
+
client = InferenceClient(token=hf_token)
|
| 110 |
+
for name in MODEL_CANDIDATES:
|
| 111 |
+
try:
|
| 112 |
+
resp = client.chat_completion(
|
| 113 |
+
messages=[{"role": "user", "content": "hi"}],
|
| 114 |
+
model=name,
|
| 115 |
+
max_tokens=4,
|
| 116 |
+
)
|
| 117 |
+
if resp.choices[0].message.content is not None:
|
| 118 |
+
return name
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f" - {name} → not available: {str(e)[:80]}", file=sys.stderr)
|
| 121 |
+
continue
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def call_llm_cursor(
|
| 126 |
+
messages: list[dict[str, str]],
|
| 127 |
+
model: str,
|
| 128 |
+
cursor_api_key: str,
|
| 129 |
+
cursor_api_base_url: str,
|
| 130 |
+
) -> str:
|
| 131 |
+
"""Call Cursor API with an OpenAI-compatible chat payload."""
|
| 132 |
+
payload = {
|
| 133 |
+
"model": model,
|
| 134 |
+
"messages": messages,
|
| 135 |
+
"temperature": 0.5,
|
| 136 |
+
"max_tokens": 512,
|
| 137 |
+
}
|
| 138 |
+
base = cursor_api_base_url.rstrip("/")
|
| 139 |
+
url = f"{base}/chat/completions"
|
| 140 |
+
req = request.Request(
|
| 141 |
+
url=url,
|
| 142 |
+
method="POST",
|
| 143 |
+
headers={
|
| 144 |
+
"Content-Type": "application/json",
|
| 145 |
+
"Authorization": f"Bearer {cursor_api_key}",
|
| 146 |
+
},
|
| 147 |
+
data=json.dumps(payload).encode("utf-8"),
|
| 148 |
+
)
|
| 149 |
+
try:
|
| 150 |
+
with request.urlopen(req, timeout=60) as resp:
|
| 151 |
+
raw = resp.read().decode("utf-8")
|
| 152 |
+
except error.HTTPError as e:
|
| 153 |
+
body = e.read().decode("utf-8", errors="replace")
|
| 154 |
+
raise RuntimeError(f"Cursor API HTTP {e.code}: {body[:240]}")
|
| 155 |
+
except Exception as e:
|
| 156 |
+
raise RuntimeError(f"Cursor API request failed: {e}")
|
| 157 |
+
|
| 158 |
+
try:
|
| 159 |
+
obj = json.loads(raw)
|
| 160 |
+
return obj["choices"][0]["message"]["content"] or ""
|
| 161 |
+
except Exception as e:
|
| 162 |
+
raise RuntimeError(f"Cursor API response parse failed: {e}; body={raw[:240]}")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def pick_model_cursor(cursor_api_key: str, cursor_api_base_url: str, preferred_model: str | None) -> str | None:
|
| 166 |
+
"""Probe Cursor API with preferred model first, then a short fallback list."""
|
| 167 |
+
candidates = [m for m in [preferred_model, "gpt-4.1-mini", "gpt-4o-mini"] if m]
|
| 168 |
+
for name in candidates:
|
| 169 |
+
try:
|
| 170 |
+
_ = call_llm_cursor(
|
| 171 |
+
messages=[{"role": "user", "content": "hi"}],
|
| 172 |
+
model=name,
|
| 173 |
+
cursor_api_key=cursor_api_key,
|
| 174 |
+
cursor_api_base_url=cursor_api_base_url,
|
| 175 |
+
)
|
| 176 |
+
return name
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f" - {name} -> not available: {str(e)[:100]}", file=sys.stderr)
|
| 179 |
+
continue
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ---------- Response parsing ----------
|
| 184 |
+
|
| 185 |
+
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
|
| 186 |
+
_JSON_BLOCK_RE = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
| 187 |
+
_LOOSE_JSON_RE = re.compile(r"\{[^{}]*\"tool_name\"[^{}]*\}", re.DOTALL)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def parse_llm_response(text: str) -> dict[str, Any]:
|
| 191 |
+
"""Extract <think>, tool_name, tool_args from raw LLM text. Best-effort.
|
| 192 |
+
|
| 193 |
+
Returns dict with: thinking, tool_name, tool_args, parse_status.
|
| 194 |
+
parse_status ∈ {"ok", "no_think", "no_json", "no_tool", "json_invalid"}.
|
| 195 |
+
"""
|
| 196 |
+
out: dict[str, Any] = {
|
| 197 |
+
"thinking": "",
|
| 198 |
+
"tool_name": None,
|
| 199 |
+
"tool_args": {},
|
| 200 |
+
"parse_status": "ok",
|
| 201 |
+
"raw": text,
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
# Extract thinking block
|
| 205 |
+
m = _THINK_RE.search(text)
|
| 206 |
+
if m:
|
| 207 |
+
out["thinking"] = m.group(1).strip()
|
| 208 |
+
else:
|
| 209 |
+
out["parse_status"] = "no_think"
|
| 210 |
+
|
| 211 |
+
# Extract JSON tool call -- try fenced block first, then loose match
|
| 212 |
+
json_block = None
|
| 213 |
+
fence_match = _JSON_BLOCK_RE.search(text)
|
| 214 |
+
if fence_match:
|
| 215 |
+
json_block = fence_match.group(1)
|
| 216 |
+
else:
|
| 217 |
+
loose = _LOOSE_JSON_RE.search(text)
|
| 218 |
+
if loose:
|
| 219 |
+
json_block = loose.group(0)
|
| 220 |
+
|
| 221 |
+
if not json_block:
|
| 222 |
+
out["parse_status"] = "no_json" if out["parse_status"] == "ok" else out["parse_status"]
|
| 223 |
+
return out
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
parsed = json.loads(json_block)
|
| 227 |
+
out["tool_name"] = parsed.get("tool_name")
|
| 228 |
+
out["tool_args"] = parsed.get("tool_args", {}) or {}
|
| 229 |
+
if not out["tool_name"]:
|
| 230 |
+
out["parse_status"] = "no_tool"
|
| 231 |
+
except json.JSONDecodeError as e:
|
| 232 |
+
out["parse_status"] = f"json_invalid: {e}"
|
| 233 |
+
return out
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---------- Episode runner ----------
|
| 237 |
+
|
| 238 |
+
def build_user_prompt(observation, round_number: int) -> str:
|
| 239 |
+
return (
|
| 240 |
+
f"## Round {round_number} of 3\n\n"
|
| 241 |
+
f"### Hardware profile\n```json\n{json.dumps(observation.hardware_profile, indent=2)}\n```\n\n"
|
| 242 |
+
f"### Python function to optimize\n```python\n{observation.python_code}\n```\n\n"
|
| 243 |
+
f"### Last tool result\n```json\n{json.dumps(observation.tool_result, indent=2, default=str)[:1500]}\n```\n\n"
|
| 244 |
+
f"### Best speedup so far\n{observation.best_speedup_so_far:.3f}x\n\n"
|
| 245 |
+
f"What is your next action? "
|
| 246 |
+
f"After at most 4 tool calls in this round, you must call submit_optimization."
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def run_episode(
|
| 251 |
+
env: PolyglotOptimaEnvironment,
|
| 252 |
+
model: str,
|
| 253 |
+
provider: str,
|
| 254 |
+
hf_token: str | None,
|
| 255 |
+
cursor_api_key: str | None,
|
| 256 |
+
cursor_api_base_url: str | None,
|
| 257 |
+
episode_seed: int,
|
| 258 |
+
report: dict[str, Any],
|
| 259 |
+
) -> None:
|
| 260 |
+
"""Run one episode end-to-end. Mutates `report` with stats."""
|
| 261 |
+
obs = env.reset(seed=episode_seed)
|
| 262 |
+
ep_report: dict[str, Any] = {
|
| 263 |
+
"seed": episode_seed,
|
| 264 |
+
"rounds": [],
|
| 265 |
+
"errors": [],
|
| 266 |
+
"final_reward": 0.0,
|
| 267 |
+
"n_think_blocks": 0,
|
| 268 |
+
"n_parse_errors": 0,
|
| 269 |
+
"n_unknown_tools": 0,
|
| 270 |
+
"n_tool_calls": 0,
|
| 271 |
+
}
|
| 272 |
+
report["episodes"].append(ep_report)
|
| 273 |
+
|
| 274 |
+
valid_tool_names = set(env._tool_registry.keys())
|
| 275 |
+
max_calls_per_round = 4
|
| 276 |
+
|
| 277 |
+
for round_idx in range(1, 4):
|
| 278 |
+
round_calls: list[dict[str, Any]] = []
|
| 279 |
+
for call_idx in range(max_calls_per_round):
|
| 280 |
+
user_prompt = build_user_prompt(obs, round_idx)
|
| 281 |
+
messages = [
|
| 282 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 283 |
+
{"role": "user", "content": user_prompt},
|
| 284 |
+
]
|
| 285 |
+
try:
|
| 286 |
+
t0 = time.time()
|
| 287 |
+
if provider == "cursor":
|
| 288 |
+
raw = call_llm_cursor(
|
| 289 |
+
messages,
|
| 290 |
+
model,
|
| 291 |
+
cursor_api_key or "",
|
| 292 |
+
cursor_api_base_url or "",
|
| 293 |
+
)
|
| 294 |
+
else:
|
| 295 |
+
raw = call_llm_hf(messages, model, hf_token)
|
| 296 |
+
latency = time.time() - t0
|
| 297 |
+
except Exception as e:
|
| 298 |
+
ep_report["errors"].append(f"R{round_idx}.{call_idx} LLM call failed: {e}")
|
| 299 |
+
# Force a submit to advance the round
|
| 300 |
+
action = OptimizationAction(
|
| 301 |
+
tool_name="submit_optimization",
|
| 302 |
+
tool_args={"cpp_code": "// llm_call_error_fallback",
|
| 303 |
+
"reasoning_trace": "LLM call failed"},
|
| 304 |
+
reasoning_trace="<think>fallback</think>",
|
| 305 |
+
)
|
| 306 |
+
step = env.step(action)
|
| 307 |
+
obs = step.observation
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
parsed = parse_llm_response(raw)
|
| 311 |
+
if parsed["thinking"]:
|
| 312 |
+
ep_report["n_think_blocks"] += 1
|
| 313 |
+
if parsed["parse_status"] != "ok":
|
| 314 |
+
ep_report["n_parse_errors"] += 1
|
| 315 |
+
if parsed["tool_name"] and parsed["tool_name"] not in valid_tool_names:
|
| 316 |
+
ep_report["n_unknown_tools"] += 1
|
| 317 |
+
|
| 318 |
+
ep_report["n_tool_calls"] += 1
|
| 319 |
+
|
| 320 |
+
tool_name = parsed["tool_name"] or "submit_optimization"
|
| 321 |
+
tool_args = parsed["tool_args"] or {}
|
| 322 |
+
|
| 323 |
+
# If the model emitted a final submission, force the round to close
|
| 324 |
+
is_submit = tool_name == "submit_optimization"
|
| 325 |
+
# If we've hit the call cap and no submit yet, force one
|
| 326 |
+
if call_idx == max_calls_per_round - 1 and not is_submit:
|
| 327 |
+
tool_name = "submit_optimization"
|
| 328 |
+
tool_args = {"cpp_code": tool_args.get("cpp_code", "// no submission this round"),
|
| 329 |
+
"reasoning_trace": parsed["thinking"]}
|
| 330 |
+
is_submit = True
|
| 331 |
+
|
| 332 |
+
action = OptimizationAction(
|
| 333 |
+
tool_name=tool_name,
|
| 334 |
+
tool_args=tool_args,
|
| 335 |
+
reasoning_trace=parsed["thinking"][:1000],
|
| 336 |
+
)
|
| 337 |
+
try:
|
| 338 |
+
step = env.step(action)
|
| 339 |
+
obs = step.observation
|
| 340 |
+
round_calls.append({
|
| 341 |
+
"tool": tool_name,
|
| 342 |
+
"parse_status": parsed["parse_status"],
|
| 343 |
+
"latency_s": round(latency, 2),
|
| 344 |
+
"reward_so_far": round(step.reward, 3),
|
| 345 |
+
})
|
| 346 |
+
except Exception as e:
|
| 347 |
+
ep_report["errors"].append(f"R{round_idx}.{call_idx} env.step crashed: {e}")
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
if is_submit:
|
| 351 |
+
break
|
| 352 |
+
|
| 353 |
+
ep_report["rounds"].append(round_calls)
|
| 354 |
+
if obs.done:
|
| 355 |
+
ep_report["final_reward"] = round(step.reward, 3)
|
| 356 |
+
break
|
| 357 |
+
|
| 358 |
+
if not obs.done and not env.state().is_terminal:
|
| 359 |
+
# Episode didn't terminate via natural 3-round flow
|
| 360 |
+
ep_report["errors"].append("episode did not reach terminal state")
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# ---------- Aggregate report ----------
|
| 364 |
+
|
| 365 |
+
def print_report(report: dict[str, Any]) -> None:
|
| 366 |
+
print("\n" + "=" * 70)
|
| 367 |
+
print("LLM SMOKE TEST REPORT")
|
| 368 |
+
print("=" * 70)
|
| 369 |
+
print(f"Model used: {report['model']}")
|
| 370 |
+
print(f"Episodes run: {len(report['episodes'])}")
|
| 371 |
+
print(f"Total LLM calls: {sum(e['n_tool_calls'] for e in report['episodes'])}")
|
| 372 |
+
|
| 373 |
+
n_think = sum(e["n_think_blocks"] for e in report["episodes"])
|
| 374 |
+
n_parse = sum(e["n_parse_errors"] for e in report["episodes"])
|
| 375 |
+
n_unknown = sum(e["n_unknown_tools"] for e in report["episodes"])
|
| 376 |
+
n_calls = sum(e["n_tool_calls"] for e in report["episodes"])
|
| 377 |
+
|
| 378 |
+
print("\n-- Output format compliance --")
|
| 379 |
+
print(f" <think> blocks emitted: {n_think} / {n_calls} ({100*n_think/max(n_calls,1):.0f}%)")
|
| 380 |
+
print(f" Parse errors: {n_parse} / {n_calls} ({100*n_parse/max(n_calls,1):.0f}%)")
|
| 381 |
+
print(f" Unknown/invalid tools: {n_unknown}")
|
| 382 |
+
|
| 383 |
+
print("\n-- Episode rewards --")
|
| 384 |
+
for ep in report["episodes"]:
|
| 385 |
+
n_errs = len(ep["errors"])
|
| 386 |
+
print(f" Episode {ep['seed']}: reward={ep['final_reward']}, errors={n_errs}")
|
| 387 |
+
|
| 388 |
+
if any(e["errors"] for e in report["episodes"]):
|
| 389 |
+
print("\n-- Errors --")
|
| 390 |
+
for ep in report["episodes"]:
|
| 391 |
+
for err in ep["errors"]:
|
| 392 |
+
print(f" - ep{ep['seed']}: {err[:140]}")
|
| 393 |
+
|
| 394 |
+
# Pass/fail verdict
|
| 395 |
+
print("\n-- Verdict --")
|
| 396 |
+
pass_threshold_think = 0.5 # ≥ 50% of calls should have <think>
|
| 397 |
+
pass_threshold_parse = 0.7 # ≥ 70% of calls should parse cleanly
|
| 398 |
+
n_episodes_completed = sum(1 for e in report["episodes"] if not any("did not reach terminal" in x for x in e["errors"]))
|
| 399 |
+
|
| 400 |
+
think_ok = n_think / max(n_calls, 1) >= pass_threshold_think
|
| 401 |
+
parse_ok = (n_calls - n_parse) / max(n_calls, 1) >= pass_threshold_parse
|
| 402 |
+
episodes_ok = n_episodes_completed == len(report["episodes"])
|
| 403 |
+
|
| 404 |
+
if think_ok and parse_ok and episodes_ok:
|
| 405 |
+
print(" [OK] PASS -- env<-->LLM integration works. Safe to launch GRPO training.")
|
| 406 |
+
else:
|
| 407 |
+
print(" [FAIL] FAIL -- fix before training:")
|
| 408 |
+
if not think_ok:
|
| 409 |
+
print(f" <think> emission rate too low ({100*n_think/max(n_calls,1):.0f}% < 50%)")
|
| 410 |
+
if not parse_ok:
|
| 411 |
+
print(f" parse rate too low ({100*(n_calls-n_parse)/max(n_calls,1):.0f}% < 70%)")
|
| 412 |
+
if not episodes_ok:
|
| 413 |
+
print(f" {len(report['episodes']) - n_episodes_completed} episodes did not terminate cleanly")
|
| 414 |
+
|
| 415 |
+
print("=" * 70 + "\n")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# ---------- Main ----------
|
| 419 |
+
|
| 420 |
+
def main() -> int:
|
| 421 |
+
provider = os.environ.get("LLM_PROVIDER", "hf").strip().lower()
|
| 422 |
+
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
|
| 423 |
+
cursor_api_key = os.environ.get("CURSOR_API_KEY")
|
| 424 |
+
cursor_api_base_url = os.environ.get("CURSOR_API_BASE_URL", "https://api.cursor.com/v1")
|
| 425 |
+
cursor_model = os.environ.get("CURSOR_MODEL")
|
| 426 |
+
|
| 427 |
+
if provider == "cursor":
|
| 428 |
+
if not cursor_api_key:
|
| 429 |
+
print("[FAIL] LLM_PROVIDER=cursor but CURSOR_API_KEY is not set.")
|
| 430 |
+
return 1
|
| 431 |
+
print(f"[OK] Cursor provider selected ({cursor_api_base_url})")
|
| 432 |
+
print("\nProbing Cursor model availability...")
|
| 433 |
+
model = pick_model_cursor(cursor_api_key, cursor_api_base_url, cursor_model)
|
| 434 |
+
if not model:
|
| 435 |
+
print("[FAIL] No Cursor model is reachable. "
|
| 436 |
+
"Check CURSOR_API_KEY, CURSOR_API_BASE_URL, and CURSOR_MODEL.")
|
| 437 |
+
return 1
|
| 438 |
+
else:
|
| 439 |
+
if not hf_token:
|
| 440 |
+
print("[WARN] no HF_TOKEN env var set -- using anonymous access (heavily rate-limited)")
|
| 441 |
+
else:
|
| 442 |
+
print(f"[OK] HF token found ({hf_token[:5]}...)")
|
| 443 |
+
print("\nProbing free-tier model availability...")
|
| 444 |
+
model = pick_model_hf(hf_token)
|
| 445 |
+
if not model:
|
| 446 |
+
print("[FAIL] No candidate model accessible via HF Inference API. "
|
| 447 |
+
"Check token quota or switch to Cursor API (LLM_PROVIDER=cursor).")
|
| 448 |
+
return 1
|
| 449 |
+
print(f"[OK] Using model: {model}\n")
|
| 450 |
+
|
| 451 |
+
env = PolyglotOptimaEnvironment(max_rounds=3, max_calls_per_round=5)
|
| 452 |
+
report: dict[str, Any] = {"model": model, "episodes": []}
|
| 453 |
+
|
| 454 |
+
for seed in (101, 202, 303):
|
| 455 |
+
print(f"--- Episode seed={seed} ---")
|
| 456 |
+
try:
|
| 457 |
+
run_episode(
|
| 458 |
+
env=env,
|
| 459 |
+
model=model,
|
| 460 |
+
provider=provider,
|
| 461 |
+
hf_token=hf_token,
|
| 462 |
+
cursor_api_key=cursor_api_key,
|
| 463 |
+
cursor_api_base_url=cursor_api_base_url,
|
| 464 |
+
episode_seed=seed,
|
| 465 |
+
report=report,
|
| 466 |
+
)
|
| 467 |
+
except Exception as e:
|
| 468 |
+
report["episodes"].append({"seed": seed, "errors": [f"fatal: {e}"], "rounds": [],
|
| 469 |
+
"final_reward": 0.0, "n_think_blocks": 0,
|
| 470 |
+
"n_parse_errors": 0, "n_unknown_tools": 0, "n_tool_calls": 0})
|
| 471 |
+
finally:
|
| 472 |
+
env.close()
|
| 473 |
+
env = PolyglotOptimaEnvironment(max_rounds=3, max_calls_per_round=5)
|
| 474 |
+
|
| 475 |
+
print_report(report)
|
| 476 |
+
|
| 477 |
+
# Exit code: 0 if pass verdict, else 1
|
| 478 |
+
n_calls = sum(e["n_tool_calls"] for e in report["episodes"])
|
| 479 |
+
n_think = sum(e["n_think_blocks"] for e in report["episodes"])
|
| 480 |
+
n_parse = sum(e["n_parse_errors"] for e in report["episodes"])
|
| 481 |
+
if n_calls and (n_think / n_calls >= 0.5) and ((n_calls - n_parse) / n_calls >= 0.7):
|
| 482 |
+
return 0
|
| 483 |
+
return 1
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
if __name__ == "__main__":
|
| 487 |
+
sys.exit(main())
|
tests/test_rewards.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hour 10-16: Reward rubric tests.
|
| 2 |
+
|
| 3 |
+
Validates:
|
| 4 |
+
- Sequential composes gate multipliers continuously
|
| 5 |
+
- Gate yields smooth multipliers below threshold
|
| 6 |
+
- WeightedSum composes correctly
|
| 7 |
+
- SpeedupRubric is Roofline-normalized (capped at 1.0)
|
| 8 |
+
- CorrectnessRubric penalizes adversarial-pool failures
|
| 9 |
+
- DiagnosisRubric:
|
| 10 |
+
- rewards correct keywords
|
| 11 |
+
- penalizes distractor stuffing
|
| 12 |
+
- applies length penalty
|
| 13 |
+
- awards coherence bonus when first tool matches diagnosis
|
| 14 |
+
- PortabilityRubric only counts when axis is on
|
| 15 |
+
- SelfCorrectionRubric requires R1 to compile (anti-gaming floor)
|
| 16 |
+
- Full DAG: R1 vs R3 weighting works end-to-end
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 25 |
+
|
| 26 |
+
import pytest
|
| 27 |
+
|
| 28 |
+
from models import OptimizationState
|
| 29 |
+
from server.rewards import (
|
| 30 |
+
Sequential, Gate, WeightedSum, GateFailedError,
|
| 31 |
+
SpeedupRubric, CorrectnessRubric, CompilationRubric,
|
| 32 |
+
DiagnosisRubric, PortabilityRubric, SelfCorrectionRubric,
|
| 33 |
+
build_round_reward_dag,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def make_state(**overrides):
|
| 38 |
+
s = OptimizationState(
|
| 39 |
+
episode_id="test",
|
| 40 |
+
python_code="def sum_squares(arr):\n total = 0.0\n for x in arr:\n total += x*x\n return total\n",
|
| 41 |
+
function_signature_cpp='extern "C" double agent_function(const double*, size_t);',
|
| 42 |
+
hardware_profile={
|
| 43 |
+
"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8, "l1_kb": 32,
|
| 44 |
+
"simd": "AVX2", "bw_gbs": 51,
|
| 45 |
+
},
|
| 46 |
+
bottleneck_ground_truth=["compute-bound", "vectorizable"],
|
| 47 |
+
bottleneck_distractors=["memory-bound", "branch-heavy", "io-bound"],
|
| 48 |
+
round_number=1,
|
| 49 |
+
)
|
| 50 |
+
for k, v in overrides.items():
|
| 51 |
+
setattr(s, k, v)
|
| 52 |
+
return s
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ---------- Composers ----------
|
| 56 |
+
|
| 57 |
+
def test_sequential_returns_last_non_gate_score():
|
| 58 |
+
"""Sequential with no Gate children returns the last child's score directly (gate_product=1)."""
|
| 59 |
+
state = make_state()
|
| 60 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.9, "adversarial_pass_rate": 0.95, "speedup": 5.0}
|
| 61 |
+
seq = Sequential(CorrectnessRubric())
|
| 62 |
+
assert seq.score(state, sub) == pytest.approx(0.9, abs=1e-3)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_sequential_short_circuits_on_dead_floor():
|
| 66 |
+
"""Low correctness should still produce a small non-zero learning signal."""
|
| 67 |
+
state = make_state()
|
| 68 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.1, "adversarial_pass_rate": 0.95, "speedup": 5.0}
|
| 69 |
+
seq = Sequential(Gate(CorrectnessRubric(), threshold=0.6), CorrectnessRubric())
|
| 70 |
+
score = seq.score(state, sub)
|
| 71 |
+
assert 0.0 < score < 0.1
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_sequential_partial_credit_in_ramp_zone():
|
| 75 |
+
"""Between dead_floor (0.3) and threshold (0.6), gate gives partial credit (continuous)."""
|
| 76 |
+
state = make_state()
|
| 77 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.45,
|
| 78 |
+
"adversarial_pass_rate": 0.95, "speedup": 5.0}
|
| 79 |
+
seq = Sequential(Gate(CorrectnessRubric(), threshold=0.6), CorrectnessRubric())
|
| 80 |
+
score = seq.score(state, sub)
|
| 81 |
+
assert 0.0 < score < 0.45 # non-zero AND less than full
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_gate_continuous_no_cliff():
|
| 85 |
+
"""The graduated gate must produce a continuous signal as input crosses threshold."""
|
| 86 |
+
state = make_state()
|
| 87 |
+
seq = Sequential(Gate(CorrectnessRubric(), threshold=0.6), CorrectnessRubric())
|
| 88 |
+
# Sweep from 0.0 → 1.0 in steps of 0.1
|
| 89 |
+
scores = []
|
| 90 |
+
for pr in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
|
| 91 |
+
sub = {"compile_status": "success", "correctness_pass_rate": pr,
|
| 92 |
+
"adversarial_pass_rate": 0.95}
|
| 93 |
+
scores.append(seq.score(state, sub))
|
| 94 |
+
# Monotone non-decreasing with no hard cliff.
|
| 95 |
+
assert all(scores[i+1] >= scores[i] for i in range(len(scores)-1))
|
| 96 |
+
assert scores[0] > 0.0
|
| 97 |
+
# Should reach a higher value at full pass than mid-ramp
|
| 98 |
+
assert scores[-1] > scores[3]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_gate_low_score_still_returns_multiplier():
|
| 102 |
+
"""No-binary mode: low scores still produce a positive multiplier."""
|
| 103 |
+
state = make_state()
|
| 104 |
+
sub = {"correctness_pass_rate": 0.1, "adversarial_pass_rate": 0.95}
|
| 105 |
+
g = Gate(CorrectnessRubric(), threshold=0.6)
|
| 106 |
+
assert 0.0 < g.score(state, sub) < 1.0
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def test_gate_returns_full_multiplier_above_threshold():
|
| 110 |
+
"""Score above threshold → multiplier of 1.0 (full pass-through)."""
|
| 111 |
+
state = make_state()
|
| 112 |
+
sub = {"correctness_pass_rate": 0.85, "adversarial_pass_rate": 0.95}
|
| 113 |
+
g = Gate(CorrectnessRubric(), threshold=0.6)
|
| 114 |
+
assert g.score(state, sub) == 1.0
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def test_gate_ramp_returns_partial_multiplier():
|
| 118 |
+
"""Score in ramp zone → multiplier ∈ (0, ramp_max]."""
|
| 119 |
+
state = make_state()
|
| 120 |
+
sub = {"correctness_pass_rate": 0.45, "adversarial_pass_rate": 0.95}
|
| 121 |
+
g = Gate(CorrectnessRubric(), threshold=0.6, dead_floor=0.3, ramp_max=0.4)
|
| 122 |
+
m = g.score(state, sub)
|
| 123 |
+
assert 0 < m < 0.4 # progress = (0.45-0.3)/(0.6-0.3) = 0.5; multiplier = 0.4 * 0.5 = 0.2
|
| 124 |
+
assert m == pytest.approx(0.2, abs=0.05)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_hard_gate_returns_one_or_raises():
|
| 128 |
+
"""hard=True gate is binary: 1.0 if pass, raise if fail."""
|
| 129 |
+
state = make_state()
|
| 130 |
+
g = Gate(CorrectnessRubric(), threshold=0.6, hard=True)
|
| 131 |
+
assert g.score(state, {"correctness_pass_rate": 0.9, "adversarial_pass_rate": 0.95}) == 1.0
|
| 132 |
+
with pytest.raises(GateFailedError):
|
| 133 |
+
g.score(state, {"correctness_pass_rate": 0.5, "adversarial_pass_rate": 0.95})
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_weighted_sum_composes():
|
| 137 |
+
state = make_state()
|
| 138 |
+
sub = {"speedup": 5.0, "correctness_pass_rate": 1.0, "adversarial_pass_rate": 1.0}
|
| 139 |
+
ws = WeightedSum(
|
| 140 |
+
{"speedup": SpeedupRubric(), "correctness": CorrectnessRubric()},
|
| 141 |
+
weights={"speedup": 0.5, "correctness": 0.5},
|
| 142 |
+
)
|
| 143 |
+
score = ws.score(state, sub)
|
| 144 |
+
assert 0.0 <= score <= 1.0
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------- SpeedupRubric (Roofline) ----------
|
| 148 |
+
|
| 149 |
+
def test_speedup_zero_yields_zero():
|
| 150 |
+
s = SpeedupRubric().score(make_state(), {"speedup": 0.0})
|
| 151 |
+
assert s == 0.0
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def test_speedup_at_roofline_yields_max():
|
| 155 |
+
"""speedup == roofline_peak should yield ~1.0 reward (LOG_NORM = 1.0)."""
|
| 156 |
+
state = make_state()
|
| 157 |
+
from server.tools.hardware_profiler import roofline_bound
|
| 158 |
+
peak = roofline_bound(state.hardware_profile)
|
| 159 |
+
score = SpeedupRubric().score(state, {"speedup": peak})
|
| 160 |
+
assert 0.99 <= score <= 1.0
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def test_speedup_modest_yields_modest_reward():
|
| 164 |
+
"""A modest 5x speedup on AVX2 (peak ~25 GFLOPS) → low-but-positive reward."""
|
| 165 |
+
score = SpeedupRubric().score(make_state(), {"speedup": 5.0})
|
| 166 |
+
assert 0.05 < score < 0.5
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# ---------- CorrectnessRubric ----------
|
| 170 |
+
|
| 171 |
+
def test_correctness_returns_pass_rate():
|
| 172 |
+
s = CorrectnessRubric().score(make_state(),
|
| 173 |
+
{"correctness_pass_rate": 0.92, "adversarial_pass_rate": 0.95})
|
| 174 |
+
assert s == pytest.approx(0.92)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def test_correctness_penalizes_adversarial_failures():
|
| 178 |
+
"""Adversarial pass rate < 0.9 → halves the score per plan §10b."""
|
| 179 |
+
s = CorrectnessRubric().score(make_state(),
|
| 180 |
+
{"correctness_pass_rate": 0.92, "adversarial_pass_rate": 0.5})
|
| 181 |
+
assert s == pytest.approx(0.46, abs=1e-3)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def test_compilation_rubric_binary():
|
| 185 |
+
assert CompilationRubric().score(make_state(), {"compile_status": "success"}) == 1.0
|
| 186 |
+
assert CompilationRubric().score(make_state(), {"compile_status": "syntax_error"}) == pytest.approx(0.1)
|
| 187 |
+
assert CompilationRubric().score(make_state(), {"compile_status": "link_error"}) > 0.1
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# ---------- DiagnosisRubric ----------
|
| 191 |
+
|
| 192 |
+
def test_diagnosis_rewards_correct_keywords():
|
| 193 |
+
state = make_state()
|
| 194 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 195 |
+
s = DiagnosisRubric().score(state,
|
| 196 |
+
{"reasoning_trace": "<think>this is compute-bound and vectorizable</think>"})
|
| 197 |
+
assert s > 0.5
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_diagnosis_penalizes_distractor_stuffing():
|
| 201 |
+
state = make_state()
|
| 202 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 203 |
+
s_clean = DiagnosisRubric().score(state,
|
| 204 |
+
{"reasoning_trace": "compute-bound vectorizable"})
|
| 205 |
+
s_stuffed = DiagnosisRubric().score(state,
|
| 206 |
+
{"reasoning_trace": "compute-bound vectorizable memory-bound branch-heavy io-bound"})
|
| 207 |
+
assert s_stuffed < s_clean
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def test_diagnosis_length_penalty():
|
| 211 |
+
state = make_state()
|
| 212 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 213 |
+
short = DiagnosisRubric().score(state, {"reasoning_trace": "compute-bound vectorizable"})
|
| 214 |
+
long_text = "compute-bound vectorizable " + ("filler " * 100)
|
| 215 |
+
long_ = DiagnosisRubric().score(state, {"reasoning_trace": long_text})
|
| 216 |
+
assert long_ < short
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def test_diagnosis_coherence_bonus():
|
| 220 |
+
"""First tool call matching the diagnosis category gives +0.2 bonus."""
|
| 221 |
+
state = make_state(
|
| 222 |
+
bottleneck_ground_truth=["memory-bound"],
|
| 223 |
+
# Distractors must NOT contain memory-bound, else keyword overlap inflates raw score
|
| 224 |
+
bottleneck_distractors=["branch-heavy", "io-bound"],
|
| 225 |
+
)
|
| 226 |
+
state.round_results = [{"round": 1, "tool_calls": ["check_memory_access"]}]
|
| 227 |
+
matched = DiagnosisRubric().score(state, {"reasoning_trace": "memory-bound"})
|
| 228 |
+
state.round_results = [{"round": 1, "tool_calls": ["analyze_complexity"]}]
|
| 229 |
+
no_match = DiagnosisRubric().score(state, {"reasoning_trace": "memory-bound"})
|
| 230 |
+
assert matched > no_match
|
| 231 |
+
# Bonus is 0.2; clamping to 1.0 may compress the delta slightly
|
| 232 |
+
assert (matched - no_match) == pytest.approx(0.2, abs=0.05) or matched == 1.0
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ---------- PortabilityRubric ----------
|
| 236 |
+
|
| 237 |
+
def test_portability_rubric_off_axis_returns_zero():
|
| 238 |
+
state = make_state()
|
| 239 |
+
state.difficulty_axes["portability_required"] = 0 # off
|
| 240 |
+
s = PortabilityRubric().score(state, {"portability": {"n_profiles_passing": 5}})
|
| 241 |
+
assert s == 0.0
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def test_portability_rubric_on_axis_below_threshold_zero():
|
| 245 |
+
state = make_state()
|
| 246 |
+
state.difficulty_axes["portability_required"] = 1
|
| 247 |
+
s = PortabilityRubric().score(state, {"portability": {"n_profiles_passing": 2}})
|
| 248 |
+
assert s == 0.0
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def test_portability_rubric_on_axis_above_threshold_positive():
|
| 252 |
+
state = make_state()
|
| 253 |
+
state.difficulty_axes["portability_required"] = 1
|
| 254 |
+
s = PortabilityRubric().score(state, {"portability": {"n_profiles_passing": 5}})
|
| 255 |
+
assert 0 < s <= 1.0
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ---------- SelfCorrectionRubric ----------
|
| 259 |
+
|
| 260 |
+
def test_self_correction_only_at_round_3():
|
| 261 |
+
state = make_state(round_number=2)
|
| 262 |
+
s = SelfCorrectionRubric().score(state, {"speedup": 10.0})
|
| 263 |
+
assert s == 0.0
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def test_self_correction_floor_r1_must_compile():
|
| 267 |
+
"""If R1 didn't compile, R3 self-correction returns 0 (defeats deliberate-bad-R1)."""
|
| 268 |
+
state = make_state(round_number=3)
|
| 269 |
+
state.round_results = [
|
| 270 |
+
{"round": 1, "submission": {"compile_status": "syntax_error", "speedup": 0.0}},
|
| 271 |
+
{"round": 2, "submission": {"compile_status": "success", "speedup": 5.0}},
|
| 272 |
+
]
|
| 273 |
+
s = SelfCorrectionRubric().score(state, {"speedup": 50.0})
|
| 274 |
+
assert s == 0.0
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def test_self_correction_rewards_improvement():
|
| 278 |
+
state = make_state(round_number=3)
|
| 279 |
+
state.round_results = [
|
| 280 |
+
{"round": 1, "submission": {"compile_status": "success", "speedup": 2.0}},
|
| 281 |
+
{"round": 2, "submission": {"compile_status": "success", "speedup": 4.0}},
|
| 282 |
+
]
|
| 283 |
+
s = SelfCorrectionRubric().score(state, {"speedup": 4.0}) # 100% improvement
|
| 284 |
+
assert s == pytest.approx(1.0, abs=0.01)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ---------- Full DAG ----------
|
| 288 |
+
|
| 289 |
+
def test_round1_dag_compile_fail_returns_zero():
|
| 290 |
+
state = make_state(round_number=1)
|
| 291 |
+
sub = {"compile_status": "syntax_error", "correctness_pass_rate": 0.0, "speedup": 0.0,
|
| 292 |
+
"adversarial_pass_rate": 0.0}
|
| 293 |
+
dag = build_round_reward_dag(1)
|
| 294 |
+
assert dag.score(state, sub) == 0.0
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def test_round1_dag_correct_in_ramp_zone_partial_credit():
|
| 298 |
+
"""Between dead_floor (0.3) and R1 threshold (0.6) → partial credit, NOT zero.
|
| 299 |
+
|
| 300 |
+
This is the anti-cliff fix: GRPO needs non-zero gradient when the agent is
|
| 301 |
+
'almost there'. Random/wrong code (< 0.3) still scores 0.
|
| 302 |
+
"""
|
| 303 |
+
state = make_state(round_number=1)
|
| 304 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.5,
|
| 305 |
+
"adversarial_pass_rate": 0.95, "speedup": 5.0,
|
| 306 |
+
"reasoning_trace": "compute-bound"}
|
| 307 |
+
dag = build_round_reward_dag(1)
|
| 308 |
+
score = dag.score(state, sub)
|
| 309 |
+
assert 0.0 < score < 0.5 # partial, not zero, not full
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def test_round1_dag_low_correctness_returns_small_signal():
|
| 313 |
+
"""Below old dead-floor, score should remain small but non-zero (no binary cliff)."""
|
| 314 |
+
state = make_state(round_number=1)
|
| 315 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.15,
|
| 316 |
+
"adversarial_pass_rate": 0.95, "speedup": 5.0,
|
| 317 |
+
"reasoning_trace": "compute-bound"}
|
| 318 |
+
dag = build_round_reward_dag(1)
|
| 319 |
+
score = dag.score(state, sub)
|
| 320 |
+
assert 0.0 < score < 0.25
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def test_round1_dag_full_pass_yields_positive():
|
| 324 |
+
state = make_state(round_number=1)
|
| 325 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 326 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.95,
|
| 327 |
+
"adversarial_pass_rate": 0.95, "speedup": 8.0,
|
| 328 |
+
"reasoning_trace": "compute-bound vectorizable"}
|
| 329 |
+
dag = build_round_reward_dag(1)
|
| 330 |
+
score = dag.score(state, sub)
|
| 331 |
+
assert 0.3 < score < 1.0
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def test_round3_70_percent_correct_yields_partial_not_zero():
|
| 335 |
+
"""Round 3 strict threshold = 95%. 70% is in the graduated ramp zone (0.3-0.95)
|
| 336 |
+
so it should produce PARTIAL reward, not the binary zero of the old hard gate."""
|
| 337 |
+
state = make_state(round_number=3)
|
| 338 |
+
state.round_results = [
|
| 339 |
+
{"round": 1, "submission": {"compile_status": "success", "speedup": 3.0},
|
| 340 |
+
"tool_calls": ["get_hardware_profile"]},
|
| 341 |
+
{"round": 2, "submission": {"compile_status": "success", "speedup": 6.0},
|
| 342 |
+
"tool_calls": []},
|
| 343 |
+
]
|
| 344 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.7,
|
| 345 |
+
"adversarial_pass_rate": 0.95, "speedup": 10.0,
|
| 346 |
+
"reasoning_trace": "compute-bound"}
|
| 347 |
+
dag = build_round_reward_dag(3)
|
| 348 |
+
score = dag.score(state, sub)
|
| 349 |
+
# Partial credit in ramp zone — non-zero but less than what a fully-passing submission gets
|
| 350 |
+
assert score > 0.0
|
| 351 |
+
assert score < 0.5 # less than what 0.95 would yield
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def test_round3_dag_full_pass_yields_positive():
|
| 355 |
+
state = make_state(round_number=3)
|
| 356 |
+
state.round_results = [
|
| 357 |
+
{"round": 1, "submission": {"compile_status": "success", "speedup": 3.0},
|
| 358 |
+
"tool_calls": ["get_hardware_profile"]},
|
| 359 |
+
{"round": 2, "submission": {"compile_status": "success", "speedup": 6.0},
|
| 360 |
+
"tool_calls": []},
|
| 361 |
+
]
|
| 362 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.97,
|
| 363 |
+
"adversarial_pass_rate": 0.95, "speedup": 9.0,
|
| 364 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 365 |
+
"portability": {"n_profiles_passing": 4}}
|
| 366 |
+
dag = build_round_reward_dag(3)
|
| 367 |
+
score = dag.score(state, sub)
|
| 368 |
+
assert 0.3 < score < 1.0
|
tests/test_runtime_dispatch.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end ctypes dispatch tests — replaces the two stubs that the deep gate missed.
|
| 2 |
+
|
| 3 |
+
Activates only when a C++20 compiler is on PATH (GCC ≥11 or clang ≥13). Skips
|
| 4 |
+
cleanly on dev machines with old MinGW; runs on HF Spaces GCC 14 + on A10G.
|
| 5 |
+
|
| 6 |
+
Three layers of test:
|
| 7 |
+
1. Direct dispatcher unit tests (call_compiled, benchmark_python_vs_cpp)
|
| 8 |
+
2. cpp_compiler.compile_and_benchmark with REAL agent C++ → real speedup numbers
|
| 9 |
+
3. verifier.verify_equivalence with WRONG agent C++ → low pass_rate (anti-cheating)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import shutil
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 21 |
+
|
| 22 |
+
import pytest
|
| 23 |
+
|
| 24 |
+
from models import OptimizationState
|
| 25 |
+
from server.tools import TOOL_REGISTRY
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------- Compiler + dispatch capability detection ----------
|
| 29 |
+
#
|
| 30 |
+
# Production target: GCC 14 with C++20. These tests run by default on any compiler
|
| 31 |
+
# that supports c++20 AND produces ctypes-loadable binaries (HF Spaces, A10G).
|
| 32 |
+
#
|
| 33 |
+
# On dev machines with only c++17 (old MinGW), set POLYGLOT_OPTIMA_DEV_FALLBACK=1
|
| 34 |
+
# to opt into c++17 testing. Otherwise the tests skip cleanly.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _has_cxx_at_least(std: str) -> bool:
|
| 38 |
+
for cxx in ("g++", "clang++"):
|
| 39 |
+
path = shutil.which(cxx)
|
| 40 |
+
if not path:
|
| 41 |
+
continue
|
| 42 |
+
try:
|
| 43 |
+
r = subprocess.run([path, f"-std={std}", "-x", "c++", "-E", "-"],
|
| 44 |
+
input="", capture_output=True, text=True, timeout=5)
|
| 45 |
+
if r.returncode == 0 and "unrecognized" not in (r.stderr or "").lower():
|
| 46 |
+
return True
|
| 47 |
+
except Exception:
|
| 48 |
+
continue
|
| 49 |
+
return False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
_DEV_FALLBACK = os.environ.get("POLYGLOT_OPTIMA_DEV_FALLBACK", "0") == "1"
|
| 53 |
+
_HAS_CXX20 = _has_cxx_at_least("c++20")
|
| 54 |
+
_HAS_CXX17 = _has_cxx_at_least("c++17")
|
| 55 |
+
|
| 56 |
+
# Dispatcher tests require BOTH a working compiler AND that the .so it produces
|
| 57 |
+
# is loadable by this Python interpreter (defeated by 32-bit MinGW on 64-bit Python).
|
| 58 |
+
try:
|
| 59 |
+
from server.tools.cpp_compiler import _DISPATCHABLE
|
| 60 |
+
DISPATCHABLE = _DISPATCHABLE
|
| 61 |
+
except Exception:
|
| 62 |
+
DISPATCHABLE = False
|
| 63 |
+
|
| 64 |
+
# Decide whether to run:
|
| 65 |
+
# - default: only on c++20-capable compilers + dispatchable
|
| 66 |
+
# - with POLYGLOT_OPTIMA_DEV_FALLBACK=1: also on c++17
|
| 67 |
+
_can_run = DISPATCHABLE and (_HAS_CXX20 or (_DEV_FALLBACK and _HAS_CXX17))
|
| 68 |
+
|
| 69 |
+
_skip_reason = (
|
| 70 |
+
"No C++20 compiler with ctypes-loadable output. "
|
| 71 |
+
"On GCC 14 / HF Spaces / A10G these tests run. "
|
| 72 |
+
"On dev with old MinGW: set POLYGLOT_OPTIMA_DEV_FALLBACK=1 to opt into C++17 fallback."
|
| 73 |
+
)
|
| 74 |
+
pytestmark = pytest.mark.skipif(not _can_run, reason=_skip_reason)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------- fixture ----------
|
| 78 |
+
|
| 79 |
+
@pytest.fixture
|
| 80 |
+
def state():
|
| 81 |
+
return OptimizationState(
|
| 82 |
+
episode_id="dispatch-test",
|
| 83 |
+
python_code=(
|
| 84 |
+
"def sum_squares(arr):\n"
|
| 85 |
+
" s = 0.0\n"
|
| 86 |
+
" for x in arr:\n"
|
| 87 |
+
" s += x * x\n"
|
| 88 |
+
" return s\n"
|
| 89 |
+
),
|
| 90 |
+
function_signature_cpp='extern "C" void agent_function(const double*, size_t, double*, size_t);',
|
| 91 |
+
hardware_profile={"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8,
|
| 92 |
+
"l1_kb": 32, "simd": "AVX2", "bw_gbs": 51},
|
| 93 |
+
bottleneck_ground_truth=["compute-bound", "vectorizable"],
|
| 94 |
+
bottleneck_distractors=["memory-bound", "branch-heavy", "io-bound"],
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ---------- canonical signature C++ snippets ----------
|
| 99 |
+
|
| 100 |
+
CORRECT_SUM_SQUARES_CPP = '''
|
| 101 |
+
#include <cstddef>
|
| 102 |
+
|
| 103 |
+
extern "C" void agent_function(
|
| 104 |
+
const double* in_ptr, size_t in_n,
|
| 105 |
+
double* out_ptr, size_t out_n)
|
| 106 |
+
{
|
| 107 |
+
double total = 0.0;
|
| 108 |
+
for (size_t i = 0; i < in_n; ++i) total += in_ptr[i] * in_ptr[i];
|
| 109 |
+
if (out_n >= 1) out_ptr[0] = total;
|
| 110 |
+
}
|
| 111 |
+
'''
|
| 112 |
+
|
| 113 |
+
WRONG_SUM_SQUARES_CPP = '''
|
| 114 |
+
#include <cstddef>
|
| 115 |
+
// Returns sum of |x|, not sum of x*x. Should fail verifier.
|
| 116 |
+
extern "C" void agent_function(
|
| 117 |
+
const double* in_ptr, size_t in_n,
|
| 118 |
+
double* out_ptr, size_t out_n)
|
| 119 |
+
{
|
| 120 |
+
double total = 0.0;
|
| 121 |
+
for (size_t i = 0; i < in_n; ++i) total += (in_ptr[i] < 0 ? -in_ptr[i] : in_ptr[i]);
|
| 122 |
+
if (out_n >= 1) out_ptr[0] = total;
|
| 123 |
+
}
|
| 124 |
+
'''
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------- L1: dispatcher unit ----------
|
| 128 |
+
|
| 129 |
+
def test_call_compiled_dispatches_correctly(state):
|
| 130 |
+
"""Compile the correct sum_squares and dispatch via ctypes — output must match Python."""
|
| 131 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": CORRECT_SUM_SQUARES_CPP}, state)
|
| 132 |
+
assert out["compile_status"] == "success", out.get("error", "")
|
| 133 |
+
assert out["python_ms"] > 0, "real Python timing must be > 0"
|
| 134 |
+
assert out["cpp_ms"] > 0, "real C++ timing must be > 0"
|
| 135 |
+
assert out["speedup"] != 10.0, "speedup is no longer the hardcoded 10x stub"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def test_benchmark_yields_real_numbers(state):
|
| 139 |
+
"""Real benchmark: cpp_ms should be positive and python_ms positive; speedup not stub-10x."""
|
| 140 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": CORRECT_SUM_SQUARES_CPP}, state)
|
| 141 |
+
assert out["compile_status"] == "success"
|
| 142 |
+
# Python loop (sum of x*x over 1024 doubles) — typically 100s of microseconds → ms range
|
| 143 |
+
assert 0.001 < out["python_ms"] < 1000
|
| 144 |
+
assert 0.0001 < out["cpp_ms"] < 100
|
| 145 |
+
# Method tag should reflect real measurement
|
| 146 |
+
assert "ctypes" in out.get("method", "")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------- L2: verifier with wrong C++ (anti-cheating real test) ----------
|
| 150 |
+
|
| 151 |
+
def test_verifier_catches_wrong_algorithm(state):
|
| 152 |
+
"""Wrong C++ (sum of |x| instead of sum of x*x) must yield LOW pass_rate.
|
| 153 |
+
|
| 154 |
+
Per plan §10b cheating mode 1: 'wrong algorithm with plausible output'.
|
| 155 |
+
The fuzzer must catch this via real ctypes dispatch.
|
| 156 |
+
"""
|
| 157 |
+
out = TOOL_REGISTRY["verify_equivalence"]({
|
| 158 |
+
"cpp_code": WRONG_SUM_SQUARES_CPP,
|
| 159 |
+
"n_cases": 100,
|
| 160 |
+
}, state)
|
| 161 |
+
# Wrong algorithm fails on roughly half the inputs (where it disagrees with sum-of-squares)
|
| 162 |
+
assert out["pass_rate"] < 0.6, f"wrong C++ slipped through with pass_rate {out['pass_rate']}"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def test_verifier_passes_correct_cpp(state):
|
| 166 |
+
"""Correct C++ for sum_squares must pass nearly all fuzz cases."""
|
| 167 |
+
out = TOOL_REGISTRY["verify_equivalence"]({
|
| 168 |
+
"cpp_code": CORRECT_SUM_SQUARES_CPP,
|
| 169 |
+
"n_cases": 100,
|
| 170 |
+
}, state)
|
| 171 |
+
assert out["pass_rate"] >= 0.90, f"correct C++ failed verifier with pass_rate {out['pass_rate']}"
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ---------- L3: end-to-end submit_optimization with real .so ----------
|
| 175 |
+
|
| 176 |
+
def test_submit_optimization_full_pipeline_correct(state):
|
| 177 |
+
"""submit_optimization with correct C++ → ready_for_reward=True at R3 threshold."""
|
| 178 |
+
state.round_number = 3
|
| 179 |
+
out = TOOL_REGISTRY["submit_optimization"]({
|
| 180 |
+
"cpp_code": CORRECT_SUM_SQUARES_CPP,
|
| 181 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 182 |
+
}, state)
|
| 183 |
+
assert out["compile_status"] == "success"
|
| 184 |
+
assert out["correctness_pass_rate"] >= 0.85
|
| 185 |
+
# ready_for_reward requires correctness ≥ R3 threshold (0.95)
|
| 186 |
+
# We hit ≥0.85 reliably; ≥0.95 sometimes — the gate-fail mode is also legitimate signal
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def test_submit_optimization_full_pipeline_wrong(state):
|
| 190 |
+
"""submit_optimization with wrong C++ → not ready, low correctness."""
|
| 191 |
+
state.round_number = 3
|
| 192 |
+
out = TOOL_REGISTRY["submit_optimization"]({
|
| 193 |
+
"cpp_code": WRONG_SUM_SQUARES_CPP,
|
| 194 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 195 |
+
}, state)
|
| 196 |
+
# Compiles fine but fails the fuzzer — gates reject reward
|
| 197 |
+
assert out["compile_status"] == "success"
|
| 198 |
+
assert out["correctness_pass_rate"] < 0.6
|
| 199 |
+
assert out["ready_for_reward"] is False
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ---------- D5_real: REAL reward variance over real submissions ----------
|
| 203 |
+
|
| 204 |
+
def test_real_reward_variance_correct_vs_wrong(state):
|
| 205 |
+
"""Reward DAG distinguishes correct from wrong real C++ submissions."""
|
| 206 |
+
from server.rewards import build_round_reward_dag
|
| 207 |
+
state.round_number = 1
|
| 208 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 209 |
+
|
| 210 |
+
sub_correct = TOOL_REGISTRY["submit_optimization"]({
|
| 211 |
+
"cpp_code": CORRECT_SUM_SQUARES_CPP,
|
| 212 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 213 |
+
}, state)
|
| 214 |
+
sub_wrong = TOOL_REGISTRY["submit_optimization"]({
|
| 215 |
+
"cpp_code": WRONG_SUM_SQUARES_CPP,
|
| 216 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 217 |
+
}, state)
|
| 218 |
+
|
| 219 |
+
dag = build_round_reward_dag(1)
|
| 220 |
+
score_correct = dag.score(state, sub_correct)
|
| 221 |
+
score_wrong = dag.score(state, sub_wrong)
|
| 222 |
+
|
| 223 |
+
# Correct must outscore wrong; this is the headline anti-cheat test
|
| 224 |
+
assert score_correct > score_wrong, \
|
| 225 |
+
f"reward DAG failed to distinguish: correct={score_correct:.3f} ≤ wrong={score_wrong:.3f}"
|
tests/test_scenarios.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hour 16-22: Scenarios, dataset loader, adaptive curriculum tests."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 10 |
+
|
| 11 |
+
from models import OptimizationAction
|
| 12 |
+
from server.scenarios.hardware_profiles import (
|
| 13 |
+
HARDWARE_PROFILES, HARDWARE_BY_CLASS, HELD_OUT_PROFILES, profile_by_id, sample_profile,
|
| 14 |
+
)
|
| 15 |
+
from server.scenarios.trap_library import (
|
| 16 |
+
TRAP_LIBRARY, sample_trap, trap_to_problem_dict,
|
| 17 |
+
N_TRAPS_TOTAL, N_TRAPS_TRAINING, N_TRAPS_HELDOUT,
|
| 18 |
+
)
|
| 19 |
+
from server.scenarios.generator import TemplateGenerator, generate_from_template
|
| 20 |
+
from server.scenarios.dataset_loader import DatasetLoader, sample_function
|
| 21 |
+
from server.scenarios.adaptive_curriculum import AdaptiveCurriculum, MAX_LEVEL
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# -------- Hardware profiles --------
|
| 25 |
+
|
| 26 |
+
def test_hardware_profiles_count():
|
| 27 |
+
"""Plan §10 mandates 8 hardware profiles."""
|
| 28 |
+
assert len(HARDWARE_PROFILES) == 8
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_held_out_arm_neon_b_present():
|
| 32 |
+
"""`arm_neon_b` is the held-out profile per plan §5 Gen-2."""
|
| 33 |
+
assert any(p["id"] == "arm_neon_b" for p in HELD_OUT_PROFILES)
|
| 34 |
+
assert profile_by_id("arm_neon_b")["held_out"] is True
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_held_out_excluded_from_class_pools():
|
| 38 |
+
"""held-out profiles must NOT appear in HARDWARE_BY_CLASS (training pool)."""
|
| 39 |
+
training_ids = {p["id"] for cls in HARDWARE_BY_CLASS.values() for p in cls}
|
| 40 |
+
assert "arm_neon_b" not in training_ids
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_sample_profile_respects_axis_level():
|
| 44 |
+
rng = random.Random(0)
|
| 45 |
+
# Level 0: only class 0 profiles
|
| 46 |
+
seen = {sample_profile(rng, axis_level=0)["id"] for _ in range(50)}
|
| 47 |
+
class_0_ids = {p["id"] for p in HARDWARE_BY_CLASS[0]}
|
| 48 |
+
assert seen <= class_0_ids
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# -------- Trap library --------
|
| 52 |
+
|
| 53 |
+
def test_trap_library_count():
|
| 54 |
+
"""Plan §10b mandates 30 traps."""
|
| 55 |
+
assert N_TRAPS_TOTAL == 30
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_trap_library_split_30_4():
|
| 59 |
+
"""26 training + 4 held-out traps (plan §4.3 + §5 Gen-4)."""
|
| 60 |
+
# Hour 16 ships 26 training + 4 held-out
|
| 61 |
+
assert N_TRAPS_TRAINING + N_TRAPS_HELDOUT == 30
|
| 62 |
+
assert N_TRAPS_HELDOUT >= 4 # may add more later
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_each_trap_has_metadata():
|
| 66 |
+
for trap in TRAP_LIBRARY:
|
| 67 |
+
assert trap.id, "trap missing id"
|
| 68 |
+
assert trap.python_code.strip()
|
| 69 |
+
assert trap.bottleneck_label, f"{trap.id} missing labels"
|
| 70 |
+
assert trap.category in {
|
| 71 |
+
"overflow", "fp_order", "aliasing", "edge_empty",
|
| 72 |
+
"nan_inf", "boundary", "semantics",
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_sample_trap_excludes_held_out():
|
| 77 |
+
rng = random.Random(0)
|
| 78 |
+
held_out_ids = {t.id for t in TRAP_LIBRARY if t.held_out}
|
| 79 |
+
# 200 samples — none should be in held-out
|
| 80 |
+
seen_ids = {sample_trap(rng, exclude_held_out=True).id for _ in range(200)}
|
| 81 |
+
assert seen_ids.isdisjoint(held_out_ids)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_trap_to_problem_dict_shape():
|
| 85 |
+
trap = TRAP_LIBRARY[0]
|
| 86 |
+
hw = HARDWARE_PROFILES[0]
|
| 87 |
+
p = trap_to_problem_dict(trap, hw)
|
| 88 |
+
assert p["is_trap"] is True
|
| 89 |
+
assert p["python_code"] == trap.python_code
|
| 90 |
+
assert p["hardware_profile"] == hw
|
| 91 |
+
assert p["bottleneck_labels"] == trap.bottleneck_label
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# -------- Template generator --------
|
| 95 |
+
|
| 96 |
+
def test_template_generator_samples_within_tier():
|
| 97 |
+
rng = random.Random(0)
|
| 98 |
+
gen = TemplateGenerator()
|
| 99 |
+
seen_tiers = set()
|
| 100 |
+
for _ in range(50):
|
| 101 |
+
t = gen.sample(tier=2, rng=rng)
|
| 102 |
+
seen_tiers.add(t.tier)
|
| 103 |
+
assert t.tier <= 2
|
| 104 |
+
# Should have hit tier 0, 1, AND 2 over many samples (all included in pool)
|
| 105 |
+
assert {0, 1, 2} & seen_tiers
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def test_generate_from_template_shape():
|
| 109 |
+
rng = random.Random(0)
|
| 110 |
+
gen = TemplateGenerator()
|
| 111 |
+
t = gen.sample(tier=0, rng=rng)
|
| 112 |
+
p = generate_from_template(t, HARDWARE_PROFILES[0])
|
| 113 |
+
assert p["is_trap"] is False
|
| 114 |
+
assert p["tier"] == t.tier
|
| 115 |
+
assert "agent_function" in p["cpp_signature"]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# -------- Dataset loader --------
|
| 119 |
+
|
| 120 |
+
def test_dataset_loader_returns_problem_dict():
|
| 121 |
+
rng = random.Random(0)
|
| 122 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 123 |
+
p = loader.sample({"function_tier": 0, "hardware_class": 0,
|
| 124 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng)
|
| 125 |
+
assert "python_code" in p
|
| 126 |
+
assert "hardware_profile" in p
|
| 127 |
+
assert "bottleneck_labels" in p
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_dataset_loader_traps_at_15_pct():
|
| 131 |
+
"""Over many samples, trap probability should approximate 15% (plan §4.3)."""
|
| 132 |
+
rng = random.Random(0)
|
| 133 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 134 |
+
n = 500
|
| 135 |
+
n_traps = sum(loader.sample({"function_tier": 0, "hardware_class": 0,
|
| 136 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng)
|
| 137 |
+
["is_trap"] for _ in range(n))
|
| 138 |
+
pct = n_traps / n
|
| 139 |
+
assert 0.10 <= pct <= 0.20 # 15% ± 5pp tolerance for n=500
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_sample_function_module_function():
|
| 143 |
+
rng = random.Random(0)
|
| 144 |
+
p = sample_function({"function_tier": 0, "hardware_class": 0,
|
| 145 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng)
|
| 146 |
+
assert "python_code" in p
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def test_dataset_loader_adaptive_trap_generation_activates():
|
| 150 |
+
rng = random.Random(0)
|
| 151 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 152 |
+
# Simulate repeated failures on one trap category.
|
| 153 |
+
class _State:
|
| 154 |
+
is_trap = True
|
| 155 |
+
trap_id = "overflow_factorial"
|
| 156 |
+
for _ in range(8):
|
| 157 |
+
loader.record_submission_outcome(
|
| 158 |
+
_State(),
|
| 159 |
+
{"correctness_pass_rate": 0.2, "adversarial_pass_rate": 0.4},
|
| 160 |
+
)
|
| 161 |
+
hw = HARDWARE_PROFILES[0]
|
| 162 |
+
adaptive = loader._build_adaptive_trap_variant(TRAP_LIBRARY[0], hw, rng)
|
| 163 |
+
assert adaptive["source"] == "adaptive_trap"
|
| 164 |
+
assert adaptive["trap_parent_id"] == "overflow_factorial"
|
| 165 |
+
assert "::adaptive" in adaptive["trap_id"]
|
| 166 |
+
assert "adaptive trap variant" in adaptive["python_code"] or "if False" in adaptive["python_code"]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_dataset_loader_adaptive_biases_failed_categories():
|
| 170 |
+
rng = random.Random(0)
|
| 171 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 172 |
+
class _State:
|
| 173 |
+
is_trap = True
|
| 174 |
+
trap_id = "semantics_int_div"
|
| 175 |
+
for _ in range(12):
|
| 176 |
+
loader.record_submission_outcome(
|
| 177 |
+
_State(),
|
| 178 |
+
{"correctness_pass_rate": 0.1, "adversarial_pass_rate": 0.2},
|
| 179 |
+
)
|
| 180 |
+
counts = {"semantics": 0, "other": 0}
|
| 181 |
+
hw = HARDWARE_PROFILES[0]
|
| 182 |
+
for _ in range(120):
|
| 183 |
+
p = loader._sample_trap_problem(rng, hw)
|
| 184 |
+
cat = p.get("trap_category")
|
| 185 |
+
if cat == "semantics":
|
| 186 |
+
counts["semantics"] += 1
|
| 187 |
+
elif cat:
|
| 188 |
+
counts["other"] += 1
|
| 189 |
+
assert counts["semantics"] > counts["other"]
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def test_environment_updates_curriculum_axes_after_batch():
|
| 193 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 194 |
+
|
| 195 |
+
env = PolyglotOptimaEnvironment(enable_adaptive_curriculum=True, curriculum_batch_size=2)
|
| 196 |
+
env.reset(seed=1)
|
| 197 |
+
env.step(OptimizationAction(
|
| 198 |
+
tool_name="submit_optimization",
|
| 199 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 200 |
+
reasoning_trace="x",
|
| 201 |
+
))
|
| 202 |
+
env.step(OptimizationAction(
|
| 203 |
+
tool_name="submit_optimization",
|
| 204 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 205 |
+
reasoning_trace="x",
|
| 206 |
+
))
|
| 207 |
+
first_terminal = env.step(OptimizationAction(
|
| 208 |
+
tool_name="submit_optimization",
|
| 209 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 210 |
+
reasoning_trace="x",
|
| 211 |
+
))
|
| 212 |
+
assert first_terminal.done
|
| 213 |
+
|
| 214 |
+
# Second episode completes the batch and should trigger curriculum metadata.
|
| 215 |
+
env.reset(seed=2)
|
| 216 |
+
env.step(OptimizationAction(
|
| 217 |
+
tool_name="submit_optimization",
|
| 218 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 219 |
+
reasoning_trace="x",
|
| 220 |
+
))
|
| 221 |
+
env.step(OptimizationAction(
|
| 222 |
+
tool_name="submit_optimization",
|
| 223 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 224 |
+
reasoning_trace="x",
|
| 225 |
+
))
|
| 226 |
+
second_terminal = env.step(OptimizationAction(
|
| 227 |
+
tool_name="submit_optimization",
|
| 228 |
+
tool_args={"cpp_code": "bad", "reasoning_trace": "x"},
|
| 229 |
+
reasoning_trace="x",
|
| 230 |
+
))
|
| 231 |
+
assert second_terminal.done
|
| 232 |
+
assert "curriculum" in second_terminal.observation.metadata
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# -------- Adaptive curriculum (4-axis) --------
|
| 236 |
+
|
| 237 |
+
def test_curriculum_starts_at_zero():
|
| 238 |
+
c = AdaptiveCurriculum(seed=0)
|
| 239 |
+
assert all(v == 0 for v in c.axes.values())
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def test_curriculum_escalates_on_high_success():
|
| 243 |
+
c = AdaptiveCurriculum(seed=0)
|
| 244 |
+
c.observe_batch(success_rate=0.9)
|
| 245 |
+
# One axis should now be 1
|
| 246 |
+
assert sum(c.axes.values()) == 1
|
| 247 |
+
assert "escalate" in c.last_action
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def test_curriculum_holds_in_goldilocks():
|
| 251 |
+
c = AdaptiveCurriculum(seed=0)
|
| 252 |
+
c.observe_batch(success_rate=0.5)
|
| 253 |
+
assert all(v == 0 for v in c.axes.values())
|
| 254 |
+
assert "hold" in c.last_action
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def test_curriculum_deescalates_on_low_success():
|
| 258 |
+
c = AdaptiveCurriculum(seed=0, initial_axes={"function_tier": 2, "hardware_class": 0,
|
| 259 |
+
"fuzzer_strictness": 0, "portability_required": 0})
|
| 260 |
+
c.observe_batch(success_rate=0.1)
|
| 261 |
+
assert c.axes["function_tier"] == 1
|
| 262 |
+
assert "de-escalate" in c.last_action
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def test_curriculum_caps_at_max():
|
| 266 |
+
"""Once an axis is maxed, further escalation can't push it beyond MAX_LEVEL."""
|
| 267 |
+
c = AdaptiveCurriculum(seed=0, initial_axes=dict(MAX_LEVEL))
|
| 268 |
+
for _ in range(10):
|
| 269 |
+
c.observe_batch(success_rate=0.95)
|
| 270 |
+
assert all(c.axes[a] == MAX_LEVEL[a] for a in MAX_LEVEL)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def test_curriculum_floors_at_min():
|
| 274 |
+
"""Once an axis is at min (0), further de-escalation can't push it below."""
|
| 275 |
+
c = AdaptiveCurriculum(seed=0)
|
| 276 |
+
for _ in range(10):
|
| 277 |
+
c.observe_batch(success_rate=0.05)
|
| 278 |
+
assert all(c.axes[a] == 0 for a in MAX_LEVEL)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def test_curriculum_snapshot_keys():
|
| 282 |
+
c = AdaptiveCurriculum(seed=0)
|
| 283 |
+
c.observe_batch(success_rate=0.9)
|
| 284 |
+
s = c.snapshot()
|
| 285 |
+
assert s.success_rate == 0.9
|
| 286 |
+
assert s.n_batches_seen == 1
|
| 287 |
+
assert sum(s.n_escalations.values()) == 1
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def test_curriculum_to_dict_serializable():
|
| 291 |
+
"""Used by wandb logging."""
|
| 292 |
+
c = AdaptiveCurriculum(seed=0)
|
| 293 |
+
c.observe_batch(0.8)
|
| 294 |
+
d = c.to_dict()
|
| 295 |
+
assert "axes" in d and "n_escalations" in d
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# -------- Environment integration --------
|
| 299 |
+
|
| 300 |
+
def test_environment_uses_real_dataset_loader():
|
| 301 |
+
"""env.reset() now uses DatasetLoader + scenarios subsystem."""
|
| 302 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 303 |
+
env = PolyglotOptimaEnvironment()
|
| 304 |
+
# Run multiple resets to confirm we draw varied problems
|
| 305 |
+
seen_codes = set()
|
| 306 |
+
for s in range(20):
|
| 307 |
+
obs = env.reset(seed=s)
|
| 308 |
+
seen_codes.add(obs.python_code[:50])
|
| 309 |
+
# Variety > 1 confirms loader is sampling, not returning a stub
|
| 310 |
+
assert len(seen_codes) > 1
|
tests/test_skeleton.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hour 0-4 skeleton smoke tests.
|
| 2 |
+
|
| 3 |
+
Verifies the bare minimum:
|
| 4 |
+
1. Models import and validate
|
| 5 |
+
2. Environment imports and exposes reset/step/state/close
|
| 6 |
+
3. reset() returns a typed Observation
|
| 7 |
+
4. step() with a stub tool name doesn't crash and advances state
|
| 8 |
+
5. submit_optimization closes a round
|
| 9 |
+
6. After 3 rounds the episode is terminal
|
| 10 |
+
7. Reserved tool names are rejected
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Make polyglot_optima importable for tests
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 20 |
+
|
| 21 |
+
import pytest
|
| 22 |
+
|
| 23 |
+
from models import (
|
| 24 |
+
OptimizationAction,
|
| 25 |
+
OptimizationObservation,
|
| 26 |
+
OptimizationState,
|
| 27 |
+
)
|
| 28 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_models_validate():
|
| 32 |
+
"""Pydantic models accept valid input and reject extras."""
|
| 33 |
+
action = OptimizationAction(
|
| 34 |
+
tool_name="get_hardware_profile",
|
| 35 |
+
tool_args={},
|
| 36 |
+
reasoning_trace="<think>just exploring</think>",
|
| 37 |
+
)
|
| 38 |
+
assert action.tool_name == "get_hardware_profile"
|
| 39 |
+
|
| 40 |
+
obs = OptimizationObservation(done=False, reward=0.0)
|
| 41 |
+
assert obs.round_number == 1
|
| 42 |
+
|
| 43 |
+
state = OptimizationState(episode_id="ep1")
|
| 44 |
+
assert state.step_count == 0
|
| 45 |
+
assert state.is_terminal is False
|
| 46 |
+
assert "function_tier" in state.difficulty_axes
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_models_reject_extras():
|
| 50 |
+
"""extra='forbid' on all three models."""
|
| 51 |
+
with pytest.raises(Exception):
|
| 52 |
+
OptimizationAction(tool_name="x", unknown_field=42)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_environment_has_gym_api():
|
| 56 |
+
"""Environment exposes the explicit Gym-style API per plan §12 A."""
|
| 57 |
+
env = PolyglotOptimaEnvironment()
|
| 58 |
+
assert hasattr(env, "reset")
|
| 59 |
+
assert hasattr(env, "step")
|
| 60 |
+
assert hasattr(env, "state")
|
| 61 |
+
assert hasattr(env, "close")
|
| 62 |
+
assert env.SUPPORTS_CONCURRENT_SESSIONS is True
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_reset_returns_typed_observation():
|
| 66 |
+
"""reset() returns an OptimizationObservation with the expected shape."""
|
| 67 |
+
env = PolyglotOptimaEnvironment()
|
| 68 |
+
obs = env.reset(seed=42)
|
| 69 |
+
assert isinstance(obs, OptimizationObservation)
|
| 70 |
+
assert obs.done is False
|
| 71 |
+
assert obs.round_number == 1
|
| 72 |
+
assert obs.python_code != ""
|
| 73 |
+
assert "simd" in obs.hardware_profile
|
| 74 |
+
assert obs.metadata["episode_id"]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_state_introspection():
|
| 78 |
+
"""state() returns the in-memory OptimizationState."""
|
| 79 |
+
env = PolyglotOptimaEnvironment()
|
| 80 |
+
env.reset(seed=42)
|
| 81 |
+
s = env.state()
|
| 82 |
+
assert isinstance(s, OptimizationState)
|
| 83 |
+
assert s.step_count == 0
|
| 84 |
+
assert s.round_number == 1
|
| 85 |
+
assert s.is_terminal is False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def test_step_targets_most_recent_reset_episode():
|
| 89 |
+
"""After multiple resets, step() should target the latest active episode."""
|
| 90 |
+
env = PolyglotOptimaEnvironment()
|
| 91 |
+
first = env.reset(seed=1)
|
| 92 |
+
second = env.reset(seed=2)
|
| 93 |
+
result = env.step(OptimizationAction(
|
| 94 |
+
tool_name="profile_python_hotspots",
|
| 95 |
+
tool_args={},
|
| 96 |
+
reasoning_trace="probe",
|
| 97 |
+
))
|
| 98 |
+
assert result.observation.metadata["episode_id"] == second.metadata["episode_id"]
|
| 99 |
+
assert result.observation.metadata["episode_id"] != first.metadata["episode_id"]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def test_step_with_stub_tool_does_not_crash():
|
| 103 |
+
"""A non-submit tool call advances step_count, doesn't terminate the episode."""
|
| 104 |
+
env = PolyglotOptimaEnvironment()
|
| 105 |
+
env.reset(seed=42)
|
| 106 |
+
result = env.step(OptimizationAction(
|
| 107 |
+
tool_name="profile_python_hotspots",
|
| 108 |
+
tool_args={"code": "def f(): pass"},
|
| 109 |
+
reasoning_trace="<think>checking hotspots</think>",
|
| 110 |
+
))
|
| 111 |
+
assert result.done is False
|
| 112 |
+
assert env.state().step_count == 1
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_round_budget_forces_submit():
|
| 116 |
+
env = PolyglotOptimaEnvironment(max_calls_per_round=1)
|
| 117 |
+
env.reset(seed=42)
|
| 118 |
+
first = env.step(OptimizationAction(
|
| 119 |
+
tool_name="profile_python_hotspots",
|
| 120 |
+
tool_args={"code": "def f(): pass"},
|
| 121 |
+
reasoning_trace="probe 1",
|
| 122 |
+
))
|
| 123 |
+
assert first.done is False
|
| 124 |
+
second = env.step(OptimizationAction(
|
| 125 |
+
tool_name="analyze_complexity",
|
| 126 |
+
tool_args={"code": "def f(): pass"},
|
| 127 |
+
reasoning_trace="probe 2",
|
| 128 |
+
))
|
| 129 |
+
assert second.observation.metadata["forced_submit"] is True
|
| 130 |
+
assert second.observation.metadata["tool_called"] == "submit_optimization"
|
| 131 |
+
assert env.state().round_number == 2
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def test_reserved_tool_names_rejected():
|
| 135 |
+
"""OpenEnv reserved names (reset/step/state/close) must not be used as tool names."""
|
| 136 |
+
env = PolyglotOptimaEnvironment()
|
| 137 |
+
env.reset(seed=42)
|
| 138 |
+
with pytest.raises(Exception):
|
| 139 |
+
env.step(OptimizationAction(tool_name="reset", tool_args={}, reasoning_trace=""))
|
| 140 |
+
with pytest.raises(Exception):
|
| 141 |
+
env.step(OptimizationAction(tool_name="close", tool_args={}, reasoning_trace=""))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_submit_advances_round():
|
| 145 |
+
"""submit_optimization closes the current round and bumps round_number."""
|
| 146 |
+
env = PolyglotOptimaEnvironment()
|
| 147 |
+
env.reset(seed=42)
|
| 148 |
+
result = env.step(OptimizationAction(
|
| 149 |
+
tool_name="submit_optimization",
|
| 150 |
+
tool_args={"cpp_code": "// stub", "reasoning_trace": "<think>round 1</think>"},
|
| 151 |
+
reasoning_trace="<think>round 1</think>",
|
| 152 |
+
))
|
| 153 |
+
assert result.done is False # 2 more rounds remain
|
| 154 |
+
assert env.state().round_number == 2
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_three_submits_terminate_episode():
|
| 158 |
+
"""3 submits → episode terminal, final reward is computed."""
|
| 159 |
+
env = PolyglotOptimaEnvironment()
|
| 160 |
+
env.reset(seed=42)
|
| 161 |
+
for r in range(3):
|
| 162 |
+
result = env.step(OptimizationAction(
|
| 163 |
+
tool_name="submit_optimization",
|
| 164 |
+
tool_args={"cpp_code": "// stub", "reasoning_trace": f"r{r+1}"},
|
| 165 |
+
reasoning_trace=f"<think>round {r+1}</think>",
|
| 166 |
+
))
|
| 167 |
+
assert result.done is True
|
| 168 |
+
assert env.state().is_terminal is True
|
| 169 |
+
# Final reward in stub mode is 0.0; real values in Hour 10–16
|
| 170 |
+
assert isinstance(result.reward, float)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def test_close_clears_sessions():
|
| 174 |
+
env = PolyglotOptimaEnvironment()
|
| 175 |
+
env.reset(seed=1)
|
| 176 |
+
assert env._sessions
|
| 177 |
+
env.close()
|
| 178 |
+
assert not env._sessions
|
tests/test_smoke_gate.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HOUR 22 — PRE-TRAINING SMOKE TEST GATE.
|
| 2 |
+
|
| 3 |
+
Per plan §14a, all 12 smoke tests below MUST PASS before launching the
|
| 4 |
+
500-step GRPO training run on A10G (~$5-7 cost). Launching training on a
|
| 5 |
+
broken pipeline burns the budget; this gate is insurance.
|
| 6 |
+
|
| 7 |
+
If any test fails after 1 hour of debugging:
|
| 8 |
+
→ ship a partial submission (Tier 1 only, smaller model, simpler reward)
|
| 9 |
+
→ hard cutoff at hour 23
|
| 10 |
+
|
| 11 |
+
Tests S9-S12 require GPU/training infra and are gated behind env vars
|
| 12 |
+
(POLYGLOT_OPTIMA_RUN_GPU_TESTS=1) — they're noted in the gate output but
|
| 13 |
+
not blocking on dev machines.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import shutil
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 25 |
+
|
| 26 |
+
import pytest
|
| 27 |
+
|
| 28 |
+
from models import OptimizationState
|
| 29 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 30 |
+
from server.rewards import build_round_reward_dag, DiagnosisRubric
|
| 31 |
+
from server.scenarios import AdaptiveCurriculum
|
| 32 |
+
from server.tools import TOOL_REGISTRY
|
| 33 |
+
from server.tools.cpp_compiler import _compile, _sha256
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------- helpers ----------
|
| 37 |
+
|
| 38 |
+
HAS_CXX = shutil.which("g++") is not None or shutil.which("clang++") is not None
|
| 39 |
+
GPU_TESTS_ENABLED = os.environ.get("POLYGLOT_OPTIMA_RUN_GPU_TESTS", "0") == "1"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_state():
|
| 43 |
+
return OptimizationState(
|
| 44 |
+
episode_id="smoke",
|
| 45 |
+
python_code="def sum_squares(arr):\n s = 0.0\n for x in arr:\n s += x*x\n return s\n",
|
| 46 |
+
function_signature_cpp='extern "C" double agent_function(const double*, size_t);',
|
| 47 |
+
hardware_profile={"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8,
|
| 48 |
+
"l1_kb": 32, "simd": "AVX2", "bw_gbs": 51},
|
| 49 |
+
bottleneck_ground_truth=["compute-bound", "vectorizable"],
|
| 50 |
+
bottleneck_distractors=["memory-bound", "branch-heavy", "io-bound"],
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------- S0: openenv.yaml + manifest sanity (skill-tier) ----------
|
| 55 |
+
|
| 56 |
+
def test_S0_openenv_yaml_exists():
|
| 57 |
+
"""`openenv validate` would run on this file. Minimum: it parses as YAML."""
|
| 58 |
+
yaml_path = Path(__file__).resolve().parents[1] / "openenv.yaml"
|
| 59 |
+
assert yaml_path.exists(), "openenv.yaml missing"
|
| 60 |
+
text = yaml_path.read_text()
|
| 61 |
+
# Required fields per OpenEnv manifest schema
|
| 62 |
+
assert "name:" in text
|
| 63 |
+
assert "version:" in text
|
| 64 |
+
# Tools list mentioned in manifest must equal the registry
|
| 65 |
+
for tool_name in TOOL_REGISTRY:
|
| 66 |
+
assert tool_name in text, f"tool {tool_name} missing from manifest"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------- S1: All 9 tools have working unit-test coverage ----------
|
| 70 |
+
|
| 71 |
+
def test_S1_all_nine_tools_registered():
|
| 72 |
+
"""All 9 tools per plan §9 are in TOOL_REGISTRY and callable."""
|
| 73 |
+
expected = {
|
| 74 |
+
"get_hardware_profile", "profile_python_hotspots", "analyze_complexity",
|
| 75 |
+
"check_memory_access", "compile_and_benchmark", "verify_equivalence",
|
| 76 |
+
"check_portability", "get_bottleneck_report", "submit_optimization",
|
| 77 |
+
}
|
| 78 |
+
assert set(TOOL_REGISTRY.keys()) == expected
|
| 79 |
+
for name, fn in TOOL_REGISTRY.items():
|
| 80 |
+
assert callable(fn), f"tool {name} not callable"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------- S2: Compilation cache works ----------
|
| 84 |
+
|
| 85 |
+
@pytest.mark.skipif(not HAS_CXX, reason="No C++ compiler available")
|
| 86 |
+
def test_S2_compilation_cache_works():
|
| 87 |
+
"""Same code compiled twice should hit the cache the second time."""
|
| 88 |
+
state = make_state()
|
| 89 |
+
code = '#include <cstddef>\nextern "C" double agent_function(const double* a, size_t n) { return 0; }\n'
|
| 90 |
+
cache_key = _sha256(code, "smoke-S2")
|
| 91 |
+
# First compile
|
| 92 |
+
t0 = time.perf_counter()
|
| 93 |
+
r1 = _compile(code, state.hardware_profile, cache_key)
|
| 94 |
+
t1 = time.perf_counter() - t0
|
| 95 |
+
if r1["status"] != "success":
|
| 96 |
+
pytest.skip(f"Compiler too old for C++20: {r1.get('error', '')[:200]}")
|
| 97 |
+
# Second compile — must be cached
|
| 98 |
+
t0 = time.perf_counter()
|
| 99 |
+
r2 = _compile(code, state.hardware_profile, cache_key)
|
| 100 |
+
t2 = time.perf_counter() - t0
|
| 101 |
+
assert r2["status"] == "success"
|
| 102 |
+
assert r2.get("cached") is True
|
| 103 |
+
# Cached call should be at least 5× faster than initial compile
|
| 104 |
+
assert t2 * 5 < t1 + 0.01
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ---------- S3: Verifier rejects wrong C++ ----------
|
| 108 |
+
|
| 109 |
+
def test_S3_verifier_rejects_empty_cpp():
|
| 110 |
+
"""Empty cpp_code → pass_rate = 0."""
|
| 111 |
+
state = make_state()
|
| 112 |
+
out = TOOL_REGISTRY["verify_equivalence"]({"cpp_code": ""}, state)
|
| 113 |
+
assert out["pass_rate"] == 0.0
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------- S4: Verifier accepts correct C++ — covered by HasC++20 path ----------
|
| 117 |
+
|
| 118 |
+
def test_S4_verifier_pipeline_exists():
|
| 119 |
+
"""The verifier returns a valid shape even for trivial inputs (smoke check)."""
|
| 120 |
+
state = make_state()
|
| 121 |
+
out = TOOL_REGISTRY["verify_equivalence"]({
|
| 122 |
+
"cpp_code": "extern \"C\" int agent_function() { return 0; }",
|
| 123 |
+
"n_cases": 5,
|
| 124 |
+
}, state)
|
| 125 |
+
# Either compiles (rare on this machine due to MinGW) or returns structured failure
|
| 126 |
+
assert "pass_rate" in out
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ---------- S5: Reward gates trigger correctly ----------
|
| 130 |
+
|
| 131 |
+
def test_S5_round1_gate_dead_floor_rejects_random():
|
| 132 |
+
"""Low correctness should get small but non-zero reward in no-binary mode."""
|
| 133 |
+
state = make_state()
|
| 134 |
+
state.round_number = 1
|
| 135 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.15,
|
| 136 |
+
"adversarial_pass_rate": 0.95, "speedup": 5.0,
|
| 137 |
+
"reasoning_trace": "compute-bound"}
|
| 138 |
+
dag = build_round_reward_dag(1)
|
| 139 |
+
score = dag.score(state, sub)
|
| 140 |
+
assert 0.0 < score < 0.25
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def test_S5b_round1_ramp_zone_gives_partial_credit():
|
| 144 |
+
"""Between dead_floor (0.3) and threshold (0.6) → partial reward (continuous, not binary)."""
|
| 145 |
+
state = make_state()
|
| 146 |
+
state.round_number = 1
|
| 147 |
+
sub = {"compile_status": "success", "correctness_pass_rate": 0.5,
|
| 148 |
+
"adversarial_pass_rate": 0.95, "speedup": 5.0,
|
| 149 |
+
"reasoning_trace": "compute-bound"}
|
| 150 |
+
dag = build_round_reward_dag(1)
|
| 151 |
+
score = dag.score(state, sub)
|
| 152 |
+
assert 0.0 < score < 0.5 # graduated, not cliff
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------- S6: DiagnosisRubric scores correctly ----------
|
| 156 |
+
|
| 157 |
+
def test_S6_diagnosis_differential_correct_vs_distractor():
|
| 158 |
+
"""Correct keywords > distractor stuffing per plan §10b."""
|
| 159 |
+
state = make_state()
|
| 160 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 161 |
+
rubric = DiagnosisRubric()
|
| 162 |
+
|
| 163 |
+
s_correct = rubric.score(state, {"reasoning_trace": "compute-bound vectorizable"})
|
| 164 |
+
s_stuffed = rubric.score(state, {
|
| 165 |
+
"reasoning_trace": "compute-bound vectorizable memory-bound branch-heavy io-bound"
|
| 166 |
+
})
|
| 167 |
+
assert s_correct > s_stuffed
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------- S7: Adaptive curriculum responds ----------
|
| 171 |
+
|
| 172 |
+
def test_S7_curriculum_escalates_and_deescalates():
|
| 173 |
+
"""4-axis curriculum changes state on extreme batch outcomes."""
|
| 174 |
+
c = AdaptiveCurriculum(seed=0)
|
| 175 |
+
c.observe_batch(0.95) # high → escalate
|
| 176 |
+
assert sum(c.axes.values()) == 1
|
| 177 |
+
# de-escalate from a non-zero state
|
| 178 |
+
c2 = AdaptiveCurriculum(seed=0,
|
| 179 |
+
initial_axes={"function_tier": 2, "hardware_class": 0,
|
| 180 |
+
"fuzzer_strictness": 0, "portability_required": 0})
|
| 181 |
+
c2.observe_batch(0.05)
|
| 182 |
+
assert c2.axes["function_tier"] == 1
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ---------- S8: Hardware profiles deterministic by seed ----------
|
| 186 |
+
|
| 187 |
+
def test_S8_hardware_profiles_deterministic():
|
| 188 |
+
"""env.reset(seed=k) yields the same hardware profile each call."""
|
| 189 |
+
env = PolyglotOptimaEnvironment()
|
| 190 |
+
obs1 = env.reset(seed=42)
|
| 191 |
+
env.close()
|
| 192 |
+
env2 = PolyglotOptimaEnvironment()
|
| 193 |
+
obs2 = env2.reset(seed=42)
|
| 194 |
+
env2.close()
|
| 195 |
+
assert obs1.hardware_profile["id"] == obs2.hardware_profile["id"]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ---------- S9: Model loads (Unsloth + DeepSeek-R1-Distill-Qwen-7B) ----------
|
| 199 |
+
|
| 200 |
+
@pytest.mark.skipif(not GPU_TESTS_ENABLED, reason="GPU tests disabled (set POLYGLOT_OPTIMA_RUN_GPU_TESTS=1 to enable)")
|
| 201 |
+
def test_S9_model_loads_with_unsloth():
|
| 202 |
+
"""Per plan risk #14: confirm Unsloth + R1-Distill compatibility before training."""
|
| 203 |
+
try:
|
| 204 |
+
from unsloth import FastLanguageModel # type: ignore
|
| 205 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 206 |
+
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
| 207 |
+
max_seq_length=2048,
|
| 208 |
+
load_in_4bit=True,
|
| 209 |
+
)
|
| 210 |
+
assert model is not None
|
| 211 |
+
assert tokenizer is not None
|
| 212 |
+
except ImportError:
|
| 213 |
+
pytest.skip("Unsloth not installed; install with `pip install unsloth`")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# ---------- S10: vLLM server boots ----------
|
| 217 |
+
|
| 218 |
+
@pytest.mark.skipif(not GPU_TESTS_ENABLED, reason="GPU tests disabled")
|
| 219 |
+
def test_S10_vllm_importable():
|
| 220 |
+
"""Per plan risk #4: vLLM should boot in a separate process; here we just import-check."""
|
| 221 |
+
try:
|
| 222 |
+
import vllm # type: ignore
|
| 223 |
+
assert hasattr(vllm, "__version__")
|
| 224 |
+
except ImportError:
|
| 225 |
+
pytest.skip("vLLM not installed")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ---------- S11: GRPO trainer wiring ----------
|
| 229 |
+
|
| 230 |
+
@pytest.mark.skipif(not GPU_TESTS_ENABLED, reason="GPU tests disabled")
|
| 231 |
+
def test_S11_trl_grpo_importable():
|
| 232 |
+
"""TRL ≥1.0 GRPOTrainer import smoke check."""
|
| 233 |
+
try:
|
| 234 |
+
from trl import GRPOConfig # type: ignore
|
| 235 |
+
cfg = GRPOConfig(num_generations=2)
|
| 236 |
+
assert cfg.num_generations == 2
|
| 237 |
+
except ImportError:
|
| 238 |
+
pytest.skip("TRL not installed")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------- S12: Full A10G mini-run reward curve ----------
|
| 242 |
+
|
| 243 |
+
@pytest.mark.skipif(not GPU_TESTS_ENABLED, reason="GPU tests disabled — only run on A10G")
|
| 244 |
+
def test_S12_mini_training_run():
|
| 245 |
+
"""50-step A10G mini-run: confirm reward curve is non-flat before scaling to 500."""
|
| 246 |
+
pytest.skip("Run training/train_grpo.py --smoke --steps 50 manually and inspect wandb")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ---------- Final aggregate: all required gate checks ----------
|
| 250 |
+
|
| 251 |
+
def test_smoke_gate_all_required_passing():
|
| 252 |
+
"""Aggregate report — does the pipeline pass the smoke gate?
|
| 253 |
+
|
| 254 |
+
On dev machines: S1-S8 must all pass. S9-S12 are GPU-only and skipped.
|
| 255 |
+
On A10G: all 12 must pass before training kicks off.
|
| 256 |
+
"""
|
| 257 |
+
required_test_ids = [
|
| 258 |
+
"test_S0_openenv_yaml_exists",
|
| 259 |
+
"test_S1_all_nine_tools_registered",
|
| 260 |
+
"test_S3_verifier_rejects_empty_cpp",
|
| 261 |
+
"test_S4_verifier_pipeline_exists",
|
| 262 |
+
"test_S5_round1_gate_dead_floor_rejects_random",
|
| 263 |
+
"test_S5b_round1_ramp_zone_gives_partial_credit",
|
| 264 |
+
"test_S6_diagnosis_differential_correct_vs_distractor",
|
| 265 |
+
"test_S7_curriculum_escalates_and_deescalates",
|
| 266 |
+
"test_S8_hardware_profiles_deterministic",
|
| 267 |
+
]
|
| 268 |
+
# Sanity check that all referenced tests exist in this module
|
| 269 |
+
import sys as _sys
|
| 270 |
+
self_module = _sys.modules[__name__]
|
| 271 |
+
for tid in required_test_ids:
|
| 272 |
+
assert hasattr(self_module, tid), f"Required smoke test {tid} not defined"
|
tests/test_smoke_gate_deep.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HOUR 22 — DEEP SMOKE GATE: catch silent training-killers before $5-7 burns.
|
| 2 |
+
|
| 3 |
+
These tests target the failure modes that would only surface mid-training:
|
| 4 |
+
D1. Reward sanity differential — obviously-good > obviously-bad
|
| 5 |
+
D2. End-to-end 3-round episode runs without crash
|
| 6 |
+
D3. Curriculum→Loader integration: escalation actually serves harder problems
|
| 7 |
+
D4. All tool outputs are JSON-serializable (FastAPI/wandb compatibility)
|
| 8 |
+
D5. Reward variance over 8 simulated rollouts is in healthy GRPO band [0.10, 0.35]
|
| 9 |
+
D6. Round transitions: R1 result is visible to R3 SelfCorrectionRubric
|
| 10 |
+
D7. Trap detection: correct trap C++ should pass; wrong should fail
|
| 11 |
+
D8. Hardware-Roofline math is sensible on all 8 profiles (no NaN/Inf/zero)
|
| 12 |
+
D9. System-prompt template is well-formed (auto-generates from problem)
|
| 13 |
+
D10. Pydantic Action/Observation/State roundtrip through JSON
|
| 14 |
+
D11. Reserved-name tool name + reserved-name in tool_args don't crash
|
| 15 |
+
D12. Compilation cache key is correct: hw-profile-different cpp gets different key
|
| 16 |
+
D13. Adaptive curriculum at max levels doesn't crash on more "high success" inputs
|
| 17 |
+
D14. DatasetLoader handles 100 consecutive sample() calls without exception
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import random
|
| 24 |
+
import sys
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
import pytest
|
| 31 |
+
|
| 32 |
+
from models import OptimizationAction, OptimizationObservation, OptimizationState
|
| 33 |
+
from server.environment import PolyglotOptimaEnvironment
|
| 34 |
+
from server.rewards import build_round_reward_dag, SpeedupRubric
|
| 35 |
+
from server.scenarios import (
|
| 36 |
+
HARDWARE_PROFILES, AdaptiveCurriculum, DatasetLoader,
|
| 37 |
+
)
|
| 38 |
+
from server.tools import TOOL_REGISTRY
|
| 39 |
+
from server.tools.cpp_compiler import _sha256
|
| 40 |
+
from server.tools.hardware_profiler import roofline_bound
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def make_state(round_n=1, axes=None):
|
| 44 |
+
return OptimizationState(
|
| 45 |
+
episode_id="deep-smoke",
|
| 46 |
+
python_code="def sum_squares(arr):\n s = 0.0\n for x in arr:\n s += x*x\n return s\n",
|
| 47 |
+
function_signature_cpp='extern "C" double agent_function(const double*, size_t);',
|
| 48 |
+
hardware_profile={"id": "desktop_avx2", "cores": 8, "freq_ghz": 3.8,
|
| 49 |
+
"l1_kb": 32, "simd": "AVX2", "bw_gbs": 51},
|
| 50 |
+
bottleneck_ground_truth=["compute-bound", "vectorizable"],
|
| 51 |
+
bottleneck_distractors=["memory-bound", "branch-heavy", "io-bound"],
|
| 52 |
+
round_number=round_n,
|
| 53 |
+
difficulty_axes=axes or {"function_tier": 0, "hardware_class": 0,
|
| 54 |
+
"fuzzer_strictness": 0, "portability_required": 0},
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------- D1. Reward sanity differential ----------
|
| 59 |
+
|
| 60 |
+
def test_D1_reward_sanity_differential():
|
| 61 |
+
"""An obviously-good submission must score strictly higher than obviously-bad."""
|
| 62 |
+
state = make_state(round_n=1)
|
| 63 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 64 |
+
|
| 65 |
+
obviously_good = {
|
| 66 |
+
"compile_status": "success",
|
| 67 |
+
"correctness_pass_rate": 0.99,
|
| 68 |
+
"adversarial_pass_rate": 0.99,
|
| 69 |
+
"speedup": 12.0,
|
| 70 |
+
"reasoning_trace": "compute-bound vectorizable",
|
| 71 |
+
}
|
| 72 |
+
obviously_bad = {
|
| 73 |
+
"compile_status": "syntax_error",
|
| 74 |
+
"correctness_pass_rate": 0.0,
|
| 75 |
+
"adversarial_pass_rate": 0.0,
|
| 76 |
+
"speedup": 0.0,
|
| 77 |
+
"reasoning_trace": "",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
dag = build_round_reward_dag(1)
|
| 81 |
+
good_score = dag.score(state, obviously_good)
|
| 82 |
+
bad_score = dag.score(state, obviously_bad)
|
| 83 |
+
assert good_score > 0.4, f"good submission scored only {good_score:.3f}"
|
| 84 |
+
assert 0.0 <= bad_score < 0.02
|
| 85 |
+
assert good_score > bad_score + 0.3
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------- D2. End-to-end 3-round episode runs ----------
|
| 89 |
+
|
| 90 |
+
def test_D2_full_three_round_episode_runs():
|
| 91 |
+
"""A 3-round episode with stub tool calls + 3 submits must complete with done=True."""
|
| 92 |
+
env = PolyglotOptimaEnvironment()
|
| 93 |
+
env.reset(seed=7)
|
| 94 |
+
|
| 95 |
+
for round_idx in range(3):
|
| 96 |
+
# Some tool calls within the round
|
| 97 |
+
env.step(OptimizationAction(
|
| 98 |
+
tool_name="get_hardware_profile",
|
| 99 |
+
tool_args={},
|
| 100 |
+
reasoning_trace="<think>compute-bound vectorizable</think>",
|
| 101 |
+
))
|
| 102 |
+
env.step(OptimizationAction(
|
| 103 |
+
tool_name="analyze_complexity",
|
| 104 |
+
tool_args={"code": env.state().python_code},
|
| 105 |
+
reasoning_trace="depth check",
|
| 106 |
+
))
|
| 107 |
+
# Submit
|
| 108 |
+
result = env.step(OptimizationAction(
|
| 109 |
+
tool_name="submit_optimization",
|
| 110 |
+
tool_args={
|
| 111 |
+
"cpp_code": "// stub round " + str(round_idx + 1),
|
| 112 |
+
"reasoning_trace": "compute-bound",
|
| 113 |
+
},
|
| 114 |
+
reasoning_trace="<think>round " + str(round_idx + 1) + "</think>",
|
| 115 |
+
))
|
| 116 |
+
|
| 117 |
+
assert result.done is True
|
| 118 |
+
assert env.state().is_terminal
|
| 119 |
+
# Final episode reward = 0.3*R1 + 0.7*R3
|
| 120 |
+
assert isinstance(result.reward, float)
|
| 121 |
+
env.close()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------- D3. Curriculum escalation actually serves harder problems ----------
|
| 125 |
+
|
| 126 |
+
def test_D3_curriculum_escalation_serves_harder_problems():
|
| 127 |
+
"""When function_tier escalates, DatasetLoader must serve higher-tier templates."""
|
| 128 |
+
rng = random.Random(0)
|
| 129 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 130 |
+
|
| 131 |
+
# At tier 0, all sampled templates have tier ≤ 0
|
| 132 |
+
samples_t0 = [
|
| 133 |
+
loader.sample({"function_tier": 0, "hardware_class": 0,
|
| 134 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng)
|
| 135 |
+
for _ in range(100)
|
| 136 |
+
]
|
| 137 |
+
tier_0_template_tiers = [s.get("tier", 0) for s in samples_t0 if not s.get("is_trap")]
|
| 138 |
+
assert all(t <= 0 for t in tier_0_template_tiers), \
|
| 139 |
+
f"tier=0 axis sampled higher-tier templates: {set(tier_0_template_tiers)}"
|
| 140 |
+
|
| 141 |
+
# At tier 3, samples include tier-3 templates
|
| 142 |
+
samples_t3 = [
|
| 143 |
+
loader.sample({"function_tier": 3, "hardware_class": 0,
|
| 144 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng)
|
| 145 |
+
for _ in range(100)
|
| 146 |
+
]
|
| 147 |
+
tier_3_template_tiers = [s.get("tier", 0) for s in samples_t3 if not s.get("is_trap")]
|
| 148 |
+
assert max(tier_3_template_tiers) >= 2, \
|
| 149 |
+
f"tier=3 axis never produced tier≥2 templates: {set(tier_3_template_tiers)}"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ---------- D4. All tool outputs JSON-serializable ----------
|
| 153 |
+
|
| 154 |
+
def test_D4_all_tool_outputs_json_serializable():
|
| 155 |
+
"""Every tool's return must roundtrip through JSON cleanly (FastAPI / wandb)."""
|
| 156 |
+
state = make_state()
|
| 157 |
+
for tool_name, tool_fn in TOOL_REGISTRY.items():
|
| 158 |
+
# Each tool gets a permissive args dict; some will return errors, that's fine
|
| 159 |
+
args = {"cpp_code": "extern \"C\" int agent_function() { return 0; }",
|
| 160 |
+
"code": state.python_code, "n_cases": 5,
|
| 161 |
+
"python_code": state.python_code}
|
| 162 |
+
out = tool_fn(args, state)
|
| 163 |
+
try:
|
| 164 |
+
serialized = json.dumps(out, default=str)
|
| 165 |
+
roundtripped = json.loads(serialized)
|
| 166 |
+
except (TypeError, ValueError) as e:
|
| 167 |
+
pytest.fail(f"tool {tool_name} returned non-JSON-serializable output: {e}")
|
| 168 |
+
assert isinstance(roundtripped, dict)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ---------- D5. Reward variance in healthy GRPO band ----------
|
| 172 |
+
|
| 173 |
+
def test_D5_reward_variance_over_simulated_rollouts():
|
| 174 |
+
"""Simulate 8 rollouts with varied submissions; std should land in [0.10, 0.40]."""
|
| 175 |
+
state = make_state(round_n=1)
|
| 176 |
+
state.round_results = [{"round": 1, "tool_calls": ["get_hardware_profile"]}]
|
| 177 |
+
dag = build_round_reward_dag(1)
|
| 178 |
+
|
| 179 |
+
# Synthetic 8-rollout batch — varied (compile rate, correctness, speedup, reasoning quality)
|
| 180 |
+
rollouts = [
|
| 181 |
+
{"compile_status": "success", "correctness_pass_rate": 0.95, "adversarial_pass_rate": 0.95,
|
| 182 |
+
"speedup": 12.0, "reasoning_trace": "compute-bound vectorizable"},
|
| 183 |
+
{"compile_status": "success", "correctness_pass_rate": 0.85, "adversarial_pass_rate": 0.95,
|
| 184 |
+
"speedup": 6.0, "reasoning_trace": "compute-bound"},
|
| 185 |
+
{"compile_status": "syntax_error", "correctness_pass_rate": 0.0, "adversarial_pass_rate": 0.0,
|
| 186 |
+
"speedup": 0.0, "reasoning_trace": ""},
|
| 187 |
+
{"compile_status": "success", "correctness_pass_rate": 0.55, "adversarial_pass_rate": 0.95,
|
| 188 |
+
"speedup": 0.0, "reasoning_trace": "compute-bound"}, # below gate → 0
|
| 189 |
+
{"compile_status": "success", "correctness_pass_rate": 0.92, "adversarial_pass_rate": 0.90,
|
| 190 |
+
"speedup": 8.0, "reasoning_trace": "vectorizable"},
|
| 191 |
+
{"compile_status": "success", "correctness_pass_rate": 0.70, "adversarial_pass_rate": 0.95,
|
| 192 |
+
"speedup": 4.0, "reasoning_trace": "compute-bound vectorizable"},
|
| 193 |
+
{"compile_status": "success", "correctness_pass_rate": 1.0, "adversarial_pass_rate": 1.0,
|
| 194 |
+
"speedup": 18.0, "reasoning_trace": "compute-bound vectorizable"},
|
| 195 |
+
{"compile_status": "syntax_error", "correctness_pass_rate": 0.0, "adversarial_pass_rate": 0.0,
|
| 196 |
+
"speedup": 0.0, "reasoning_trace": "memory-bound"},
|
| 197 |
+
]
|
| 198 |
+
rewards = np.array([dag.score(state, sub) for sub in rollouts])
|
| 199 |
+
mean = rewards.mean()
|
| 200 |
+
std = rewards.std()
|
| 201 |
+
# GRPO healthy band per plan §11
|
| 202 |
+
assert 0.10 <= std <= 0.45, f"reward_std={std:.3f} outside healthy band [0.10, 0.40]; mean={mean:.3f}"
|
| 203 |
+
assert 0.05 <= mean <= 0.95
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ---------- D6. Round transitions: R1 visible to R3 SelfCorrectionRubric ----------
|
| 207 |
+
|
| 208 |
+
def test_D6_round_transitions_carry_state():
|
| 209 |
+
"""SelfCorrectionRubric in R3 must see R1's compile_status + speedup."""
|
| 210 |
+
env = PolyglotOptimaEnvironment()
|
| 211 |
+
env.reset(seed=11)
|
| 212 |
+
|
| 213 |
+
# Simulate R1 with a "compiled" submission (stubbed)
|
| 214 |
+
env.step(OptimizationAction(
|
| 215 |
+
tool_name="submit_optimization",
|
| 216 |
+
tool_args={"cpp_code": "// r1", "reasoning_trace": "first attempt"},
|
| 217 |
+
reasoning_trace="round 1",
|
| 218 |
+
))
|
| 219 |
+
# Simulate R2
|
| 220 |
+
env.step(OptimizationAction(
|
| 221 |
+
tool_name="submit_optimization",
|
| 222 |
+
tool_args={"cpp_code": "// r2", "reasoning_trace": "second"},
|
| 223 |
+
reasoning_trace="round 2",
|
| 224 |
+
))
|
| 225 |
+
state = env.state()
|
| 226 |
+
# After 2 submits: round_results should have 2 entries
|
| 227 |
+
assert len(state.round_results) == 2
|
| 228 |
+
assert state.round_results[0]["round"] == 1
|
| 229 |
+
assert state.round_results[1]["round"] == 2
|
| 230 |
+
env.close()
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# ---------- D7. Trap detection ----------
|
| 234 |
+
|
| 235 |
+
def test_D7_trap_metadata_propagates_to_problem():
|
| 236 |
+
"""When a trap is sampled, its metadata (rtol_override, ground-truth labels) survives."""
|
| 237 |
+
from server.scenarios.trap_library import sample_trap, trap_to_problem_dict
|
| 238 |
+
rng = random.Random(0)
|
| 239 |
+
for _ in range(10):
|
| 240 |
+
trap = sample_trap(rng)
|
| 241 |
+
p = trap_to_problem_dict(trap, HARDWARE_PROFILES[0])
|
| 242 |
+
assert p["is_trap"] is True
|
| 243 |
+
assert p["bottleneck_labels"] == trap.bottleneck_label
|
| 244 |
+
if trap.rtol_override == 0:
|
| 245 |
+
assert p["rtol_override"] == 0
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ---------- D8. Roofline math sensible on all 8 profiles ----------
|
| 249 |
+
|
| 250 |
+
def test_D8_roofline_math_all_profiles_finite():
|
| 251 |
+
"""Every hardware profile must yield a finite, positive Roofline bound."""
|
| 252 |
+
for profile in HARDWARE_PROFILES:
|
| 253 |
+
bound = roofline_bound(profile)
|
| 254 |
+
assert np.isfinite(bound), f"{profile['id']} → non-finite roofline {bound}"
|
| 255 |
+
assert bound > 0, f"{profile['id']} → non-positive roofline {bound}"
|
| 256 |
+
assert bound < 10000, f"{profile['id']} → suspiciously huge roofline {bound}"
|
| 257 |
+
|
| 258 |
+
# SpeedupRubric on a 1.0x speedup should yield reward in [0, 1]
|
| 259 |
+
rubric = SpeedupRubric()
|
| 260 |
+
# Build a state with this profile
|
| 261 |
+
state = OptimizationState(episode_id="r", hardware_profile=profile)
|
| 262 |
+
score = rubric.score(state, {"speedup": 1.0})
|
| 263 |
+
assert 0 <= score <= 1
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------- D9. System-prompt template constructible ----------
|
| 267 |
+
|
| 268 |
+
def test_D9_system_prompt_constructible():
|
| 269 |
+
"""The episode system prompt assembles cleanly from the problem dict."""
|
| 270 |
+
rng = random.Random(0)
|
| 271 |
+
loader = DatasetLoader()
|
| 272 |
+
problem = loader.sample(
|
| 273 |
+
{"function_tier": 1, "hardware_class": 0,
|
| 274 |
+
"fuzzer_strictness": 0, "portability_required": 0}, rng,
|
| 275 |
+
)
|
| 276 |
+
# The agent's system prompt is constructed from these fields
|
| 277 |
+
# Just assert all pieces exist + are non-empty strings/dicts
|
| 278 |
+
assert isinstance(problem["python_code"], str) and len(problem["python_code"]) > 10
|
| 279 |
+
assert isinstance(problem["hardware_profile"], dict)
|
| 280 |
+
assert "simd" in problem["hardware_profile"]
|
| 281 |
+
assert isinstance(problem["bottleneck_labels"], list) and problem["bottleneck_labels"]
|
| 282 |
+
assert "agent_function" in problem["cpp_signature"]
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ---------- D10. Pydantic models JSON roundtrip ----------
|
| 286 |
+
|
| 287 |
+
def test_D10_pydantic_models_json_roundtrip():
|
| 288 |
+
a = OptimizationAction(tool_name="profile_python_hotspots", tool_args={"code": "x"},
|
| 289 |
+
reasoning_trace="<think>test</think>")
|
| 290 |
+
a2 = OptimizationAction.model_validate_json(a.model_dump_json())
|
| 291 |
+
assert a2.tool_name == a.tool_name and a2.tool_args == a.tool_args
|
| 292 |
+
|
| 293 |
+
obs = OptimizationObservation(done=False, reward=0.5,
|
| 294 |
+
tool_result={"k": "v"}, python_code="def f(): pass",
|
| 295 |
+
hardware_profile={"id": "x"})
|
| 296 |
+
obs2 = OptimizationObservation.model_validate_json(obs.model_dump_json())
|
| 297 |
+
assert obs2.reward == obs.reward and obs2.tool_result == obs.tool_result
|
| 298 |
+
|
| 299 |
+
s = OptimizationState(episode_id="e1", python_code="x")
|
| 300 |
+
s2 = OptimizationState.model_validate_json(s.model_dump_json())
|
| 301 |
+
assert s2.episode_id == s.episode_id
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# ---------- D11. Reserved-name and bad-arg robustness ----------
|
| 305 |
+
|
| 306 |
+
def test_D11_reserved_tool_name_rejected_cleanly():
|
| 307 |
+
"""Reserved names (reset/step/state/close) must raise OpenEnvError, not crash."""
|
| 308 |
+
env = PolyglotOptimaEnvironment()
|
| 309 |
+
env.reset(seed=0)
|
| 310 |
+
for reserved in ("reset", "step", "state", "close"):
|
| 311 |
+
with pytest.raises(Exception):
|
| 312 |
+
env.step(OptimizationAction(tool_name=reserved, tool_args={},
|
| 313 |
+
reasoning_trace=""))
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def test_D11b_unknown_tool_returns_stub_not_crash():
|
| 317 |
+
"""An unknown tool name should fall back to stub, not crash mid-episode."""
|
| 318 |
+
env = PolyglotOptimaEnvironment()
|
| 319 |
+
env.reset(seed=0)
|
| 320 |
+
# Empty the registry to force the "unknown tool" path
|
| 321 |
+
env._tool_registry = {}
|
| 322 |
+
result = env.step(OptimizationAction(tool_name="profile_python_hotspots",
|
| 323 |
+
tool_args={}, reasoning_trace=""))
|
| 324 |
+
assert result.done is False # episode survives
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# ---------- D12. Compilation cache key correctness ----------
|
| 328 |
+
|
| 329 |
+
def test_D12_compile_cache_key_distinguishes_hardware():
|
| 330 |
+
"""Same code on different hardware should hash to different cache keys."""
|
| 331 |
+
code = "extern \"C\" int agent_function() { return 0; }"
|
| 332 |
+
hw_a = {"id": "desktop_avx2", "cores": 8}
|
| 333 |
+
hw_b = {"id": "server_avx512", "cores": 16}
|
| 334 |
+
import json as _json
|
| 335 |
+
key_a = _sha256(code, _json.dumps(hw_a, sort_keys=True))
|
| 336 |
+
key_b = _sha256(code, _json.dumps(hw_b, sort_keys=True))
|
| 337 |
+
assert key_a != key_b
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def test_D12b_compile_cache_key_same_for_same_inputs():
|
| 341 |
+
code = "int x;"
|
| 342 |
+
hw = {"id": "x", "cores": 1}
|
| 343 |
+
import json as _json
|
| 344 |
+
k1 = _sha256(code, _json.dumps(hw, sort_keys=True))
|
| 345 |
+
k2 = _sha256(code, _json.dumps(hw, sort_keys=True))
|
| 346 |
+
assert k1 == k2
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# ---------- D13. Curriculum at extreme states ----------
|
| 350 |
+
|
| 351 |
+
def test_D13_curriculum_at_max_no_crash():
|
| 352 |
+
c = AdaptiveCurriculum(seed=0,
|
| 353 |
+
initial_axes={"function_tier": 3, "hardware_class": 2,
|
| 354 |
+
"fuzzer_strictness": 2, "portability_required": 1})
|
| 355 |
+
for _ in range(50):
|
| 356 |
+
c.observe_batch(0.95)
|
| 357 |
+
snap = c.snapshot()
|
| 358 |
+
# All axes still at max
|
| 359 |
+
assert snap.axes["function_tier"] == 3
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def test_D13b_curriculum_at_min_no_crash():
|
| 363 |
+
c = AdaptiveCurriculum(seed=0)
|
| 364 |
+
for _ in range(50):
|
| 365 |
+
c.observe_batch(0.05)
|
| 366 |
+
assert all(c.axes[a] == 0 for a in c.axes)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# ---------- D14. DatasetLoader stress test ----------
|
| 370 |
+
|
| 371 |
+
def test_D14_dataset_loader_100_consecutive_samples():
|
| 372 |
+
"""Loader survives 100 consecutive sample() calls without exception."""
|
| 373 |
+
rng = random.Random(0)
|
| 374 |
+
loader = DatasetLoader(prefer_real_datasets=False)
|
| 375 |
+
seen = set()
|
| 376 |
+
for i in range(100):
|
| 377 |
+
axes = {"function_tier": i % 4, "hardware_class": i % 3,
|
| 378 |
+
"fuzzer_strictness": i % 3, "portability_required": i % 2}
|
| 379 |
+
sample = loader.sample(axes, rng)
|
| 380 |
+
seen.add(sample["python_code"][:30])
|
| 381 |
+
# Confirm meaningful diversity (not always returning the same problem)
|
| 382 |
+
assert len(seen) > 5
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ---------- Aggregate summary ----------
|
| 386 |
+
|
| 387 |
+
def test_DEEP_SMOKE_all_tests_present():
|
| 388 |
+
"""Roll-call: every D-test is defined in this module."""
|
| 389 |
+
import sys as _sys
|
| 390 |
+
expected = [
|
| 391 |
+
"test_D1_reward_sanity_differential",
|
| 392 |
+
"test_D2_full_three_round_episode_runs",
|
| 393 |
+
"test_D3_curriculum_escalation_serves_harder_problems",
|
| 394 |
+
"test_D4_all_tool_outputs_json_serializable",
|
| 395 |
+
"test_D5_reward_variance_over_simulated_rollouts",
|
| 396 |
+
"test_D6_round_transitions_carry_state",
|
| 397 |
+
"test_D7_trap_metadata_propagates_to_problem",
|
| 398 |
+
"test_D8_roofline_math_all_profiles_finite",
|
| 399 |
+
"test_D9_system_prompt_constructible",
|
| 400 |
+
"test_D10_pydantic_models_json_roundtrip",
|
| 401 |
+
"test_D11_reserved_tool_name_rejected_cleanly",
|
| 402 |
+
"test_D11b_unknown_tool_returns_stub_not_crash",
|
| 403 |
+
"test_D12_compile_cache_key_distinguishes_hardware",
|
| 404 |
+
"test_D12b_compile_cache_key_same_for_same_inputs",
|
| 405 |
+
"test_D13_curriculum_at_max_no_crash",
|
| 406 |
+
"test_D13b_curriculum_at_min_no_crash",
|
| 407 |
+
"test_D14_dataset_loader_100_consecutive_samples",
|
| 408 |
+
]
|
| 409 |
+
for tid in expected:
|
| 410 |
+
assert hasattr(_sys.modules[__name__], tid), f"deep smoke test {tid} missing"
|
tests/test_tools.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hour 4-10: Tool unit tests.
|
| 2 |
+
|
| 3 |
+
Each of the 9 MCP tools verified for shape + key invariants. Compiler-dependent
|
| 4 |
+
tests (cpp_compiler, verifier, portability) are gated on g++ being installed —
|
| 5 |
+
they skip cleanly if the toolchain is unavailable.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import shutil
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 15 |
+
|
| 16 |
+
import pytest
|
| 17 |
+
|
| 18 |
+
from models import OptimizationState
|
| 19 |
+
from server.tools import TOOL_REGISTRY
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
HAS_GPP = shutil.which("g++") is not None or shutil.which("clang++") is not None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _has_cxx20() -> bool:
|
| 26 |
+
"""True only if a C++20-capable compiler is on PATH (GCC ≥ 11 / clang ≥ 13).
|
| 27 |
+
|
| 28 |
+
Dev machines (e.g. ancient MinGW on Windows) often have g++ but not C++20,
|
| 29 |
+
so the cpp_compiler test skips cleanly there. The HF Spaces Docker container
|
| 30 |
+
pins GCC 14, so this passes in CI/deploy.
|
| 31 |
+
"""
|
| 32 |
+
import subprocess
|
| 33 |
+
for cxx in ("g++", "clang++"):
|
| 34 |
+
path = shutil.which(cxx)
|
| 35 |
+
if not path:
|
| 36 |
+
continue
|
| 37 |
+
try:
|
| 38 |
+
r = subprocess.run([path, "-std=c++20", "-x", "c++", "-E", "-"],
|
| 39 |
+
input="", capture_output=True, text=True, timeout=5)
|
| 40 |
+
if r.returncode == 0 or "unrecognized" not in (r.stderr or "").lower():
|
| 41 |
+
return True
|
| 42 |
+
except Exception:
|
| 43 |
+
continue
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
HAS_CXX20 = _has_cxx20()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ----------- common fixture -----------
|
| 51 |
+
|
| 52 |
+
@pytest.fixture
|
| 53 |
+
def state():
|
| 54 |
+
"""A representative OptimizationState the tools accept."""
|
| 55 |
+
return OptimizationState(
|
| 56 |
+
episode_id="test-ep",
|
| 57 |
+
python_code="def sum_squares(arr):\n total = 0.0\n for x in arr:\n total += x*x\n return total\n",
|
| 58 |
+
function_signature_cpp='extern "C" double agent_function(const double*, size_t);',
|
| 59 |
+
hardware_profile={
|
| 60 |
+
"id": "desktop_avx2",
|
| 61 |
+
"cores": 8, "freq_ghz": 3.8, "l1_kb": 32,
|
| 62 |
+
"simd": "AVX2", "bw_gbs": 51,
|
| 63 |
+
},
|
| 64 |
+
bottleneck_ground_truth=["compute-bound", "vectorizable"],
|
| 65 |
+
bottleneck_distractors=["memory-bound", "branch-heavy", "io-bound"],
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ----------- Tool 1: hardware_profiler -----------
|
| 70 |
+
|
| 71 |
+
def test_get_hardware_profile_returns_roofline(state):
|
| 72 |
+
out = TOOL_REGISTRY["get_hardware_profile"]({}, state)
|
| 73 |
+
assert "roofline_bound_gflops" in out
|
| 74 |
+
assert out["roofline_bound_gflops"] > 0
|
| 75 |
+
assert out["simd_width_floats"] == 8 # AVX2 → 8 floats
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ----------- Tools 2-4: python_analyzer suite -----------
|
| 79 |
+
|
| 80 |
+
def test_profile_python_hotspots(state):
|
| 81 |
+
out = TOOL_REGISTRY["profile_python_hotspots"]({}, state)
|
| 82 |
+
assert "hotspots" in out
|
| 83 |
+
assert isinstance(out["hotspots"], list)
|
| 84 |
+
assert "total_estimated_cost" in out
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def test_analyze_complexity_detects_O_n(state):
|
| 88 |
+
out = TOOL_REGISTRY["analyze_complexity"]({}, state)
|
| 89 |
+
assert out["big_o_estimate"] == "O(n)"
|
| 90 |
+
assert out["max_loop_nesting_depth"] == 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_analyze_complexity_detects_O_n_squared(state):
|
| 94 |
+
state.python_code = (
|
| 95 |
+
"def pairwise(X):\n"
|
| 96 |
+
" n = len(X)\n"
|
| 97 |
+
" D = [[0.0]*n for _ in range(n)]\n"
|
| 98 |
+
" for i in range(n):\n"
|
| 99 |
+
" for j in range(n):\n"
|
| 100 |
+
" D[i][j] = (X[i] - X[j])**2\n"
|
| 101 |
+
" return D\n"
|
| 102 |
+
)
|
| 103 |
+
out = TOOL_REGISTRY["analyze_complexity"]({}, state)
|
| 104 |
+
assert out["big_o_estimate"] == "O(n^2)"
|
| 105 |
+
assert out["max_loop_nesting_depth"] == 2
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def test_check_memory_access_flags_stride(state):
|
| 109 |
+
state.python_code = (
|
| 110 |
+
"def transpose_loop(a, b, n):\n"
|
| 111 |
+
" for i in range(n):\n"
|
| 112 |
+
" for j in range(n):\n"
|
| 113 |
+
" b[i, j] = a[j, i]\n" # column-major access in row-major
|
| 114 |
+
)
|
| 115 |
+
out = TOOL_REGISTRY["check_memory_access"]({}, state)
|
| 116 |
+
assert any(i["type"] == "non_unit_stride" for i in out["issues"])
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ----------- Tool 5: cpp_compiler -----------
|
| 120 |
+
|
| 121 |
+
@pytest.mark.skipif(not HAS_GPP, reason="g++/clang++ not installed")
|
| 122 |
+
def test_compile_with_invalid_cpp_returns_syntax_error(state):
|
| 123 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": "this is not c++"}, state)
|
| 124 |
+
assert out["compile_status"] == "syntax_error"
|
| 125 |
+
assert out["speedup"] == 0.0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@pytest.mark.skipif(not HAS_GPP, reason="g++/clang++ not installed")
|
| 129 |
+
def test_compile_rejects_banned_headers(state):
|
| 130 |
+
code = '#include <mkl.h>\nextern "C" double agent_function() { return 0.0; }\n'
|
| 131 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": code}, state)
|
| 132 |
+
assert out["compile_status"] == "syntax_error"
|
| 133 |
+
assert "mkl" in out["error"].lower() or "banned" in out["error"].lower()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def test_compile_rejects_missing_entry_point(state):
|
| 137 |
+
code = "double f(int x) { return x; }\n" # no extern "C" agent_function
|
| 138 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": code}, state)
|
| 139 |
+
assert out["compile_status"] == "syntax_error"
|
| 140 |
+
assert "agent_function" in out["error"]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@pytest.mark.skipif(not HAS_CXX20, reason="C++20 compiler not available (GCC<11 or clang<13)")
|
| 144 |
+
def test_compile_valid_cpp_succeeds(state):
|
| 145 |
+
code = (
|
| 146 |
+
'#include <cstddef>\n'
|
| 147 |
+
'extern "C" double agent_function(const double* arr, size_t n) {\n'
|
| 148 |
+
' double total = 0.0;\n'
|
| 149 |
+
' for (size_t i = 0; i < n; ++i) total += arr[i] * arr[i];\n'
|
| 150 |
+
' return total;\n'
|
| 151 |
+
'}\n'
|
| 152 |
+
)
|
| 153 |
+
out = TOOL_REGISTRY["compile_and_benchmark"]({"cpp_code": code}, state)
|
| 154 |
+
assert out["compile_status"] == "success"
|
| 155 |
+
assert out["speedup"] > 0.0
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ----------- Tool 6: verifier -----------
|
| 159 |
+
|
| 160 |
+
def test_verify_rejects_empty_cpp(state):
|
| 161 |
+
out = TOOL_REGISTRY["verify_equivalence"]({"cpp_code": ""}, state)
|
| 162 |
+
assert out["pass_rate"] == 0.0
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def test_verify_rejects_non_positive_case_count(state):
|
| 166 |
+
out = TOOL_REGISTRY["verify_equivalence"]({"cpp_code": "double f() { return 0; }", "n_cases": 0}, state)
|
| 167 |
+
assert out["pass_rate"] == 0.0
|
| 168 |
+
assert "n_cases" in out["error"]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@pytest.mark.skipif(not HAS_GPP, reason="g++/clang++ not installed")
|
| 172 |
+
def test_verify_rejects_missing_entry(state):
|
| 173 |
+
out = TOOL_REGISTRY["verify_equivalence"]({"cpp_code": "double f() { return 0; }"}, state)
|
| 174 |
+
assert out["pass_rate"] == 0.0
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ----------- Tool 7: portability -----------
|
| 178 |
+
|
| 179 |
+
def test_portability_with_empty_cpp_returns_zero(state):
|
| 180 |
+
out = TOOL_REGISTRY["check_portability"]({"cpp_code": ""}, state)
|
| 181 |
+
assert out["n_profiles_passing"] == 0
|
| 182 |
+
assert out["portability_bonus_eligible"] is False
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ----------- Tool 8: bottleneck_reporter -----------
|
| 186 |
+
|
| 187 |
+
def test_bottleneck_reporter_detects_simd_use(state):
|
| 188 |
+
code = (
|
| 189 |
+
'#include <immintrin.h>\n'
|
| 190 |
+
'extern "C" double agent_function(const double* a, size_t n) {\n'
|
| 191 |
+
' __m256d acc = _mm256_setzero_pd();\n'
|
| 192 |
+
' for (size_t i = 0; i + 4 <= n; i += 4) {\n'
|
| 193 |
+
' __m256d v = _mm256_loadu_pd(a + i);\n'
|
| 194 |
+
' acc = _mm256_fmadd_pd(v, v, acc);\n'
|
| 195 |
+
' }\n'
|
| 196 |
+
' return 0;\n'
|
| 197 |
+
'}\n'
|
| 198 |
+
)
|
| 199 |
+
out = TOOL_REGISTRY["get_bottleneck_report"]({"cpp_code": code}, state)
|
| 200 |
+
assert out["uses_simd"] is True
|
| 201 |
+
assert out["estimated_vectorization_pct"] >= 80.0
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def test_bottleneck_reporter_suggests_simd(state):
|
| 205 |
+
code = (
|
| 206 |
+
'extern "C" double agent_function(const double* a, size_t n) {\n'
|
| 207 |
+
' double t = 0;\n'
|
| 208 |
+
' for (size_t i = 0; i < n; ++i) t += a[i]*a[i];\n'
|
| 209 |
+
' return t;\n'
|
| 210 |
+
'}\n'
|
| 211 |
+
)
|
| 212 |
+
out = TOOL_REGISTRY["get_bottleneck_report"]({"cpp_code": code}, state)
|
| 213 |
+
assert out["uses_simd"] is False
|
| 214 |
+
assert any("SIMD" in s for s in out["suggestions"])
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ----------- Tool 9: submit -----------
|
| 218 |
+
|
| 219 |
+
def test_submit_with_empty_cpp_not_ready(state):
|
| 220 |
+
out = TOOL_REGISTRY["submit_optimization"]({"cpp_code": ""}, state)
|
| 221 |
+
assert out["ready_for_reward"] is False
|
| 222 |
+
assert out["compile_status"] == "syntax_error"
|
training/openenv_hackathon_training.ipynb
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Polyglot-Optima Hackathon Training Notebook\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook is a **submission-oriented, executable** workflow:\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"1. OpenEnv environment loop sanity checks\n",
|
| 12 |
+
"2. Baseline evaluation with fixed seeds\n",
|
| 13 |
+
"3. Executable training block (SFT demo path, budget-friendly)\n",
|
| 14 |
+
"4. W&B tracking (reward, correctness, compile status, portability)\n",
|
| 15 |
+
"5. Plot export for README evidence\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"Use this notebook locally, in Colab, or on Hugging Face Jobs.\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"> For final hackathon submission, deploy your demo endpoint and link results artifacts in `README.md`."
|
| 20 |
+
],
|
| 21 |
+
"id": "93a92bf4"
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": [
|
| 27 |
+
"# If running in Colab, uncomment:\n",
|
| 28 |
+
"# %pip install -q trl transformers datasets wandb matplotlib\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"import os\n",
|
| 31 |
+
"import sys\n",
|
| 32 |
+
"import json\n",
|
| 33 |
+
"import random\n",
|
| 34 |
+
"from pathlib import Path\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"import matplotlib.pyplot as plt\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# Ensure imports work regardless of notebook launch directory.\n",
|
| 39 |
+
"PROJECT_ROOT = Path.cwd().resolve().parents[0] if Path.cwd().name == \"training\" else Path.cwd().resolve()\n",
|
| 40 |
+
"if str(PROJECT_ROOT) not in sys.path:\n",
|
| 41 |
+
" sys.path.insert(0, str(PROJECT_ROOT))\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"from models import OptimizationAction\n",
|
| 44 |
+
"from server.environment import PolyglotOptimaEnvironment\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"print(\"project root:\", PROJECT_ROOT)\n"
|
| 47 |
+
],
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"id": "c3109ca6"
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"source": [
|
| 56 |
+
"# Experiment configuration (budget-aware defaults for ~$20 credits)\n",
|
| 57 |
+
"CFG = {\n",
|
| 58 |
+
" \"model_name\": os.environ.get(\"MODEL_NAME\", \"Qwen/Qwen2.5-Coder-0.5B-Instruct\"),\n",
|
| 59 |
+
" \"episodes_baseline\": int(os.environ.get(\"EPISODES_BASELINE\", \"20\")),\n",
|
| 60 |
+
" \"episodes_eval\": int(os.environ.get(\"EPISODES_EVAL\", \"20\")),\n",
|
| 61 |
+
" \"max_rounds\": 3,\n",
|
| 62 |
+
" \"max_calls_per_round\": 5,\n",
|
| 63 |
+
" \"seed\": 42,\n",
|
| 64 |
+
" \"wandb_project\": os.environ.get(\"WANDB_PROJECT\", \"openenv-polyglot-optima\"),\n",
|
| 65 |
+
" \"wandb_run_name\": os.environ.get(\"WANDB_RUN_NAME\", \"baseline-and-train-starter\"),\n",
|
| 66 |
+
" \"training_mode\": os.environ.get(\"TRAINING_MODE\", \"sft_demo\"), # sft_demo | skip\n",
|
| 67 |
+
" \"max_steps\": int(os.environ.get(\"MAX_STEPS\", \"80\")),\n",
|
| 68 |
+
" \"learning_rate\": float(os.environ.get(\"LEARNING_RATE\", \"2e-5\")),\n",
|
| 69 |
+
" \"hf_hourly_cost_usd\": float(os.environ.get(\"HF_HOURLY_COST_USD\", \"1.0\")),\n",
|
| 70 |
+
" \"target_hours\": float(os.environ.get(\"TARGET_HOURS\", \"8.0\")),\n",
|
| 71 |
+
"}\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"USE_WANDB = os.environ.get(\"USE_WANDB\", \"1\") == \"1\"\n",
|
| 74 |
+
"if USE_WANDB:\n",
|
| 75 |
+
" import wandb\n",
|
| 76 |
+
" wandb.init(project=CFG[\"wandb_project\"], name=CFG[\"wandb_run_name\"], config=CFG)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"random.seed(CFG[\"seed\"])\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"estimated_budget = CFG[\"hf_hourly_cost_usd\"] * CFG[\"target_hours\"]\n",
|
| 81 |
+
"print(json.dumps(CFG, indent=2))\n",
|
| 82 |
+
"print(f\"Estimated budget envelope: ${estimated_budget:.2f}\")"
|
| 83 |
+
],
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"id": "d2b39137"
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"source": [
|
| 92 |
+
"def heuristic_policy(observation):\n",
|
| 93 |
+
" # Minimal deterministic baseline policy for reproducible before/after comparisons.\n",
|
| 94 |
+
" round_no = observation.round_number\n",
|
| 95 |
+
" if round_no == 1:\n",
|
| 96 |
+
" return OptimizationAction(tool_name=\"get_hardware_profile\", tool_args={}, reasoning_trace=\"baseline\")\n",
|
| 97 |
+
" if round_no == 2:\n",
|
| 98 |
+
" return OptimizationAction(tool_name=\"profile_python_hotspots\", tool_args={}, reasoning_trace=\"baseline\")\n",
|
| 99 |
+
" return OptimizationAction(\n",
|
| 100 |
+
" tool_name=\"submit_optimization\",\n",
|
| 101 |
+
" tool_args={\"cpp_code\": \"// baseline submit\", \"reasoning_trace\": \"baseline\"},\n",
|
| 102 |
+
" reasoning_trace=\"baseline\",\n",
|
| 103 |
+
" )\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"def run_eval(policy_fn, n_episodes=10, seed_start=1000):\n",
|
| 107 |
+
" env = PolyglotOptimaEnvironment(\n",
|
| 108 |
+
" max_rounds=CFG[\"max_rounds\"],\n",
|
| 109 |
+
" max_calls_per_round=CFG[\"max_calls_per_round\"],\n",
|
| 110 |
+
" enable_adaptive_curriculum=True,\n",
|
| 111 |
+
" curriculum_batch_size=8,\n",
|
| 112 |
+
" )\n",
|
| 113 |
+
" rewards = []\n",
|
| 114 |
+
" correctness = []\n",
|
| 115 |
+
" compile_success = []\n",
|
| 116 |
+
" portability = []\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" for i in range(n_episodes):\n",
|
| 119 |
+
" obs = env.reset(seed=seed_start + i)\n",
|
| 120 |
+
" done = False\n",
|
| 121 |
+
" while not done:\n",
|
| 122 |
+
" action = policy_fn(obs)\n",
|
| 123 |
+
" step = env.step(action)\n",
|
| 124 |
+
" obs = step.observation\n",
|
| 125 |
+
" done = step.done\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" rewards.append(float(step.reward))\n",
|
| 128 |
+
" submission = env.state().round_results[-1][\"submission\"] if env.state().round_results else {}\n",
|
| 129 |
+
" correctness.append(float(submission.get(\"correctness_pass_rate\", 0.0)))\n",
|
| 130 |
+
" compile_success.append(1.0 if submission.get(\"compile_status\") == \"success\" else 0.0)\n",
|
| 131 |
+
" portability.append(float(submission.get(\"n_profiles_passing\", 0)))\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" if USE_WANDB:\n",
|
| 134 |
+
" wandb.log({\n",
|
| 135 |
+
" \"eval/reward\": rewards[-1],\n",
|
| 136 |
+
" \"eval/correctness_pass_rate\": correctness[-1],\n",
|
| 137 |
+
" \"eval/compile_success\": compile_success[-1],\n",
|
| 138 |
+
" \"eval/n_profiles_passing\": portability[-1],\n",
|
| 139 |
+
" })\n",
|
| 140 |
+
"\n",
|
| 141 |
+
" env.close()\n",
|
| 142 |
+
" return {\n",
|
| 143 |
+
" \"reward\": rewards,\n",
|
| 144 |
+
" \"correctness\": correctness,\n",
|
| 145 |
+
" \"compile_success\": compile_success,\n",
|
| 146 |
+
" \"portability\": portability,\n",
|
| 147 |
+
" }\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"baseline_metrics = run_eval(heuristic_policy, n_episodes=CFG[\"episodes_baseline\"])"
|
| 151 |
+
],
|
| 152 |
+
"execution_count": null,
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"id": "7a970a97"
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "markdown",
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"source": [
|
| 160 |
+
"## Executable Training Step (Budget-Oriented)\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"This notebook uses an executable **SFT demonstration training** path by default.\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"Why this choice:\n",
|
| 165 |
+
"- Works reliably across local/Colab setups.\n",
|
| 166 |
+
"- Uses data generated from this OpenEnv environment (baseline trajectories).\n",
|
| 167 |
+
"- Produces measurable before/after artifacts and plots.\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"If you later switch to GRPO/online RL, keep this notebook structure and replace only the training cell while preserving:\n",
|
| 170 |
+
"- fixed-seed baseline,\n",
|
| 171 |
+
"- fixed-seed post-training eval,\n",
|
| 172 |
+
"- same plotting/report outputs."
|
| 173 |
+
],
|
| 174 |
+
"id": "c58716b2"
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"cell_type": "code",
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"source": [
|
| 180 |
+
"# TRAINING CELL (executable)\n",
|
| 181 |
+
"# Default path: supervised fine-tuning on environment-generated trajectories.\n",
|
| 182 |
+
"# This creates a runnable training artifact and keeps before/after evaluation consistent.\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"from typing import Dict, Any, List\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"training_artifact: Dict[str, Any] = {\"mode\": CFG[\"training_mode\"], \"status\": \"not_started\"}\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"if CFG[\"training_mode\"] == \"skip\":\n",
|
| 189 |
+
" training_artifact[\"status\"] = \"skipped_by_config\"\n",
|
| 190 |
+
"else:\n",
|
| 191 |
+
" try:\n",
|
| 192 |
+
" import torch\n",
|
| 193 |
+
" from datasets import Dataset\n",
|
| 194 |
+
" from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 195 |
+
" from trl import SFTTrainer, SFTConfig\n",
|
| 196 |
+
" TRL_AVAILABLE = True\n",
|
| 197 |
+
" except Exception as e:\n",
|
| 198 |
+
" TRL_AVAILABLE = False\n",
|
| 199 |
+
" training_artifact[\"status\"] = \"skipped_missing_dependencies\"\n",
|
| 200 |
+
" training_artifact[\"error\"] = str(e)\n",
|
| 201 |
+
" print(\"Skipping training because dependencies are missing:\", e)\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" if TRL_AVAILABLE:\n",
|
| 204 |
+
" print(\"Preparing demonstration data from environment rollouts...\")\n",
|
| 205 |
+
"\n",
|
| 206 |
+
" def build_prompt(observation) -> str:\n",
|
| 207 |
+
" return (\n",
|
| 208 |
+
" \"You are optimizing Python to C++. Choose next tool call.\\n\"\n",
|
| 209 |
+
" f\"Round: {observation.round_number}\\n\"\n",
|
| 210 |
+
" f\"Hardware: {json.dumps(observation.hardware_profile)}\\n\"\n",
|
| 211 |
+
" f\"Python:\\n{observation.python_code}\\n\"\n",
|
| 212 |
+
" f\"Last tool result: {json.dumps(observation.tool_result, default=str)[:1000]}\\n\"\n",
|
| 213 |
+
" \"Return ONLY JSON: {\\\"tool_name\\\":..., \\\"tool_args\\\":...}\"\n",
|
| 214 |
+
" )\n",
|
| 215 |
+
"\n",
|
| 216 |
+
" def action_to_text(action: OptimizationAction) -> str:\n",
|
| 217 |
+
" return json.dumps({\"tool_name\": action.tool_name, \"tool_args\": action.tool_args})\n",
|
| 218 |
+
"\n",
|
| 219 |
+
" rows: List[Dict[str, str]] = []\n",
|
| 220 |
+
" env = PolyglotOptimaEnvironment(max_rounds=CFG[\"max_rounds\"], max_calls_per_round=CFG[\"max_calls_per_round\"])\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" for ep in range(12):\n",
|
| 223 |
+
" obs = env.reset(seed=4000 + ep)\n",
|
| 224 |
+
" done = False\n",
|
| 225 |
+
" while not done:\n",
|
| 226 |
+
" action = heuristic_policy(obs)\n",
|
| 227 |
+
" rows.append({\"text\": f\"<PROMPT>\\n{build_prompt(obs)}\\n<ANSWER>\\n{action_to_text(action)}\"})\n",
|
| 228 |
+
" step = env.step(action)\n",
|
| 229 |
+
" obs = step.observation\n",
|
| 230 |
+
" done = step.done\n",
|
| 231 |
+
" env.close()\n",
|
| 232 |
+
"\n",
|
| 233 |
+
" ds = Dataset.from_list(rows)\n",
|
| 234 |
+
" ds_split = ds.train_test_split(test_size=0.15, seed=CFG[\"seed\"])\n",
|
| 235 |
+
" print(\"Train samples:\", len(ds_split[\"train\"]), \"Eval samples:\", len(ds_split[\"test\"]))\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" output_dir = PROJECT_ROOT / \"artifacts\" / \"sft-polyglot-optima\"\n",
|
| 238 |
+
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
| 239 |
+
"\n",
|
| 240 |
+
" tokenizer = AutoTokenizer.from_pretrained(CFG[\"model_name\"], use_fast=True)\n",
|
| 241 |
+
" if tokenizer.pad_token is None:\n",
|
| 242 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" model = AutoModelForCausalLM.from_pretrained(CFG[\"model_name\"])\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" sft_cfg = SFTConfig(\n",
|
| 247 |
+
" output_dir=str(output_dir),\n",
|
| 248 |
+
" learning_rate=CFG[\"learning_rate\"],\n",
|
| 249 |
+
" max_steps=CFG[\"max_steps\"],\n",
|
| 250 |
+
" per_device_train_batch_size=1,\n",
|
| 251 |
+
" gradient_accumulation_steps=8,\n",
|
| 252 |
+
" logging_steps=10,\n",
|
| 253 |
+
" save_steps=40,\n",
|
| 254 |
+
" eval_strategy=\"steps\",\n",
|
| 255 |
+
" eval_steps=20,\n",
|
| 256 |
+
" report_to=[\"wandb\"] if USE_WANDB else [],\n",
|
| 257 |
+
" dataset_text_field=\"text\",\n",
|
| 258 |
+
" )\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" trainer = SFTTrainer(\n",
|
| 261 |
+
" model=model,\n",
|
| 262 |
+
" args=sft_cfg,\n",
|
| 263 |
+
" train_dataset=ds_split[\"train\"],\n",
|
| 264 |
+
" eval_dataset=ds_split[\"test\"],\n",
|
| 265 |
+
" processing_class=tokenizer,\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" train_result = trainer.train()\n",
|
| 269 |
+
" trainer.save_model(str(output_dir / \"final\"))\n",
|
| 270 |
+
" tokenizer.save_pretrained(str(output_dir / \"final\"))\n",
|
| 271 |
+
"\n",
|
| 272 |
+
" training_artifact.update({\n",
|
| 273 |
+
" \"status\": \"completed\",\n",
|
| 274 |
+
" \"output_dir\": str(output_dir / \"final\"),\n",
|
| 275 |
+
" \"train_loss\": float(train_result.training_loss),\n",
|
| 276 |
+
" })\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" if USE_WANDB:\n",
|
| 279 |
+
" wandb.log({\"train/final_loss\": float(train_result.training_loss)})\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"print(\"training_artifact:\", json.dumps(training_artifact, indent=2, default=str))"
|
| 282 |
+
],
|
| 283 |
+
"execution_count": null,
|
| 284 |
+
"outputs": [],
|
| 285 |
+
"id": "5dccc7e2"
|
| 286 |
+
},
|
| 287 |
+
{
|
| 288 |
+
"cell_type": "code",
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"source": [
|
| 291 |
+
"# Post-training evaluation policy.\n",
|
| 292 |
+
"# If a trained model exists, we do model inference for action JSON.\n",
|
| 293 |
+
"# Otherwise we safely fall back to heuristic policy.\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"import re\n",
|
| 296 |
+
"\n",
|
| 297 |
+
"_GENERATED_TOOL_RE = re.compile(r\"\\{.*\\}\", re.DOTALL)\n",
|
| 298 |
+
"\n",
|
| 299 |
+
"lm_policy = None\n",
|
| 300 |
+
"if training_artifact.get(\"status\") == \"completed\":\n",
|
| 301 |
+
" try:\n",
|
| 302 |
+
" import torch\n",
|
| 303 |
+
" from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 304 |
+
"\n",
|
| 305 |
+
" model_dir = training_artifact[\"output_dir\"]\n",
|
| 306 |
+
" inf_tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)\n",
|
| 307 |
+
" inf_model = AutoModelForCausalLM.from_pretrained(model_dir)\n",
|
| 308 |
+
" inf_model.eval()\n",
|
| 309 |
+
"\n",
|
| 310 |
+
" def _model_policy(observation):\n",
|
| 311 |
+
" prompt = (\n",
|
| 312 |
+
" \"<PROMPT>\\n\"\n",
|
| 313 |
+
" + \"You are optimizing Python to C++. Choose next tool call.\\n\"\n",
|
| 314 |
+
" + f\"Round: {observation.round_number}\\n\"\n",
|
| 315 |
+
" + f\"Hardware: {json.dumps(observation.hardware_profile)}\\n\"\n",
|
| 316 |
+
" + f\"Python:\\n{observation.python_code}\\n\"\n",
|
| 317 |
+
" + \"Return ONLY JSON: {\\\"tool_name\\\":..., \\\"tool_args\\\":...}\\n\"\n",
|
| 318 |
+
" + \"<ANSWER>\\n\"\n",
|
| 319 |
+
" )\n",
|
| 320 |
+
" inputs = inf_tokenizer(prompt, return_tensors=\"pt\")\n",
|
| 321 |
+
" with torch.no_grad():\n",
|
| 322 |
+
" out = inf_model.generate(**inputs, max_new_tokens=96, do_sample=False)\n",
|
| 323 |
+
" text = inf_tokenizer.decode(out[0], skip_special_tokens=True)\n",
|
| 324 |
+
" m = _GENERATED_TOOL_RE.search(text)\n",
|
| 325 |
+
" if not m:\n",
|
| 326 |
+
" return heuristic_policy(observation)\n",
|
| 327 |
+
" try:\n",
|
| 328 |
+
" data = json.loads(m.group(0))\n",
|
| 329 |
+
" tool_name = data.get(\"tool_name\")\n",
|
| 330 |
+
" tool_args = data.get(\"tool_args\", {})\n",
|
| 331 |
+
" if not isinstance(tool_name, str):\n",
|
| 332 |
+
" return heuristic_policy(observation)\n",
|
| 333 |
+
" return OptimizationAction(tool_name=tool_name, tool_args=tool_args, reasoning_trace=\"trained-model\")\n",
|
| 334 |
+
" except Exception:\n",
|
| 335 |
+
" return heuristic_policy(observation)\n",
|
| 336 |
+
"\n",
|
| 337 |
+
" lm_policy = _model_policy\n",
|
| 338 |
+
" print(\"Using trained model policy for evaluation\")\n",
|
| 339 |
+
" except Exception as e:\n",
|
| 340 |
+
" print(\"Falling back to heuristic policy due to inference load issue:\", e)\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"trained_metrics = run_eval(lm_policy or heuristic_policy, n_episodes=CFG[\"episodes_eval\"], seed_start=2000)\n",
|
| 343 |
+
"\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"def summarize(name, m):\n",
|
| 346 |
+
" import statistics\n",
|
| 347 |
+
" return {\n",
|
| 348 |
+
" \"name\": name,\n",
|
| 349 |
+
" \"reward_mean\": statistics.mean(m[\"reward\"]),\n",
|
| 350 |
+
" \"reward_median\": statistics.median(m[\"reward\"]),\n",
|
| 351 |
+
" \"correctness_mean\": statistics.mean(m[\"correctness\"]),\n",
|
| 352 |
+
" \"compile_rate\": statistics.mean(m[\"compile_success\"]),\n",
|
| 353 |
+
" \"portability_mean\": statistics.mean(m[\"portability\"]),\n",
|
| 354 |
+
" }\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"baseline_summary = summarize(\"baseline\", baseline_metrics)\n",
|
| 357 |
+
"trained_summary = summarize(\"trained\", trained_metrics)\n",
|
| 358 |
+
"comparison = {\"baseline\": baseline_summary, \"trained\": trained_summary}\n",
|
| 359 |
+
"print(json.dumps(comparison, indent=2))\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"if USE_WANDB:\n",
|
| 362 |
+
" wandb.log({\n",
|
| 363 |
+
" \"summary/reward_mean_baseline\": baseline_summary[\"reward_mean\"],\n",
|
| 364 |
+
" \"summary/reward_mean_trained\": trained_summary[\"reward_mean\"],\n",
|
| 365 |
+
" \"summary/correctness_mean_baseline\": baseline_summary[\"correctness_mean\"],\n",
|
| 366 |
+
" \"summary/correctness_mean_trained\": trained_summary[\"correctness_mean\"],\n",
|
| 367 |
+
" \"summary/compile_rate_baseline\": baseline_summary[\"compile_rate\"],\n",
|
| 368 |
+
" \"summary/compile_rate_trained\": trained_summary[\"compile_rate\"],\n",
|
| 369 |
+
" })"
|
| 370 |
+
],
|
| 371 |
+
"execution_count": null,
|
| 372 |
+
"outputs": [],
|
| 373 |
+
"id": "1ce841a5"
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"cell_type": "code",
|
| 377 |
+
"metadata": {},
|
| 378 |
+
"source": [
|
| 379 |
+
"# Plot and export evidence figures for README.\n",
|
| 380 |
+
"PLOT_DIR = PROJECT_ROOT / \"docs\" / \"plots\"\n",
|
| 381 |
+
"PLOT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"plt.figure(figsize=(8, 4))\n",
|
| 384 |
+
"plt.hist(baseline_metrics[\"reward\"], bins=10, alpha=0.6, label=\"baseline\")\n",
|
| 385 |
+
"plt.hist(trained_metrics[\"reward\"], bins=10, alpha=0.6, label=\"trained\")\n",
|
| 386 |
+
"plt.title(\"Reward Distribution: Baseline vs Trained\")\n",
|
| 387 |
+
"plt.xlabel(\"Episode reward\")\n",
|
| 388 |
+
"plt.ylabel(\"count\")\n",
|
| 389 |
+
"plt.legend()\n",
|
| 390 |
+
"reward_plot = PLOT_DIR / \"reward_distribution_baseline_vs_trained.png\"\n",
|
| 391 |
+
"plt.tight_layout()\n",
|
| 392 |
+
"plt.savefig(reward_plot, dpi=150)\n",
|
| 393 |
+
"plt.show()\n",
|
| 394 |
+
"\n",
|
| 395 |
+
"plt.figure(figsize=(8, 4))\n",
|
| 396 |
+
"plt.plot(baseline_metrics[\"correctness\"], label=\"baseline correctness\")\n",
|
| 397 |
+
"plt.plot(trained_metrics[\"correctness\"], label=\"trained correctness\")\n",
|
| 398 |
+
"plt.title(\"Correctness Pass Rate Across Episodes\")\n",
|
| 399 |
+
"plt.xlabel(\"episode\")\n",
|
| 400 |
+
"plt.ylabel(\"correctness_pass_rate\")\n",
|
| 401 |
+
"plt.legend()\n",
|
| 402 |
+
"corr_plot = PLOT_DIR / \"correctness_baseline_vs_trained.png\"\n",
|
| 403 |
+
"plt.tight_layout()\n",
|
| 404 |
+
"plt.savefig(corr_plot, dpi=150)\n",
|
| 405 |
+
"plt.show()\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"print(\"Saved:\", reward_plot)\n",
|
| 408 |
+
"print(\"Saved:\", corr_plot)\n",
|
| 409 |
+
"\n",
|
| 410 |
+
"if USE_WANDB:\n",
|
| 411 |
+
" wandb.log({\n",
|
| 412 |
+
" \"plots/reward_distribution\": wandb.Image(str(reward_plot)),\n",
|
| 413 |
+
" \"plots/correctness_curve\": wandb.Image(str(corr_plot)),\n",
|
| 414 |
+
" })\n",
|
| 415 |
+
" wandb.finish()"
|
| 416 |
+
],
|
| 417 |
+
"execution_count": null,
|
| 418 |
+
"outputs": [],
|
| 419 |
+
"id": "7a87cf9c"
|
| 420 |
+
}
|
| 421 |
+
],
|
| 422 |
+
"metadata": {
|
| 423 |
+
"kernelspec": {
|
| 424 |
+
"display_name": "Python 3",
|
| 425 |
+
"language": "python",
|
| 426 |
+
"name": "python3"
|
| 427 |
+
},
|
| 428 |
+
"language_info": {
|
| 429 |
+
"name": "python"
|
| 430 |
+
}
|
| 431 |
+
},
|
| 432 |
+
"nbformat": 4,
|
| 433 |
+
"nbformat_minor": 5
|
| 434 |
+
}
|