# replicalab/tests/test_scientist_policy.py
# maxxie114's picture
# Initial HF Spaces deployment
# 80d8c84
from __future__ import annotations
import json
import httpx
import pytest
from replicalab.agents.scientist_policy import (
RetryMetadata,
ScientistCallResult,
ScientistOutputParseError,
build_anthropic_scientist_policy,
build_baseline_scientist_action,
build_ollama_scientist_policy,
build_scientist_system_prompt,
call_scientist_with_retry,
format_scientist_observation,
parse_scientist_output,
)
from replicalab.models import (
ConversationEntry,
Protocol,
ScientistActionType,
ScientistObservation,
)
from replicalab.scenarios import generate_scenario
# ---------------------------------------------------------------------------
# Shared valid JSON payloads for retry tests
# ---------------------------------------------------------------------------
# Canned ``request_info`` reply with a single question. The retry tests return
# it from their fake generators and the provider-policy tests embed it in the
# mocked HTTP responses.
_VALID_REQUEST_INFO_JSON = """{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": ["What compute budget is available?"],
"rationale": ""
}"""
def test_parse_scientist_output_accepts_plain_json() -> None:
    """A bare JSON object padded only by newlines parses into a request_info action."""
    payload = """
{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": ["What compute budget is available?"],
"rationale": ""
}
"""
    parsed = parse_scientist_output(payload)

    expected_questions = ["What compute budget is available?"]
    assert parsed.action_type is ScientistActionType.REQUEST_INFO
    assert parsed.questions == expected_questions
def test_parse_scientist_output_accepts_fenced_json_with_prose() -> None:
    """A ```json fenced object preceded by a prose sentence is still parsed."""
    model_reply = """
I would revise the plan as follows:
```json
{
"action_type": "revise_protocol",
"sample_size": 24,
"controls": ["baseline", "ablation"],
"technique": "small_scale_backtest",
"duration_days": 3,
"required_equipment": ["gpu_node"],
"required_reagents": [],
"questions": [],
"rationale": "Shrink the trial to fit the available compute window."
}
```
"""
    parsed = parse_scientist_output(model_reply)

    assert parsed.action_type is ScientistActionType.REVISE_PROTOCOL
    assert parsed.technique == "small_scale_backtest"
def test_parse_scientist_output_raises_explicit_error_when_json_is_missing() -> None:
    """Prose without any JSON object is rejected with the ``no_json`` code."""
    prose_only = "I need more context before I can answer."

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(prose_only)

    error = raised.value
    assert error.code == "no_json"
    assert "does not contain a JSON object" in error.message
def test_parse_scientist_output_raises_explicit_error_when_json_is_invalid() -> None:
    """A fenced object with a trailing comma is reported as ``invalid_json``."""
    broken_reply = """
```json
{
"action_type": "request_info",
"questions": ["What budget do we have?"],
}
```
"""

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(broken_reply)

    error = raised.value
    assert error.code == "invalid_json"
    assert "could not be decoded" in error.message
def test_parse_scientist_output_raises_explicit_error_when_schema_is_invalid() -> None:
    """A decodable request_info payload with no questions is reported as ``invalid_action``."""
    schema_violation = """
{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": ""
}
"""

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(schema_violation)

    error = raised.value
    assert error.code == "invalid_action"
    assert "ScientistAction validation" in error.message
