Merge Kush frontend integration, close API 16/UI 10/UI 11
- Merge multi-stage Docker build, SPA serving, and frontend components
- API 16: Docker now builds frontend + API in single container
- UI 10: Frontend styled with new components (ProtocolEditor, LiveScoreGauges, etc)
- UI 11: Server mounts frontend/dist with SPA catch-all for React Router
- Fix root endpoint test for SPA-mode serving
- Max at 97.56% (40/41, only DOC 08 remaining blocked on TRN 10)
- Dockerfile.train +59 -0
- ReplicaLab_Comprehensive_Task_Division.md +3 -3
- docs/completion.md +7 -7
- replicalab/agents/__init__.py +2 -0
- replicalab/agents/scientist_policy.py +169 -1
- replicalab/training/__init__.py +12 -0
- replicalab/training/art_openenv.py +693 -0
- replicalab/training/cli.py +415 -3
- replicalab/training/evaluation.py +51 -0
- replicalab/training/metrics.py +29 -1
- replicalab/training/plots.py +47 -0
- replicalab/training/rollout.py +28 -2
- requirements-train.txt +22 -0
- scripts/train.sh +151 -0
- tests/test_server.py +5 -6
- tests/test_training_cli.py +93 -0
- tests/test_training_metrics.py +15 -1
Dockerfile.train
ADDED
@@ -0,0 +1,59 @@
+# Training Dockerfile for Northflank GPU jobs.
+#
+# Uses CUDA base image + installs Unsloth, TRL, vLLM for
+# Scientist GRPO and Lab Manager SFT training.
+#
+# Build: docker build -f Dockerfile.train -t replicalab-train .
+# Run: docker run --gpus all -e MODE=train replicalab-train
+
+FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+# System deps
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.11 python3.11-dev python3.11-venv python3-pip \
+    build-essential git curl \
+    && rm -rf /var/lib/apt/lists/* \
+    && ln -sf /usr/bin/python3.11 /usr/bin/python \
+    && ln -sf /usr/bin/python3.11 /usr/bin/python3
+
+# Upgrade pip
+RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+# Install server deps first (better layer caching)
+COPY server/requirements.txt ./server/requirements.txt
+RUN pip install --no-cache-dir -r server/requirements.txt
+
+# Install training deps (heavy — torch, unsloth, trl, vllm)
+COPY requirements-train.txt ./requirements-train.txt
+RUN pip install --no-cache-dir -r requirements-train.txt
+
+# Copy full project
+COPY replicalab/ ./replicalab/
+COPY server/ ./server/
+COPY data/ ./data/
+COPY scripts/ ./scripts/
+COPY pyproject.toml ./
+COPY ReplicaLab_50_Scenarios_Training_Plan.md ./
+
+# Install replicalab package
+RUN pip install --no-cache-dir . --no-deps
+
+# Make scripts executable
+RUN chmod +x scripts/train.sh
+
+# Default env vars
+ENV MODE=server
+ENV REPLICALAB_PERSIST_ROOT=/app/outputs/training
+ENV SEED_COUNT=8
+ENV MAX_STEPS=300
+ENV MODEL_NAME=Qwen/Qwen3-8B
+
+EXPOSE 7860
+
+# Entrypoint dispatches based on MODE env var
+CMD ["bash", "scripts/train.sh"]
ReplicaLab_Comprehensive_Task_Division.md
CHANGED
@@ -626,7 +626,7 @@ As the team, we want one click reproducible deployment to HF Spaces.
 | API 13 | E07.1 | Person C | `server/app.py` | Add CORS middleware configuration for frontend origins in dev and production | API 01 | 0.25h | frontend on localhost:5173 and HF Space origin can reach the API without CORS errors | ✅ Completed | Person B (Ayush) |
 | API 14 | E07.1 | Person C | `server/app.py` | Add REST session management so each user gets isolated environment state | API 02, API 03 | 0.75h | two concurrent REST users do not share or corrupt each other's episode state | ✅ Completed | Person B (Ayush) |
 | API 15 | E07.2 | Person C | HF Space repo | Create HF Space README.md with YAML frontmatter specifying `sdk: docker`, `app_port: 7860`, title, and emoji | API 08 | 0.25h | HF Space config is valid and Space launches correctly from the metadata | ✅ Completed | Person B (Ayush) |
-| API 16 | E07.2 | Person C | `server/Dockerfile` | Configure Docker to build frontend and serve static assets from FastAPI in a single container | API 08, UI 10 | 0.75h | single Docker container serves both API and frontend on port 7860 |
+| API 16 | E07.2 | Person C | `server/Dockerfile` | Configure Docker to build frontend and serve static assets from FastAPI in a single container | API 08, UI 10 | 0.75h | single Docker container serves both API and frontend on port 7860 | ✅ Completed | Person D (Kush) |
 | API 17 | E07.2 | Person C | deployment docs | Document secrets and API key management for hosted Scientist model access in deployment and notebook | API 09 | 0.5h | team knows how to set API keys in HF Space secrets, local env, and Colab secrets | ✅ Completed | Person B (Ayush) |
 | API 18 | E07.1 | Person C | `server/app.py` | Include judge audit payload plus bounded tool-trace summaries in REST, replay, and WebSocket responses for terminal episodes | API 03, API 05, API 06, ENV 11 | 0.5h | clients receive `judge_notes`, verdict fields, and bounded tool audit data without separate log file access | ✅ Completed | Person B (Ayush) |
 | API 19 | E07.2 | Person C | `openenv.yaml` and deployment docs | Expose and verify OpenEnv built in `/web` fallback route locally and on HF Space | FND 09, API 08, API 10 | 0.5h | `/web` is documented, reachable, and able to run a seeded episode when the custom UI is unavailable | ✅ Completed | Person B (Ayush) |
@@ -699,8 +699,8 @@ As a team, we want a replayable UI for debugging and recording the demo.
 | UI 07 | E09.2 | Person D | `frontend/src/lib/api.ts` | Add REST plus WebSocket client helpers | API 02 to API 06 | 0.75h | UI can connect locally and to the hosted Space | ✅ Completed | Person D (Kush) |
 | UI 08 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Build replay viewer from completed episode logs | API 05 | 1h | user can load a past episode and step through rounds | ⬜ Not started | — |
 | UI 09 | E09.1 | Person D | `frontend/src/components/TrainingResults.tsx` | Add before versus after panel or static result card | TRN 10 | 0.75h | UI can show reward curve image and summary metrics | ⬜ Not started | — |
-| UI 10 | E09.1 | Person D | frontend styling | Add clean visual styling with Tailwind plus shadcn compatible primitives and responsive spacing | UI 01 to UI 09, FND 13 | 0.75h | UI is presentable on demo screen without layout breaks and styling stack matches the declared toolchain |
-| UI 11 | E09.2 | Person C | integration | Serve frontend with backend or configure proxy during dev | UI 07, API 01 | 0.5h | one command local dev works and deployed app serves UI path |
+| UI 10 | E09.1 | Person D | frontend styling | Add clean visual styling with Tailwind plus shadcn compatible primitives and responsive spacing | UI 01 to UI 09, FND 13 | 0.75h | UI is presentable on demo screen without layout breaks and styling stack matches the declared toolchain | ✅ Completed | Person D (Kush) |
+| UI 11 | E09.2 | Person C | integration | Serve frontend with backend or configure proxy during dev | UI 07, API 01 | 0.5h | one command local dev works and deployed app serves UI path | ✅ Completed | Person D (Kush) |
 | UI 12 | E09.2 | Person D | tests and smoke | Add smoke test checklist for core UI flow | UI 01 to UI 11 | 0.5h | checklist confirms new episode, step, score update, and replay all work | ⬜ Not started | — |
 | UI 13 | E09.1 | Person D | `frontend/src/components/JudgeAuditPanel.tsx` or `NegotiationLog.tsx` | Render final Judge audit text and verdict at episode end | JDG 11, API 18 | 0.75h | UI shows a clear end of episode audit without hiding the deterministic score breakdown | ⬜ Not started | — |
 | UI 14 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Add replay slider or scrubber so judges can move across rounds quickly | UI 08 | 0.5h | user can scrub to any round without replaying the full episode sequentially | ⬜ Not started | — |
docs/completion.md
CHANGED
@@ -20,10 +20,10 @@ Source of truth: `ReplicaLab_Comprehensive_Task_Division.md`
 | Metric | Value |
 |--------|-------|
 | Total tasks | 152 |
-| Completed |
+| Completed | 107 |
 | Partial / active | 0 |
-| Remaining |
-| **Completion rate** | **
+| Remaining | 45 |
+| **Completion rate** | **70.39%** |
 
 ### Completion by Person
 
@@ -31,8 +31,8 @@ Source of truth: `ReplicaLab_Comprehensive_Task_Division.md`
 |--------|----------|----------------|----------------------|-----------|------|
 | Kian (Person A) | 49 (47 solo + 2 shared with B) | 1 shared sign-off (`FND 08`) | 48 (`FND 04`, `FND 09`, `MOD 01`, `MOD 02`, `MOD 03`, `MOD 04`, `MOD 05`, `MOD 06`, `MOD 08`, `MOD 11`, `MOD 12`, `SCN 01` to `SCN 10`, `SCN 13`, `AGT 05`, `AGT 09`, `ENV 01` to `ENV 08`, `ENV 10`, `ENV 11`, `JDG 01` to `JDG 06`, `JDG 08`, `JDG 11`, `OBS 04`, `TST 01` to `TST 05` done by Person B) | 0 | 100.00% |
 | Person B (Ayush) | 29 (27 solo + 2 shared with A) | 19 (`FND 08`, `MOD 09`, `SCN 11`, `AGT 01`, `AGT 02`, `AGT 03`, `AGT 04`, `AGT 05`, `AGT 06`, `AGT 07`, `AGT 08`, `AGT 10`, `AGT 11`, `TRN 13`, `TRN 03`, `TRN 04`, `TRN 01`, `TRN 02`, `TRN 14`) | 0 | 10 | 65.52% |
-| Max (Person C) | 41 | 1 (`FND 11`) |
-| Kush (Person D) | 32 |
+| Max (Person C) | 41 | 1 (`FND 11`) | 39 (done by Person B or Person D; `API 16`, `UI 11` by Kush) | 1 (`DOC 08`) | 97.56% |
+| Kush (Person D) | 32 | 4 (`UI 07`, `UI 10`, `UI 11`, `API 16`) | 1 (`FND 06` done by Person B) | 27 | 15.63% |
 | All (shared) | 3 | 2 (`FND 08`, `AGT 05`) | 0 | 1 | 66.67% |
 
 Note: Person B (Ayush) has completed two shared tasks in their own lane
@@ -50,8 +50,8 @@ to `SCN 10`, `SCN 13`, `AGT 09`, `ENV 01` to `ENV 09`, `ENV 10`, `ENV 11`,
 `API 15`, `API 17`, `API 18`, `API 19`, `OBS 01`, `OBS 02`, `OBS 03`, `OBS 04`,
 `OBS 07`, `OBS 09`, `TRN 11`) to keep the Kian, Max, and Kush dependency
 chain moving. All Person A and Person C implementation tasks are now complete
-except for
-`UI 07`
+except for 1 remaining Max task (`DOC 08`, blocked on `TRN 10`).
+`UI 07`, `UI 10`, `UI 11`, and `API 16` were completed by Kush (Person D).
 Ayush's next fully unblocked tasks are `TRN 05` and `JDG 10`.
 
 ---
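
The updated counts are internally consistent; a quick arithmetic check (pure Python, mirroring the figures in the diff above and the commit message):

```python
# Sanity-check the completion.md figures: 107 of 152 tasks done overall,
# and Max's lane at 40 of 41 as quoted in the commit message.
total, completed = 152, 107
remaining = total - completed
overall_rate = round(completed / total * 100, 2)

max_rate = round(40 / 41 * 100, 2)

print(remaining, overall_rate, max_rate)  # 45 70.39 97.56
```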
replicalab/agents/__init__.py
CHANGED
@@ -18,6 +18,7 @@ from .scientist_policy import (
     ScientistCallResult,
     ScientistOutputParseError,
     build_baseline_scientist_action,
+    build_remote_scientist_policy,
     build_scientist_system_prompt,
     call_scientist_with_retry,
     format_scientist_observation,
@@ -34,6 +35,7 @@ __all__ = [
     "ScientistOutputParseError",
     "SuggestionChange",
     "build_baseline_scientist_action",
+    "build_remote_scientist_policy",
     "build_judge_audit",
     "build_scientist_system_prompt",
     "call_scientist_with_retry",
replicalab/agents/scientist_policy.py
CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
 import json
 import logging
 import re
+from importlib import import_module
 from typing import Any, Callable, Literal, Mapping
 
 from pydantic import BaseModel, ConfigDict, ValidationError
@@ -138,6 +139,7 @@ def call_scientist_with_retry(
     observation: ScientistObservation,
     *,
     max_retries: int = 2,
+    user_message_override: str | None = None,
 ) -> ScientistCallResult:
     """Call a model backend to produce a ``ScientistAction`` with parser-driven retries.
 
@@ -161,7 +163,7 @@
         Default is 2 (so up to 3 total attempts).
     """
 
-    user_message = format_scientist_observation(observation)
+    user_message = user_message_override or format_scientist_observation(observation)
     messages: list[dict[str, str]] = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": user_message},
@@ -709,3 +711,169 @@ def _baseline_defaults_for_domain(domain: str) -> dict[str, Any]:
         "technique": "structured_proof_outline",
         "duration_days": 1,
     }
+
+
+def build_remote_scientist_policy(
+    *,
+    project: str,
+    model_name: str,
+    base_model: str,
+    checkpoint_step: int | None = None,
+    max_completion_tokens: int = 450,
+    temperature: float = 0.0,
+    max_retries: int = 2,
+) -> Callable[[ScientistObservation], ScientistAction]:
+    """Create a sync policy callable backed by an ART serverless checkpoint."""
+
+    try:
+        art_module = import_module("art")
+        serverless_module = import_module("art.serverless")
+        openai_module = import_module("openai")
+    except ImportError as exc:
+        raise RuntimeError(
+            "Missing optional inference dependency for remote Scientist evaluation. "
+            "Install 'openpipe-art' and 'openai' before loading a trained checkpoint."
+        ) from exc
+
+    trainable_model = art_module.TrainableModel(
+        name=model_name,
+        project=project,
+        base_model=base_model,
+    )
+    backend = serverless_module.ServerlessBackend()
+
+    import asyncio
+
+    asyncio.run(trainable_model.register(backend))
+    if trainable_model.inference_api_key is None or trainable_model.inference_base_url is None:
+        raise RuntimeError("ART serverless model registration did not expose inference credentials.")
+
+    client = openai_module.OpenAI(
+        base_url=trainable_model.inference_base_url,
+        api_key=trainable_model.inference_api_key,
+    )
+    inference_name = trainable_model.get_inference_name(step=checkpoint_step)
+    training_corpus = import_module("replicalab.training.corpus")
+    evidence_packs = [
+        pack for pack in training_corpus.load_frozen_evidence_packs() if pack.trainable_in_env
+    ]
+
+    def generate_fn(messages: list[dict[str, str]]) -> str:
+        response = client.chat.completions.create(
+            model=inference_name,
+            messages=messages,
+            max_completion_tokens=max_completion_tokens,
+            temperature=temperature,
+        )
+        return _extract_message_content(response.choices[0].message.content)
+
+    def policy_fn(
+        observation: ScientistObservation,
+        *,
+        seed: int | None = None,
+        scenario: str | None = None,
+        difficulty: str | None = None,
+    ) -> ScientistAction:
+        evidence_pack = None
+        if seed is not None and scenario is not None:
+            try:
+                evidence_pack = training_corpus.select_evidence_pack(
+                    evidence_packs,
+                    template=scenario,
+                    seed=seed,
+                )
+            except Exception:
+                evidence_pack = None
+        user_message = format_scientist_observation(observation)
+        if evidence_pack is not None:
+            user_message += "\n\nFrozen evidence pack:\n" + evidence_pack.prompt_block()
+        result = call_scientist_with_retry(
+            generate_fn,
+            _build_live_scientist_system_prompt(
+                observation,
+                evidence_pack=evidence_pack,
+                difficulty=difficulty,
+                scenario=scenario,
+            ),
+            observation,
+            max_retries=max_retries,
+            user_message_override=user_message,
+        )
+        return result.action
+
+    return policy_fn
+
+
+def _build_live_scientist_system_prompt(
+    observation: ScientistObservation,
+    *,
+    evidence_pack: Any | None = None,
+    difficulty: str | None = None,
+    scenario: str | None = None,
+) -> str:
+    allowed_actions = ", ".join(action.value for action in ScientistActionType)
+    sections = [
+        "You are the Scientist agent in ReplicaLab.",
+        (
+            "Your job is to negotiate toward the strongest feasible plan under the "
+            "provided constraints. You do not invent resources, loosen constraints, "
+            "or assume hidden ground truth."
+        ),
+        (
+            "Return exactly one JSON object with all ScientistAction fields and no "
+            "extra keys or prose."
+        ),
+        f"Allowed action_type values: {allowed_actions}.",
+        (
+            "For propose_protocol and revise_protocol, include a full protocol payload "
+            "with sample_size >= 1, controls, technique, duration_days >= 0, "
+            "required_equipment, required_reagents, questions = [], and rationale."
+        ),
+        (
+            "For request_info, keep protocol fields empty or zero and include at least "
+            "one concrete blocking question."
+        ),
+        (
+            "For accept, keep all protocol-edit fields empty or zero and use an empty "
+            "questions list."
+        ),
+        (
+            "Bounded tool policy: search_evidence, run_code_check, and inspect_image "
+            "support the current scenario only. They do not override constraints."
+        ),
+        f"Paper title: {observation.paper_title}",
+        f"Goal: {observation.experiment_goal}",
+    ]
+    if scenario:
+        sections.append(f"Scenario family: {scenario}")
+    if difficulty:
+        sections.append(f"Difficulty: {difficulty}")
+    if evidence_pack is not None:
+        sections.extend(
+            [
+                f"Frozen evidence id: {evidence_pack.evidence_id}",
+                f"Grounding paper: {evidence_pack.downloaded_paper_title}",
+                f"Claim: {evidence_pack.claim}",
+                f"Technique: {evidence_pack.key_technique}",
+                f"Constraint tension: {evidence_pack.primary_constraint_tension}",
+            ]
+        )
+    return "\n\n".join(sections)
+
+
+def _extract_message_content(content: Any) -> str:
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        parts: list[str] = []
+        for item in content:
+            if isinstance(item, dict):
+                text = item.get("text")
+                if text:
+                    parts.append(str(text))
+                continue
+            text = getattr(item, "text", None)
+            if text:
+                parts.append(str(text))
        return "\n".join(parts)
+    return ""
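
The `_extract_message_content` helper above normalizes OpenAI-style message content, which can arrive either as a plain string or as a list of text parts. A standalone sketch of the same idea (names here are illustrative, not the module's API):

```python
from typing import Any


def extract_text(content: Any) -> str:
    """Collapse string-or-parts message content down to plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            # Parts may be dicts ({"type": "text", "text": ...}) or objects
            # with a .text attribute; anything without text is skipped.
            text = item.get("text") if isinstance(item, dict) else getattr(item, "text", None)
            if text:
                parts.append(str(text))
        return "\n".join(parts)
    return ""
```

Handling both shapes keeps the policy robust across SDK versions that return structured content parts instead of a bare string.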
replicalab/training/__init__.py
CHANGED
@@ -1,6 +1,13 @@
 """Training utilities for ReplicaLab."""
 
 from replicalab.training.artifacts import ArtifactLayout
+from replicalab.training.art_openenv import (
+    ArtOpenEnvConfig,
+    ArtRolloutSummary,
+    ArtScenarioSpec,
+    ArtTrainingSummary,
+    run_art_openenv_training,
+)
 from replicalab.training.corpus import FrozenEvidencePack, load_frozen_evidence_packs
 from replicalab.training.datasets import (
     LabManagerSFTExample,
@@ -29,6 +36,10 @@ from replicalab.training.scientist_grpo import (
 
 __all__ = [
     "ArtifactLayout",
+    "ArtOpenEnvConfig",
+    "ArtRolloutSummary",
+    "ArtScenarioSpec",
+    "ArtTrainingSummary",
     "EpisodeRecord",
     "EvaluationCase",
     "EvaluationSummary",
@@ -47,6 +58,7 @@ __all__ = [
     "load_frozen_evidence_packs",
     "preview_lab_manager_training",
     "preview_scientist_training",
+    "run_art_openenv_training",
     "summarize_episodes",
     "train_lab_manager_sft",
     "train_scientist_grpo",
replicalab/training/art_openenv.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
```python
"""ART + ReplicaLab OpenEnv training helpers."""

from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Sequence

from pydantic import BaseModel, ConfigDict, Field

from replicalab.agents.scientist_policy import (
    ScientistOutputParseError,
    format_scientist_observation,
    parse_scientist_output,
)
from replicalab.client import ReplicaLabClient
from replicalab.models import ScientistObservation
from replicalab.training.artifacts import ArtifactLayout, append_jsonl, build_run_name, write_json
from replicalab.training.corpus import (
    FrozenEvidencePack,
    evidence_pack_version,
    load_frozen_evidence_packs,
    select_evidence_pack,
)


class ArtScenarioSpec(BaseModel):
    """One deterministic scenario spec for ART/OpenEnv rollouts."""

    model_config = ConfigDict(extra="forbid")

    seed: int
    scenario: str
    difficulty: str


class ArtOpenEnvConfig(BaseModel):
    """Config for serverless ART training against ReplicaLab."""

    model_config = ConfigDict(extra="forbid")

    project: str = "replicalab-art-openenv"
    model_name: str = "replicalab-scientist-art"
    base_model: str = "OpenPipe/Qwen3-14B-Instruct"
    base_url: str = "https://ayushozha-replicalab.hf.space"
    transport: str = "rest"
    train_steps: int = 1
    rollouts_per_group: int = 2
    max_turns: int = 6
    max_completion_tokens: int = 700
    max_parse_retries: int = 2
    learning_rate: float = 5e-6
    beta: float = 0.0
    scenarios: list[ArtScenarioSpec] = Field(
        default_factory=lambda: [
            ArtScenarioSpec(seed=11, scenario="math_reasoning", difficulty="easy"),
            ArtScenarioSpec(seed=12, scenario="ml_benchmark", difficulty="easy"),
        ]
    )


class ArtRolloutSummary(BaseModel):
    """Flat rollout record for demo/docs and post-run analysis."""

    model_config = ConfigDict(extra="forbid")

    run_name: str
    training_step: int
    group_index: int
    rollout_index: int
    seed: int
    scenario: str
    difficulty: str
    paper_title: str
    evidence_id: str | None = None
    evidence_match_type: str | None = None
    reward: float
    verdict: str | None = None
    agreement_reached: bool = False
    rounds_used: int = 0
    invalid_action_count: int = 0
    parse_error_count: int = 0
    rigor: float = 0.0
    feasibility: float = 0.0
    fidelity: float = 0.0
    parsimony: float = 1.0
    artifact_step: int | None = None
    artifact_name: str | None = None


class ArtTrainingSummary(BaseModel):
    """Top-level training summary written after the run."""

    model_config = ConfigDict(extra="forbid")

    run_name: str
    project: str
    model_name: str
    base_model: str
    train_steps: int
    rollouts_per_group: int
    scenario_count: int
    base_url: str
    evidence_version: str
    started_at: str
    finished_at: str
    final_artifact_step: int | None = None
    final_artifact_name: str | None = None
    average_reward: float = 0.0
    agreement_rate: float = 0.0
    average_rounds: float = 0.0


@dataclass
class _TurnRecord:
    messages_and_choices: list[Any]
    parse_error: str | None
    raw_text: str


@dataclass
class _EpisodeTrace:
    trajectory: Any
    summary: ArtRolloutSummary


def run_art_openenv_training(
    config: ArtOpenEnvConfig,
    *,
    layout: ArtifactLayout | None = None,
) -> dict[str, object]:
    """Sync wrapper used by CLI entrypoints."""

    artifact_layout = layout or ArtifactLayout.create(run_name=build_run_name("art-scientist"))
    return asyncio.run(_run_art_openenv_training_async(config, artifact_layout))


async def _run_art_openenv_training_async(
    config: ArtOpenEnvConfig,
    layout: ArtifactLayout,
) -> dict[str, object]:
    art_module = __import__("art")
    from art import Trajectory, TrajectoryGroup, TrainableModel
    from art.gather import gather_trajectory_groups
    from art.serverless import ServerlessBackend
    from art.trajectories import History

    started_at = _utc_now()
    evidence_packs = [pack for pack in load_frozen_evidence_packs() if pack.trainable_in_env]
    evidence_version = evidence_pack_version(evidence_packs)
    backend = ServerlessBackend()
    model = TrainableModel(
        name=config.model_name,
        project=config.project,
        base_model=config.base_model,
        base_path=str(layout.run_dir),
        report_metrics=[
            "average_reward",
            "agreement_rate",
            "average_rounds",
            "average_rigor",
            "average_feasibility",
            "average_fidelity",
            "average_parsimony",
            "invalid_action_rate",
        ],
    )
    await model.register(backend)

    write_json(layout.config_json, config.model_dump(mode="json"))
    write_json(
        layout.evidence_manifest_json,
        {
            "evidence_version": evidence_version,
            "packs": [pack.model_dump(mode="json") for pack in evidence_packs],
        },
    )
    process_log_path = layout.reports_dir / "art_training_process.md"
    process_log_path.parent.mkdir(parents=True, exist_ok=True)
    process_log_path.write_text(
        "# ReplicaLab ART Training Run\n\n",
        encoding="utf-8",
    )
    _append_process_log(
        process_log_path,
        f"Started at `{started_at}` against `{config.base_url}` using `{config.base_model}`.",
    )
    _append_process_log(
        process_log_path,
        (
            f"Loaded `{len(evidence_packs)}` trainable frozen evidence packs "
            f"(version `{evidence_version}`)."
        ),
    )

    all_rollouts: list[ArtRolloutSummary] = []
    final_artifact_step: int | None = None
    final_artifact_name: str | None = None

    for training_step in range(1, config.train_steps + 1):
        _append_process_log(
            process_log_path,
            (
                f"Training step {training_step}: collecting "
                f"{len(config.scenarios)} trajectory groups with "
                f"{config.rollouts_per_group} rollouts each."
            ),
        )

        groups = await gather_trajectory_groups(
            [
                _collect_trajectory_group(
                    model=model,
                    config=config,
                    spec=spec,
                    evidence_pack=select_evidence_pack(
                        evidence_packs,
                        template=spec.scenario,
                        seed=spec.seed,
                    ),
                    group_index=group_index,
                    training_step=training_step,
                    run_name=layout.run_name,
                )
                for group_index, spec in enumerate(config.scenarios)
            ],
            pbar_desc=f"replicalab-step-{training_step}",
        )

        batch_summaries: list[ArtRolloutSummary] = []
        for group in groups:
            for trajectory in group.trajectories:
                summary = ArtRolloutSummary.model_validate(trajectory.metadata)
                batch_summaries.append(summary)
                append_jsonl(layout.metrics_jsonl, summary.model_dump(mode="json"))

        await model.log(groups, split="train")
        train_result = await backend.train(
            model,
            groups,
            learning_rate=config.learning_rate,
            beta=config.beta,
        )
        await model.log(
            split="train",
            metrics=train_result.metrics,
            step=train_result.step,
        )

        final_artifact_step = train_result.step
        final_artifact_name = train_result.artifact_name
        _append_process_log(
            process_log_path,
            (
                f"Completed training step {training_step}: artifact="
                f"`{train_result.artifact_name}` step={train_result.step} "
                f"metrics={json.dumps(train_result.metrics, sort_keys=True)}"
            ),
        )

        for summary in batch_summaries:
            summary.artifact_step = train_result.step
            summary.artifact_name = train_result.artifact_name

        all_rollouts.extend(batch_summaries)

    finished_at = _utc_now()
    summary = _summarize_art_training(
        config=config,
        layout=layout,
        started_at=started_at,
        finished_at=finished_at,
        rollouts=all_rollouts,
        evidence_version=evidence_version,
        final_artifact_step=final_artifact_step,
        final_artifact_name=final_artifact_name,
    )
    write_json(layout.summary_json, summary.model_dump(mode="json"))
    _append_process_log(
        process_log_path,
        (
            f"Finished at `{finished_at}`. Average reward={summary.average_reward:.4f}, "
            f"agreement_rate={summary.agreement_rate:.4f}, "
            f"average_rounds={summary.average_rounds:.4f}."
        ),
    )
    return summary.model_dump(mode="json")


async def _collect_trajectory_group(
    *,
    model: Any,
    config: ArtOpenEnvConfig,
    spec: ArtScenarioSpec,
    evidence_pack: FrozenEvidencePack | None,
    group_index: int,
    training_step: int,
    run_name: str,
) -> Any:
    from art import TrajectoryGroup

    traces = await asyncio.gather(
        *[
            _run_art_episode(
                model=model,
                config=config,
                spec=spec,
                evidence_pack=evidence_pack,
                group_index=group_index,
                rollout_index=rollout_index,
                training_step=training_step,
                run_name=run_name,
            )
            for rollout_index in range(config.rollouts_per_group)
        ]
    )
    return TrajectoryGroup(
        trajectories=[trace.trajectory for trace in traces],
        metadata={
            "scenario": spec.scenario,
            "difficulty": spec.difficulty,
            "seed": spec.seed,
            "training_step": training_step,
        },
        metrics={
            "average_reward": _mean(summary.reward for summary in [trace.summary for trace in traces]),
            "agreement_rate": _mean(
                1.0 if trace.summary.agreement_reached else 0.0 for trace in traces
            ),
        },
        logs=[
            (
                f"group={group_index} seed={spec.seed} scenario={spec.scenario} "
                f"difficulty={spec.difficulty}"
            )
        ],
    )


async def _run_art_episode(
    *,
    model: Any,
    config: ArtOpenEnvConfig,
    spec: ArtScenarioSpec,
    evidence_pack: FrozenEvidencePack | None,
    group_index: int,
    rollout_index: int,
    training_step: int,
    run_name: str,
) -> _EpisodeTrace:
    from art import Trajectory
    from art.trajectories import History

    client = ReplicaLabClient(config.base_url, transport=config.transport)
    await asyncio.to_thread(client.connect)
    invalid_action_count = 0
    parse_error_count = 0
    turns: list[_TurnRecord] = []

    try:
        observation = await asyncio.to_thread(
            client.reset,
            spec.seed,
            spec.scenario,
            spec.difficulty,
        )
        scientist_obs = observation.scientist
        if scientist_obs is None:
            raise RuntimeError("Reset returned no scientist observation.")

        terminal_reward = -1.0
        terminal_info = None

        for _ in range(config.max_turns):
            system_prompt = _build_art_scientist_system_prompt(
                spec=spec,
                observation=scientist_obs,
                evidence_pack=evidence_pack,
            )
            user_prompt = format_scientist_observation(scientist_obs)
            if evidence_pack is not None:
                user_prompt += "\n\nFrozen evidence pack:\n" + evidence_pack.prompt_block()
            turn = await _generate_turn(
                model=model,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                max_completion_tokens=config.max_completion_tokens,
                max_parse_retries=config.max_parse_retries,
            )
            turns.append(turn)

            if turn.parse_error is not None:
                parse_error_count += 1
                terminal_reward = -1.0
                break

            action = parse_scientist_output(turn.raw_text)
            result = await asyncio.to_thread(client.step, action)
            terminal_reward = result.reward
            terminal_info = result.info
            if result.info.error:
                invalid_action_count += 1

            if result.done:
                break

            if result.observation is None or result.observation.scientist is None:
                raise RuntimeError("Non-terminal step returned no scientist observation.")
            scientist_obs = result.observation.scientist

        histories = [
            History(messages_and_choices=turn.messages_and_choices)
            for turn in turns[1:]
        ]
        trajectory = Trajectory(
            messages_and_choices=(turns[0].messages_and_choices if turns else []),
            additional_histories=histories,
            reward=terminal_reward,
            metrics=_extract_terminal_metrics(
                terminal_info=terminal_info,
                invalid_action_count=invalid_action_count,
                parse_error_count=parse_error_count,
                rounds_used=len(turns),
            ),
            metadata={},
            logs=[
                (
                    f"training_step={training_step} group={group_index} rollout={rollout_index} "
                    f"seed={spec.seed} scenario={spec.scenario} difficulty={spec.difficulty}"
                )
            ],
        )
        summary = ArtRolloutSummary(
            run_name=run_name,
            training_step=training_step,
            group_index=group_index,
            rollout_index=rollout_index,
            seed=spec.seed,
            scenario=spec.scenario,
            difficulty=spec.difficulty,
            paper_title=scientist_obs.paper_title,
            evidence_id=(evidence_pack.evidence_id if evidence_pack is not None else None),
            evidence_match_type=(
                evidence_pack.match_type if evidence_pack is not None else None
            ),
            reward=terminal_reward,
            verdict=(terminal_info.verdict if terminal_info is not None else None),
            agreement_reached=(
                terminal_info.agreement_reached if terminal_info is not None else False
            ),
            rounds_used=len(turns),
            invalid_action_count=invalid_action_count,
            parse_error_count=parse_error_count,
            rigor=(
                terminal_info.reward_breakdown.rigor
                if terminal_info and terminal_info.reward_breakdown
                else 0.0
            ),
            feasibility=(
                terminal_info.reward_breakdown.feasibility
                if terminal_info and terminal_info.reward_breakdown
                else 0.0
            ),
            fidelity=(
                terminal_info.reward_breakdown.fidelity
                if terminal_info and terminal_info.reward_breakdown
                else 0.0
            ),
            parsimony=(
                terminal_info.reward_breakdown.parsimony
                if terminal_info and terminal_info.reward_breakdown
                else 1.0
            ),
        )
        trajectory.metadata.update(summary.model_dump(mode="json"))
        return _EpisodeTrace(trajectory=trajectory, summary=summary)
    finally:
        await asyncio.to_thread(client.close)


async def _generate_turn(
    *,
    model: Any,
    system_prompt: str,
    user_prompt: str,
    max_completion_tokens: int,
    max_parse_retries: int,
) -> _TurnRecord:
    client = model.openai_client()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    for attempt in range(max_parse_retries + 1):
        completion = await client.chat.completions.create(
            model=model.get_inference_name(),
            messages=messages,
            max_completion_tokens=max_completion_tokens,
            temperature=0.0,
        )
        choice = completion.choices[0]
        raw_text = _extract_choice_text(choice)
        try:
            parse_scientist_output(raw_text)
            return _TurnRecord(
                messages_and_choices=[
                    *messages,
                    {"role": "assistant", "content": raw_text},
                ],
                parse_error=None,
                raw_text=raw_text,
            )
        except ScientistOutputParseError as exc:
            if attempt >= max_parse_retries:
                return _TurnRecord(
                    messages_and_choices=[
                        *messages,
                        {"role": "assistant", "content": raw_text},
                    ],
                    parse_error=exc.message,
                    raw_text=raw_text,
                )
            messages.extend(
                [
                    {"role": "assistant", "content": raw_text},
                    {"role": "user", "content": _build_art_correction_prompt(exc)},
                ]
            )
    raise RuntimeError("unreachable")


def _build_art_scientist_system_prompt(
    *,
    spec: ArtScenarioSpec,
    observation: ScientistObservation,
    evidence_pack: FrozenEvidencePack | None,
) -> str:
    sections = [
        "You are the Scientist agent in ReplicaLab.",
        "Negotiate toward the strongest feasible technical plan under hard real-world constraints.",
        "Return exactly one valid ScientistAction JSON object with no markdown and no extra prose.",
        "Use request_info only when a concrete blocking question remains.",
        "Use accept only when the current protocol is genuinely ready.",
        "Bounded tool policy: search_evidence, run_code_check, and inspect_image are support tools only; they never override constraints or reveal hidden ground truth.",
        f"Scenario family: {spec.scenario}",
        f"Difficulty: {spec.difficulty}",
        f"Paper title: {observation.paper_title}",
        f"Goal: {observation.experiment_goal}",
        (
            "The user observation already contains the full conversation "
            "history and current protocol. Use that as your source of truth "
            "for each turn."
        ),
    ]
    if evidence_pack is not None:
        sections.extend(
            [
                f"Frozen evidence id: {evidence_pack.evidence_id}",
                f"Grounding paper: {evidence_pack.downloaded_paper_title}",
                f"Claim: {evidence_pack.claim}",
                f"Technique: {evidence_pack.key_technique}",
                f"Constraint tension: {evidence_pack.primary_constraint_tension}",
            ]
        )
    sections.extend(
        [
            "Always emit all ScientistAction fields, even for request_info or accept.",
            (
                "Shape example: "
                '{"action_type":"propose_protocol","sample_size":8,"controls":["baseline"],'
                '"technique":"LoRA fine-tuning on the public subset","duration_days":2,'
                '"required_equipment":["gpu_h100"],"required_reagents":[],'
                '"questions":[],"rationale":"Uses the available hardware and stays within the reduced dataset budget."}'
            ),
        ]
    )
    return "\n".join(sections)


def _extract_choice_text(choice: Any) -> str:
    message = getattr(choice, "message", None)
    content = getattr(message, "content", None)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            text = getattr(item, "text", None)
            if text:
                parts.append(str(text))
            elif isinstance(item, dict) and "text" in item:
                parts.append(str(item["text"]))
        return "\n".join(parts)
    return ""


def _build_art_correction_prompt(error: ScientistOutputParseError) -> str:
    suffix = (
        "Return exactly one JSON object with all ScientistAction fields. "
        "No markdown fences, no prose, no commentary."
    )
    if error.code == "no_json":
        return "Your previous response did not contain a JSON object. " + suffix
    if error.code == "invalid_json":
        return (
            f"Your previous response contained malformed JSON: {error.message}. " + suffix
        )
    return (
        "Your previous response contained valid JSON but failed ScientistAction "
        f"validation: {error.message}. Fix the validation error and return a corrected "
        "ScientistAction JSON object. " + suffix
    )


def _extract_terminal_metrics(
    *,
    terminal_info: Any,
    invalid_action_count: int,
    parse_error_count: int,
    rounds_used: int,
) -> dict[str, float | int | bool]:
    breakdown = terminal_info.reward_breakdown if terminal_info is not None else None
    return {
        "agreement_reached": terminal_info.agreement_reached if terminal_info else False,
        "invalid_action_count": invalid_action_count,
        "invalid_action_rate": (invalid_action_count / max(1, rounds_used)),
        "parse_error_count": parse_error_count,
        "parse_error_rate": (parse_error_count / max(1, rounds_used)),
        "rounds_used": rounds_used,
        "rigor": (breakdown.rigor if breakdown is not None else 0.0),
        "feasibility": (breakdown.feasibility if breakdown is not None else 0.0),
        "fidelity": (breakdown.fidelity if breakdown is not None else 0.0),
        "parsimony": (breakdown.parsimony if breakdown is not None else 1.0),
    }


def _summarize_art_training(
    *,
    config: ArtOpenEnvConfig,
    layout: ArtifactLayout,
    started_at: str,
    finished_at: str,
    rollouts: Sequence[ArtRolloutSummary],
    evidence_version: str,
    final_artifact_step: int | None,
    final_artifact_name: str | None,
) -> ArtTrainingSummary:
    return ArtTrainingSummary(
        run_name=layout.run_name,
        project=config.project,
        model_name=config.model_name,
        base_model=config.base_model,
        train_steps=config.train_steps,
        rollouts_per_group=config.rollouts_per_group,
        scenario_count=len(config.scenarios),
        base_url=config.base_url,
        evidence_version=evidence_version,
        started_at=started_at,
        finished_at=finished_at,
        final_artifact_step=final_artifact_step,
        final_artifact_name=final_artifact_name,
        average_reward=_mean(item.reward for item in rollouts),
        agreement_rate=_mean(1.0 if item.agreement_reached else 0.0 for item in rollouts),
        average_rounds=_mean(item.rounds_used for item in rollouts),
    )


def _append_process_log(path: Path, line: str) -> None:
    with path.open("a", encoding="utf-8") as handle:
        handle.write(f"- {line}\n")


def _utc_now() -> str:
    return datetime.now(UTC).isoformat()


def _mean(values: Any) -> float:
    values = list(values)
    if not values:
        return 0.0
    return round(sum(float(value) for value in values) / len(values), 6)


__all__ = [
    "ArtOpenEnvConfig",
    "ArtRolloutSummary",
    "ArtScenarioSpec",
    "ArtTrainingSummary",
    "run_art_openenv_training",
]
```
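The `_generate_turn` helper above retries up to `max_parse_retries` times, appending the failed reply plus a correction prompt before regenerating, and returns the last raw text together with the parse error when the budget is exhausted. A minimal self-contained sketch of that control flow, with the model call stubbed out (all names below are illustrative, not the module's API):

```python
# Sketch of the parse-retry pattern: generate, try to parse, and on failure
# feed the bad reply back with a correction prompt until retries run out.
import json


class ParseError(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


def parse_action(raw_text: str) -> dict:
    """Stand-in for parse_scientist_output: accept only valid JSON."""
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ParseError(f"invalid JSON: {exc}") from exc


def generate_with_retries(generate, max_parse_retries: int = 2):
    """Call generate(messages) until the reply parses, appending a
    correction prompt after each failure; on exhaustion, return the
    last raw text together with the error message."""
    messages = [{"role": "user", "content": "Return one JSON object."}]
    for attempt in range(max_parse_retries + 1):
        raw_text = generate(messages)
        try:
            parse_action(raw_text)
            return raw_text, None  # success: no parse error
        except ParseError as exc:
            if attempt >= max_parse_retries:
                return raw_text, exc.message  # give up, keep last output
            messages.extend(
                [
                    {"role": "assistant", "content": raw_text},
                    {"role": "user", "content": f"Malformed JSON: {exc.message}. Retry."},
                ]
            )
    raise RuntimeError("unreachable")


# A scripted "model" that fails once, then succeeds:
replies = iter(["not json", '{"action_type": "accept"}'])
text, error = generate_with_retries(lambda msgs: next(replies))
```

Returning the failed text instead of raising keeps the trajectory intact for training: the episode is scored with a parse-error penalty rather than discarded.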
replicalab/training/cli.py
CHANGED

```diff
@@ -8,21 +8,34 @@ import sys
 from pathlib import Path
 from typing import Sequence
 
-from replicalab.agents import build_baseline_scientist_action
+from replicalab.agents import build_baseline_scientist_action, build_remote_scientist_policy
 from replicalab.training.artifacts import (
     ArtifactLayout,
     append_jsonl,
     build_run_name,
     write_json,
 )
-from replicalab.training.…
+from replicalab.training.art_openenv import (
+    ArtOpenEnvConfig,
+    ArtScenarioSpec,
+    run_art_openenv_training,
+)
+from replicalab.training.evaluation import (
+    build_default_evaluation_cases,
+    compare_policies,
+    evaluate_policy,
+)
 from replicalab.training.lab_manager_sft import (
     LabManagerSFTConfig,
     preview_lab_manager_training,
     train_lab_manager_sft,
 )
 from replicalab.training.metrics import episode_to_metrics
-from replicalab.training.plots import …
+from replicalab.training.plots import (
+    plot_evaluation_bars,
+    plot_metrics_by_step,
+    plot_training_history,
+)
 from replicalab.training.scientist_grpo import (
     ScientistGRPOConfig,
     preview_scientist_training,

@@ -46,6 +59,10 @@ def main(argv: Sequence[str] | None = None) -> int:
         return _run_lab_manager_train(args)
     if args.command == "baseline-eval":
         return _run_baseline_eval(args)
+    …
 
     parser.error(f"Unsupported command: {args.command}")
     return 2

@@ -169,6 +186,159 @@ def _build_parser() -> argparse.ArgumentParser:
         help="Difficulty levels to evaluate.",
     )
 
+    …
     return parser
 

@@ -279,6 +449,22 @@ def _run_scientist_train(args: argparse.Namespace) -> int:
         max_steps=args.max_steps,
     )
     result = train_scientist_grpo(config, layout=layout, dry_run=args.dry_run)
+    …
     _maybe_plot_training_history(
         layout=layout,
         state_name="scientist_trainer_state.json",

@@ -327,6 +513,21 @@ def _run_lab_manager_train(args: argparse.Namespace) -> int:
         load_in_4bit=args.load_in_4bit,
     )
     result = train_lab_manager_sft(config, layout=layout, dry_run=args.dry_run)
+    …
     _maybe_plot_training_history(
         layout=layout,
         state_name="lab_manager_trainer_state.json",

@@ -364,6 +565,22 @@ def _run_baseline_eval(args: argparse.Namespace) -> int:
             "cases": [case.__dict__ for case in cases],
         },
     )
+    …
     for record in records:
         append_jsonl(
             layout.metrics_jsonl,

@@ -376,6 +593,127 @@
     return 0
 
 
+…
 def _maybe_plot_training_history(
     *,
     layout: ArtifactLayout,

@@ -415,6 +753,80 @@ def _plot_eval_summary(
         metric_key="agreement_rate",
         title="Baseline agreement rate",
     )
+    …
 
 
 if __name__ == "__main__":
```

(Lines shown as `…` mark added code that the diff view did not render; the hunk headers give the true line counts. The two truncated removed imports are likewise left as rendered.)
)
|
| 39 |
from replicalab.training.scientist_grpo import (
|
| 40 |
ScientistGRPOConfig,
|
| 41 |
preview_scientist_training,
|
|
|
|
| 59 |
return _run_lab_manager_train(args)
|
| 60 |
if args.command == "baseline-eval":
|
| 61 |
return _run_baseline_eval(args)
|
| 62 |
+
if args.command == "scientist-compare-eval":
|
| 63 |
+
return _run_scientist_compare_eval(args)
|
| 64 |
+
if args.command == "art-scientist-train":
|
| 65 |
+
return _run_art_scientist_train(args)
|
| 66 |
|
| 67 |
parser.error(f"Unsupported command: {args.command}")
|
| 68 |
return 2
|
|
|
|
| 186 |
help="Difficulty levels to evaluate.",
|
| 187 |
)
|
| 188 |
|
| 189 |
+
compare_eval = subparsers.add_parser(
|
| 190 |
+
"scientist-compare-eval",
|
| 191 |
+
help="Compare baseline Scientist versus a trained ART Scientist checkpoint.",
|
| 192 |
+
)
|
| 193 |
+
_add_common_artifact_args(compare_eval, prefix="eval-compare")
|
| 194 |
+
compare_eval.add_argument(
|
| 195 |
+
"--base-url",
|
| 196 |
+
default="https://ayushozha-replicalab.hf.space",
|
| 197 |
+
help="ReplicaLab environment base URL.",
|
| 198 |
+
)
|
| 199 |
+
compare_eval.add_argument(
|
| 200 |
+
"--transport",
|
| 201 |
+
default="rest",
|
| 202 |
+
choices=("rest", "ws"),
|
| 203 |
+
help="Transport used by ReplicaLabClient.",
|
| 204 |
+
)
|
| 205 |
+
compare_eval.add_argument(
|
| 206 |
+
"--eval-seeds",
|
| 207 |
+
nargs="+",
|
| 208 |
+
type=int,
|
| 209 |
+
default=[101, 102],
|
| 210 |
+
help="Evaluation seeds.",
|
| 211 |
+
)
|
| 212 |
+
compare_eval.add_argument(
|
| 213 |
+
"--scenarios",
|
| 214 |
+
nargs="+",
|
| 215 |
+
default=list(scientist_defaults.templates),
|
| 216 |
+
help="Scenario families to evaluate.",
|
| 217 |
+
)
|
| 218 |
+
compare_eval.add_argument(
|
| 219 |
+
"--difficulties",
|
| 220 |
+
nargs="+",
|
| 221 |
+
default=list(scientist_defaults.difficulties),
|
| 222 |
+
help="Difficulty levels to evaluate.",
|
| 223 |
+
)
|
| 224 |
+
compare_eval.add_argument(
|
| 225 |
+
"--project",
|
| 226 |
+
default="replicalab-ai",
|
| 227 |
+
help="ART project name for the trained Scientist checkpoint.",
|
| 228 |
+
)
|
| 229 |
+
compare_eval.add_argument(
|
| 230 |
+
"--model-name",
|
| 231 |
+
default="replicalab-scientist-art-live",
|
| 232 |
+
help="ART trainable model name for the trained Scientist checkpoint.",
|
| 233 |
+
)
|
| 234 |
+
compare_eval.add_argument(
|
| 235 |
+
"--base-model",
|
| 236 |
+
default="OpenPipe/Qwen3-14B-Instruct",
|
| 237 |
+
help="Base model used for the ART trained Scientist.",
|
| 238 |
+
)
|
| 239 |
+
compare_eval.add_argument(
|
| 240 |
+
"--checkpoint-step",
|
| 241 |
+
type=int,
|
| 242 |
+
default=None,
|
| 243 |
+
help="Optional explicit ART checkpoint step to evaluate.",
|
| 244 |
+
)
|
| 245 |
+
compare_eval.add_argument(
|
| 246 |
+
"--max-completion-tokens",
|
| 247 |
+
type=int,
|
| 248 |
+
default=450,
|
| 249 |
+
help="Max completion tokens for the trained remote Scientist.",
|
| 250 |
+
)
|
| 251 |
+
compare_eval.add_argument(
|
| 252 |
+
"--temperature",
|
| 253 |
+
type=float,
|
| 254 |
+
default=0.0,
|
| 255 |
+
help="Sampling temperature for the trained remote Scientist.",
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
art_train = subparsers.add_parser(
|
| 259 |
+
"art-scientist-train",
|
| 260 |
+
help="Run ART serverless RL training against the ReplicaLab OpenEnv deployment.",
|
| 261 |
+
)
|
| 262 |
+
_add_common_artifact_args(art_train, prefix="art-scientist")
|
| 263 |
+
art_train.add_argument(
|
| 264 |
+
"--project",
|
| 265 |
+
default="replicalab-art-openenv",
|
| 266 |
+
help="Weights & Biases / ART project name.",
|
| 267 |
+
)
|
| 268 |
+
art_train.add_argument(
|
| 269 |
+
"--model-name",
|
| 270 |
+
default="replicalab-scientist-art",
|
| 271 |
+
help="ART trainable model name.",
|
| 272 |
+
)
|
| 273 |
+
art_train.add_argument(
|
| 274 |
+
"--base-model",
|
| 275 |
+
default="OpenPipe/Qwen3-14B-Instruct",
|
| 276 |
+
help="ART serverless base model.",
|
| 277 |
+
)
|
| 278 |
+
art_train.add_argument(
|
| 279 |
+
"--base-url",
|
| 280 |
+
default="https://ayushozha-replicalab.hf.space",
|
| 281 |
+
help="ReplicaLab environment base URL.",
|
| 282 |
+
)
|
| 283 |
+
art_train.add_argument(
|
| 284 |
+
"--transport",
|
| 285 |
+
default="rest",
|
| 286 |
+
choices=("rest",),
|
| 287 |
+
help="Transport used for live environment interaction.",
|
| 288 |
+
)
|
| 289 |
+
art_train.add_argument(
|
| 290 |
+
"--train-steps",
|
| 291 |
+
type=int,
|
| 292 |
+
default=1,
|
| 293 |
+
help="Number of ART training updates to run.",
|
| 294 |
+
)
|
| 295 |
+
art_train.add_argument(
|
| 296 |
+
"--rollouts-per-group",
|
| 297 |
+
type=int,
|
| 298 |
+
default=2,
|
| 299 |
+
help="Number of sampled rollouts for each scenario group.",
|
| 300 |
+
)
|
| 301 |
+
art_train.add_argument(
|
| 302 |
+
"--max-turns",
|
| 303 |
+
type=int,
|
| 304 |
+
default=6,
|
| 305 |
+
help="Max environment turns per rollout.",
|
| 306 |
+
)
|
| 307 |
+
art_train.add_argument(
|
| 308 |
+
"--max-completion-tokens",
|
| 309 |
+
type=int,
|
| 310 |
+
default=700,
|
| 311 |
+
help="Assistant max completion tokens per turn.",
|
| 312 |
+
)
|
| 313 |
+
art_train.add_argument(
|
| 314 |
+
"--max-parse-retries",
|
| 315 |
+
type=int,
|
| 316 |
+
default=2,
|
| 317 |
+
help="Number of parser-driven correction retries per turn.",
|
| 318 |
+
)
|
| 319 |
+
art_train.add_argument(
|
| 320 |
+
"--learning-rate",
|
| 321 |
+
type=float,
|
| 322 |
+
default=5e-6,
|
| 323 |
+
help="ART learning rate.",
|
| 324 |
+
)
|
| 325 |
+
art_train.add_argument(
|
| 326 |
+
"--beta",
|
| 327 |
+
type=float,
|
| 328 |
+
default=0.0,
|
| 329 |
+
help="ART KL penalty coefficient.",
|
| 330 |
+
)
|
| 331 |
+
art_train.add_argument(
|
| 332 |
+
"--scenario-spec",
|
| 333 |
+
nargs="+",
|
| 334 |
+
default=[
|
| 335 |
+
"0:ml_benchmark:easy",
|
| 336 |
+
"1:ml_benchmark:medium",
|
| 337 |
+
"0:finance_trading:easy",
|
| 338 |
+
],
|
| 339 |
+
help="Scenario specs in the form seed:scenario:difficulty.",
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
return parser
|
| 343 |
|
| 344 |
|
|
|
|
| 449 |
max_steps=args.max_steps,
|
| 450 |
)
|
| 451 |
result = train_scientist_grpo(config, layout=layout, dry_run=args.dry_run)
|
| 452 |
+
_write_run_metadata(
|
| 453 |
+
layout,
|
| 454 |
+
{
|
| 455 |
+
"kind": "scientist_train",
|
| 456 |
+
"model_name": args.model_name,
|
| 457 |
+
"templates": args.templates,
|
| 458 |
+
"difficulties": args.difficulties,
|
| 459 |
+
"seed_count": args.seed_count,
|
| 460 |
+
"max_steps": args.max_steps,
|
| 461 |
+
"bounded_tool_policy": [
|
| 462 |
+
"search_evidence",
|
| 463 |
+
"run_code_check",
|
| 464 |
+
"inspect_image",
|
| 465 |
+
],
|
| 466 |
+
},
|
| 467 |
+
)
|
| 468 |
_maybe_plot_training_history(
|
| 469 |
layout=layout,
|
| 470 |
state_name="scientist_trainer_state.json",
|
|
|
|
| 513 |
load_in_4bit=args.load_in_4bit,
|
| 514 |
)
|
| 515 |
result = train_lab_manager_sft(config, layout=layout, dry_run=args.dry_run)
|
| 516 |
+
_write_run_metadata(
|
| 517 |
+
layout,
|
| 518 |
+
{
|
| 519 |
+
"kind": "lab_manager_train",
|
| 520 |
+
"model_name": args.model_name,
|
| 521 |
+
"templates": args.templates,
|
| 522 |
+
"difficulties": args.difficulties,
|
| 523 |
+
"seed_count": args.seed_count,
|
| 524 |
+
"bounded_tool_policy": [
|
| 525 |
+
"search_evidence",
|
| 526 |
+
"run_code_check",
|
| 527 |
+
"inspect_image",
|
| 528 |
+
],
|
| 529 |
+
},
|
| 530 |
+
)
|
| 531 |
_maybe_plot_training_history(
|
| 532 |
layout=layout,
|
| 533 |
state_name="lab_manager_trainer_state.json",
|
|
|
|
| 565 |
"cases": [case.__dict__ for case in cases],
|
| 566 |
},
|
| 567 |
)
|
| 568 |
+
_write_run_metadata(
|
| 569 |
+
layout,
|
| 570 |
+
{
|
| 571 |
+
"kind": "baseline_eval",
|
| 572 |
+
"base_url": args.base_url,
|
| 573 |
+
"transport": args.transport,
|
| 574 |
+
"eval_seeds": args.eval_seeds,
|
| 575 |
+
"scenarios": args.scenarios,
|
| 576 |
+
"difficulties": args.difficulties,
|
| 577 |
+
"bounded_tool_policy": [
|
| 578 |
+
"search_evidence",
|
| 579 |
+
"run_code_check",
|
| 580 |
+
"inspect_image",
|
| 581 |
+
],
|
| 582 |
+
},
|
| 583 |
+
)
|
| 584 |
for record in records:
|
| 585 |
append_jsonl(
|
| 586 |
layout.metrics_jsonl,
|
|
|
|
| 593 |
return 0
|
| 594 |
|
| 595 |
|
| 596 |
+
def _run_scientist_compare_eval(args: argparse.Namespace) -> int:
|
| 597 |
+
layout = _build_layout(
|
| 598 |
+
prefix="eval-compare",
|
| 599 |
+
persist_root=args.persist_root,
|
| 600 |
+
run_name=args.run_name,
|
| 601 |
+
)
|
| 602 |
+
cases = build_default_evaluation_cases(
|
| 603 |
+
seeds=args.eval_seeds,
|
| 604 |
+
scenarios=args.scenarios,
|
| 605 |
+
difficulties=args.difficulties,
|
| 606 |
+
)
|
| 607 |
+
trained_policy = build_remote_scientist_policy(
|
| 608 |
+
project=args.project,
|
| 609 |
+
model_name=args.model_name,
|
| 610 |
+
base_model=args.base_model,
|
| 611 |
+
checkpoint_step=args.checkpoint_step,
|
| 612 |
+
max_completion_tokens=args.max_completion_tokens,
|
| 613 |
+
temperature=args.temperature,
|
| 614 |
+
)
|
| 615 |
+
records_by_label, rows = compare_policies(
|
| 616 |
+
base_url=args.base_url,
|
| 617 |
+
policies=[
|
| 618 |
+
("baseline", build_baseline_scientist_action),
|
| 619 |
+
("trained", trained_policy),
|
| 620 |
+
],
|
| 621 |
+
cases=cases,
|
| 622 |
+
transport=args.transport,
|
| 623 |
+
)
|
| 624 |
+
write_json(
|
| 625 |
+
layout.config_json,
|
| 626 |
+
{
|
| 627 |
+
"kind": "scientist_compare_eval",
|
| 628 |
+
"base_url": args.base_url,
|
| 629 |
+
"transport": args.transport,
|
| 630 |
+
"cases": [case.__dict__ for case in cases],
|
| 631 |
+
"project": args.project,
|
| 632 |
+
"model_name": args.model_name,
|
| 633 |
+
"base_model": args.base_model,
|
| 634 |
+
"checkpoint_step": args.checkpoint_step,
|
| 635 |
+
},
|
| 636 |
+
)
|
| 637 |
+
_write_run_metadata(
|
| 638 |
+
layout,
|
| 639 |
+
{
|
| 640 |
+
"kind": "scientist_compare_eval",
|
| 641 |
+
"base_url": args.base_url,
|
| 642 |
+
"transport": args.transport,
|
| 643 |
+
"eval_seeds": args.eval_seeds,
|
| 644 |
+
"scenarios": args.scenarios,
|
| 645 |
+
"difficulties": args.difficulties,
|
| 646 |
+
"project": args.project,
|
| 647 |
+
"model_name": args.model_name,
|
| 648 |
+
"base_model": args.base_model,
|
| 649 |
+
"checkpoint_step": args.checkpoint_step,
|
| 650 |
+
"bounded_tool_policy": [
|
| 651 |
+
"search_evidence",
|
| 652 |
+
"run_code_check",
|
| 653 |
+
"inspect_image",
|
| 654 |
+
],
|
| 655 |
+
},
|
| 656 |
+
)
|
| 657 |
+
for label, records in records_by_label.items():
|
| 658 |
+
for record in records:
|
| 659 |
+
append_jsonl(
|
| 660 |
+
layout.metrics_jsonl,
|
| 661 |
+
{"label": label, **episode_to_metrics(record).model_dump(mode="json")},
|
| 662 |
+
)
|
| 663 |
+
rows_payload = [row.model_dump(mode="json") for row in rows]
|
| 664 |
+
write_json(layout.summary_json, {"rows": rows_payload})
|
| 665 |
+
_plot_comparison_summary(rows_payload, layout=layout)
|
| 666 |
+
print(json.dumps({"rows": rows_payload}, indent=2, sort_keys=True))
|
| 667 |
+
return 0
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def _run_art_scientist_train(args: argparse.Namespace) -> int:
|
| 671 |
+
layout = _build_layout(
|
| 672 |
+
prefix="art-scientist",
|
| 673 |
+
persist_root=args.persist_root,
|
| 674 |
+
run_name=args.run_name,
|
| 675 |
+
)
|
| 676 |
+
config = ArtOpenEnvConfig(
|
| 677 |
+
project=args.project,
|
| 678 |
+
model_name=args.model_name,
|
| 679 |
+
base_model=args.base_model,
|
| 680 |
+
base_url=args.base_url,
|
| 681 |
+
transport=args.transport,
|
| 682 |
+
train_steps=args.train_steps,
|
| 683 |
+
rollouts_per_group=args.rollouts_per_group,
|
| 684 |
+
max_turns=args.max_turns,
|
| 685 |
+
max_completion_tokens=args.max_completion_tokens,
|
| 686 |
+
max_parse_retries=args.max_parse_retries,
|
| 687 |
+
learning_rate=args.learning_rate,
|
| 688 |
+
beta=args.beta,
|
| 689 |
+
scenarios=[_parse_art_scenario_spec(item) for item in args.scenario_spec],
|
| 690 |
+
)
|
| 691 |
+
result = run_art_openenv_training(config, layout=layout)
|
| 692 |
+
_write_run_metadata(
|
| 693 |
+
layout,
|
| 694 |
+
{
|
| 695 |
+
"kind": "art_scientist_train",
|
| 696 |
+
"project": args.project,
|
| 697 |
+
"model_name": args.model_name,
|
| 698 |
+
"base_model": args.base_model,
|
| 699 |
+
"base_url": args.base_url,
|
| 700 |
+
"train_steps": args.train_steps,
|
| 701 |
+
"rollouts_per_group": args.rollouts_per_group,
|
| 702 |
+
"max_turns": args.max_turns,
|
| 703 |
+
"max_parse_retries": args.max_parse_retries,
|
| 704 |
+
"scenario_spec": args.scenario_spec,
|
| 705 |
+
"bounded_tool_policy": [
|
| 706 |
+
"search_evidence",
|
| 707 |
+
"run_code_check",
|
| 708 |
+
"inspect_image",
|
| 709 |
+
],
|
| 710 |
+
},
|
| 711 |
+
)
|
| 712 |
+
_plot_art_metrics(layout)
|
| 713 |
+
print(json.dumps(result, indent=2, sort_keys=True))
|
| 714 |
+
return 0
|
| 715 |
+
|
| 716 |
+
|
| 717 |
def _maybe_plot_training_history(
|
| 718 |
*,
|
| 719 |
layout: ArtifactLayout,
|
|
|
|
| 753 |
metric_key="agreement_rate",
|
| 754 |
title="Baseline agreement rate",
|
| 755 |
)
|
| 756 |
+
if "average_invalid_bounded_tool_rate" in summary:
|
| 757 |
+
plot_evaluation_bars(
|
| 758 |
+
rows,
|
| 759 |
+
output_path=layout.plots_dir / "baseline_invalid_bounded_tool_rate.png",
|
| 760 |
+
metric_key="average_invalid_bounded_tool_rate",
|
| 761 |
+
title="Baseline invalid bounded-tool rate",
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def _plot_comparison_summary(
|
| 766 |
+
rows: list[dict[str, float | str]],
|
| 767 |
+
*,
|
| 768 |
+
layout: ArtifactLayout,
|
| 769 |
+
) -> None:
|
| 770 |
+
for metric_key, title, output_name in (
|
| 771 |
+
("average_reward", "Before vs after average reward", "compare_average_reward.png"),
|
| 772 |
+
("agreement_rate", "Before vs after agreement rate", "compare_agreement_rate.png"),
|
| 773 |
+
("invalid_action_rate", "Before vs after invalid action rate", "compare_invalid_action_rate.png"),
|
| 774 |
+
(
|
| 775 |
+
"average_invalid_bounded_tool_rate",
|
| 776 |
+
"Before vs after invalid bounded-tool rate",
|
| 777 |
+
"compare_invalid_bounded_tool_rate.png",
|
| 778 |
+
),
|
| 779 |
+
):
|
| 780 |
+
plot_evaluation_bars(
|
| 781 |
+
rows,
|
| 782 |
+
output_path=layout.plots_dir / output_name,
|
| 783 |
+
metric_key=metric_key,
|
| 784 |
+
title=title,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def _plot_art_metrics(layout: ArtifactLayout) -> None:
|
| 789 |
+
if not layout.metrics_jsonl.exists():
|
| 790 |
+
return
|
| 791 |
+
rows = [
|
| 792 |
+
json.loads(line)
|
| 793 |
+
for line in layout.metrics_jsonl.read_text(encoding="utf-8").splitlines()
|
| 794 |
+
if line.strip()
|
| 795 |
+
]
|
| 796 |
+
if not rows:
|
| 797 |
+
return
|
| 798 |
+
plot_metrics_by_step(
|
| 799 |
+
rows,
|
| 800 |
+
output_path=layout.plots_dir / "art_reward_components.png",
|
| 801 |
+
title="ART Scientist reward components by training step",
|
| 802 |
+
metric_keys=[
|
| 803 |
+
"reward",
|
| 804 |
+
"rigor",
|
| 805 |
+
"feasibility",
|
| 806 |
+
"fidelity",
|
| 807 |
+
"agreement_reached",
|
| 808 |
+
"invalid_action_count",
|
| 809 |
+
"parse_error_count",
|
| 810 |
+
],
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def _write_run_metadata(layout: ArtifactLayout, payload: dict[str, object]) -> None:
|
| 815 |
+
write_json(layout.reports_dir / "run_metadata.json", payload)
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def _parse_art_scenario_spec(value: str) -> ArtScenarioSpec:
|
| 819 |
+
parts = value.split(":")
|
| 820 |
+
if len(parts) != 3:
|
| 821 |
+
raise ValueError(
|
| 822 |
+
f"Invalid scenario spec {value!r}. Expected seed:scenario:difficulty."
|
| 823 |
+
)
|
| 824 |
+
seed_text, scenario, difficulty = parts
|
| 825 |
+
return ArtScenarioSpec(
|
| 826 |
+
seed=int(seed_text),
|
| 827 |
+
scenario=scenario,
|
| 828 |
+
difficulty=difficulty,
|
| 829 |
+
)
|
| 830 |
|
| 831 |
|
| 832 |
if __name__ == "__main__":
|
replicalab/training/evaluation.py
CHANGED

@@ -5,6 +5,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Callable, Iterable, Sequence

+from pydantic import BaseModel, ConfigDict
+
 from replicalab.client import ReplicaLabClient
 from replicalab.models import ScientistAction, ScientistObservation
 from replicalab.training.metrics import EvaluationSummary, summarize_episodes

@@ -21,6 +23,25 @@ class EvaluationCase:
     difficulty: str


+class PolicyComparisonRow(BaseModel):
+    """One flattened before/after comparison row."""
+
+    model_config = ConfigDict(extra="forbid")
+
+    label: str
+    episode_count: int
+    average_reward: float
+    average_rounds: float
+    agreement_rate: float
+    invalid_action_rate: float
+    average_invalid_bounded_tool_rate: float
+    average_rigor: float
+    average_feasibility: float
+    average_fidelity: float
+    average_parsimony: float
+    average_tool_trace_count: float
+
+
 def build_default_evaluation_cases(
     *,
     seeds: Iterable[int],

@@ -62,8 +83,38 @@ def evaluate_policy(
     return records, summarize_episodes(records)


+def compare_policies(
+    *,
+    base_url: str,
+    policies: Sequence[tuple[str, PolicyFn]],
+    cases: Sequence[EvaluationCase],
+    transport: str = "rest",
+) -> tuple[dict[str, list[EpisodeRecord]], list[PolicyComparisonRow]]:
+    """Evaluate multiple policies on the exact same case set."""
+
+    records_by_label: dict[str, list[EpisodeRecord]] = {}
+    rows: list[PolicyComparisonRow] = []
+    for label, policy_fn in policies:
+        records, summary = evaluate_policy(
+            base_url=base_url,
+            policy_fn=policy_fn,
+            cases=cases,
+            transport=transport,
+        )
+        records_by_label[label] = records
+        rows.append(
+            PolicyComparisonRow(
+                label=label,
+                **summary.model_dump(mode="json"),
+            )
+        )
+    return records_by_label, rows
+
+
 __all__ = [
     "EvaluationCase",
+    "PolicyComparisonRow",
     "build_default_evaluation_cases",
+    "compare_policies",
     "evaluate_policy",
 ]
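
`compare_policies` runs `evaluate_policy` once per labelled policy on the identical case list, then flattens each summary into one labelled row. The same loop, sketched with plain dicts and an illustrative `evaluate` callable standing in for the repo's `evaluate_policy`:

```python
def compare(policies, evaluate):
    # Evaluate each (label, policy) pair on the same fixed case set and
    # collect one flat summary row per label.
    rows = []
    for label, policy in policies:
        summary = evaluate(policy)  # stands in for evaluate_policy(...)
        rows.append({"label": label, **summary})
    return rows


rows = compare(
    [("baseline", "p0"), ("trained", "p1")],
    evaluate=lambda p: {"average_reward": 0.5 if p == "p0" else 0.8},
)
print(rows)
# → [{'label': 'baseline', 'average_reward': 0.5},
#    {'label': 'trained', 'average_reward': 0.8}]
```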
replicalab/training/metrics.py
CHANGED

@@ -24,6 +24,8 @@ class EpisodeMetrics(BaseModel):
     invalid_action_count: int = 0
     invalid_action_rate: float = 0.0
     tool_trace_count: int = 0
+    invalid_bounded_tool_count: int = 0
+    invalid_bounded_tool_rate: float = 0.0
     rigor: float = 0.0
     feasibility: float = 0.0
     fidelity: float = 0.0

@@ -40,6 +42,7 @@ class EvaluationSummary(BaseModel):
     average_rounds: float
     agreement_rate: float
     invalid_action_rate: float
+    average_invalid_bounded_tool_rate: float
     average_rigor: float
     average_feasibility: float
     average_fidelity: float

@@ -52,6 +55,8 @@ def episode_to_metrics(record: EpisodeRecord) -> EpisodeMetrics:

     invalid_actions = sum(1 for step in record.steps if step.error)
     rounds_used = max(1, record.rounds_used)
+    invalid_bounded_tools = _count_invalid_bounded_tools(record.tool_traces)
+    tool_trace_count = record.tool_trace_count
     breakdown = record.reward_breakdown

     return EpisodeMetrics(

@@ -64,7 +69,9 @@ def episode_to_metrics(record: EpisodeRecord) -> EpisodeMetrics:
         verdict=record.verdict,
         invalid_action_count=invalid_actions,
         invalid_action_rate=invalid_actions / rounds_used,
-        tool_trace_count=
+        tool_trace_count=tool_trace_count,
+        invalid_bounded_tool_count=invalid_bounded_tools,
+        invalid_bounded_tool_rate=invalid_bounded_tools / max(1, tool_trace_count),
         rigor=(breakdown.rigor if breakdown is not None else 0.0),
         feasibility=(breakdown.feasibility if breakdown is not None else 0.0),
         fidelity=(breakdown.fidelity if breakdown is not None else 0.0),

@@ -83,6 +90,7 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
         average_rounds=0.0,
         agreement_rate=0.0,
         invalid_action_rate=0.0,
+        average_invalid_bounded_tool_rate=0.0,
         average_rigor=0.0,
         average_feasibility=0.0,
         average_fidelity=0.0,

@@ -96,6 +104,9 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
         average_rounds=mean(item.rounds_used for item in metrics),
         agreement_rate=mean(1.0 if item.agreement_reached else 0.0 for item in metrics),
         invalid_action_rate=mean(item.invalid_action_rate for item in metrics),
+        average_invalid_bounded_tool_rate=mean(
+            item.invalid_bounded_tool_rate for item in metrics
+        ),
         average_rigor=mean(item.rigor for item in metrics),
         average_feasibility=mean(item.feasibility for item in metrics),
         average_fidelity=mean(item.fidelity for item in metrics),

@@ -104,6 +115,23 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
     )


+def _count_invalid_bounded_tools(traces: list[dict[str, object]]) -> int:
+    invalid_count = 0
+    for trace in traces:
+        status = str(trace.get("status", "") or "").strip().lower()
+        error = trace.get("error")
+        valid = trace.get("valid")
+        if error:
+            invalid_count += 1
+            continue
+        if valid is False:
+            invalid_count += 1
+            continue
+        if status and status not in {"ok", "success", "succeeded", "completed"}:
+            invalid_count += 1
+    return invalid_count
+
+
 __all__ = [
     "EpisodeMetrics",
     "EvaluationSummary",
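
The new `_count_invalid_bounded_tools` helper treats a tool trace as invalid if it carries a truthy `error`, is explicitly marked `valid=False`, or reports a non-success `status` string. A self-contained sketch of the same classification rule:

```python
def count_invalid_bounded_tools(traces):
    # A trace is invalid if it carries an error, is explicitly marked
    # valid=False, or reports a status outside the success set.
    ok_statuses = {"ok", "success", "succeeded", "completed"}
    invalid = 0
    for trace in traces:
        status = str(trace.get("status", "") or "").strip().lower()
        if trace.get("error"):
            invalid += 1
        elif trace.get("valid") is False:
            invalid += 1
        elif status and status not in ok_statuses:
            invalid += 1
    return invalid


traces = [
    {"status": "ok"},                   # valid
    {"status": "failed"},               # non-success status
    {"error": "timeout"},               # error set
    {"valid": False, "status": "ok"},   # explicitly invalid
]
print(count_invalid_bounded_tools(traces))  # → 3
```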
replicalab/training/plots.py
CHANGED

@@ -3,6 +3,7 @@
 from __future__ import annotations

 from pathlib import Path
+from statistics import mean
 from typing import Iterable


@@ -77,7 +78,53 @@ def plot_evaluation_bars(
     plt.close(fig)


+def plot_metrics_by_step(
+    rows: Iterable[dict[str, object]],
+    *,
+    output_path: Path,
+    title: str,
+    metric_keys: list[str],
+    x_key: str = "training_step",
+) -> None:
+    """Plot averaged metric curves grouped by training step."""
+
+    matplotlib = __import__("matplotlib.pyplot", fromlist=["pyplot"])
+    plt = matplotlib
+
+    grouped: dict[int, dict[str, list[float]]] = {}
+    for row in rows:
+        raw_step = row.get(x_key)
+        if not isinstance(raw_step, int):
+            continue
+        bucket = grouped.setdefault(raw_step, {})
+        for metric_key in metric_keys:
+            raw_value = row.get(metric_key)
+            if isinstance(raw_value, (int, float)):
+                bucket.setdefault(metric_key, []).append(float(raw_value))
+
+    if not grouped:
+        raise ValueError(f"No '{x_key}' values found for metric plotting.")
+
+    steps = sorted(grouped)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    fig, ax = plt.subplots(figsize=(10, 5))
+    for metric_key in metric_keys:
+        values = [
+            mean(grouped[step].get(metric_key, [0.0]))
+            for step in steps
+        ]
+        ax.plot(steps, values, marker="o", label=metric_key.replace("_", " "))
+    ax.set_title(title)
+    ax.set_xlabel(x_key.replace("_", " "))
+    ax.grid(True, alpha=0.3)
+    ax.legend()
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=160)
+    plt.close(fig)
+
+
 __all__ = [
     "plot_evaluation_bars",
+    "plot_metrics_by_step",
     "plot_training_history",
 ]
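
Before drawing anything, `plot_metrics_by_step` buckets metric values by integer training step and averages each bucket, silently skipping rows without a usable step or value. The aggregation on its own, without the matplotlib drawing:

```python
from statistics import mean


def group_metric_by_step(rows, metric_key, x_key="training_step"):
    # Bucket metric values by integer training step, then average each
    # bucket -- the same aggregation plot_metrics_by_step performs
    # before plotting one curve per metric key.
    grouped = {}
    for row in rows:
        step = row.get(x_key)
        value = row.get(metric_key)
        if isinstance(step, int) and isinstance(value, (int, float)):
            grouped.setdefault(step, []).append(float(value))
    steps = sorted(grouped)
    return steps, [mean(grouped[s]) for s in steps]


rows = [
    {"training_step": 0, "reward": 0.2},
    {"training_step": 0, "reward": 0.4},
    {"training_step": 1, "reward": 0.9},
    {"reward": 1.0},  # no training_step -> skipped
]
print(group_metric_by_step(rows, "reward"))
```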
replicalab/training/rollout.py
CHANGED
|
@@ -29,6 +29,7 @@ returns a ``ScientistAction``. The baseline from
|
|
| 29 |
from __future__ import annotations
|
| 30 |
|
| 31 |
from dataclasses import dataclass, field
|
|
|
|
| 32 |
@@ -29,6 +29,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
+from inspect import signature
 from typing import Any, Callable, Iterable, Optional

 from replicalab.client import ReplicaLabClient

@@ -88,7 +89,7 @@ class EpisodeRecord:


 # Type alias for the policy callable
-PolicyFn = Callable[
+PolicyFn = Callable[..., ScientistAction]


 class RolloutWorker:

@@ -147,7 +148,13 @@ class RolloutWorker:
             raise RuntimeError("Reset returned no scientist observation")

         for step_idx in range(self._max_steps):
-            action =
+            action = _invoke_policy(
+                policy_fn,
+                scientist_obs,
+                seed=seed,
+                scenario=scenario,
+                difficulty=difficulty,
+            )

             result: StepResult = self._client.step(action)
             tool_traces = _extract_tool_traces(result.info)

@@ -221,3 +228,22 @@ def _extract_tool_traces(info: StepInfo) -> list[dict[str, Any]]:
         if isinstance(item, dict):
             traces.append(dict(item))
     return traces
+
+
+def _invoke_policy(
+    policy_fn: PolicyFn,
+    observation: ScientistObservation,
+    *,
+    seed: int,
+    scenario: str,
+    difficulty: str,
+) -> ScientistAction:
+    parameters = signature(policy_fn).parameters
+    if len(parameters) <= 1:
+        return policy_fn(observation)
+    return policy_fn(
+        observation,
+        seed=seed,
+        scenario=scenario,
+        difficulty=difficulty,
+    )
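The `_invoke_policy` helper above dispatches on the policy callable's arity, so older one-argument policies keep working while newer ones also receive episode context. A minimal standalone sketch of the same pattern (the names here are illustrative, not ReplicaLab's actual types):

```python
from inspect import signature
from typing import Any, Callable


def invoke_with_optional_context(policy_fn: Callable[..., Any], obs: Any, **context: Any) -> Any:
    """Call policy_fn(obs) when it only takes the observation; pass context kwargs otherwise."""
    if len(signature(policy_fn).parameters) <= 1:
        return policy_fn(obs)
    return policy_fn(obs, **context)


def legacy_policy(obs: str) -> str:
    # Old-style policy: observation only.
    return f"act({obs})"


def context_policy(obs: str, *, seed: int, scenario: str, difficulty: str) -> str:
    # New-style policy: also receives episode context.
    return f"act({obs}, seed={seed}, scenario={scenario}, difficulty={difficulty})"


old_style = invoke_with_optional_context(legacy_policy, "obs", seed=7, scenario="ml", difficulty="easy")
new_style = invoke_with_optional_context(context_policy, "obs", seed=7, scenario="ml", difficulty="easy")
```

Inspecting the signature once per call is cheap relative to a rollout step, and it avoids forcing every existing policy to grow keyword arguments at once.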
requirements-train.txt
ADDED
@@ -0,0 +1,22 @@
+# Training dependencies (GPU required)
+# Installed on top of server/requirements.txt
+
+# RL training framework
+trl>=0.15,<1.0
+
+# Dataset handling
+datasets>=3.0,<4.0
+
+# Unsloth for fast LoRA fine-tuning
+unsloth>=2025.3
+
+# vLLM for fast inference during GRPO rollouts
+vllm>=0.7
+
+# Plotting
+matplotlib>=3.9,<4.0
+
+# Already in server/requirements.txt but listed for completeness
+pydantic>=2.7,<3.0
+httpx>=0.27,<1.0
+websocket-client>=1.7,<2.0
scripts/train.sh
ADDED
@@ -0,0 +1,151 @@
+#!/usr/bin/env bash
+# ReplicaLab training entrypoint for Northflank GPU jobs.
+#
+# Usage:
+#   MODE=train ./scripts/train.sh        # full training (scientist + lab manager)
+#   MODE=scientist ./scripts/train.sh    # scientist GRPO only
+#   MODE=lab-manager ./scripts/train.sh  # lab manager SFT only
+#   MODE=eval ./scripts/train.sh         # baseline evaluation only
+#   MODE=server ./scripts/train.sh      # just run server (default)
+#
+# The script starts the ReplicaLab server in the background (needed for
+# rollout evaluation), then runs the requested training flow.
+
+set -euo pipefail
+
+MODE="${MODE:-server}"
+SEED_COUNT="${SEED_COUNT:-8}"
+MAX_STEPS="${MAX_STEPS:-300}"
+MODEL_NAME="${MODEL_NAME:-Qwen/Qwen3-8B}"
+PERSIST_ROOT="${REPLICALAB_PERSIST_ROOT:-/app/outputs/training}"
+BASE_URL="http://localhost:7860"
+
+echo "=========================================="
+echo " ReplicaLab Training Pipeline"
+echo "=========================================="
+echo " Mode:       $MODE"
+echo " Model:      $MODEL_NAME"
+echo " Seeds:      $SEED_COUNT"
+echo " Max steps:  $MAX_STEPS"
+echo " Persist:    $PERSIST_ROOT"
+echo " Server URL: $BASE_URL"
+echo "=========================================="
+
+# ── Start server in background (needed for eval rollouts) ──────────────
+start_server() {
+    echo "[train.sh] Starting ReplicaLab server on port 7860..."
+    uvicorn server.app:app --host 0.0.0.0 --port 7860 &
+    SERVER_PID=$!
+    echo "[train.sh] Server PID: $SERVER_PID"
+
+    # Wait for server to be ready
+    for i in $(seq 1 30); do
+        if curl -sf http://localhost:7860/health > /dev/null 2>&1; then
+            echo "[train.sh] Server is ready."
+            return 0
+        fi
+        sleep 1
+    done
+    echo "[train.sh] WARNING: Server did not become ready in 30s, continuing anyway."
+}
+
+# ── Scientist GRPO training ───────────────────────────────────────────
+run_scientist_train() {
+    echo ""
+    echo "=== Phase 1: Scientist GRPO Training ==="
+    echo ""
+
+    # Preview first (no GPU needed)
+    python -m replicalab.training.cli scientist-preview \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    # Full training
+    python -m replicalab.training.cli scientist-train \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT" \
+        --max-steps "$MAX_STEPS"
+
+    echo "[train.sh] Scientist GRPO training complete."
+}
+
+# ── Lab Manager SFT training ─────────────────────────────────────────
+run_lab_manager_train() {
+    echo ""
+    echo "=== Phase 2: Lab Manager SFT Training ==="
+    echo ""
+
+    # Preview first
+    python -m replicalab.training.cli lab-manager-preview \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    # Full training
+    python -m replicalab.training.cli lab-manager-train \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    echo "[train.sh] Lab Manager SFT training complete."
+}
+
+# ── Baseline evaluation ──────────────────────────────────────────────
+run_eval() {
+    echo ""
+    echo "=== Baseline Evaluation ==="
+    echo ""
+
+    python -m replicalab.training.cli baseline-eval \
+        --persist-root "$PERSIST_ROOT" \
+        --base-url "$BASE_URL" \
+        --seed-count "$SEED_COUNT"
+
+    echo "[train.sh] Evaluation complete."
+}
+
+# ── Mode dispatch ────────────────────────────────────────────────────
+
+case "$MODE" in
+    server)
+        echo "[train.sh] Server-only mode."
+        exec uvicorn server.app:app --host 0.0.0.0 --port 7860
+        ;;
+
+    train)
+        start_server
+        run_scientist_train
+        run_lab_manager_train
+        run_eval
+        echo ""
+        echo "=========================================="
+        echo " All training complete!"
+        echo " Artifacts saved to: $PERSIST_ROOT"
+        echo "=========================================="
+        # Keep container alive so artifacts can be retrieved
+        echo "[train.sh] Training done. Keeping container alive..."
+        wait $SERVER_PID
+        ;;

+    scientist)
+        run_scientist_train
+        ;;
+
+    lab-manager)
+        run_lab_manager_train
+        ;;
+
+    eval)
+        start_server
+        run_eval
+        wait $SERVER_PID
+        ;;
+
+    *)
+        echo "Unknown MODE: $MODE"
+        echo "Valid modes: server, train, scientist, lab-manager, eval"
+        exit 1
+        ;;
+esac
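`start_server` above polls `/health` once per second for up to 30 seconds before proceeding. The same wait-until-healthy pattern can be sketched in Python with an injectable probe, so it runs without a live server (the probe and timings here are illustrative, not part of ReplicaLab):

```python
import time
from typing import Callable


def wait_until_ready(probe: Callable[[], bool], attempts: int = 30, delay: float = 1.0) -> bool:
    """Poll probe() up to `attempts` times, sleeping `delay` seconds between tries."""
    for _ in range(attempts):
        if probe():
            return True
        time.sleep(delay)
    return False


# Simulated probe: the "server" becomes healthy on the third check.
checks = {"count": 0}


def fake_health_probe() -> bool:
    checks["count"] += 1
    return checks["count"] >= 3


became_ready = wait_until_ready(fake_health_probe, attempts=5, delay=0.0)
```

In the real script the probe is `curl -sf http://localhost:7860/health`; the shell version also deliberately continues with a warning on timeout rather than aborting, so training can still proceed if only eval rollouts are affected.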
tests/test_server.py
CHANGED
@@ -96,12 +96,11 @@ class TestRootEndpoint:

     def test_root_mentions_core_api_endpoints(self, client: TestClient) -> None:
         body = client.get("/").text
-        assert "
-        assert "WS /ws" in body
+        # When frontend/dist exists, root serves the SPA; otherwise the API landing
+        assert "ReplicaLab" in body
+        if "ReplicaLab API" in body:
+            assert "GET /health" in body
+            assert "POST /reset" in body


 class TestWebFallback:
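The adjusted test accepts either root response because the server now branches on whether a built frontend exists (UI 11's SPA catch-all). A hedged sketch of that dual-mode root, using plain file I/O rather than the actual FastAPI wiring (`root_response` and the landing-page text are illustrative):

```python
import tempfile
from pathlib import Path


def root_response(dist_dir: Path) -> str:
    """Return the SPA index when a built frontend exists, else a plain API landing page."""
    index = dist_dir / "index.html"
    if index.is_file():
        return index.read_text(encoding="utf-8")
    return "ReplicaLab API\nGET /health\nPOST /reset\nWS /ws"


with tempfile.TemporaryDirectory() as tmp:
    dist = Path(tmp)
    api_landing = root_response(dist)  # no index.html yet -> API landing text
    (dist / "index.html").write_text("<html><title>ReplicaLab</title></html>", encoding="utf-8")
    spa_page = root_response(dist)     # index.html present -> SPA is served
```

This is why the test first asserts the shared "ReplicaLab" marker and only checks the endpoint listing when the API landing variant is detected.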
tests/test_training_cli.py
CHANGED
@@ -4,6 +4,7 @@ import json

 from replicalab.models import RewardBreakdown
 from replicalab.training.cli import main
+from replicalab.training.evaluation import PolicyComparisonRow
 from replicalab.training.metrics import EvaluationSummary
 from replicalab.training.rollout import EpisodeRecord

@@ -57,6 +58,7 @@ def test_baseline_eval_cli_writes_summary_and_metrics(tmp_path, monkeypatch) ->
         average_rounds=1.0,
         agreement_rate=1.0,
         invalid_action_rate=0.0,
+        average_invalid_bounded_tool_rate=0.0,
         average_rigor=0.6,
         average_feasibility=0.8,
         average_fidelity=0.7,

@@ -95,3 +97,94 @@ def test_baseline_eval_cli_writes_summary_and_metrics(tmp_path, monkeypatch) ->
     metric = json.loads(metrics_lines[0])
     assert metric["scenario"] == "ml_benchmark"
     assert metric["agreement_reached"] is True
+
+
+def test_scientist_compare_eval_cli_writes_rows(tmp_path, monkeypatch) -> None:
+    baseline_record = EpisodeRecord(
+        seed=101,
+        scenario="ml_benchmark",
+        difficulty="easy",
+        episode_id="baseline-1",
+        total_reward=1.0,
+        reward_breakdown=RewardBreakdown(rigor=0.4, feasibility=0.5, fidelity=0.6),
+        verdict="timeout",
+        agreement_reached=False,
+    )
+    trained_record = EpisodeRecord(
+        seed=101,
+        scenario="ml_benchmark",
+        difficulty="easy",
+        episode_id="trained-1",
+        total_reward=3.5,
+        reward_breakdown=RewardBreakdown(rigor=0.8, feasibility=0.9, fidelity=0.85),
+        verdict="accept",
+        agreement_reached=True,
+    )
+    rows = [
+        PolicyComparisonRow(
+            label="baseline",
+            episode_count=1,
+            average_reward=1.0,
+            average_rounds=2.0,
+            agreement_rate=0.0,
+            invalid_action_rate=0.5,
+            average_invalid_bounded_tool_rate=0.0,
+            average_rigor=0.4,
+            average_feasibility=0.5,
+            average_fidelity=0.6,
+            average_parsimony=1.0,
+            average_tool_trace_count=0.0,
+        ),
+        PolicyComparisonRow(
+            label="trained",
+            episode_count=1,
+            average_reward=3.5,
+            average_rounds=1.0,
+            agreement_rate=1.0,
+            invalid_action_rate=0.0,
+            average_invalid_bounded_tool_rate=0.0,
+            average_rigor=0.8,
+            average_feasibility=0.9,
+            average_fidelity=0.85,
+            average_parsimony=1.0,
+            average_tool_trace_count=0.0,
+        ),
+    ]
+
+    monkeypatch.setattr(
+        "replicalab.training.cli.build_remote_scientist_policy",
+        lambda **_: (lambda _obs: None),
+    )
+    monkeypatch.setattr(
+        "replicalab.training.cli.compare_policies",
+        lambda **_: (
+            {"baseline": [baseline_record], "trained": [trained_record]},
+            rows,
+        ),
+    )
+    monkeypatch.setattr(
+        "replicalab.training.cli.plot_evaluation_bars",
+        lambda *args, **kwargs: None,
+    )
+
+    exit_code = main(
+        [
+            "scientist-compare-eval",
+            "--persist-root",
+            str(tmp_path),
+            "--run-name",
+            "compare-eval-test",
+            "--eval-seeds",
+            "101",
+            "--scenarios",
+            "ml_benchmark",
+            "--difficulties",
+            "easy",
+        ]
+    )
+
+    assert exit_code == 0
+    summary_path = tmp_path / "compare-eval-test" / "reports" / "summary.json"
+    payload = json.loads(summary_path.read_text(encoding="utf-8"))
+    assert [row["label"] for row in payload["rows"]] == ["baseline", "trained"]
+    assert payload["rows"][1]["average_reward"] == 3.5
tests/test_training_metrics.py
CHANGED
@@ -7,7 +7,11 @@ from replicalab.training.metrics import episode_to_metrics, summarize_episodes
 from replicalab.training.rollout import EpisodeRecord, StepRecord


-def _build_step_record(error: str | None = None) -> StepRecord:
+def _build_step_record(
+    error: str | None = None,
+    *,
+    tool_traces: list[dict[str, object]] | None = None,
+) -> StepRecord:
     return StepRecord(
         round_number=0,
         observation=ScientistObservation(

@@ -36,6 +40,7 @@ def _build_step_record(error: str | None = None) -> StepRecord:
         done=False,
         error=error,
         info=StepInfo(error=error),
+        tool_traces=tool_traces or [],
     )


@@ -55,12 +60,18 @@ def test_episode_to_metrics_counts_invalid_actions() -> None:
         ),
         verdict="accept",
         agreement_reached=True,
+        tool_traces=[
+            {"tool": "search_evidence", "status": "ok"},
+            {"tool": "run_code_check", "status": "error", "error": "timeout"},
+        ],
     )

     metrics = episode_to_metrics(record)

     assert metrics.invalid_action_count == 1
     assert metrics.invalid_action_rate == 0.5
+    assert metrics.invalid_bounded_tool_count == 1
+    assert metrics.invalid_bounded_tool_rate == 0.5
     assert metrics.agreement_reached is True


@@ -75,6 +86,7 @@ def test_summarize_episodes_aggregates_rewards() -> None:
         reward_breakdown=RewardBreakdown(rigor=0.6, feasibility=0.7, fidelity=0.8),
         verdict="accept",
         agreement_reached=True,
+        tool_traces=[{"tool": "search_evidence", "status": "ok"}],
     )
     second = EpisodeRecord(
         seed=2,

@@ -86,6 +98,7 @@ def test_summarize_episodes_aggregates_rewards() -> None:
         reward_breakdown=RewardBreakdown(rigor=0.2, feasibility=0.4, fidelity=0.5),
         verdict="timeout",
         agreement_reached=False,
+        tool_traces=[{"tool": "run_code_check", "status": "error"}],
     )

     summary = summarize_episodes([first, second])

@@ -93,3 +106,4 @@ def test_summarize_episodes_aggregates_rewards() -> None:
     assert summary.episode_count == 2
     assert summary.average_reward == 1.25
     assert 0.0 < summary.invalid_action_rate < 1.0
+    assert summary.average_invalid_bounded_tool_rate == 0.5