ayushozha committed on
Commit abb29f8 · 1 Parent(s): 5af5f17

Merge Kush frontend integration, close API 16/UI 10/UI 11


- Merge multi-stage Docker build, SPA serving, and frontend components
- API 16: Docker now builds frontend + API in single container
- UI 10: Frontend styled with new components (ProtocolEditor, LiveScoreGauges, etc)
- UI 11: Server mounts frontend/dist with SPA catch-all for React Router
- Fix root endpoint test for SPA-mode serving
- Max at 97.56% (40/41; only DOC 08 remaining, blocked on TRN 10)
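The API 16 bullet says the frontend build and the API now ship in one container; the updated `server/Dockerfile` itself is not shown on this page. As a minimal sketch of the multi-stage pattern being described (stage names, paths, and the uvicorn entrypoint are illustrative assumptions, not the commit's actual file):

```dockerfile
# Stage 1: build the React frontend (hypothetical stage layout)
FROM node:20-slim AS frontend
WORKDIR /build
COPY frontend/package*.json ./
RUN npm ci
COPY frontend/ ./
RUN npm run build

# Stage 2: Python API image that also serves the built static assets
FROM python:3.11-slim
WORKDIR /app
COPY server/requirements.txt ./server/requirements.txt
RUN pip install --no-cache-dir -r server/requirements.txt
COPY server/ ./server/
# Copy only the compiled assets out of the frontend stage
COPY --from=frontend /build/dist ./frontend/dist
EXPOSE 7860
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
```

The point of the two stages is that the final image carries no Node toolchain, only `frontend/dist`, which the FastAPI app can mount with an SPA catch-all (UI 11).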

Dockerfile.train ADDED
@@ -0,0 +1,59 @@
+ # Training Dockerfile for Northflank GPU jobs.
+ #
+ # Uses CUDA base image + installs Unsloth, TRL, vLLM for
+ # Scientist GRPO and Lab Manager SFT training.
+ #
+ # Build: docker build -f Dockerfile.train -t replicalab-train .
+ # Run: docker run --gpus all -e MODE=train replicalab-train
+
+ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
+
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV PYTHONUNBUFFERED=1
+
+ WORKDIR /app
+
+ # System deps
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+ python3.11 python3.11-dev python3.11-venv python3-pip \
+ build-essential git curl \
+ && rm -rf /var/lib/apt/lists/* \
+ && ln -sf /usr/bin/python3.11 /usr/bin/python \
+ && ln -sf /usr/bin/python3.11 /usr/bin/python3
+
+ # Upgrade pip
+ RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ # Install server deps first (better layer caching)
+ COPY server/requirements.txt ./server/requirements.txt
+ RUN pip install --no-cache-dir -r server/requirements.txt
+
+ # Install training deps (heavy — torch, unsloth, trl, vllm)
+ COPY requirements-train.txt ./requirements-train.txt
+ RUN pip install --no-cache-dir -r requirements-train.txt
+
+ # Copy full project
+ COPY replicalab/ ./replicalab/
+ COPY server/ ./server/
+ COPY data/ ./data/
+ COPY scripts/ ./scripts/
+ COPY pyproject.toml ./
+ COPY ReplicaLab_50_Scenarios_Training_Plan.md ./
+
+ # Install replicalab package
+ RUN pip install --no-cache-dir . --no-deps
+
+ # Make scripts executable
+ RUN chmod +x scripts/train.sh
+
+ # Default env vars
+ ENV MODE=server
+ ENV REPLICALAB_PERSIST_ROOT=/app/outputs/training
+ ENV SEED_COUNT=8
+ ENV MAX_STEPS=300
+ ENV MODEL_NAME=Qwen/Qwen3-8B
+
+ EXPOSE 7860
+
+ # Entrypoint dispatches based on MODE env var
+ CMD ["bash", "scripts/train.sh"]
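The final `CMD` defers everything to `scripts/train.sh`, whose contents are not part of this diff. A minimal sketch of such a `MODE` dispatcher (hypothetical; the real script may branch differently) looks like:

```shell
#!/usr/bin/env bash
# Hypothetical MODE dispatcher in the spirit of scripts/train.sh
# (the real script is not shown in this commit).
set -euo pipefail

# Mirrors the image default ENV MODE=server.
MODE="${MODE:-server}"

case "$MODE" in
  train)
    echo "MODE=train: would launch GRPO/SFT training"
    ;;
  server)
    echo "MODE=server: would launch the API on port 7860"
    ;;
  *)
    echo "unknown MODE: $MODE" >&2
    exit 1
    ;;
esac
```

Dispatching on a single env var keeps one image usable both as the Northflank training job (`-e MODE=train`) and as a plain server.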
ReplicaLab_Comprehensive_Task_Division.md CHANGED
@@ -626,7 +626,7 @@ As the team, we want one click reproducible deployment to HF Spaces.
 | API 13 | E07.1 | Person C | `server/app.py` | Add CORS middleware configuration for frontend origins in dev and production | API 01 | 0.25h | frontend on localhost:5173 and HF Space origin can reach the API without CORS errors | ✅ Completed | Person B (Ayush) |
 | API 14 | E07.1 | Person C | `server/app.py` | Add REST session management so each user gets isolated environment state | API 02, API 03 | 0.75h | two concurrent REST users do not share or corrupt each other's episode state | ✅ Completed | Person B (Ayush) |
 | API 15 | E07.2 | Person C | HF Space repo | Create HF Space README.md with YAML frontmatter specifying `sdk: docker`, `app_port: 7860`, title, and emoji | API 08 | 0.25h | HF Space config is valid and Space launches correctly from the metadata | ✅ Completed | Person B (Ayush) |
- | API 16 | E07.2 | Person C | `server/Dockerfile` | Configure Docker to build frontend and serve static assets from FastAPI in a single container | API 08, UI 10 | 0.75h | single Docker container serves both API and frontend on port 7860 | Not started | |
 | API 17 | E07.2 | Person C | deployment docs | Document secrets and API key management for hosted Scientist model access in deployment and notebook | API 09 | 0.5h | team knows how to set API keys in HF Space secrets, local env, and Colab secrets | ✅ Completed | Person B (Ayush) |
 | API 18 | E07.1 | Person C | `server/app.py` | Include judge audit payload plus bounded tool-trace summaries in REST, replay, and WebSocket responses for terminal episodes | API 03, API 05, API 06, ENV 11 | 0.5h | clients receive `judge_notes`, verdict fields, and bounded tool audit data without separate log file access | ✅ Completed | Person B (Ayush) |
 | API 19 | E07.2 | Person C | `openenv.yaml` and deployment docs | Expose and verify OpenEnv built in `/web` fallback route locally and on HF Space | FND 09, API 08, API 10 | 0.5h | `/web` is documented, reachable, and able to run a seeded episode when the custom UI is unavailable | ✅ Completed | Person B (Ayush) |
@@ -699,8 +699,8 @@ As a team, we want a replayable UI for debugging and recording the demo.
 | UI 07 | E09.2 | Person D | `frontend/src/lib/api.ts` | Add REST plus WebSocket client helpers | API 02 to API 06 | 0.75h | UI can connect locally and to the hosted Space | ✅ Completed | Person D (Kush) |
 | UI 08 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Build replay viewer from completed episode logs | API 05 | 1h | user can load a past episode and step through rounds | ⬜ Not started | — |
 | UI 09 | E09.1 | Person D | `frontend/src/components/TrainingResults.tsx` | Add before versus after panel or static result card | TRN 10 | 0.75h | UI can show reward curve image and summary metrics | ⬜ Not started | — |
- | UI 10 | E09.1 | Person D | frontend styling | Add clean visual styling with Tailwind plus shadcn compatible primitives and responsive spacing | UI 01 to UI 09, FND 13 | 0.75h | UI is presentable on demo screen without layout breaks and styling stack matches the declared toolchain | Not started | |
- | UI 11 | E09.2 | Person C | integration | Serve frontend with backend or configure proxy during dev | UI 07, API 01 | 0.5h | one command local dev works and deployed app serves UI path | Not started | |
 | UI 12 | E09.2 | Person D | tests and smoke | Add smoke test checklist for core UI flow | UI 01 to UI 11 | 0.5h | checklist confirms new episode, step, score update, and replay all work | ⬜ Not started | — |
 | UI 13 | E09.1 | Person D | `frontend/src/components/JudgeAuditPanel.tsx` or `NegotiationLog.tsx` | Render final Judge audit text and verdict at episode end | JDG 11, API 18 | 0.75h | UI shows a clear end of episode audit without hiding the deterministic score breakdown | ⬜ Not started | — |
 | UI 14 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Add replay slider or scrubber so judges can move across rounds quickly | UI 08 | 0.5h | user can scrub to any round without replaying the full episode sequentially | ⬜ Not started | — |

 | API 13 | E07.1 | Person C | `server/app.py` | Add CORS middleware configuration for frontend origins in dev and production | API 01 | 0.25h | frontend on localhost:5173 and HF Space origin can reach the API without CORS errors | ✅ Completed | Person B (Ayush) |
 | API 14 | E07.1 | Person C | `server/app.py` | Add REST session management so each user gets isolated environment state | API 02, API 03 | 0.75h | two concurrent REST users do not share or corrupt each other's episode state | ✅ Completed | Person B (Ayush) |
 | API 15 | E07.2 | Person C | HF Space repo | Create HF Space README.md with YAML frontmatter specifying `sdk: docker`, `app_port: 7860`, title, and emoji | API 08 | 0.25h | HF Space config is valid and Space launches correctly from the metadata | ✅ Completed | Person B (Ayush) |
+ | API 16 | E07.2 | Person C | `server/Dockerfile` | Configure Docker to build frontend and serve static assets from FastAPI in a single container | API 08, UI 10 | 0.75h | single Docker container serves both API and frontend on port 7860 | Completed | Person D (Kush) |
 | API 17 | E07.2 | Person C | deployment docs | Document secrets and API key management for hosted Scientist model access in deployment and notebook | API 09 | 0.5h | team knows how to set API keys in HF Space secrets, local env, and Colab secrets | ✅ Completed | Person B (Ayush) |
 | API 18 | E07.1 | Person C | `server/app.py` | Include judge audit payload plus bounded tool-trace summaries in REST, replay, and WebSocket responses for terminal episodes | API 03, API 05, API 06, ENV 11 | 0.5h | clients receive `judge_notes`, verdict fields, and bounded tool audit data without separate log file access | ✅ Completed | Person B (Ayush) |
 | API 19 | E07.2 | Person C | `openenv.yaml` and deployment docs | Expose and verify OpenEnv built in `/web` fallback route locally and on HF Space | FND 09, API 08, API 10 | 0.5h | `/web` is documented, reachable, and able to run a seeded episode when the custom UI is unavailable | ✅ Completed | Person B (Ayush) |

 | UI 07 | E09.2 | Person D | `frontend/src/lib/api.ts` | Add REST plus WebSocket client helpers | API 02 to API 06 | 0.75h | UI can connect locally and to the hosted Space | ✅ Completed | Person D (Kush) |
 | UI 08 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Build replay viewer from completed episode logs | API 05 | 1h | user can load a past episode and step through rounds | ⬜ Not started | — |
 | UI 09 | E09.1 | Person D | `frontend/src/components/TrainingResults.tsx` | Add before versus after panel or static result card | TRN 10 | 0.75h | UI can show reward curve image and summary metrics | ⬜ Not started | — |
+ | UI 10 | E09.1 | Person D | frontend styling | Add clean visual styling with Tailwind plus shadcn compatible primitives and responsive spacing | UI 01 to UI 09, FND 13 | 0.75h | UI is presentable on demo screen without layout breaks and styling stack matches the declared toolchain | Completed | Person D (Kush) |
+ | UI 11 | E09.2 | Person C | integration | Serve frontend with backend or configure proxy during dev | UI 07, API 01 | 0.5h | one command local dev works and deployed app serves UI path | Completed | Person D (Kush) |
 | UI 12 | E09.2 | Person D | tests and smoke | Add smoke test checklist for core UI flow | UI 01 to UI 11 | 0.5h | checklist confirms new episode, step, score update, and replay all work | ⬜ Not started | — |
 | UI 13 | E09.1 | Person D | `frontend/src/components/JudgeAuditPanel.tsx` or `NegotiationLog.tsx` | Render final Judge audit text and verdict at episode end | JDG 11, API 18 | 0.75h | UI shows a clear end of episode audit without hiding the deterministic score breakdown | ⬜ Not started | — |
 | UI 14 | E09.2 | Person D | `frontend/src/components/ReplayViewer.tsx` | Add replay slider or scrubber so judges can move across rounds quickly | UI 08 | 0.5h | user can scrub to any round without replaying the full episode sequentially | ⬜ Not started | — |
docs/completion.md CHANGED
@@ -20,10 +20,10 @@ Source of truth: `ReplicaLab_Comprehensive_Task_Division.md`
 | Metric | Value |
 |--------|-------|
 | Total tasks | 152 |
- | Completed | 104 |
 | Partial / active | 0 |
- | Remaining | 48 |
- | **Completion rate** | **68.42%** |

 ### Completion by Person

@@ -31,8 +31,8 @@ Source of truth: `ReplicaLab_Comprehensive_Task_Division.md`
 |--------|----------|----------------|----------------------|-----------|------|
 | Kian (Person A) | 49 (47 solo + 2 shared with B) | 1 shared sign-off (`FND 08`) | 48 (`FND 04`, `FND 09`, `MOD 01`, `MOD 02`, `MOD 03`, `MOD 04`, `MOD 05`, `MOD 06`, `MOD 08`, `MOD 11`, `MOD 12`, `SCN 01` to `SCN 10`, `SCN 13`, `AGT 05`, `AGT 09`, `ENV 01` to `ENV 08`, `ENV 10`, `ENV 11`, `JDG 01` to `JDG 06`, `JDG 08`, `JDG 11`, `OBS 04`, `TST 01` to `TST 05` done by Person B) | 0 | 100.00% |
 | Person B (Ayush) | 29 (27 solo + 2 shared with A) | 19 (`FND 08`, `MOD 09`, `SCN 11`, `AGT 01`, `AGT 02`, `AGT 03`, `AGT 04`, `AGT 05`, `AGT 06`, `AGT 07`, `AGT 08`, `AGT 10`, `AGT 11`, `TRN 13`, `TRN 03`, `TRN 04`, `TRN 01`, `TRN 02`, `TRN 14`) | 0 | 10 | 65.52% |
- | Max (Person C) | 41 | 1 (`FND 11`) | 37 (`FND 01`, `FND 02`, `FND 03`, `FND 05`, `FND 07`, `FND 10`, `FND 12` done by others, `MOD 07`, `MOD 10`, `API 01`, `API 02`, `API 03`, `API 04`, `API 05`, `API 06`, `API 07`, `API 08`, `API 09`, `API 10`, `API 11`, `API 13`, `API 14`, `API 15`, `API 17`, `API 18`, `API 19`, `JDG 07`, `OBS 01`, `OBS 02`, `OBS 03`, `OBS 07`, `OBS 09`, `TRN 11`, `TST 06`, `TST 07`, `TST 11`, `ENV 09` done by Person B) | 3 | 92.68% |
- | Kush (Person D) | 32 | 1 (`UI 07`) | 1 (`FND 06` done by Person B) | 30 | 6.25% |
 | All (shared) | 3 | 2 (`FND 08`, `AGT 05`) | 0 | 1 | 66.67% |

 Note: Person B (Ayush) has completed two shared tasks in their own lane
@@ -50,8 +50,8 @@ to `SCN 10`, `SCN 13`, `AGT 09`, `ENV 01` to `ENV 09`, `ENV 10`, `ENV 11`,
 `API 15`, `API 17`, `API 18`, `API 19`, `OBS 01`, `OBS 02`, `OBS 03`, `OBS 04`,
 `OBS 07`, `OBS 09`, `TRN 11`) to keep the Kian, Max, and Kush dependency
 chain moving. All Person A and Person C implementation tasks are now complete
- except for 3 remaining Max tasks (`API 16`, `DOC 08`, `UI 11`).
- `UI 07` was completed by Kush (Person D), unblocking `UI 11`.
 Ayush's next fully unblocked tasks are `TRN 05` and `JDG 10`.

 ---
 
 | Metric | Value |
 |--------|-------|
 | Total tasks | 152 |
+ | Completed | 107 |
 | Partial / active | 0 |
+ | Remaining | 45 |
+ | **Completion rate** | **70.39%** |

 ### Completion by Person

 |--------|----------|----------------|----------------------|-----------|------|
 | Kian (Person A) | 49 (47 solo + 2 shared with B) | 1 shared sign-off (`FND 08`) | 48 (`FND 04`, `FND 09`, `MOD 01`, `MOD 02`, `MOD 03`, `MOD 04`, `MOD 05`, `MOD 06`, `MOD 08`, `MOD 11`, `MOD 12`, `SCN 01` to `SCN 10`, `SCN 13`, `AGT 05`, `AGT 09`, `ENV 01` to `ENV 08`, `ENV 10`, `ENV 11`, `JDG 01` to `JDG 06`, `JDG 08`, `JDG 11`, `OBS 04`, `TST 01` to `TST 05` done by Person B) | 0 | 100.00% |
 | Person B (Ayush) | 29 (27 solo + 2 shared with A) | 19 (`FND 08`, `MOD 09`, `SCN 11`, `AGT 01`, `AGT 02`, `AGT 03`, `AGT 04`, `AGT 05`, `AGT 06`, `AGT 07`, `AGT 08`, `AGT 10`, `AGT 11`, `TRN 13`, `TRN 03`, `TRN 04`, `TRN 01`, `TRN 02`, `TRN 14`) | 0 | 10 | 65.52% |
+ | Max (Person C) | 41 | 1 (`FND 11`) | 39 (done by Person B or Person D; `API 16`, `UI 11` by Kush) | 1 (`DOC 08`) | 97.56% |
+ | Kush (Person D) | 32 | 4 (`UI 07`, `UI 10`, `UI 11`, `API 16`) | 1 (`FND 06` done by Person B) | 27 | 15.63% |
 | All (shared) | 3 | 2 (`FND 08`, `AGT 05`) | 0 | 1 | 66.67% |

 Note: Person B (Ayush) has completed two shared tasks in their own lane

 `API 15`, `API 17`, `API 18`, `API 19`, `OBS 01`, `OBS 02`, `OBS 03`, `OBS 04`,
 `OBS 07`, `OBS 09`, `TRN 11`) to keep the Kian, Max, and Kush dependency
 chain moving. All Person A and Person C implementation tasks are now complete
+ except for 1 remaining Max task (`DOC 08`, blocked on `TRN 10`).
+ `UI 07`, `UI 10`, `UI 11`, and `API 16` were completed by Kush (Person D).
 Ayush's next fully unblocked tasks are `TRN 05` and `JDG 10`.

 ---
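The revised completion figures in `docs/completion.md` are internally consistent; a quick arithmetic check reproduces the headline percentages quoted in the updated tables:

```python
# Recompute the headline completion figures from docs/completion.md.
total_tasks = 152
completed = 107
remaining = total_tasks - completed  # should match the "Remaining" row

# Overall completion rate quoted as **70.39%**.
overall_rate = round(completed / total_tasks * 100, 2)

# Max (Person C): 41 tasks with only DOC 08 still open, quoted as 97.56%.
max_rate = round(40 / 41 * 100, 2)

print(remaining, overall_rate, max_rate)
```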
replicalab/agents/__init__.py CHANGED
@@ -18,6 +18,7 @@ from .scientist_policy import (
 ScientistCallResult,
 ScientistOutputParseError,
 build_baseline_scientist_action,
 build_scientist_system_prompt,
 call_scientist_with_retry,
 format_scientist_observation,
@@ -34,6 +35,7 @@ __all__ = [
 "ScientistOutputParseError",
 "SuggestionChange",
 "build_baseline_scientist_action",
 "build_judge_audit",
 "build_scientist_system_prompt",
 "call_scientist_with_retry",

 ScientistCallResult,
 ScientistOutputParseError,
 build_baseline_scientist_action,
+ build_remote_scientist_policy,
 build_scientist_system_prompt,
 call_scientist_with_retry,
 format_scientist_observation,

 "ScientistOutputParseError",
 "SuggestionChange",
 "build_baseline_scientist_action",
+ "build_remote_scientist_policy",
 "build_judge_audit",
 "build_scientist_system_prompt",
 "call_scientist_with_retry",
replicalab/agents/scientist_policy.py CHANGED
@@ -16,6 +16,7 @@ from __future__ import annotations
 import json
 import logging
 import re
 from typing import Any, Callable, Literal, Mapping

 from pydantic import BaseModel, ConfigDict, ValidationError
@@ -138,6 +139,7 @@ def call_scientist_with_retry(
 observation: ScientistObservation,
 *,
 max_retries: int = 2,
 ) -> ScientistCallResult:
 """Call a model backend to produce a ``ScientistAction`` with parser-driven retries.

@@ -161,7 +163,7 @@
 Default is 2 (so up to 3 total attempts).
 """

- user_message = format_scientist_observation(observation)
 messages: list[dict[str, str]] = [
 {"role": "system", "content": system_prompt},
 {"role": "user", "content": user_message},
@@ -709,3 +711,169 @@ def _baseline_defaults_for_domain(domain: str) -> dict[str, Any]:
 "technique": "structured_proof_outline",
 "duration_days": 1,
 }
 import json
 import logging
 import re
+ from importlib import import_module
 from typing import Any, Callable, Literal, Mapping

 from pydantic import BaseModel, ConfigDict, ValidationError

 observation: ScientistObservation,
 *,
 max_retries: int = 2,
+ user_message_override: str | None = None,
 ) -> ScientistCallResult:
 """Call a model backend to produce a ``ScientistAction`` with parser-driven retries.

 Default is 2 (so up to 3 total attempts).
 """

+ user_message = user_message_override or format_scientist_observation(observation)
 messages: list[dict[str, str]] = [
 {"role": "system", "content": system_prompt},
 {"role": "user", "content": user_message},

 "technique": "structured_proof_outline",
 "duration_days": 1,
 }
+
+
+ def build_remote_scientist_policy(
+ *,
+ project: str,
+ model_name: str,
+ base_model: str,
+ checkpoint_step: int | None = None,
+ max_completion_tokens: int = 450,
+ temperature: float = 0.0,
+ max_retries: int = 2,
+ ) -> Callable[[ScientistObservation], ScientistAction]:
+ """Create a sync policy callable backed by an ART serverless checkpoint."""
+
+ try:
+ art_module = import_module("art")
+ serverless_module = import_module("art.serverless")
+ openai_module = import_module("openai")
+ except ImportError as exc:
+ raise RuntimeError(
+ "Missing optional inference dependency for remote Scientist evaluation. "
+ "Install 'openpipe-art' and 'openai' before loading a trained checkpoint."
+ ) from exc
+
+ trainable_model = art_module.TrainableModel(
+ name=model_name,
+ project=project,
+ base_model=base_model,
+ )
+ backend = serverless_module.ServerlessBackend()
+
+ import asyncio
+
+ asyncio.run(trainable_model.register(backend))
+ if trainable_model.inference_api_key is None or trainable_model.inference_base_url is None:
+ raise RuntimeError("ART serverless model registration did not expose inference credentials.")
+
+ client = openai_module.OpenAI(
+ base_url=trainable_model.inference_base_url,
+ api_key=trainable_model.inference_api_key,
+ )
+ inference_name = trainable_model.get_inference_name(step=checkpoint_step)
+ training_corpus = import_module("replicalab.training.corpus")
+ evidence_packs = [
+ pack for pack in training_corpus.load_frozen_evidence_packs() if pack.trainable_in_env
+ ]
+
+ def generate_fn(messages: list[dict[str, str]]) -> str:
+ response = client.chat.completions.create(
+ model=inference_name,
+ messages=messages,
+ max_completion_tokens=max_completion_tokens,
+ temperature=temperature,
+ )
+ return _extract_message_content(response.choices[0].message.content)
+
+ def policy_fn(
+ observation: ScientistObservation,
+ *,
+ seed: int | None = None,
+ scenario: str | None = None,
+ difficulty: str | None = None,
+ ) -> ScientistAction:
+ evidence_pack = None
+ if seed is not None and scenario is not None:
+ try:
+ evidence_pack = training_corpus.select_evidence_pack(
+ evidence_packs,
+ template=scenario,
+ seed=seed,
+ )
+ except Exception:
+ evidence_pack = None
+ user_message = format_scientist_observation(observation)
+ if evidence_pack is not None:
+ user_message += "\n\nFrozen evidence pack:\n" + evidence_pack.prompt_block()
+ result = call_scientist_with_retry(
+ generate_fn,
+ _build_live_scientist_system_prompt(
+ observation,
+ evidence_pack=evidence_pack,
+ difficulty=difficulty,
+ scenario=scenario,
+ ),
+ observation,
+ max_retries=max_retries,
+ user_message_override=user_message,
+ )
+ return result.action
+
+ return policy_fn
+
+
+ def _build_live_scientist_system_prompt(
+ observation: ScientistObservation,
+ *,
+ evidence_pack: Any | None = None,
+ difficulty: str | None = None,
+ scenario: str | None = None,
+ ) -> str:
+ allowed_actions = ", ".join(action.value for action in ScientistActionType)
+ sections = [
+ "You are the Scientist agent in ReplicaLab.",
+ (
+ "Your job is to negotiate toward the strongest feasible plan under the "
+ "provided constraints. You do not invent resources, loosen constraints, "
+ "or assume hidden ground truth."
+ ),
+ (
+ "Return exactly one JSON object with all ScientistAction fields and no "
+ "extra keys or prose."
+ ),
+ f"Allowed action_type values: {allowed_actions}.",
+ (
+ "For propose_protocol and revise_protocol, include a full protocol payload "
+ "with sample_size >= 1, controls, technique, duration_days >= 0, "
+ "required_equipment, required_reagents, questions = [], and rationale."
+ ),
+ (
+ "For request_info, keep protocol fields empty or zero and include at least "
+ "one concrete blocking question."
+ ),
+ (
+ "For accept, keep all protocol-edit fields empty or zero and use an empty "
+ "questions list."
+ ),
+ (
+ "Bounded tool policy: search_evidence, run_code_check, and inspect_image "
+ "support the current scenario only. They do not override constraints."
+ ),
+ f"Paper title: {observation.paper_title}",
+ f"Goal: {observation.experiment_goal}",
+ ]
+ if scenario:
+ sections.append(f"Scenario family: {scenario}")
+ if difficulty:
+ sections.append(f"Difficulty: {difficulty}")
+ if evidence_pack is not None:
+ sections.extend(
+ [
+ f"Frozen evidence id: {evidence_pack.evidence_id}",
+ f"Grounding paper: {evidence_pack.downloaded_paper_title}",
+ f"Claim: {evidence_pack.claim}",
+ f"Technique: {evidence_pack.key_technique}",
+ f"Constraint tension: {evidence_pack.primary_constraint_tension}",
+ ]
+ )
+ return "\n\n".join(sections)
+
+
+ def _extract_message_content(content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for item in content:
+ if isinstance(item, dict):
+ text = item.get("text")
+ if text:
+ parts.append(str(text))
+ continue
+ text = getattr(item, "text", None)
+ if text:
+ parts.append(str(text))
+ return "\n".join(parts)
+ return ""
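The `_extract_message_content` helper added above tolerates both plain-string message content and OpenAI-style lists of text parts. Its behavior can be exercised standalone; the function below re-creates the same logic for illustration (it is not imported from the package):

```python
from typing import Any


def extract_message_content(content: Any) -> str:
    """Flatten chat message content (str, or list of text parts) to one string.

    Mirrors the _extract_message_content helper in scientist_policy.py:
    strings pass through, lists of dicts/objects with a `text` field are
    joined with newlines, and anything else yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                text = item.get("text")
                if text:
                    parts.append(str(text))
                continue
            text = getattr(item, "text", None)
            if text:
                parts.append(str(text))
        return "\n".join(parts)
    return ""


print(extract_message_content("plain"))
print(extract_message_content([{"type": "text", "text": "a"}, {"text": "b"}]))
```

Accepting both shapes matters because chat-completions backends may return `message.content` as either a string or a list of content parts.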
replicalab/training/__init__.py CHANGED
@@ -1,6 +1,13 @@
 """Training utilities for ReplicaLab."""

 from replicalab.training.artifacts import ArtifactLayout
 from replicalab.training.corpus import FrozenEvidencePack, load_frozen_evidence_packs
 from replicalab.training.datasets import (
 LabManagerSFTExample,
@@ -29,6 +36,10 @@ from replicalab.training.scientist_grpo import (

 __all__ = [
 "ArtifactLayout",
 "EpisodeRecord",
 "EvaluationCase",
 "EvaluationSummary",
@@ -47,6 +58,7 @@ __all__ = [
 "load_frozen_evidence_packs",
 "preview_lab_manager_training",
 "preview_scientist_training",
 "summarize_episodes",
 "train_lab_manager_sft",
 "train_scientist_grpo",

 """Training utilities for ReplicaLab."""

 from replicalab.training.artifacts import ArtifactLayout
+ from replicalab.training.art_openenv import (
+ ArtOpenEnvConfig,
+ ArtRolloutSummary,
+ ArtScenarioSpec,
+ ArtTrainingSummary,
+ run_art_openenv_training,
+ )
 from replicalab.training.corpus import FrozenEvidencePack, load_frozen_evidence_packs
 from replicalab.training.datasets import (
 LabManagerSFTExample,

 __all__ = [
 "ArtifactLayout",
+ "ArtOpenEnvConfig",
+ "ArtRolloutSummary",
+ "ArtScenarioSpec",
+ "ArtTrainingSummary",
 "EpisodeRecord",
 "EvaluationCase",
 "EvaluationSummary",

 "load_frozen_evidence_packs",
 "preview_lab_manager_training",
 "preview_scientist_training",
+ "run_art_openenv_training",
 "summarize_episodes",
 "train_lab_manager_sft",
 "train_scientist_grpo",
replicalab/training/art_openenv.py ADDED
@@ -0,0 +1,693 @@
+ """ART + ReplicaLab OpenEnv training helpers."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ from dataclasses import dataclass
+ from datetime import UTC, datetime
+ from pathlib import Path
+ from typing import Any, Sequence
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from replicalab.agents.scientist_policy import (
+ ScientistOutputParseError,
+ format_scientist_observation,
+ parse_scientist_output,
+ )
+ from replicalab.client import ReplicaLabClient
+ from replicalab.models import ScientistObservation
+ from replicalab.training.artifacts import ArtifactLayout, append_jsonl, build_run_name, write_json
+ from replicalab.training.corpus import (
+ FrozenEvidencePack,
+ evidence_pack_version,
+ load_frozen_evidence_packs,
+ select_evidence_pack,
+ )
+
+
+ class ArtScenarioSpec(BaseModel):
+ """One deterministic scenario spec for ART/OpenEnv rollouts."""
+
+ model_config = ConfigDict(extra="forbid")
+
+ seed: int
+ scenario: str
+ difficulty: str
+
+
+ class ArtOpenEnvConfig(BaseModel):
+ """Config for serverless ART training against ReplicaLab."""
+
+ model_config = ConfigDict(extra="forbid")
+
+ project: str = "replicalab-art-openenv"
+ model_name: str = "replicalab-scientist-art"
+ base_model: str = "OpenPipe/Qwen3-14B-Instruct"
+ base_url: str = "https://ayushozha-replicalab.hf.space"
+ transport: str = "rest"
+ train_steps: int = 1
+ rollouts_per_group: int = 2
+ max_turns: int = 6
+ max_completion_tokens: int = 700
+ max_parse_retries: int = 2
+ learning_rate: float = 5e-6
+ beta: float = 0.0
+ scenarios: list[ArtScenarioSpec] = Field(
+ default_factory=lambda: [
+ ArtScenarioSpec(seed=11, scenario="math_reasoning", difficulty="easy"),
+ ArtScenarioSpec(seed=12, scenario="ml_benchmark", difficulty="easy"),
+ ]
+ )
+
+
+ class ArtRolloutSummary(BaseModel):
+ """Flat rollout record for demo/docs and post-run analysis."""
+
+ model_config = ConfigDict(extra="forbid")
+
+ run_name: str
+ training_step: int
+ group_index: int
+ rollout_index: int
+ seed: int
+ scenario: str
+ difficulty: str
+ paper_title: str
+ evidence_id: str | None = None
+ evidence_match_type: str | None = None
+ reward: float
+ verdict: str | None = None
+ agreement_reached: bool = False
+ rounds_used: int = 0
+ invalid_action_count: int = 0
+ parse_error_count: int = 0
+ rigor: float = 0.0
+ feasibility: float = 0.0
+ fidelity: float = 0.0
+ parsimony: float = 1.0
+ artifact_step: int | None = None
+ artifact_name: str | None = None
+
+
+ class ArtTrainingSummary(BaseModel):
+ """Top-level training summary written after the run."""
+
+ model_config = ConfigDict(extra="forbid")
+
+ run_name: str
+ project: str
+ model_name: str
+ base_model: str
+ train_steps: int
+ rollouts_per_group: int
+ scenario_count: int
+ base_url: str
+ evidence_version: str
+ started_at: str
+ finished_at: str
+ final_artifact_step: int | None = None
+ final_artifact_name: str | None = None
+ average_reward: float = 0.0
+ agreement_rate: float = 0.0
+ average_rounds: float = 0.0
+
+
+ @dataclass
+ class _TurnRecord:
+ messages_and_choices: list[Any]
+ parse_error: str | None
+ raw_text: str
+
+
+ @dataclass
+ class _EpisodeTrace:
+ trajectory: Any
+ summary: ArtRolloutSummary
+
+
+ def run_art_openenv_training(
+ config: ArtOpenEnvConfig,
+ *,
+ layout: ArtifactLayout | None = None,
+ ) -> dict[str, object]:
+ """Sync wrapper used by CLI entrypoints."""
+
+ artifact_layout = layout or ArtifactLayout.create(run_name=build_run_name("art-scientist"))
+ return asyncio.run(_run_art_openenv_training_async(config, artifact_layout))
+
+
+ async def _run_art_openenv_training_async(
+ config: ArtOpenEnvConfig,
+ layout: ArtifactLayout,
+ ) -> dict[str, object]:
+ art_module = __import__("art")
+ from art import Trajectory, TrajectoryGroup, TrainableModel
+ from art.gather import gather_trajectory_groups
+ from art.serverless import ServerlessBackend
+ from art.trajectories import History
+
+ started_at = _utc_now()
+ evidence_packs = [pack for pack in load_frozen_evidence_packs() if pack.trainable_in_env]
+ evidence_version = evidence_pack_version(evidence_packs)
+ backend = ServerlessBackend()
+ model = TrainableModel(
+ name=config.model_name,
+ project=config.project,
+ base_model=config.base_model,
+ base_path=str(layout.run_dir),
+ report_metrics=[
+ "average_reward",
+ "agreement_rate",
+ "average_rounds",
+ "average_rigor",
+ "average_feasibility",
+ "average_fidelity",
+ "average_parsimony",
+ "invalid_action_rate",
+ ],
+ )
+ await model.register(backend)
+
+ write_json(layout.config_json, config.model_dump(mode="json"))
+ write_json(
+ layout.evidence_manifest_json,
+ {
+ "evidence_version": evidence_version,
+ "packs": [pack.model_dump(mode="json") for pack in evidence_packs],
+ },
+ )
+ process_log_path = layout.reports_dir / "art_training_process.md"
+ process_log_path.parent.mkdir(parents=True, exist_ok=True)
+ process_log_path.write_text(
+ "# ReplicaLab ART Training Run\n\n",
+ encoding="utf-8",
+ )
+ _append_process_log(
+ process_log_path,
+ f"Started at `{started_at}` against `{config.base_url}` using `{config.base_model}`.",
+ )
+ _append_process_log(
+ process_log_path,
+ (
+ f"Loaded `{len(evidence_packs)}` trainable frozen evidence packs "
+ f"(version `{evidence_version}`)."
+ ),
+ )
+
+ all_rollouts: list[ArtRolloutSummary] = []
+ final_artifact_step: int | None = None
+ final_artifact_name: str | None = None
+
+ for training_step in range(1, config.train_steps + 1):
+ _append_process_log(
+ process_log_path,
+ (
+ f"Training step {training_step}: collecting "
+ f"{len(config.scenarios)} trajectory groups with "
+ f"{config.rollouts_per_group} rollouts each."
+ ),
+ )
+
+ groups = await gather_trajectory_groups(
+ [
+ _collect_trajectory_group(
+ model=model,
+ config=config,
+ spec=spec,
+ evidence_pack=select_evidence_pack(
+ evidence_packs,
+ template=spec.scenario,
+ seed=spec.seed,
+ ),
+ group_index=group_index,
225
+ training_step=training_step,
226
+ run_name=layout.run_name,
227
+ )
228
+ for group_index, spec in enumerate(config.scenarios)
229
+ ],
230
+ pbar_desc=f"replicalab-step-{training_step}",
231
+ )
232
+
233
+ batch_summaries: list[ArtRolloutSummary] = []
234
+ for group in groups:
235
+ for trajectory in group.trajectories:
236
+ summary = ArtRolloutSummary.model_validate(trajectory.metadata)
237
+ batch_summaries.append(summary)
238
+ append_jsonl(layout.metrics_jsonl, summary.model_dump(mode="json"))
239
+
240
+ await model.log(groups, split="train")
241
+ train_result = await backend.train(
242
+ model,
243
+ groups,
244
+ learning_rate=config.learning_rate,
245
+ beta=config.beta,
246
+ )
247
+ await model.log(
248
+ split="train",
249
+ metrics=train_result.metrics,
250
+ step=train_result.step,
251
+ )
252
+
253
+ final_artifact_step = train_result.step
254
+ final_artifact_name = train_result.artifact_name
255
+ _append_process_log(
256
+ process_log_path,
257
+ (
258
+ f"Completed training step {training_step}: artifact="
259
+ f"`{train_result.artifact_name}` step={train_result.step} "
260
+ f"metrics={json.dumps(train_result.metrics, sort_keys=True)}"
261
+ ),
262
+ )
263
+
264
+ for summary in batch_summaries:
265
+ summary.artifact_step = train_result.step
266
+ summary.artifact_name = train_result.artifact_name
267
+
268
+ all_rollouts.extend(batch_summaries)
269
+
270
+ finished_at = _utc_now()
271
+ summary = _summarize_art_training(
272
+ config=config,
273
+ layout=layout,
274
+ started_at=started_at,
275
+ finished_at=finished_at,
276
+ rollouts=all_rollouts,
277
+ evidence_version=evidence_version,
278
+ final_artifact_step=final_artifact_step,
279
+ final_artifact_name=final_artifact_name,
280
+ )
281
+ write_json(layout.summary_json, summary.model_dump(mode="json"))
282
+ _append_process_log(
283
+ process_log_path,
284
+ (
285
+ f"Finished at `{finished_at}`. Average reward={summary.average_reward:.4f}, "
286
+ f"agreement_rate={summary.agreement_rate:.4f}, "
287
+ f"average_rounds={summary.average_rounds:.4f}."
288
+ ),
289
+ )
290
+ return summary.model_dump(mode="json")
291
+
292
+
293
+ async def _collect_trajectory_group(
294
+ *,
295
+ model: Any,
296
+ config: ArtOpenEnvConfig,
297
+ spec: ArtScenarioSpec,
298
+ evidence_pack: FrozenEvidencePack | None,
299
+ group_index: int,
300
+ training_step: int,
301
+ run_name: str,
302
+ ) -> Any:
303
+ from art import TrajectoryGroup
304
+
305
+ traces = await asyncio.gather(
306
+ *[
307
+ _run_art_episode(
308
+ model=model,
309
+ config=config,
310
+ spec=spec,
311
+ evidence_pack=evidence_pack,
312
+ group_index=group_index,
313
+ rollout_index=rollout_index,
314
+ training_step=training_step,
315
+ run_name=run_name,
316
+ )
317
+ for rollout_index in range(config.rollouts_per_group)
318
+ ]
319
+ )
320
+ return TrajectoryGroup(
321
+ trajectories=[trace.trajectory for trace in traces],
322
+ metadata={
323
+ "scenario": spec.scenario,
324
+ "difficulty": spec.difficulty,
325
+ "seed": spec.seed,
326
+ "training_step": training_step,
327
+ },
328
+ metrics={
329
+ "average_reward": _mean(trace.summary.reward for trace in traces),
330
+ "agreement_rate": _mean(
331
+ 1.0 if trace.summary.agreement_reached else 0.0 for trace in traces
332
+ ),
333
+ },
334
+ logs=[
335
+ (
336
+ f"group={group_index} seed={spec.seed} scenario={spec.scenario} "
337
+ f"difficulty={spec.difficulty}"
338
+ )
339
+ ],
340
+ )
341
+
342
+
343
+ async def _run_art_episode(
344
+ *,
345
+ model: Any,
346
+ config: ArtOpenEnvConfig,
347
+ spec: ArtScenarioSpec,
348
+ evidence_pack: FrozenEvidencePack | None,
349
+ group_index: int,
350
+ rollout_index: int,
351
+ training_step: int,
352
+ run_name: str,
353
+ ) -> _EpisodeTrace:
354
+ from art import Trajectory
355
+ from art.trajectories import History
356
+
357
+ client = ReplicaLabClient(config.base_url, transport=config.transport)
358
+ await asyncio.to_thread(client.connect)
359
+ invalid_action_count = 0
360
+ parse_error_count = 0
361
+ turns: list[_TurnRecord] = []
362
+
363
+ try:
364
+ observation = await asyncio.to_thread(
365
+ client.reset,
366
+ spec.seed,
367
+ spec.scenario,
368
+ spec.difficulty,
369
+ )
370
+ scientist_obs = observation.scientist
371
+ if scientist_obs is None:
372
+ raise RuntimeError("Reset returned no scientist observation.")
373
+
374
+ terminal_reward = -1.0
375
+ terminal_info = None
376
+
377
+ for _ in range(config.max_turns):
378
+ system_prompt = _build_art_scientist_system_prompt(
379
+ spec=spec,
380
+ observation=scientist_obs,
381
+ evidence_pack=evidence_pack,
382
+ )
383
+ user_prompt = format_scientist_observation(scientist_obs)
384
+ if evidence_pack is not None:
385
+ user_prompt += "\n\nFrozen evidence pack:\n" + evidence_pack.prompt_block()
386
+ turn = await _generate_turn(
387
+ model=model,
388
+ system_prompt=system_prompt,
389
+ user_prompt=user_prompt,
390
+ max_completion_tokens=config.max_completion_tokens,
391
+ max_parse_retries=config.max_parse_retries,
392
+ )
393
+ turns.append(turn)
394
+
395
+ if turn.parse_error is not None:
396
+ parse_error_count += 1
397
+ terminal_reward = -1.0
398
+ break
399
+
400
+ action = parse_scientist_output(turn.raw_text)
401
+ result = await asyncio.to_thread(client.step, action)
402
+ terminal_reward = result.reward
403
+ terminal_info = result.info
404
+ if result.info.error:
405
+ invalid_action_count += 1
406
+
407
+ if result.done:
408
+ break
409
+
410
+ if result.observation is None or result.observation.scientist is None:
411
+ raise RuntimeError("Non-terminal step returned no scientist observation.")
412
+ scientist_obs = result.observation.scientist
413
+
414
+ histories = [
415
+ History(messages_and_choices=turn.messages_and_choices)
416
+ for turn in turns[1:]
417
+ ]
418
+ trajectory = Trajectory(
419
+ messages_and_choices=(turns[0].messages_and_choices if turns else []),
420
+ additional_histories=histories,
421
+ reward=terminal_reward,
422
+ metrics=_extract_terminal_metrics(
423
+ terminal_info=terminal_info,
424
+ invalid_action_count=invalid_action_count,
425
+ parse_error_count=parse_error_count,
426
+ rounds_used=len(turns),
427
+ ),
428
+ metadata={},
429
+ logs=[
430
+ (
431
+ f"training_step={training_step} group={group_index} rollout={rollout_index} "
432
+ f"seed={spec.seed} scenario={spec.scenario} difficulty={spec.difficulty}"
433
+ )
434
+ ],
435
+ )
436
+ summary = ArtRolloutSummary(
437
+ run_name=run_name,
438
+ training_step=training_step,
439
+ group_index=group_index,
440
+ rollout_index=rollout_index,
441
+ seed=spec.seed,
442
+ scenario=spec.scenario,
443
+ difficulty=spec.difficulty,
444
+ paper_title=scientist_obs.paper_title,
445
+ evidence_id=(evidence_pack.evidence_id if evidence_pack is not None else None),
446
+ evidence_match_type=(
447
+ evidence_pack.match_type if evidence_pack is not None else None
448
+ ),
449
+ reward=terminal_reward,
450
+ verdict=(terminal_info.verdict if terminal_info is not None else None),
451
+ agreement_reached=(
452
+ terminal_info.agreement_reached if terminal_info is not None else False
453
+ ),
454
+ rounds_used=len(turns),
455
+ invalid_action_count=invalid_action_count,
456
+ parse_error_count=parse_error_count,
457
+ rigor=(
458
+ terminal_info.reward_breakdown.rigor
459
+ if terminal_info and terminal_info.reward_breakdown
460
+ else 0.0
461
+ ),
462
+ feasibility=(
463
+ terminal_info.reward_breakdown.feasibility
464
+ if terminal_info and terminal_info.reward_breakdown
465
+ else 0.0
466
+ ),
467
+ fidelity=(
468
+ terminal_info.reward_breakdown.fidelity
469
+ if terminal_info and terminal_info.reward_breakdown
470
+ else 0.0
471
+ ),
472
+ parsimony=(
473
+ terminal_info.reward_breakdown.parsimony
474
+ if terminal_info and terminal_info.reward_breakdown
475
+ else 1.0
476
+ ),
477
+ )
478
+ trajectory.metadata.update(summary.model_dump(mode="json"))
479
+ return _EpisodeTrace(trajectory=trajectory, summary=summary)
480
+ finally:
481
+ await asyncio.to_thread(client.close)
482
+
483
+
484
+ async def _generate_turn(
485
+ *,
486
+ model: Any,
487
+ system_prompt: str,
488
+ user_prompt: str,
489
+ max_completion_tokens: int,
490
+ max_parse_retries: int,
491
+ ) -> _TurnRecord:
492
+ client = model.openai_client()
493
+ messages = [
494
+ {"role": "system", "content": system_prompt},
495
+ {"role": "user", "content": user_prompt},
496
+ ]
497
+ for attempt in range(max_parse_retries + 1):
498
+ completion = await client.chat.completions.create(
499
+ model=model.get_inference_name(),
500
+ messages=messages,
501
+ max_completion_tokens=max_completion_tokens,
502
+ temperature=0.0,
503
+ )
504
+ choice = completion.choices[0]
505
+ raw_text = _extract_choice_text(choice)
506
+ try:
507
+ parse_scientist_output(raw_text)
508
+ return _TurnRecord(
509
+ messages_and_choices=[
510
+ *messages,
511
+ {"role": "assistant", "content": raw_text},
512
+ ],
513
+ parse_error=None,
514
+ raw_text=raw_text,
515
+ )
516
+ except ScientistOutputParseError as exc:
517
+ if attempt >= max_parse_retries:
518
+ return _TurnRecord(
519
+ messages_and_choices=[
520
+ *messages,
521
+ {"role": "assistant", "content": raw_text},
522
+ ],
523
+ parse_error=exc.message,
524
+ raw_text=raw_text,
525
+ )
526
+ messages.extend(
527
+ [
528
+ {"role": "assistant", "content": raw_text},
529
+ {"role": "user", "content": _build_art_correction_prompt(exc)},
530
+ ]
531
+ )
532
+ raise RuntimeError("unreachable")
533
+
534
+
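`_generate_turn` implements a bounded parse-retry loop: sample, try to parse, append a correction prompt on failure, and give up after `max_parse_retries` extra attempts. The control flow can be sketched with a self-contained stand-in (using `json.loads` in place of `parse_scientist_output`, and a canned response iterator in place of the model):

```python
import json

def generate_with_retries(responses, max_parse_retries=2):
    """Bounded retry loop: keep asking until the reply parses, or give up.

    `responses` stands in for the model; it yields raw assistant strings.
    This mirrors _generate_turn's control flow, not the real ART client API.
    """
    replies = iter(responses)
    for attempt in range(max_parse_retries + 1):
        raw = next(replies)
        try:
            json.loads(raw)  # stand-in for parse_scientist_output
            return raw, None
        except json.JSONDecodeError as exc:
            if attempt >= max_parse_retries:
                return raw, str(exc)  # out of retries: surface the parse error
            # In the real loop a correction prompt is appended to `messages` here.
    raise RuntimeError("unreachable")

print(generate_with_retries(["not json", '{"ok": true}']))
```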
535
+ def _build_art_scientist_system_prompt(
536
+ *,
537
+ spec: ArtScenarioSpec,
538
+ observation: ScientistObservation,
539
+ evidence_pack: FrozenEvidencePack | None,
540
+ ) -> str:
541
+ sections = [
542
+ "You are the Scientist agent in ReplicaLab.",
543
+ "Negotiate toward the strongest feasible technical plan under hard real-world constraints.",
544
+ "Return exactly one valid ScientistAction JSON object with no markdown and no extra prose.",
545
+ "Use request_info only when a concrete blocking question remains.",
546
+ "Use accept only when the current protocol is genuinely ready.",
547
+ "Bounded tool policy: search_evidence, run_code_check, and inspect_image are support tools only; they never override constraints or reveal hidden ground truth.",
548
+ f"Scenario family: {spec.scenario}",
549
+ f"Difficulty: {spec.difficulty}",
550
+ f"Paper title: {observation.paper_title}",
551
+ f"Goal: {observation.experiment_goal}",
552
+ (
553
+ "The user observation already contains the full conversation "
554
+ "history and current protocol. Use that as your source of truth "
555
+ "for each turn."
556
+ ),
557
+ ]
558
+ if evidence_pack is not None:
559
+ sections.extend(
560
+ [
561
+ f"Frozen evidence id: {evidence_pack.evidence_id}",
562
+ f"Grounding paper: {evidence_pack.downloaded_paper_title}",
563
+ f"Claim: {evidence_pack.claim}",
564
+ f"Technique: {evidence_pack.key_technique}",
565
+ f"Constraint tension: {evidence_pack.primary_constraint_tension}",
566
+ ]
567
+ )
568
+ sections.extend(
569
+ [
570
+ "Always emit all ScientistAction fields, even for request_info or accept.",
571
+ (
572
+ "Shape example: "
573
+ '{"action_type":"propose_protocol","sample_size":8,"controls":["baseline"],'
574
+ '"technique":"LoRA fine-tuning on the public subset","duration_days":2,'
575
+ '"required_equipment":["gpu_h100"],"required_reagents":[],'
576
+ '"questions":[],"rationale":"Uses the available hardware and stays within the reduced dataset budget."}'
577
+ ),
578
+ ]
579
+ )
580
+ return "\n".join(sections)
581
+
582
+
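The system prompt embeds a one-line shape example; it is worth confirming that this example is itself valid JSON carrying every field the prompt demands. A standalone check (the field set below is read off the example string, not from the real `ScientistAction` schema):

```python
import json

# The shape example embedded in the system prompt (copied verbatim from the prompt text).
SHAPE_EXAMPLE = (
    '{"action_type":"propose_protocol","sample_size":8,"controls":["baseline"],'
    '"technique":"LoRA fine-tuning on the public subset","duration_days":2,'
    '"required_equipment":["gpu_h100"],"required_reagents":[],'
    '"questions":[],"rationale":"Uses the available hardware and stays within the reduced dataset budget."}'
)

action = json.loads(SHAPE_EXAMPLE)
# "Always emit all ScientistAction fields" — every key in the example must be present.
REQUIRED = {
    "action_type", "sample_size", "controls", "technique", "duration_days",
    "required_equipment", "required_reagents", "questions", "rationale",
}
assert REQUIRED <= action.keys()
print(action["action_type"])
```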
583
+ def _extract_choice_text(choice: Any) -> str:
584
+ message = getattr(choice, "message", None)
585
+ content = getattr(message, "content", None)
586
+ if isinstance(content, str):
587
+ return content
588
+ if isinstance(content, list):
589
+ parts: list[str] = []
590
+ for item in content:
591
+ text = getattr(item, "text", None)
592
+ if text:
593
+ parts.append(str(text))
594
+ elif isinstance(item, dict) and "text" in item:
595
+ parts.append(str(item["text"]))
596
+ return "\n".join(parts)
597
+ return ""
598
+
599
+
600
+ def _build_art_correction_prompt(error: ScientistOutputParseError) -> str:
601
+ suffix = (
602
+ "Return exactly one JSON object with all ScientistAction fields. "
603
+ "No markdown fences, no prose, no commentary."
604
+ )
605
+ if error.code == "no_json":
606
+ return "Your previous response did not contain a JSON object. " + suffix
607
+ if error.code == "invalid_json":
608
+ return (
609
+ f"Your previous response contained malformed JSON: {error.message}. " + suffix
610
+ )
611
+ return (
612
+ "Your previous response contained valid JSON but failed ScientistAction "
613
+ f"validation: {error.message}. Fix the validation error and return a corrected "
614
+ "ScientistAction JSON object. " + suffix
615
+ )
616
+
617
+
618
+ def _extract_terminal_metrics(
619
+ *,
620
+ terminal_info: Any,
621
+ invalid_action_count: int,
622
+ parse_error_count: int,
623
+ rounds_used: int,
624
+ ) -> dict[str, float | int | bool]:
625
+ breakdown = terminal_info.reward_breakdown if terminal_info is not None else None
626
+ return {
627
+ "agreement_reached": terminal_info.agreement_reached if terminal_info else False,
628
+ "invalid_action_count": invalid_action_count,
629
+ "invalid_action_rate": (invalid_action_count / max(1, rounds_used)),
630
+ "parse_error_count": parse_error_count,
631
+ "parse_error_rate": (parse_error_count / max(1, rounds_used)),
632
+ "rounds_used": rounds_used,
633
+ "rigor": (breakdown.rigor if breakdown is not None else 0.0),
634
+ "feasibility": (breakdown.feasibility if breakdown is not None else 0.0),
635
+ "fidelity": (breakdown.fidelity if breakdown is not None else 0.0),
636
+ "parsimony": (breakdown.parsimony if breakdown is not None else 1.0),
637
+ }
638
+
639
+
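`_extract_terminal_metrics` divides by `max(1, rounds_used)` so a zero-turn episode yields a rate of `0.0` instead of raising `ZeroDivisionError`. A tiny self-contained check of that guard (the helper name here is illustrative):

```python
def safe_rate(count: int, rounds_used: int) -> float:
    # max(1, rounds_used) guards the zero-turn episode, matching the
    # invalid_action_rate / parse_error_rate computations above.
    return count / max(1, rounds_used)

print(safe_rate(0, 0))  # 0.0 rather than ZeroDivisionError
print(safe_rate(1, 4))  # 0.25
```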
640
+ def _summarize_art_training(
641
+ *,
642
+ config: ArtOpenEnvConfig,
643
+ layout: ArtifactLayout,
644
+ started_at: str,
645
+ finished_at: str,
646
+ rollouts: Sequence[ArtRolloutSummary],
647
+ evidence_version: str,
648
+ final_artifact_step: int | None,
649
+ final_artifact_name: str | None,
650
+ ) -> ArtTrainingSummary:
651
+ return ArtTrainingSummary(
652
+ run_name=layout.run_name,
653
+ project=config.project,
654
+ model_name=config.model_name,
655
+ base_model=config.base_model,
656
+ train_steps=config.train_steps,
657
+ rollouts_per_group=config.rollouts_per_group,
658
+ scenario_count=len(config.scenarios),
659
+ base_url=config.base_url,
660
+ evidence_version=evidence_version,
661
+ started_at=started_at,
662
+ finished_at=finished_at,
663
+ final_artifact_step=final_artifact_step,
664
+ final_artifact_name=final_artifact_name,
665
+ average_reward=_mean(item.reward for item in rollouts),
666
+ agreement_rate=_mean(1.0 if item.agreement_reached else 0.0 for item in rollouts),
667
+ average_rounds=_mean(item.rounds_used for item in rollouts),
668
+ )
669
+
670
+
671
+ def _append_process_log(path: Path, line: str) -> None:
672
+ with path.open("a", encoding="utf-8") as handle:
673
+ handle.write(f"- {line}\n")
674
+
675
+
676
+ def _utc_now() -> str:
677
+ return datetime.now(UTC).isoformat()
678
+
679
+
680
+ def _mean(values: Any) -> float:
681
+ values = list(values)
682
+ if not values:
683
+ return 0.0
684
+ return round(sum(float(value) for value in values) / len(values), 6)
685
+
686
+
687
+ __all__ = [
688
+ "ArtOpenEnvConfig",
689
+ "ArtRolloutSummary",
690
+ "ArtScenarioSpec",
691
+ "ArtTrainingSummary",
692
+ "run_art_openenv_training",
693
+ ]
replicalab/training/cli.py CHANGED
@@ -8,21 +8,34 @@ import sys
8
  from pathlib import Path
9
  from typing import Sequence
10
 
11
- from replicalab.agents import build_baseline_scientist_action
12
  from replicalab.training.artifacts import (
13
  ArtifactLayout,
14
  append_jsonl,
15
  build_run_name,
16
  write_json,
17
  )
18
- from replicalab.training.evaluation import build_default_evaluation_cases, evaluate_policy
19
  from replicalab.training.lab_manager_sft import (
20
  LabManagerSFTConfig,
21
  preview_lab_manager_training,
22
  train_lab_manager_sft,
23
  )
24
  from replicalab.training.metrics import episode_to_metrics
25
- from replicalab.training.plots import plot_evaluation_bars, plot_training_history
26
  from replicalab.training.scientist_grpo import (
27
  ScientistGRPOConfig,
28
  preview_scientist_training,
@@ -46,6 +59,10 @@ def main(argv: Sequence[str] | None = None) -> int:
46
  return _run_lab_manager_train(args)
47
  if args.command == "baseline-eval":
48
  return _run_baseline_eval(args)
49
 
50
  parser.error(f"Unsupported command: {args.command}")
51
  return 2
@@ -169,6 +186,159 @@ def _build_parser() -> argparse.ArgumentParser:
169
  help="Difficulty levels to evaluate.",
170
  )
171
 
172
  return parser
173
 
174
 
@@ -279,6 +449,22 @@ def _run_scientist_train(args: argparse.Namespace) -> int:
279
  max_steps=args.max_steps,
280
  )
281
  result = train_scientist_grpo(config, layout=layout, dry_run=args.dry_run)
282
  _maybe_plot_training_history(
283
  layout=layout,
284
  state_name="scientist_trainer_state.json",
@@ -327,6 +513,21 @@ def _run_lab_manager_train(args: argparse.Namespace) -> int:
327
  load_in_4bit=args.load_in_4bit,
328
  )
329
  result = train_lab_manager_sft(config, layout=layout, dry_run=args.dry_run)
330
  _maybe_plot_training_history(
331
  layout=layout,
332
  state_name="lab_manager_trainer_state.json",
@@ -364,6 +565,22 @@ def _run_baseline_eval(args: argparse.Namespace) -> int:
364
  "cases": [case.__dict__ for case in cases],
365
  },
366
  )
367
  for record in records:
368
  append_jsonl(
369
  layout.metrics_jsonl,
@@ -376,6 +593,127 @@ def _run_baseline_eval(args: argparse.Namespace) -> int:
376
  return 0
377
 
378
 
379
  def _maybe_plot_training_history(
380
  *,
381
  layout: ArtifactLayout,
@@ -415,6 +753,80 @@ def _plot_eval_summary(
415
  metric_key="agreement_rate",
416
  title="Baseline agreement rate",
417
  )
418
 
419
 
420
  if __name__ == "__main__":
 
8
  from pathlib import Path
9
  from typing import Sequence
10
 
11
+ from replicalab.agents import build_baseline_scientist_action, build_remote_scientist_policy
12
  from replicalab.training.artifacts import (
13
  ArtifactLayout,
14
  append_jsonl,
15
  build_run_name,
16
  write_json,
17
  )
18
+ from replicalab.training.art_openenv import (
19
+ ArtOpenEnvConfig,
20
+ ArtScenarioSpec,
21
+ run_art_openenv_training,
22
+ )
23
+ from replicalab.training.evaluation import (
24
+ build_default_evaluation_cases,
25
+ compare_policies,
26
+ evaluate_policy,
27
+ )
28
  from replicalab.training.lab_manager_sft import (
29
  LabManagerSFTConfig,
30
  preview_lab_manager_training,
31
  train_lab_manager_sft,
32
  )
33
  from replicalab.training.metrics import episode_to_metrics
34
+ from replicalab.training.plots import (
35
+ plot_evaluation_bars,
36
+ plot_metrics_by_step,
37
+ plot_training_history,
38
+ )
39
  from replicalab.training.scientist_grpo import (
40
  ScientistGRPOConfig,
41
  preview_scientist_training,
 
59
  return _run_lab_manager_train(args)
60
  if args.command == "baseline-eval":
61
  return _run_baseline_eval(args)
62
+ if args.command == "scientist-compare-eval":
63
+ return _run_scientist_compare_eval(args)
64
+ if args.command == "art-scientist-train":
65
+ return _run_art_scientist_train(args)
66
 
67
  parser.error(f"Unsupported command: {args.command}")
68
  return 2
 
186
  help="Difficulty levels to evaluate.",
187
  )
188
 
189
+ compare_eval = subparsers.add_parser(
190
+ "scientist-compare-eval",
191
+ help="Compare baseline Scientist versus a trained ART Scientist checkpoint.",
192
+ )
193
+ _add_common_artifact_args(compare_eval, prefix="eval-compare")
194
+ compare_eval.add_argument(
195
+ "--base-url",
196
+ default="https://ayushozha-replicalab.hf.space",
197
+ help="ReplicaLab environment base URL.",
198
+ )
199
+ compare_eval.add_argument(
200
+ "--transport",
201
+ default="rest",
202
+ choices=("rest", "ws"),
203
+ help="Transport used by ReplicaLabClient.",
204
+ )
205
+ compare_eval.add_argument(
206
+ "--eval-seeds",
207
+ nargs="+",
208
+ type=int,
209
+ default=[101, 102],
210
+ help="Evaluation seeds.",
211
+ )
212
+ compare_eval.add_argument(
213
+ "--scenarios",
214
+ nargs="+",
215
+ default=list(scientist_defaults.templates),
216
+ help="Scenario families to evaluate.",
217
+ )
218
+ compare_eval.add_argument(
219
+ "--difficulties",
220
+ nargs="+",
221
+ default=list(scientist_defaults.difficulties),
222
+ help="Difficulty levels to evaluate.",
223
+ )
224
+ compare_eval.add_argument(
225
+ "--project",
226
+ default="replicalab-ai",
227
+ help="ART project name for the trained Scientist checkpoint.",
228
+ )
229
+ compare_eval.add_argument(
230
+ "--model-name",
231
+ default="replicalab-scientist-art-live",
232
+ help="ART trainable model name for the trained Scientist checkpoint.",
233
+ )
234
+ compare_eval.add_argument(
235
+ "--base-model",
236
+ default="OpenPipe/Qwen3-14B-Instruct",
237
+ help="Base model used for the ART trained Scientist.",
238
+ )
239
+ compare_eval.add_argument(
240
+ "--checkpoint-step",
241
+ type=int,
242
+ default=None,
243
+ help="Optional explicit ART checkpoint step to evaluate.",
244
+ )
245
+ compare_eval.add_argument(
246
+ "--max-completion-tokens",
247
+ type=int,
248
+ default=450,
249
+ help="Max completion tokens for the trained remote Scientist.",
250
+ )
251
+ compare_eval.add_argument(
252
+ "--temperature",
253
+ type=float,
254
+ default=0.0,
255
+ help="Sampling temperature for the trained remote Scientist.",
256
+ )
257
+
258
+ art_train = subparsers.add_parser(
259
+ "art-scientist-train",
260
+ help="Run ART serverless RL training against the ReplicaLab OpenEnv deployment.",
261
+ )
262
+ _add_common_artifact_args(art_train, prefix="art-scientist")
263
+ art_train.add_argument(
264
+ "--project",
265
+ default="replicalab-art-openenv",
266
+ help="Weights & Biases / ART project name.",
267
+ )
268
+ art_train.add_argument(
269
+ "--model-name",
270
+ default="replicalab-scientist-art",
271
+ help="ART trainable model name.",
272
+ )
273
+ art_train.add_argument(
274
+ "--base-model",
275
+ default="OpenPipe/Qwen3-14B-Instruct",
276
+ help="ART serverless base model.",
277
+ )
278
+ art_train.add_argument(
279
+ "--base-url",
280
+ default="https://ayushozha-replicalab.hf.space",
281
+ help="ReplicaLab environment base URL.",
282
+ )
283
+ art_train.add_argument(
284
+ "--transport",
285
+ default="rest",
286
+ choices=("rest",),
287
+ help="Transport used for live environment interaction.",
288
+ )
289
+ art_train.add_argument(
290
+ "--train-steps",
291
+ type=int,
292
+ default=1,
293
+ help="Number of ART training updates to run.",
294
+ )
295
+ art_train.add_argument(
296
+ "--rollouts-per-group",
297
+ type=int,
298
+ default=2,
299
+ help="Number of sampled rollouts for each scenario group.",
300
+ )
301
+ art_train.add_argument(
302
+ "--max-turns",
303
+ type=int,
304
+ default=6,
305
+ help="Max environment turns per rollout.",
306
+ )
307
+ art_train.add_argument(
308
+ "--max-completion-tokens",
309
+ type=int,
310
+ default=700,
311
+ help="Assistant max completion tokens per turn.",
312
+ )
313
+ art_train.add_argument(
314
+ "--max-parse-retries",
315
+ type=int,
316
+ default=2,
317
+ help="Number of parser-driven correction retries per turn.",
318
+ )
319
+ art_train.add_argument(
320
+ "--learning-rate",
321
+ type=float,
322
+ default=5e-6,
323
+ help="ART learning rate.",
324
+ )
325
+ art_train.add_argument(
326
+ "--beta",
327
+ type=float,
328
+ default=0.0,
329
+ help="ART KL penalty coefficient.",
330
+ )
331
+ art_train.add_argument(
332
+ "--scenario-spec",
333
+ nargs="+",
334
+ default=[
335
+ "0:ml_benchmark:easy",
336
+ "1:ml_benchmark:medium",
337
+ "0:finance_trading:easy",
338
+ ],
339
+ help="Scenario specs in the form seed:scenario:difficulty.",
340
+ )
341
+
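These specs are consumed by `_parse_art_scenario_spec` (not shown in this diff). A plausible sketch of that parsing under the `seed:scenario:difficulty` form stated in the help text — the class and function names here are hypothetical stand-ins:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ScenarioSpec:
    # Illustrative stand-in for ArtScenarioSpec; the real type is a pydantic model.
    seed: int
    scenario: str
    difficulty: str

def parse_scenario_spec(item: str) -> ScenarioSpec:
    """Parse 'seed:scenario:difficulty', e.g. '0:ml_benchmark:easy'."""
    seed, scenario, difficulty = item.split(":", 2)
    return ScenarioSpec(seed=int(seed), scenario=scenario, difficulty=difficulty)

print(parse_scenario_spec("1:ml_benchmark:medium"))
```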
342
  return parser
343
 
344
 
 
449
  max_steps=args.max_steps,
450
  )
451
  result = train_scientist_grpo(config, layout=layout, dry_run=args.dry_run)
452
+ _write_run_metadata(
453
+ layout,
454
+ {
455
+ "kind": "scientist_train",
456
+ "model_name": args.model_name,
457
+ "templates": args.templates,
458
+ "difficulties": args.difficulties,
459
+ "seed_count": args.seed_count,
460
+ "max_steps": args.max_steps,
461
+ "bounded_tool_policy": [
462
+ "search_evidence",
463
+ "run_code_check",
464
+ "inspect_image",
465
+ ],
466
+ },
467
+ )
468
  _maybe_plot_training_history(
469
  layout=layout,
470
  state_name="scientist_trainer_state.json",
 
513
  load_in_4bit=args.load_in_4bit,
514
  )
515
  result = train_lab_manager_sft(config, layout=layout, dry_run=args.dry_run)
516
+ _write_run_metadata(
517
+ layout,
518
+ {
519
+ "kind": "lab_manager_train",
520
+ "model_name": args.model_name,
521
+ "templates": args.templates,
522
+ "difficulties": args.difficulties,
523
+ "seed_count": args.seed_count,
524
+ "bounded_tool_policy": [
525
+ "search_evidence",
526
+ "run_code_check",
527
+ "inspect_image",
528
+ ],
529
+ },
530
+ )
531
  _maybe_plot_training_history(
532
  layout=layout,
533
  state_name="lab_manager_trainer_state.json",
 
565
  "cases": [case.__dict__ for case in cases],
566
  },
567
  )
568
+ _write_run_metadata(
569
+ layout,
570
+ {
571
+ "kind": "baseline_eval",
572
+ "base_url": args.base_url,
573
+ "transport": args.transport,
574
+ "eval_seeds": args.eval_seeds,
575
+ "scenarios": args.scenarios,
576
+ "difficulties": args.difficulties,
577
+ "bounded_tool_policy": [
578
+ "search_evidence",
579
+ "run_code_check",
580
+ "inspect_image",
581
+ ],
582
+ },
583
+ )
584
  for record in records:
585
  append_jsonl(
586
  layout.metrics_jsonl,
 
593
  return 0
594
 
595
 
596
+ def _run_scientist_compare_eval(args: argparse.Namespace) -> int:
597
+ layout = _build_layout(
598
+ prefix="eval-compare",
599
+ persist_root=args.persist_root,
600
+ run_name=args.run_name,
601
+ )
602
+ cases = build_default_evaluation_cases(
603
+ seeds=args.eval_seeds,
604
+ scenarios=args.scenarios,
605
+ difficulties=args.difficulties,
606
+ )
607
+ trained_policy = build_remote_scientist_policy(
608
+ project=args.project,
609
+ model_name=args.model_name,
610
+ base_model=args.base_model,
611
+ checkpoint_step=args.checkpoint_step,
612
+ max_completion_tokens=args.max_completion_tokens,
613
+ temperature=args.temperature,
614
+ )
615
+ records_by_label, rows = compare_policies(
616
+ base_url=args.base_url,
617
+ policies=[
618
+ ("baseline", build_baseline_scientist_action),
619
+ ("trained", trained_policy),
620
+ ],
621
+ cases=cases,
622
+ transport=args.transport,
623
+ )
624
+ write_json(
625
+ layout.config_json,
626
+ {
627
+ "kind": "scientist_compare_eval",
628
+ "base_url": args.base_url,
629
+ "transport": args.transport,
630
+ "cases": [case.__dict__ for case in cases],
631
+ "project": args.project,
632
+ "model_name": args.model_name,
633
+ "base_model": args.base_model,
634
+ "checkpoint_step": args.checkpoint_step,
635
+ },
636
+ )
637
+ _write_run_metadata(
638
+ layout,
639
+ {
640
+ "kind": "scientist_compare_eval",
641
+ "base_url": args.base_url,
642
+ "transport": args.transport,
643
+ "eval_seeds": args.eval_seeds,
644
+ "scenarios": args.scenarios,
645
+ "difficulties": args.difficulties,
646
+ "project": args.project,
647
+ "model_name": args.model_name,
648
+ "base_model": args.base_model,
649
+ "checkpoint_step": args.checkpoint_step,
650
+ "bounded_tool_policy": [
651
+ "search_evidence",
652
+ "run_code_check",
653
+ "inspect_image",
654
+ ],
655
+ },
656
+ )
657
+ for label, records in records_by_label.items():
658
+ for record in records:
659
+ append_jsonl(
660
+ layout.metrics_jsonl,
661
+ {"label": label, **episode_to_metrics(record).model_dump(mode="json")},
662
+ )
663
+ rows_payload = [row.model_dump(mode="json") for row in rows]
664
+ write_json(layout.summary_json, {"rows": rows_payload})
665
+ _plot_comparison_summary(rows_payload, layout=layout)
666
+ print(json.dumps({"rows": rows_payload}, indent=2, sort_keys=True))
667
+ return 0
668
+
669
+
670
+ def _run_art_scientist_train(args: argparse.Namespace) -> int:
671
+ layout = _build_layout(
672
+ prefix="art-scientist",
673
+ persist_root=args.persist_root,
674
+ run_name=args.run_name,
675
+ )
676
+ config = ArtOpenEnvConfig(
677
+ project=args.project,
678
+ model_name=args.model_name,
679
+ base_model=args.base_model,
680
+ base_url=args.base_url,
681
+ transport=args.transport,
682
+ train_steps=args.train_steps,
683
+ rollouts_per_group=args.rollouts_per_group,
684
+ max_turns=args.max_turns,
685
+ max_completion_tokens=args.max_completion_tokens,
686
+ max_parse_retries=args.max_parse_retries,
687
+ learning_rate=args.learning_rate,
688
+ beta=args.beta,
689
+ scenarios=[_parse_art_scenario_spec(item) for item in args.scenario_spec],
690
+ )
691
+ result = run_art_openenv_training(config, layout=layout)
692
+ _write_run_metadata(
693
+ layout,
694
+ {
695
+ "kind": "art_scientist_train",
696
+ "project": args.project,
697
+ "model_name": args.model_name,
698
+ "base_model": args.base_model,
699
+ "base_url": args.base_url,
700
+ "train_steps": args.train_steps,
701
+ "rollouts_per_group": args.rollouts_per_group,
702
+ "max_turns": args.max_turns,
703
+ "max_parse_retries": args.max_parse_retries,
704
+ "scenario_spec": args.scenario_spec,
705
+ "bounded_tool_policy": [
706
+ "search_evidence",
707
+ "run_code_check",
708
+ "inspect_image",
709
+ ],
710
+ },
711
+ )
712
+ _plot_art_metrics(layout)
713
+ print(json.dumps(result, indent=2, sort_keys=True))
714
+ return 0
715
+
716
+
717
 def _maybe_plot_training_history(
     *,
     layout: ArtifactLayout,

         metric_key="agreement_rate",
         title="Baseline agreement rate",
     )
+    if "average_invalid_bounded_tool_rate" in summary:
+        plot_evaluation_bars(
+            rows,
+            output_path=layout.plots_dir / "baseline_invalid_bounded_tool_rate.png",
+            metric_key="average_invalid_bounded_tool_rate",
+            title="Baseline invalid bounded-tool rate",
+        )
+
+
+def _plot_comparison_summary(
+    rows: list[dict[str, float | str]],
+    *,
+    layout: ArtifactLayout,
+) -> None:
+    for metric_key, title, output_name in (
+        ("average_reward", "Before vs after average reward", "compare_average_reward.png"),
+        ("agreement_rate", "Before vs after agreement rate", "compare_agreement_rate.png"),
+        ("invalid_action_rate", "Before vs after invalid action rate", "compare_invalid_action_rate.png"),
+        (
+            "average_invalid_bounded_tool_rate",
+            "Before vs after invalid bounded-tool rate",
+            "compare_invalid_bounded_tool_rate.png",
+        ),
+    ):
+        plot_evaluation_bars(
+            rows,
+            output_path=layout.plots_dir / output_name,
+            metric_key=metric_key,
+            title=title,
+        )
+
+
+def _plot_art_metrics(layout: ArtifactLayout) -> None:
+    if not layout.metrics_jsonl.exists():
+        return
+    rows = [
+        json.loads(line)
+        for line in layout.metrics_jsonl.read_text(encoding="utf-8").splitlines()
+        if line.strip()
+    ]
+    if not rows:
+        return
+    plot_metrics_by_step(
+        rows,
+        output_path=layout.plots_dir / "art_reward_components.png",
+        title="ART Scientist reward components by training step",
+        metric_keys=[
+            "reward",
+            "rigor",
+            "feasibility",
+            "fidelity",
+            "agreement_reached",
+            "invalid_action_count",
+            "parse_error_count",
+        ],
+    )
+
+
+def _write_run_metadata(layout: ArtifactLayout, payload: dict[str, object]) -> None:
+    write_json(layout.reports_dir / "run_metadata.json", payload)
+
+
+def _parse_art_scenario_spec(value: str) -> ArtScenarioSpec:
+    parts = value.split(":")
+    if len(parts) != 3:
+        raise ValueError(
+            f"Invalid scenario spec {value!r}. Expected seed:scenario:difficulty."
+        )
+    seed_text, scenario, difficulty = parts
+    return ArtScenarioSpec(
+        seed=int(seed_text),
+        scenario=scenario,
+        difficulty=difficulty,
+    )


 if __name__ == "__main__":
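The `--scenario-spec` values parsed by `_parse_art_scenario_spec` use a colon-separated `seed:scenario:difficulty` format. A minimal standalone sketch of that parsing, with a hypothetical `ScenarioSpec` named tuple standing in for `ArtScenarioSpec`:

```python
from typing import NamedTuple


class ScenarioSpec(NamedTuple):
    # Hypothetical stand-in for ArtScenarioSpec.
    seed: int
    scenario: str
    difficulty: str


def parse_scenario_spec(value: str) -> ScenarioSpec:
    # Same shape as _parse_art_scenario_spec: "seed:scenario:difficulty".
    parts = value.split(":")
    if len(parts) != 3:
        raise ValueError(
            f"Invalid scenario spec {value!r}. Expected seed:scenario:difficulty."
        )
    seed_text, scenario, difficulty = parts
    return ScenarioSpec(seed=int(seed_text), scenario=scenario, difficulty=difficulty)


print(parse_scenario_spec("101:ml_benchmark:easy"))
# → ScenarioSpec(seed=101, scenario='ml_benchmark', difficulty='easy')
```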
replicalab/training/evaluation.py CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Callable, Iterable, Sequence

+from pydantic import BaseModel, ConfigDict
+
 from replicalab.client import ReplicaLabClient
 from replicalab.models import ScientistAction, ScientistObservation
 from replicalab.training.metrics import EvaluationSummary, summarize_episodes
@@ -21,6 +23,25 @@ class EvaluationCase:
     difficulty: str


+class PolicyComparisonRow(BaseModel):
+    """One flattened before/after comparison row."""
+
+    model_config = ConfigDict(extra="forbid")
+
+    label: str
+    episode_count: int
+    average_reward: float
+    average_rounds: float
+    agreement_rate: float
+    invalid_action_rate: float
+    average_invalid_bounded_tool_rate: float
+    average_rigor: float
+    average_feasibility: float
+    average_fidelity: float
+    average_parsimony: float
+    average_tool_trace_count: float
+
+
 def build_default_evaluation_cases(
     *,
     seeds: Iterable[int],
@@ -62,8 +83,38 @@ def evaluate_policy(
     return records, summarize_episodes(records)


+def compare_policies(
+    *,
+    base_url: str,
+    policies: Sequence[tuple[str, PolicyFn]],
+    cases: Sequence[EvaluationCase],
+    transport: str = "rest",
+) -> tuple[dict[str, list[EpisodeRecord]], list[PolicyComparisonRow]]:
+    """Evaluate multiple policies on the exact same case set."""
+
+    records_by_label: dict[str, list[EpisodeRecord]] = {}
+    rows: list[PolicyComparisonRow] = []
+    for label, policy_fn in policies:
+        records, summary = evaluate_policy(
+            base_url=base_url,
+            policy_fn=policy_fn,
+            cases=cases,
+            transport=transport,
+        )
+        records_by_label[label] = records
+        rows.append(
+            PolicyComparisonRow(
+                label=label,
+                **summary.model_dump(mode="json"),
+            )
+        )
+    return records_by_label, rows
+
+
 __all__ = [
     "EvaluationCase",
+    "PolicyComparisonRow",
     "build_default_evaluation_cases",
+    "compare_policies",
     "evaluate_policy",
 ]
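`compare_policies` flattens each per-policy `EvaluationSummary` into a labeled row via `PolicyComparisonRow(label=..., **summary.model_dump(mode="json"))`. A minimal sketch of that flattening using plain dicts (no pydantic), with a hypothetical `to_comparison_row` helper:

```python
def to_comparison_row(label: str, summary: dict) -> dict:
    # Mirrors PolicyComparisonRow(label=label, **summary.model_dump(mode="json")):
    # the label is prepended, the summary fields are spread flat into the row.
    return {"label": label, **summary}


summary = {"episode_count": 1, "average_reward": 3.5, "agreement_rate": 1.0}
row = to_comparison_row("trained", summary)
print(row)
# → {'label': 'trained', 'episode_count': 1, 'average_reward': 3.5, 'agreement_rate': 1.0}
```

The real model additionally enforces the exact field set with `ConfigDict(extra="forbid")`, so a summary with an unexpected key would fail validation rather than silently pass through.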
replicalab/training/metrics.py CHANGED
@@ -24,6 +24,8 @@ class EpisodeMetrics(BaseModel):
     invalid_action_count: int = 0
     invalid_action_rate: float = 0.0
     tool_trace_count: int = 0
+    invalid_bounded_tool_count: int = 0
+    invalid_bounded_tool_rate: float = 0.0
     rigor: float = 0.0
     feasibility: float = 0.0
     fidelity: float = 0.0
@@ -40,6 +42,7 @@ class EvaluationSummary(BaseModel):
     average_rounds: float
     agreement_rate: float
     invalid_action_rate: float
+    average_invalid_bounded_tool_rate: float
     average_rigor: float
     average_feasibility: float
     average_fidelity: float
@@ -52,6 +55,8 @@ def episode_to_metrics(record: EpisodeRecord) -> EpisodeMetrics:

     invalid_actions = sum(1 for step in record.steps if step.error)
     rounds_used = max(1, record.rounds_used)
+    invalid_bounded_tools = _count_invalid_bounded_tools(record.tool_traces)
+    tool_trace_count = record.tool_trace_count
     breakdown = record.reward_breakdown

     return EpisodeMetrics(
@@ -64,7 +69,9 @@ def episode_to_metrics(record: EpisodeRecord) -> EpisodeMetrics:
         verdict=record.verdict,
         invalid_action_count=invalid_actions,
         invalid_action_rate=invalid_actions / rounds_used,
-        tool_trace_count=record.tool_trace_count,
+        tool_trace_count=tool_trace_count,
+        invalid_bounded_tool_count=invalid_bounded_tools,
+        invalid_bounded_tool_rate=invalid_bounded_tools / max(1, tool_trace_count),
         rigor=(breakdown.rigor if breakdown is not None else 0.0),
         feasibility=(breakdown.feasibility if breakdown is not None else 0.0),
         fidelity=(breakdown.fidelity if breakdown is not None else 0.0),
@@ -83,6 +90,7 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
         average_rounds=0.0,
         agreement_rate=0.0,
         invalid_action_rate=0.0,
+        average_invalid_bounded_tool_rate=0.0,
         average_rigor=0.0,
         average_feasibility=0.0,
         average_fidelity=0.0,
@@ -96,6 +104,9 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
         average_rounds=mean(item.rounds_used for item in metrics),
         agreement_rate=mean(1.0 if item.agreement_reached else 0.0 for item in metrics),
         invalid_action_rate=mean(item.invalid_action_rate for item in metrics),
+        average_invalid_bounded_tool_rate=mean(
+            item.invalid_bounded_tool_rate for item in metrics
+        ),
         average_rigor=mean(item.rigor for item in metrics),
         average_feasibility=mean(item.feasibility for item in metrics),
         average_fidelity=mean(item.fidelity for item in metrics),
@@ -104,6 +115,23 @@ def summarize_episodes(records: list[EpisodeRecord]) -> EvaluationSummary:
     )


+def _count_invalid_bounded_tools(traces: list[dict[str, object]]) -> int:
+    invalid_count = 0
+    for trace in traces:
+        status = str(trace.get("status", "") or "").strip().lower()
+        error = trace.get("error")
+        valid = trace.get("valid")
+        if error:
+            invalid_count += 1
+            continue
+        if valid is False:
+            invalid_count += 1
+            continue
+        if status and status not in {"ok", "success", "succeeded", "completed"}:
+            invalid_count += 1
+    return invalid_count
+
+
 __all__ = [
     "EpisodeMetrics",
     "EvaluationSummary",
replicalab/training/plots.py CHANGED
@@ -3,6 +3,7 @@
 from __future__ import annotations

 from pathlib import Path
+from statistics import mean
 from typing import Iterable


@@ -77,7 +78,53 @@ def plot_evaluation_bars(
     plt.close(fig)


+def plot_metrics_by_step(
+    rows: Iterable[dict[str, object]],
+    *,
+    output_path: Path,
+    title: str,
+    metric_keys: list[str],
+    x_key: str = "training_step",
+) -> None:
+    """Plot averaged metric curves grouped by training step."""
+
+    matplotlib = __import__("matplotlib.pyplot", fromlist=["pyplot"])
+    plt = matplotlib
+
+    grouped: dict[int, dict[str, list[float]]] = {}
+    for row in rows:
+        raw_step = row.get(x_key)
+        if not isinstance(raw_step, int):
+            continue
+        bucket = grouped.setdefault(raw_step, {})
+        for metric_key in metric_keys:
+            raw_value = row.get(metric_key)
+            if isinstance(raw_value, (int, float)):
+                bucket.setdefault(metric_key, []).append(float(raw_value))
+
+    if not grouped:
+        raise ValueError(f"No '{x_key}' values found for metric plotting.")
+
+    steps = sorted(grouped)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    fig, ax = plt.subplots(figsize=(10, 5))
+    for metric_key in metric_keys:
+        values = [
+            mean(grouped[step].get(metric_key, [0.0]))
+            for step in steps
+        ]
+        ax.plot(steps, values, marker="o", label=metric_key.replace("_", " "))
+    ax.set_title(title)
+    ax.set_xlabel(x_key.replace("_", " "))
+    ax.grid(True, alpha=0.3)
+    ax.legend()
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=160)
+    plt.close(fig)
+
+
 __all__ = [
     "plot_evaluation_bars",
+    "plot_metrics_by_step",
     "plot_training_history",
 ]
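The grouping step inside `plot_metrics_by_step` is independent of matplotlib: rows are bucketed by integer `training_step`, then each metric is averaged per step. A sketch of just that aggregation (the `average_by_step` helper is hypothetical, not part of the module):

```python
from statistics import mean


def average_by_step(rows, metric_keys, x_key="training_step"):
    # Mirrors the grouping inside plot_metrics_by_step, minus the plotting:
    # collect numeric metric values per integer step, then average each bucket.
    grouped: dict[int, dict[str, list[float]]] = {}
    for row in rows:
        step = row.get(x_key)
        if not isinstance(step, int):
            continue  # rows without an integer step are skipped, as in the plot
        bucket = grouped.setdefault(step, {})
        for key in metric_keys:
            value = row.get(key)
            if isinstance(value, (int, float)):
                bucket.setdefault(key, []).append(float(value))
    return {
        step: {key: mean(values) for key, values in bucket.items()}
        for step, bucket in sorted(grouped.items())
    }


rows = [
    {"training_step": 0, "reward": 1.0},
    {"training_step": 0, "reward": 3.0},
    {"training_step": 1, "reward": 4.0},
]
print(average_by_step(rows, ["reward"]))  # → {0: {'reward': 2.0}, 1: {'reward': 4.0}}
```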
replicalab/training/rollout.py CHANGED
@@ -29,6 +29,7 @@ returns a ``ScientistAction``. The baseline from
 from __future__ import annotations

 from dataclasses import dataclass, field
+from inspect import signature
 from typing import Any, Callable, Iterable, Optional

 from replicalab.client import ReplicaLabClient
@@ -88,7 +89,7 @@ class EpisodeRecord:


 # Type alias for the policy callable
-PolicyFn = Callable[[ScientistObservation], ScientistAction]
+PolicyFn = Callable[..., ScientistAction]


 class RolloutWorker:
@@ -147,7 +148,13 @@ class RolloutWorker:
             raise RuntimeError("Reset returned no scientist observation")

         for step_idx in range(self._max_steps):
-            action = policy_fn(scientist_obs)
+            action = _invoke_policy(
+                policy_fn,
+                scientist_obs,
+                seed=seed,
+                scenario=scenario,
+                difficulty=difficulty,
+            )

             result: StepResult = self._client.step(action)
             tool_traces = _extract_tool_traces(result.info)
@@ -221,3 +228,22 @@ def _extract_tool_traces(info: StepInfo) -> list[dict[str, Any]]:
         if isinstance(item, dict):
             traces.append(dict(item))
     return traces
+
+
+def _invoke_policy(
+    policy_fn: PolicyFn,
+    observation: ScientistObservation,
+    *,
+    seed: int,
+    scenario: str,
+    difficulty: str,
+) -> ScientistAction:
+    parameters = signature(policy_fn).parameters
+    if len(parameters) <= 1:
+        return policy_fn(observation)
+    return policy_fn(
+        observation,
+        seed=seed,
+        scenario=scenario,
+        difficulty=difficulty,
+    )
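`_invoke_policy` keeps old single-argument policies working while letting richer policies receive rollout context: it inspects the callable's parameter count and only passes `seed`/`scenario`/`difficulty` when the signature can take them. A standalone sketch of that arity-based dispatch (the policy lambdas and string "observation" are illustrative):

```python
from inspect import signature


def invoke_policy(policy_fn, observation, **context):
    # Same arity check as _invoke_policy: a single-parameter policy gets only
    # the observation; a richer signature also receives the rollout context.
    if len(signature(policy_fn).parameters) <= 1:
        return policy_fn(observation)
    return policy_fn(observation, **context)


legacy = lambda obs: f"act({obs})"
contextual = lambda obs, *, seed, scenario, difficulty: f"act({obs}@{scenario}:{seed})"

print(invoke_policy(legacy, "obs", seed=1, scenario="ml_benchmark", difficulty="easy"))
# → act(obs)
print(invoke_policy(contextual, "obs", seed=1, scenario="ml_benchmark", difficulty="easy"))
# → act(obs@ml_benchmark:1)
```

This is why the diff also loosens `PolicyFn` to `Callable[..., ScientistAction]`: both call shapes must type-check against the alias.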
requirements-train.txt ADDED
@@ -0,0 +1,22 @@
+# Training dependencies (GPU required)
+# Installed on top of server/requirements.txt
+
+# RL training framework
+trl>=0.15,<1.0
+
+# Dataset handling
+datasets>=3.0,<4.0
+
+# Unsloth for fast LoRA fine-tuning
+unsloth>=2025.3
+
+# vLLM for fast inference during GRPO rollouts
+vllm>=0.7
+
+# Plotting
+matplotlib>=3.9,<4.0
+
+# Already in server/requirements.txt but listed for completeness
+pydantic>=2.7,<3.0
+httpx>=0.27,<1.0
+websocket-client>=1.7,<2.0
scripts/train.sh ADDED
@@ -0,0 +1,151 @@
+#!/usr/bin/env bash
+# ReplicaLab training entrypoint for Northflank GPU jobs.
+#
+# Usage:
+#   MODE=train ./scripts/train.sh        # full training (scientist + lab manager)
+#   MODE=scientist ./scripts/train.sh    # scientist GRPO only
+#   MODE=lab-manager ./scripts/train.sh  # lab manager SFT only
+#   MODE=eval ./scripts/train.sh         # baseline evaluation only
+#   MODE=server ./scripts/train.sh       # just run the server (default)
+#
+# The script starts the ReplicaLab server in the background (needed for
+# rollout evaluation), then runs the requested training flow.
+
+set -euo pipefail
+
+MODE="${MODE:-server}"
+SEED_COUNT="${SEED_COUNT:-8}"
+MAX_STEPS="${MAX_STEPS:-300}"
+MODEL_NAME="${MODEL_NAME:-Qwen/Qwen3-8B}"
+PERSIST_ROOT="${REPLICALAB_PERSIST_ROOT:-/app/outputs/training}"
+BASE_URL="http://localhost:7860"
+
+echo "=========================================="
+echo " ReplicaLab Training Pipeline"
+echo "=========================================="
+echo " Mode:       $MODE"
+echo " Model:      $MODEL_NAME"
+echo " Seeds:      $SEED_COUNT"
+echo " Max steps:  $MAX_STEPS"
+echo " Persist:    $PERSIST_ROOT"
+echo " Server URL: $BASE_URL"
+echo "=========================================="
+
+# ── Start server in background (needed for eval rollouts) ──────────────
+start_server() {
+    echo "[train.sh] Starting ReplicaLab server on port 7860..."
+    uvicorn server.app:app --host 0.0.0.0 --port 7860 &
+    SERVER_PID=$!
+    echo "[train.sh] Server PID: $SERVER_PID"
+
+    # Wait for server to be ready
+    for i in $(seq 1 30); do
+        if curl -sf http://localhost:7860/health > /dev/null 2>&1; then
+            echo "[train.sh] Server is ready."
+            return 0
+        fi
+        sleep 1
+    done
+    echo "[train.sh] WARNING: Server did not become ready in 30s, continuing anyway."
+}
+
+# ── Scientist GRPO training ───────────────────────────────────────────
+run_scientist_train() {
+    echo ""
+    echo "=== Phase 1: Scientist GRPO Training ==="
+    echo ""
+
+    # Preview first (no GPU needed)
+    python -m replicalab.training.cli scientist-preview \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    # Full training
+    python -m replicalab.training.cli scientist-train \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT" \
+        --max-steps "$MAX_STEPS"
+
+    echo "[train.sh] Scientist GRPO training complete."
+}
+
+# ── Lab Manager SFT training ─────────────────────────────────────────
+run_lab_manager_train() {
+    echo ""
+    echo "=== Phase 2: Lab Manager SFT Training ==="
+    echo ""
+
+    # Preview first
+    python -m replicalab.training.cli lab-manager-preview \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    # Full training
+    python -m replicalab.training.cli lab-manager-train \
+        --persist-root "$PERSIST_ROOT" \
+        --model-name "$MODEL_NAME" \
+        --seed-count "$SEED_COUNT"
+
+    echo "[train.sh] Lab Manager SFT training complete."
+}
+
+# ── Baseline evaluation ──────────────────────────────────────────────
+run_eval() {
+    echo ""
+    echo "=== Baseline Evaluation ==="
+    echo ""
+
+    python -m replicalab.training.cli baseline-eval \
+        --persist-root "$PERSIST_ROOT" \
+        --base-url "$BASE_URL" \
+        --seed-count "$SEED_COUNT"
+
+    echo "[train.sh] Evaluation complete."
+}
+
+# ── Mode dispatch ────────────────────────────────────────────────────
+
+case "$MODE" in
+    server)
+        echo "[train.sh] Server-only mode."
+        exec uvicorn server.app:app --host 0.0.0.0 --port 7860
+        ;;
+
+    train)
+        start_server
+        run_scientist_train
+        run_lab_manager_train
+        run_eval
+        echo ""
+        echo "=========================================="
+        echo " All training complete!"
+        echo " Artifacts saved to: $PERSIST_ROOT"
+        echo "=========================================="
+        # Keep container alive so artifacts can be retrieved
+        echo "[train.sh] Training done. Keeping container alive..."
+        wait $SERVER_PID
+        ;;
+
+    scientist)
+        run_scientist_train
+        ;;
+
+    lab-manager)
+        run_lab_manager_train
+        ;;
+
+    eval)
+        start_server
+        run_eval
+        wait $SERVER_PID
+        ;;
+
+    *)
+        echo "Unknown MODE: $MODE"
+        echo "Valid modes: server, train, scientist, lab-manager, eval"
+        exit 1
+        ;;
+esac
tests/test_server.py CHANGED
@@ -96,12 +96,11 @@ class TestRootEndpoint:

     def test_root_mentions_core_api_endpoints(self, client: TestClient) -> None:
         body = client.get("/").text
-        assert "ReplicaLab API" in body
-        assert "GET /health" in body
-        assert "GET /scenarios" in body
-        assert "POST /reset" in body
-        assert "POST /step" in body
-        assert "WS /ws" in body
+        # When frontend/dist exists, root serves the SPA; otherwise the API landing page.
+        assert "ReplicaLab" in body
+        if "ReplicaLab API" in body:
+            assert "GET /health" in body
+            assert "POST /reset" in body


 class TestWebFallback:
tests/test_training_cli.py CHANGED
@@ -4,6 +4,7 @@ import json

 from replicalab.models import RewardBreakdown
 from replicalab.training.cli import main
+from replicalab.training.evaluation import PolicyComparisonRow
 from replicalab.training.metrics import EvaluationSummary
 from replicalab.training.rollout import EpisodeRecord

@@ -57,6 +58,7 @@ def test_baseline_eval_cli_writes_summary_and_metrics(tmp_path, monkeypatch) ->
         average_rounds=1.0,
         agreement_rate=1.0,
         invalid_action_rate=0.0,
+        average_invalid_bounded_tool_rate=0.0,
         average_rigor=0.6,
         average_feasibility=0.8,
         average_fidelity=0.7,
@@ -95,3 +97,94 @@ def test_baseline_eval_cli_writes_summary_and_metrics(tmp_path, monkeypatch) ->
     metric = json.loads(metrics_lines[0])
     assert metric["scenario"] == "ml_benchmark"
     assert metric["agreement_reached"] is True
+
+
+def test_scientist_compare_eval_cli_writes_rows(tmp_path, monkeypatch) -> None:
+    baseline_record = EpisodeRecord(
+        seed=101,
+        scenario="ml_benchmark",
+        difficulty="easy",
+        episode_id="baseline-1",
+        total_reward=1.0,
+        reward_breakdown=RewardBreakdown(rigor=0.4, feasibility=0.5, fidelity=0.6),
+        verdict="timeout",
+        agreement_reached=False,
+    )
+    trained_record = EpisodeRecord(
+        seed=101,
+        scenario="ml_benchmark",
+        difficulty="easy",
+        episode_id="trained-1",
+        total_reward=3.5,
+        reward_breakdown=RewardBreakdown(rigor=0.8, feasibility=0.9, fidelity=0.85),
+        verdict="accept",
+        agreement_reached=True,
+    )
+    rows = [
+        PolicyComparisonRow(
+            label="baseline",
+            episode_count=1,
+            average_reward=1.0,
+            average_rounds=2.0,
+            agreement_rate=0.0,
+            invalid_action_rate=0.5,
+            average_invalid_bounded_tool_rate=0.0,
+            average_rigor=0.4,
+            average_feasibility=0.5,
+            average_fidelity=0.6,
+            average_parsimony=1.0,
+            average_tool_trace_count=0.0,
+        ),
+        PolicyComparisonRow(
+            label="trained",
+            episode_count=1,
+            average_reward=3.5,
+            average_rounds=1.0,
+            agreement_rate=1.0,
+            invalid_action_rate=0.0,
+            average_invalid_bounded_tool_rate=0.0,
+            average_rigor=0.8,
+            average_feasibility=0.9,
+            average_fidelity=0.85,
+            average_parsimony=1.0,
+            average_tool_trace_count=0.0,
+        ),
+    ]
+
+    monkeypatch.setattr(
+        "replicalab.training.cli.build_remote_scientist_policy",
+        lambda **_: (lambda _obs: None),
+    )
+    monkeypatch.setattr(
+        "replicalab.training.cli.compare_policies",
+        lambda **_: (
+            {"baseline": [baseline_record], "trained": [trained_record]},
+            rows,
+        ),
+    )
+    monkeypatch.setattr(
+        "replicalab.training.cli.plot_evaluation_bars",
+        lambda *args, **kwargs: None,
+    )
+
+    exit_code = main(
+        [
+            "scientist-compare-eval",
+            "--persist-root",
+            str(tmp_path),
+            "--run-name",
+            "compare-eval-test",
+            "--eval-seeds",
+            "101",
+            "--scenarios",
+            "ml_benchmark",
+            "--difficulties",
+            "easy",
+        ]
+    )
+
+    assert exit_code == 0
+    summary_path = tmp_path / "compare-eval-test" / "reports" / "summary.json"
+    payload = json.loads(summary_path.read_text(encoding="utf-8"))
+    assert [row["label"] for row in payload["rows"]] == ["baseline", "trained"]
+    assert payload["rows"][1]["average_reward"] == 3.5
tests/test_training_metrics.py CHANGED
@@ -7,7 +7,11 @@ from replicalab.training.metrics import episode_to_metrics, summarize_episodes
 from replicalab.training.rollout import EpisodeRecord, StepRecord


-def _build_step_record(error: str | None = None) -> StepRecord:
+def _build_step_record(
+    error: str | None = None,
+    *,
+    tool_traces: list[dict[str, object]] | None = None,
+) -> StepRecord:
     return StepRecord(
         round_number=0,
         observation=ScientistObservation(
@@ -36,6 +40,7 @@ def _build_step_record(error: str | None = None) -> StepRecord:
         done=False,
         error=error,
         info=StepInfo(error=error),
+        tool_traces=tool_traces or [],
     )


@@ -55,12 +60,18 @@ def test_episode_to_metrics_counts_invalid_actions() -> None:
         ),
         verdict="accept",
         agreement_reached=True,
+        tool_traces=[
+            {"tool": "search_evidence", "status": "ok"},
+            {"tool": "run_code_check", "status": "error", "error": "timeout"},
+        ],
     )

     metrics = episode_to_metrics(record)

     assert metrics.invalid_action_count == 1
     assert metrics.invalid_action_rate == 0.5
+    assert metrics.invalid_bounded_tool_count == 1
+    assert metrics.invalid_bounded_tool_rate == 0.5
     assert metrics.agreement_reached is True


@@ -75,6 +86,7 @@ def test_summarize_episodes_aggregates_rewards() -> None:
         reward_breakdown=RewardBreakdown(rigor=0.6, feasibility=0.7, fidelity=0.8),
         verdict="accept",
         agreement_reached=True,
+        tool_traces=[{"tool": "search_evidence", "status": "ok"}],
     )
     second = EpisodeRecord(
         seed=2,
@@ -86,6 +98,7 @@ def test_summarize_episodes_aggregates_rewards() -> None:
         reward_breakdown=RewardBreakdown(rigor=0.2, feasibility=0.4, fidelity=0.5),
         verdict="timeout",
         agreement_reached=False,
+        tool_traces=[{"tool": "run_code_check", "status": "error"}],
    )

     summary = summarize_episodes([first, second])
@@ -93,3 +106,4 @@ def test_summarize_episodes_aggregates_rewards() -> None:
     assert summary.episode_count == 2
     assert summary.average_reward == 1.25
     assert 0.0 < summary.invalid_action_rate < 1.0
+    assert summary.average_invalid_bounded_tool_rate == 0.5