def test_build_scientist_system_prompt_uses_normalized_scenario_data() -> None:
    """The system prompt carries scenario fields alongside the fixed role and action text."""
    pack = generate_scenario(seed=202, template="ml_benchmark", difficulty="medium")
    rendered = build_scientist_system_prompt(pack)

    expected_fragments = (
        "You are the Scientist agent in ReplicaLab.",
        pack.task_summary,
        pack.success_criteria[0],
        pack.resources[0].label,
        "action_type values",
        "propose_protocol",
        "request_info",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
# ---------------------------------------------------------------------------
# AGT 02 — format_scientist_observation
# ---------------------------------------------------------------------------
def _base_observation(**overrides) -> ScientistObservation:
    """Build a ``ScientistObservation`` from fixed test defaults.

    Any keyword argument replaces the default value of the same name.
    """
    baseline_fields = {
        "paper_title": "Test Paper Title",
        "paper_hypothesis": "X improves Y.",
        "paper_method": "Run benchmark Z.",
        "paper_key_finding": "10% improvement.",
        "experiment_goal": "Replicate the 10% improvement.",
        "conversation_history": [],
        "current_protocol": None,
        "round_number": 0,
        "max_rounds": 6,
    }
    # Later keys win, so caller overrides take precedence over the defaults.
    return ScientistObservation(**{**baseline_fields, **overrides})
def test_format_observation_empty_history_no_protocol() -> None:
    """The default observation renders paper fields, empty-state notes, and the schema reminder."""
    rendered = format_scientist_observation(_base_observation())

    expected_fragments = (
        "Round 0 of 6",
        "Test Paper Title",
        "X improves Y.",
        "Run benchmark Z.",
        "10% improvement.",
        "Replicate the 10% improvement.",
        "No conversation history yet",
        "No protocol has been proposed yet",
        "ScientistAction schema reminder:",
        (
            "allowed action_type values: propose_protocol, revise_protocol, "
            "request_info, accept"
        ),
        "include all ScientistAction fields in every response",
        "Respond with exactly one JSON object",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_format_observation_with_history_and_protocol() -> None:
    """History entries and the current protocol both appear in the rendered text."""
    scientist_turn = ConversationEntry(
        role="scientist",
        message="I propose a grid search protocol.",
        round_number=1,
        action_type="propose_protocol",
    )
    manager_turn = ConversationEntry(
        role="lab_manager",
        message="The compute cluster is booked. Consider a smaller run.",
        round_number=1,
        action_type="suggest_alternative",
    )
    current = Protocol(
        sample_size=24,
        controls=["baseline", "ablation"],
        technique="grid_search",
        duration_days=3,
        required_equipment=["gpu_node"],
        required_reagents=["benchmark_dataset"],
        rationale="Standard hyperparameter sweep.",
    )
    observation = _base_observation(
        conversation_history=[scientist_turn, manager_turn],
        current_protocol=current,
        round_number=2,
    )

    rendered = format_scientist_observation(observation)

    expected_fragments = (
        "Round 2 of 6",
        "Conversation so far:",
        "[SCIENTIST r1 [propose_protocol]]",
        "I propose a grid search protocol.",
        "[LAB_MANAGER r1 [suggest_alternative]]",
        "compute cluster is booked",
        "Current protocol:",
        "technique: grid_search",
        "sample_size: 24",
        "controls: baseline, ablation",
        "duration_days: 3",
        "gpu_node",
        "benchmark_dataset",
        "Standard hyperparameter sweep.",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_format_observation_stable_section_order() -> None:
    """Round, paper, history, protocol, and response sections appear in that order."""
    opening_turn = ConversationEntry(
        role="scientist",
        message="First proposal.",
        round_number=1,
        action_type="propose_protocol",
    )
    current = Protocol(
        sample_size=10,
        controls=["ctrl"],
        technique="method_a",
        duration_days=2,
        required_equipment=["tool_a"],
        required_reagents=["reagent_a"],
        rationale="Simple test.",
    )
    observation = _base_observation(
        conversation_history=[opening_turn],
        current_protocol=current,
        round_number=1,
    )
    rendered = format_scientist_observation(observation)

    # Sections must appear in this fixed order
    section_markers = (
        "Round 1 of 6",
        "Paper:",
        "Conversation so far:",
        "Current protocol:",
        "Respond with exactly one JSON",
    )
    # str.index raises ValueError if a marker is absent, which also fails the test.
    offsets = [rendered.index(marker) for marker in section_markers]
    assert all(earlier < later for earlier, later in zip(offsets, offsets[1:]))
def test_format_observation_history_entry_without_action_type() -> None:
    """An entry whose ``action_type`` is None renders with a tag of just role and round."""
    system_note = ConversationEntry(
        role="system",
        message="Episode started.",
        round_number=0,
        action_type=None,
    )
    observation = _base_observation(conversation_history=[system_note])

    rendered = format_scientist_observation(observation)

    assert "[SYSTEM r0]:" in rendered
    assert "Episode started." in rendered
def test_format_observation_from_generated_scenario() -> None:
    """The observation attached to a generated scenario formats like a fresh episode."""
    pack = generate_scenario(seed=77, template="finance_trading", difficulty="easy")
    observation = pack.scientist_observation

    rendered = format_scientist_observation(observation)

    expected_fragments = (
        observation.paper_title,
        observation.experiment_goal,
        "Round 0",
        "No conversation history yet",
        "Respond with exactly one JSON",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
# ---------------------------------------------------------------------------
# AGT 03 — call_scientist_with_retry
# ---------------------------------------------------------------------------
def _make_system_prompt() -> str:
    """Return the system prompt for a fixed easy ``math_reasoning`` scenario (seed 1)."""
    return build_scientist_system_prompt(
        generate_scenario(seed=1, template="math_reasoning", difficulty="easy")
    )
def test_retry_success_on_first_try() -> None:
    """A valid first reply yields one attempt, zero retries, and no recorded error."""

    def gen_fn(messages: list[dict[str, str]]) -> str:
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    outcome = call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    assert isinstance(outcome, ScientistCallResult)
    assert outcome.action.action_type is ScientistActionType.REQUEST_INFO

    meta = outcome.metadata
    assert meta.attempt_count == 1
    assert meta.retry_count == 0
    assert meta.last_error_code is None
    assert meta.last_error_message is None
def test_retry_malformed_json_then_valid() -> None:
    """An undecodable first reply followed by a valid one succeeds on the second attempt."""
    # Replies consumed before falling back to the valid payload.
    pending_bad_replies = ['{"action_type": "request_info", trailing garbage']

    def gen_fn(messages: list[dict[str, str]]) -> str:
        if pending_bad_replies:
            return pending_bad_replies.pop()
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    outcome = call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    assert outcome.action.action_type is ScientistActionType.REQUEST_INFO

    meta = outcome.metadata
    assert meta.attempt_count == 2
    assert meta.retry_count == 1
    assert meta.last_error_code == "invalid_json"
def test_retry_invalid_action_then_valid() -> None:
    """A schema-invalid first reply followed by a valid one succeeds on the second attempt."""
    # First attempt: valid JSON but questions is empty for request_info
    pending_bad_replies = [
        """{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": ""
}"""
    ]

    def gen_fn(messages: list[dict[str, str]]) -> str:
        if pending_bad_replies:
            return pending_bad_replies.pop()
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    outcome = call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    assert outcome.action.action_type is ScientistActionType.REQUEST_INFO

    meta = outcome.metadata
    assert meta.attempt_count == 2
    assert meta.retry_count == 1
    assert meta.last_error_code == "invalid_action"
    assert "ScientistAction validation" in meta.last_error_message
def test_retry_exhausted_raises_last_error() -> None:
    """When every reply lacks JSON, the call raises a ``no_json`` parse error."""

    def gen_fn(messages: list[dict[str, str]]) -> str:
        return "I cannot produce JSON right now."

    observation = _base_observation()
    system_prompt = _make_system_prompt()

    with pytest.raises(ScientistOutputParseError) as raised:
        call_scientist_with_retry(gen_fn, system_prompt, observation, max_retries=2)

    assert raised.value.code == "no_json"
def test_retry_correction_message_includes_parser_error() -> None:
    """The correction prompt sent to the model must include the specific error."""
    transcripts: list[list[dict[str, str]]] = []
    prose_reply = "Just some prose, no JSON here."

    def gen_fn(messages: list[dict[str, str]]) -> str:
        # Snapshot the messages so later mutation by the caller cannot alter the record.
        transcripts.append(list(messages))
        if len(transcripts) == 1:
            return "Just some prose, no JSON here."
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    # Second call should have 4 messages: system, user, assistant (bad), user (correction)
    assert len(transcripts) == 2
    second_call = transcripts[1]
    assert len(second_call) == 4

    echoed_reply, correction_turn = second_call[2], second_call[3]
    assert echoed_reply["role"] == "assistant"
    assert echoed_reply["content"] == prose_reply
    assert correction_turn["role"] == "user"

    correction_text = correction_turn["content"]
    assert "did not contain a JSON object" in correction_text
    assert "No markdown fences, no prose" in correction_text
def test_retry_correction_for_invalid_action_includes_validation_detail() -> None:
    """invalid_action correction must include the schema validation message."""
    transcripts: list[list[dict[str, str]]] = []
    schema_violation = """{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": ""
}"""

    def gen_fn(messages: list[dict[str, str]]) -> str:
        transcripts.append(list(messages))
        # Only the very first call gets the schema-invalid payload.
        if len(transcripts) == 1:
            return schema_violation
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    correction_text = transcripts[1][3]["content"]
    assert "failed ScientistAction validation" in correction_text
    assert "Fix the validation error" in correction_text
def test_retry_metadata_serializable() -> None:
    """Retry metadata survives a JSON dump and ``RetryMetadata`` re-validation."""

    def gen_fn(messages: list[dict[str, str]]) -> str:
        return _VALID_REQUEST_INFO_JSON

    observation = _base_observation()
    outcome = call_scientist_with_retry(gen_fn, _make_system_prompt(), observation)

    round_tripped = RetryMetadata.model_validate_json(
        outcome.metadata.model_dump_json()
    )

    assert round_tripped.attempt_count == 1
    assert round_tripped.retry_count == 0
# ---------------------------------------------------------------------------
# AGT 04 - build_baseline_scientist_action
# ---------------------------------------------------------------------------
def test_baseline_scientist_proposes_protocol_for_fresh_observation() -> None:
    """With no protocol and no history the baseline proposes a non-trivial protocol."""
    fresh_observation = _base_observation()

    proposal = build_baseline_scientist_action(fresh_observation)

    assert proposal.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert proposal.sample_size >= 1
    assert proposal.duration_days >= 1
    assert proposal.questions == []
    assert proposal.rationale
def test_baseline_scientist_accepts_existing_protocol_without_blocker() -> None:
    """An existing protocol plus a 'remains feasible' manager report leads to ACCEPT."""
    existing_protocol = Protocol(
        sample_size=10,
        controls=["baseline_check"],
        technique="published_split_replication",
        duration_days=2,
        required_equipment=[],
        required_reagents=[],
        rationale="Initial protocol is already in place.",
    )
    manager_report = ConversationEntry(
        role="lab_manager",
        message="The current plan remains feasible.",
        round_number=1,
        action_type="report_feasibility",
    )
    observation = _base_observation(
        current_protocol=existing_protocol,
        conversation_history=[manager_report],
        round_number=1,
    )

    decision = build_baseline_scientist_action(observation)

    assert decision.action_type is ScientistActionType.ACCEPT
    assert decision.sample_size == 0
    assert decision.controls == []
def test_baseline_scientist_revises_when_latest_feedback_has_blocker() -> None:
    """A manager message about a booked GPU and tight schedule leads to a smaller revision."""
    existing_protocol = Protocol(
        sample_size=12,
        controls=["published_split_check", "heldout_evaluation"],
        technique="published_split_replication",
        duration_days=3,
        required_equipment=[],
        required_reagents=[],
        rationale="Original scope is full-size.",
    )
    manager_pushback = ConversationEntry(
        role="lab_manager",
        message="The current GPU plan is booked, so the schedule is too tight.",
        round_number=1,
        action_type="suggest_alternative",
    )
    observation = _base_observation(
        current_protocol=existing_protocol,
        conversation_history=[manager_pushback],
        round_number=1,
    )

    revision = build_baseline_scientist_action(observation)

    assert revision.action_type is ScientistActionType.REVISE_PROTOCOL
    # Expected values for the revised plan given the 12-sample, 3-day original.
    assert revision.sample_size == 6
    assert revision.duration_days == 2
    assert "latest Lab Manager concern" in revision.rationale
def test_baseline_scientist_finishes_stub_episode_without_crashing() -> None:
    """Two baseline actions drive the stub environment to a finished, agreed episode."""
    from server.app import _StubEnv

    stub_env = _StubEnv()
    reset_result = stub_env.reset(
        seed=14,
        scenario="ml_benchmark",
        difficulty="easy",
    )
    opening_observation = reset_result.scientist
    assert opening_observation is not None

    # Step 1: the episode must still be running and hand back a scientist view.
    step_one = stub_env.step(build_baseline_scientist_action(opening_observation))
    assert step_one.done is False
    assert step_one.observation is not None
    assert step_one.observation.scientist is not None

    # Step 2: the follow-up action must close the episode with agreement.
    follow_up = build_baseline_scientist_action(step_one.observation.scientist)
    step_two = stub_env.step(follow_up)
    assert step_two.done is True
    assert step_two.info.agreement_reached is True
# ---------------------------------------------------------------------------
# AGT 08 — Extended prompt, parser, formatter, and baseline coverage
# ---------------------------------------------------------------------------
# --- Parser happy paths ---
def test_parse_scientist_output_accepts_propose_protocol() -> None:
    """A complete propose_protocol payload parses with its field values intact."""
    payload = """{
"action_type": "propose_protocol",
"sample_size": 48,
"controls": ["vehicle_control", "positive_control"],
"technique": "wst1_assay",
"duration_days": 5,
"required_equipment": ["plate_reader"],
"required_reagents": ["wst1", "dmso"],
"questions": [],
"rationale": "Standard viability assay with two controls."
}"""

    parsed = parse_scientist_output(payload)

    assert parsed.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert parsed.sample_size == 48
    assert parsed.technique == "wst1_assay"
    assert parsed.controls == ["vehicle_control", "positive_control"]
    assert parsed.questions == []
def test_parse_scientist_output_accepts_accept_action() -> None:
    """An accept payload with zeroed and empty fields parses successfully."""
    payload = """{
"action_type": "accept",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": ""
}"""

    parsed = parse_scientist_output(payload)

    assert parsed.action_type is ScientistActionType.ACCEPT
    assert parsed.sample_size == 0
    assert parsed.rationale == ""
def test_parse_scientist_output_accepts_prose_wrapped_json() -> None:
    """An unfenced JSON object with prose before and after it is still parsed."""
    lead_in = "After reviewing the constraints I think a request is in order.\n\n"
    embedded_object = (
        '{"action_type": "request_info", "sample_size": 0, '
        '"controls": [], "technique": "", "duration_days": 0, '
        '"required_equipment": [], "required_reagents": [], '
        '"questions": ["Is the GPU available?"], "rationale": ""}\n\n'
    )
    sign_off = "That should clarify the compute situation."

    parsed = parse_scientist_output(lead_in + embedded_object + sign_off)

    assert parsed.action_type is ScientistActionType.REQUEST_INFO
    assert parsed.questions == ["Is the GPU available?"]
# --- Parser edge cases ---
def test_parse_scientist_output_raises_on_empty_string() -> None:
    """The empty string is rejected with the ``no_json`` code."""
    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output("")

    error = raised.value
    assert error.code == "no_json"
def test_parse_scientist_output_raises_on_whitespace_only() -> None:
    """Input made only of spaces, a newline, and a tab is rejected with ``no_json``."""
    blank_reply = " \n\t "

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(blank_reply)

    assert raised.value.code == "no_json"
def test_parse_scientist_output_raises_on_json_list() -> None:
    """A JSON list wrapping a partial object is reported as ``invalid_action``."""
    # The parser's brace extractor finds the inner object from the list,
    # so this surfaces as an invalid_action (missing required fields)
    # rather than an invalid_json error.
    list_reply = '[{"action_type": "accept"}]'

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(list_reply)

    assert raised.value.code == "invalid_action"
def test_parse_scientist_output_raises_on_extra_forbidden_keys() -> None:
    """An unknown key triggers ``invalid_action`` and is kept in the error's parsed payload."""
    payload_with_extra_key = """{
"action_type": "accept",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": "",
"secret_field": "should not be here"
}"""

    with pytest.raises(ScientistOutputParseError) as raised:
        parse_scientist_output(payload_with_extra_key)

    error = raised.value
    assert error.code == "invalid_action"
    assert error.parsed_payload is not None
    assert "secret_field" in error.parsed_payload
def test_parse_error_to_dict_serialization() -> None:
    """``to_dict`` on a ``no_json`` error exposes code, raw text, payload, and message.

    Uses ``pytest.raises`` like the other error tests in this module, so a
    missing exception fails the test without a hand-written ``else`` branch.
    """
    with pytest.raises(ScientistOutputParseError) as exc_info:
        parse_scientist_output("no json here")

    result = exc_info.value.to_dict()
    assert result["code"] == "no_json"
    assert result["raw_text"] == "no json here"
    assert result["parsed_payload"] is None
    assert "message" in result
def test_parse_error_to_dict_with_parsed_payload() -> None:
    """``to_dict`` on an ``invalid_action`` error includes the decoded payload.

    Uses ``pytest.raises`` like the other error tests in this module, so a
    missing exception fails the test without a hand-written ``else`` branch.
    """
    raw_text = """{
"action_type": "request_info",
"sample_size": 0,
"controls": [],
"technique": "",
"duration_days": 0,
"required_equipment": [],
"required_reagents": [],
"questions": [],
"rationale": ""
}"""

    with pytest.raises(ScientistOutputParseError) as exc_info:
        parse_scientist_output(raw_text)

    result = exc_info.value.to_dict()
    assert result["code"] == "invalid_action"
    assert result["parsed_payload"] is not None
    assert result["parsed_payload"]["action_type"] == "request_info"
# --- System prompt: domain coverage ---
def test_system_prompt_math_domain() -> None:
    """The math_reasoning template yields a prompt labelled ``Domain: mathematics``."""
    pack = generate_scenario(seed=10, template="math_reasoning", difficulty="easy")
    rendered = build_scientist_system_prompt(pack)

    for fragment in (
        "Domain: mathematics",
        pack.task_summary,
        "You are the Scientist agent",
    ):
        assert fragment in rendered
def test_system_prompt_finance_domain() -> None:
    """The finance_trading template yields a prompt labelled ``Domain: finance_trading``."""
    pack = generate_scenario(seed=10, template="finance_trading", difficulty="easy")
    rendered = build_scientist_system_prompt(pack)

    for fragment in ("Domain: finance_trading", pack.task_summary):
        assert fragment in rendered
def test_system_prompt_ml_domain() -> None:
    """The ml_benchmark template yields a prompt labelled ``Domain: machine_learning``."""
    pack = generate_scenario(seed=10, template="ml_benchmark", difficulty="easy")
    rendered = build_scientist_system_prompt(pack)

    for fragment in ("Domain: machine_learning", pack.task_summary):
        assert fragment in rendered
def test_system_prompt_accepts_dict_input() -> None:
    """The prompt builder also accepts the ``model_dump()`` dict of a scenario."""
    pack = generate_scenario(seed=5, template="math_reasoning", difficulty="easy")

    rendered = build_scientist_system_prompt(pack.model_dump())

    for fragment in (
        "You are the Scientist agent",
        pack.task_summary,
        "Domain: mathematics",
    ):
        assert fragment in rendered
# --- System prompt: bounded-tool policy assertions ---
def test_system_prompt_contains_bounded_tool_policy() -> None:
    """The prompt names the three tools: search_evidence, run_code_check, inspect_image."""
    rendered = build_scientist_system_prompt(
        generate_scenario(seed=1, template="math_reasoning", difficulty="easy")
    )

    for tool_name in ("search_evidence", "run_code_check", "inspect_image"):
        assert tool_name in rendered
def test_system_prompt_bounded_tool_policy_rules() -> None:
    """The prompt states the tool-policy restrictions for a medium ml_benchmark scenario."""
    scenario = generate_scenario(seed=1, template="ml_benchmark", difficulty="medium")
    prompt = build_scientist_system_prompt(scenario)

    assert "No unrestricted web browsing" in prompt
    assert "No audio" in prompt
    # "Tools do not override constraints" contains this fragment, so the single
    # substring check already covers both phrasings.
    assert "do not override constraints" in prompt
def test_system_prompt_bounded_tool_policy_present_in_all_domains() -> None:
    """Every template's prompt contains the tool-policy heading and all three tool names."""
    templates = ("math_reasoning", "ml_benchmark", "finance_trading")
    for template in templates:
        rendered = build_scientist_system_prompt(
            generate_scenario(seed=42, template=template, difficulty="easy")
        )
        # Each required fragment is paired with the failure note naming the template.
        checks = (
            ("Bounded tool policy", f"Missing in {template}"),
            ("search_evidence", f"Missing search_evidence in {template}"),
            ("run_code_check", f"Missing run_code_check in {template}"),
            ("inspect_image", f"Missing inspect_image in {template}"),
        )
        for fragment, failure_note in checks:
            assert fragment in rendered, failure_note
# --- System prompt: role-boundary assertions ---
def test_system_prompt_contains_role_boundaries() -> None:
    """The prompt forbids inventing resources and mentions hidden ground truth."""
    rendered = build_scientist_system_prompt(
        generate_scenario(seed=1, template="math_reasoning", difficulty="easy")
    )

    assert "do not invent resources" in rendered

    # Either the full sentence (any casing) or the exact-case short phrase is enough.
    full_sentence_present = (
        "do not assume access to hidden ground truth" in rendered.lower()
    )
    short_phrase_present = "hidden ground truth" in rendered
    assert full_sentence_present or short_phrase_present
def test_system_prompt_contains_output_contract() -> None:
    """The prompt includes the output-contract heading and its single-object, no-extra-keys wording."""
    rendered = build_scientist_system_prompt(
        generate_scenario(seed=1, template="math_reasoning", difficulty="easy")
    )

    for fragment in ("Output contract", "exactly one JSON object", "no extra keys"):
        assert fragment in rendered
# --- Observation formatter edge cases ---
def test_format_observation_final_round() -> None:
    """Round 5 of 6 renders its round header and still ends with the response instruction."""
    late_observation = _base_observation(round_number=5, max_rounds=6)

    rendered = format_scientist_observation(late_observation)

    for fragment in ("Round 5 of 6", "Respond with exactly one JSON"):
        assert fragment in rendered
def test_format_observation_protocol_with_empty_lists() -> None:
    """Empty protocol list fields are rendered as ``(none)``."""
    sparse_protocol = Protocol(
        sample_size=1,
        controls=[],
        technique="minimal_check",
        duration_days=1,
        required_equipment=[],
        required_reagents=[],
        rationale="Minimal protocol.",
    )
    observation = _base_observation(current_protocol=sparse_protocol, round_number=1)

    rendered = format_scientist_observation(observation)

    expected_fragments = (
        "Current protocol:",
        "technique: minimal_check",
        "controls: (none)",
        "required_equipment: (none)",
        "required_reagents: (none)",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
# --- Baseline: domain inference ---
def test_baseline_scientist_infers_ml_domain() -> None:
    """CIFAR/GPU/benchmark wording makes the baseline propose ``published_split_replication``."""
    ml_observation = _base_observation(
        paper_title="Reproducing CIFAR-10 accuracy with ResNet",
        paper_method="Train on CIFAR dataset with GPU",
        experiment_goal="Match the published benchmark accuracy.",
    )

    proposal = build_baseline_scientist_action(ml_observation)

    assert proposal.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert proposal.technique == "published_split_replication"
def test_baseline_scientist_infers_finance_domain() -> None:
    """Backtest/Sharpe/drawdown wording makes the baseline propose ``offline_backtest_workflow``."""
    finance_observation = _base_observation(
        paper_title="Offline backtest of SPY mean-reversion",
        paper_method="Daily bar backtest with slippage modeling",
        experiment_goal="Evaluate Sharpe ratio under drawdown limits.",
    )

    proposal = build_baseline_scientist_action(finance_observation)

    assert proposal.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert proposal.technique == "offline_backtest_workflow"
def test_baseline_scientist_infers_math_domain() -> None:
    """Proof/inequality/induction wording makes the baseline propose ``structured_proof_outline``."""
    math_observation = _base_observation(
        paper_title="Planning a proof of AM-GM inequality",
        paper_method="Algebraic manipulation with induction.",
        experiment_goal="Verify the proof outline.",
    )

    proposal = build_baseline_scientist_action(math_observation)

    assert proposal.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert proposal.technique == "structured_proof_outline"
# --- Baseline: forced accept at final round ---
def test_baseline_scientist_accepts_at_final_round_even_with_blocker() -> None:
    """At round 5 of 6 the baseline returns ACCEPT despite a budget/equipment objection."""
    existing_protocol = Protocol(
        sample_size=20,
        controls=["ctrl"],
        technique="method_a",
        duration_days=5,
        required_equipment=[],
        required_reagents=[],
        rationale="Full scope plan.",
    )
    manager_objection = ConversationEntry(
        role="lab_manager",
        message="Budget is tight and equipment is booked.",
        round_number=4,
        action_type="suggest_alternative",
    )
    observation = _base_observation(
        current_protocol=existing_protocol,
        conversation_history=[manager_objection],
        round_number=5,
        max_rounds=6,
    )

    decision = build_baseline_scientist_action(observation)

    assert decision.action_type is ScientistActionType.ACCEPT
def test_build_anthropic_scientist_policy_calls_messages_api() -> None:
    """The Anthropic policy sends the expected payload and headers and parses the reply.

    A mock transport records the outgoing request and answers with a single
    text content block holding a valid request_info payload.
    """
    captured: dict[str, object] = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record what the policy sent so it can be inspected after the call.
        captured["headers"] = dict(request.headers)
        captured["payload"] = json.loads(request.content.decode("utf-8"))
        return httpx.Response(
            200,
            json={
                "content": [
                    {
                        "type": "text",
                        "text": _VALID_REQUEST_INFO_JSON,
                    }
                ]
            },
        )

    client = httpx.Client(transport=httpx.MockTransport(handler))
    try:
        policy = build_anthropic_scientist_policy(
            api_key="test-key",
            model="claude-test",
            client=client,
            base_url="https://api.anthropic.com/v1/messages",
        )
        action = policy(
            _base_observation(),
            scenario="ml_benchmark",
            difficulty="medium",
        )
    finally:
        # Close the client even if the policy raises, so it is not left open.
        client.close()

    assert action.action_type is ScientistActionType.REQUEST_INFO

    payload = captured["payload"]
    assert isinstance(payload, dict)
    assert payload["model"] == "claude-test"
    assert payload["temperature"] == 0.0
    assert payload["messages"][0]["role"] == "user"

    headers = captured["headers"]
    assert isinstance(headers, dict)
    assert headers["x-api-key"] == "test-key"
def test_build_ollama_scientist_policy_calls_local_chat_api() -> None:
    """The Ollama policy sends the expected chat payload and parses the reply.

    A mock transport records the outgoing request and answers with an
    assistant message holding a valid request_info payload.
    """
    captured: dict[str, object] = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record what the policy sent so it can be inspected after the call.
        captured["payload"] = json.loads(request.content.decode("utf-8"))
        return httpx.Response(
            200,
            json={
                "message": {
                    "role": "assistant",
                    "content": _VALID_REQUEST_INFO_JSON,
                }
            },
        )

    client = httpx.Client(transport=httpx.MockTransport(handler))
    try:
        policy = build_ollama_scientist_policy(
            model="glm-5:cloud",
            client=client,
            base_url="http://127.0.0.1:11434/api/chat",
        )
        action = policy(
            _base_observation(),
            scenario="finance_trading",
            difficulty="medium",
        )
    finally:
        # Close the client even if the policy raises, so it is not left open.
        client.close()

    assert action.action_type is ScientistActionType.REQUEST_INFO

    payload = captured["payload"]
    assert isinstance(payload, dict)
    assert payload["model"] == "glm-5:cloud"
    assert payload["stream"] is False
    assert payload["format"] == "json"
    assert payload["messages"][0]["role"] == "system"