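# replicalab/tests/test_models.py
# Unit tests for the pydantic models in ``replicalab.models``.
# Origin: Hugging Face Space "replicalab", commit 80d8c84
# ("Initial HF Spaces deployment") by maxxie114.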
from __future__ import annotations
import pytest
from pydantic import ValidationError
from replicalab.models import (
ConversationEntry,
EpisodeLog,
EpisodeState,
LabManagerAction,
LabManagerActionType,
LabManagerObservation,
Observation,
Protocol,
RewardBreakdown,
ScientistAction,
ScientistActionType,
ScientistObservation,
StepInfo,
StepResult,
)
def _valid_protocol_action() -> dict:
return {
"action_type": "propose_protocol",
"sample_size": 48,
"controls": ["vehicle_control", "positive_control"],
"technique": "wst1_assay",
"duration_days": 5,
"required_equipment": ["plate_reader", "co2_incubator"],
"required_reagents": ["wst1", "dmso", "drug_x"],
"questions": [],
"rationale": "Keeps the core readout while using common lab equipment.",
}
def _valid_lab_manager_action() -> dict:
return {
"action_type": "suggest_alternative",
"feasible": False,
"budget_ok": True,
"equipment_ok": False,
"reagents_ok": True,
"schedule_ok": True,
"staff_ok": True,
"suggested_technique": "manual_cell_counting",
"suggested_sample_size": 32,
"suggested_controls": ["vehicle_control", "positive_control"],
"explanation": "The plate reader is booked, so use manual counting.",
}
def _valid_observation_payload() -> dict:
return {
"scientist": {
"paper_title": "Drug X reduces glioblastoma cell viability",
"paper_hypothesis": "Drug X reduces viability in a dose-dependent manner.",
"paper_method": "96-well viability assay with 24h incubation and absorbance readout.",
"paper_key_finding": "The highest dose reduced viability by about 40 percent.",
"experiment_goal": "Replicate the dose-response trend without dropping essential controls.",
"conversation_history": [
{
"role": "scientist",
"message": "I propose a manual counting protocol.",
"round_number": 0,
"action_type": "propose_protocol",
}
],
"current_protocol": {
"sample_size": 32,
"controls": ["vehicle_control", "positive_control"],
"technique": "manual_cell_counting",
"duration_days": 5,
"required_equipment": ["microscope", "co2_incubator"],
"required_reagents": ["dmso", "drug_x", "culture_media"],
"rationale": "Uses available equipment while preserving controls.",
},
"round_number": 1,
"max_rounds": 6,
},
"lab_manager": {
"budget_total": 1200.0,
"budget_remaining": 850.0,
"equipment_available": ["co2_incubator", "microscope"],
"equipment_booked": ["plate_reader"],
"reagents_in_stock": ["dmso", "drug_x", "culture_media"],
"reagents_out_of_stock": ["wst1"],
"staff_count": 2,
"time_limit_days": 7,
"safety_restrictions": ["no_radioactive_reagents"],
"conversation_history": [
{
"role": "lab_manager",
"message": "The plate reader is unavailable.",
"round_number": 1,
"action_type": "suggest_alternative",
}
],
"current_protocol": {
"sample_size": 32,
"controls": ["vehicle_control", "positive_control"],
"technique": "manual_cell_counting",
"duration_days": 5,
"required_equipment": ["microscope", "co2_incubator"],
"required_reagents": ["dmso", "drug_x", "culture_media"],
"rationale": "Uses available equipment while preserving controls.",
},
"round_number": 1,
"max_rounds": 6,
},
}
def test_scientist_action_accepts_valid_protocol_payload() -> None:
    """A well-formed propose_protocol payload validates into a typed action."""
    raw = _valid_protocol_action()
    parsed = ScientistAction.model_validate(raw)

    assert parsed.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert parsed.sample_size == 48
    assert parsed.questions == []
def test_scientist_action_rejects_unknown_action_type() -> None:
    """An action_type outside ScientistActionType is rejected on that field.

    Only ``action_type`` is changed from an otherwise valid payload. The error
    location is asserted so the test cannot pass because some unrelated field
    (e.g. after drift in the shared fixture) fails validation instead.
    """
    payload = _valid_protocol_action()
    payload["action_type"] = "banana"
    with pytest.raises(ValidationError) as exc_info:
        ScientistAction.model_validate(payload)
    # Pydantic reports each failure with a ``loc`` tuple; a field-level enum
    # failure is located at ("action_type",).
    assert any(
        error["loc"][:1] == ("action_type",) for error in exc_info.value.errors()
    )
def test_scientist_action_rejects_request_info_without_questions() -> None:
    """A request_info action must carry at least one question."""
    # Protocol fields are all left empty/zero so that the empty questions
    # list is the violation under test.
    empty_request = dict(
        action_type="request_info",
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=[],
        rationale="",
    )
    with pytest.raises(
        ValidationError, match="questions must contain at least one item"
    ):
        ScientistAction.model_validate(empty_request)
def test_scientist_action_rejects_protocol_payload_for_request_info() -> None:
    """A request_info action may not also carry protocol details."""
    mixed_request = dict(
        action_type="request_info",
        sample_size=24,  # protocol data that request_info is not allowed to carry
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=["What plate reader is available?"],
        rationale="",
    )
    with pytest.raises(
        ValidationError, match="request_info cannot include protocol"
    ):
        ScientistAction.model_validate(mixed_request)
def test_scientist_action_rejects_protocol_with_zero_sample_size() -> None:
    """A proposed protocol needs a sample size of at least one."""
    zero_sample = {**_valid_protocol_action(), "sample_size": 0}
    with pytest.raises(ValidationError, match="sample_size must be >= 1"):
        ScientistAction.model_validate(zero_sample)
def test_scientist_action_rejects_extra_fields() -> None:
    """Unknown keys are rejected instead of being silently ignored."""
    with_extra = {**_valid_protocol_action(), "unexpected": "value"}
    with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
        ScientistAction.model_validate(with_extra)
def test_lab_manager_action_accepts_valid_suggestion_payload() -> None:
    """A consistent suggest_alternative payload validates into a typed action."""
    raw = _valid_lab_manager_action()
    suggestion = LabManagerAction.model_validate(raw)

    assert suggestion.action_type is LabManagerActionType.SUGGEST_ALTERNATIVE
    assert suggestion.feasible is False
    assert suggestion.suggested_sample_size == 32
def test_lab_manager_action_rejects_feasible_flag_mismatch() -> None:
    """``feasible`` must agree with the individual ``*_ok`` flags."""
    # Flipping equipment_ok makes every *_ok flag True while feasible stays
    # False, contradicting the invariant named in the expected error message.
    inconsistent = {**_valid_lab_manager_action(), "equipment_ok": True}
    with pytest.raises(
        ValidationError, match="feasible must equal the logical AND"
    ):
        LabManagerAction.model_validate(inconsistent)
def test_lab_manager_action_rejects_missing_suggestion_fields() -> None:
    """suggest_alternative without any populated suggestion is rejected."""
    no_suggestions = {
        **_valid_lab_manager_action(),
        "suggested_technique": "",
        "suggested_sample_size": 0,
        "suggested_controls": [],
    }
    with pytest.raises(
        ValidationError, match="requires at least one suggestion field"
    ):
        LabManagerAction.model_validate(no_suggestions)
def test_lab_manager_action_rejects_suggestions_for_report_feasibility() -> None:
    """report_feasibility may not keep the suggested_* fields populated."""
    report = {**_valid_lab_manager_action(), "action_type": "report_feasibility"}
    with pytest.raises(ValidationError, match="suggestion fields are only allowed"):
        LabManagerAction.model_validate(report)
def test_observation_coerces_nested_dicts_to_typed_models() -> None:
    """Nested dicts in an Observation payload are coerced into typed models.

    Both role views carry a ``conversation_history`` and a
    ``current_protocol`` in the payload. Each view is checked, so a missing
    typed annotation on either side is caught, not just on the scientist side.
    """
    observation = Observation.model_validate(_valid_observation_payload())

    assert isinstance(observation.scientist, ScientistObservation)
    assert isinstance(observation.lab_manager, LabManagerObservation)
    for view in (observation.scientist, observation.lab_manager):
        assert isinstance(view.conversation_history[0], ConversationEntry)
        assert isinstance(view.current_protocol, Protocol)
def test_observation_rejects_invalid_conversation_role() -> None:
    """An unrecognised conversation role such as "reviewer" fails validation."""
    payload = _valid_observation_payload()
    first_entry = payload["scientist"]["conversation_history"][0]
    first_entry["role"] = "reviewer"
    with pytest.raises(ValidationError):
        Observation.model_validate(payload)
def test_observation_rejects_negative_budget() -> None:
    """A negative lab-manager budget_total fails validation."""
    payload = _valid_observation_payload()
    lab_view = payload["lab_manager"]
    # NOTE(review): budget_remaining stays at 850.0 here. If the model also
    # checks remaining <= total, that check could be what fails — confirm the
    # rejection actually comes from the sign of budget_total.
    lab_view["budget_total"] = -1.0
    with pytest.raises(ValidationError):
        Observation.model_validate(payload)
# ---------------------------------------------------------------------------
# MOD 04 — Typed EpisodeState and EpisodeLog
# ---------------------------------------------------------------------------
def _sample_protocol() -> Protocol:
    """Return a new manual-cell-counting Protocol for the episode tests."""
    fields = {
        "sample_size": 32,
        "controls": ["vehicle_control", "positive_control"],
        "technique": "manual_cell_counting",
        "duration_days": 5,
        "required_equipment": ["microscope", "co2_incubator"],
        "required_reagents": ["dmso", "drug_x", "culture_media"],
        "rationale": "Uses available equipment while preserving controls.",
    }
    return Protocol(**fields)
def _sample_conversation_entry() -> ConversationEntry:
    """Return a new round-1 scientist entry proposing a protocol."""
    entry = ConversationEntry(
        round_number=1,
        role="scientist",
        action_type="propose_protocol",
        message="I propose a manual counting protocol.",
    )
    return entry
def test_episode_state_accepts_typed_protocol_and_history() -> None:
    """EpisodeState keeps Protocol and ConversationEntry values as typed models."""
    state = EpisodeState(
        seed=42,
        current_protocol=_sample_protocol(),
        conversation_history=[_sample_conversation_entry()],
        round_number=1,
        max_rounds=6,
    )

    stored_protocol = state.current_protocol
    first_turn = state.conversation_history[0]
    assert isinstance(stored_protocol, Protocol)
    assert stored_protocol.technique == "manual_cell_counting"
    assert isinstance(first_turn, ConversationEntry)
    assert first_turn.role == "scientist"
def test_episode_state_accepts_none_protocol() -> None:
    """An EpisodeState may hold no protocol and an empty conversation."""
    empty_state = EpisodeState(conversation_history=[], current_protocol=None)

    assert empty_state.current_protocol is None
    assert empty_state.conversation_history == []
def test_episode_state_json_round_trip() -> None:
    """EpisodeState survives a JSON dump/load with every field intact.

    The spot checks confirm that nested models come back typed. The final
    whole-dump comparison makes any field that serialisation drops or alters
    fail the test instead of going unnoticed; examples are scenario_template,
    difficulty, paper_title, round_number and max_rounds, none of which are
    spot-checked.
    """
    protocol = _sample_protocol()
    entry = _sample_conversation_entry()
    state = EpisodeState(
        seed=7,
        scenario_template="math_reasoning",
        difficulty="hard",
        paper_title="Test Paper",
        current_protocol=protocol,
        conversation_history=[entry],
        round_number=2,
        max_rounds=6,
    )
    dumped = state.model_dump_json()
    restored = EpisodeState.model_validate_json(dumped)
    assert isinstance(restored.current_protocol, Protocol)
    assert restored.current_protocol.sample_size == 32
    assert isinstance(restored.conversation_history[0], ConversationEntry)
    assert restored.conversation_history[0].action_type == "propose_protocol"
    assert restored.seed == 7
    # Structural equality of all fields, including ones not spot-checked above.
    assert restored.model_dump() == state.model_dump()
def test_episode_log_accepts_typed_fields() -> None:
    """EpisodeLog keeps its transcript and reward breakdown as typed models."""
    log = EpisodeLog(
        episode_id="ep-001",
        seed=42,
        transcript=[_sample_conversation_entry()],
        reward_breakdown=RewardBreakdown(rigor=0.8, feasibility=0.7, fidelity=0.9),
        total_reward=5.0,
        rounds_used=3,
        agreement_reached=True,
    )

    stored_breakdown = log.reward_breakdown
    assert isinstance(log.transcript[0], ConversationEntry)
    assert isinstance(stored_breakdown, RewardBreakdown)
    assert stored_breakdown.rigor == 0.8
def test_episode_log_none_reward_breakdown() -> None:
    """An EpisodeLog built from only an id has no breakdown and no transcript."""
    minimal_log = EpisodeLog(episode_id="ep-002")

    assert minimal_log.reward_breakdown is None
    assert minimal_log.transcript == []
def test_episode_log_json_round_trip() -> None:
    """A fully populated EpisodeLog survives a JSON dump/load unchanged.

    The spot checks confirm that nested models (final_state, protocol,
    transcript, reward breakdown) come back typed. The closing whole-dump
    comparison also covers the scalar fields that are not spot-checked, such
    as total_reward, rounds_used, judge_notes and verdict.
    """
    entry = _sample_conversation_entry()
    breakdown = RewardBreakdown(
        rigor=0.6,
        feasibility=0.5,
        fidelity=0.7,
        efficiency_bonus=0.1,
        communication_bonus=0.05,
        penalties={"timeout": 0.02},
    )
    state = EpisodeState(
        seed=99,
        current_protocol=_sample_protocol(),
        conversation_history=[entry],
        round_number=3,
        max_rounds=6,
        done=True,
        agreement_reached=True,
        reward=5.0,
        rigor_score=0.6,
    )
    log = EpisodeLog(
        episode_id="ep-round-trip",
        seed=99,
        final_state=state,
        transcript=[entry],
        reward_breakdown=breakdown,
        total_reward=5.0,
        rounds_used=3,
        agreement_reached=True,
        judge_notes="Good protocol.",
        verdict="accept",
    )
    dumped = log.model_dump_json()
    restored = EpisodeLog.model_validate_json(dumped)
    assert isinstance(restored.final_state, EpisodeState)
    assert isinstance(restored.final_state.current_protocol, Protocol)
    assert isinstance(restored.final_state.conversation_history[0], ConversationEntry)
    assert isinstance(restored.transcript[0], ConversationEntry)
    assert isinstance(restored.reward_breakdown, RewardBreakdown)
    assert restored.reward_breakdown.penalties == {"timeout": 0.02}
    assert restored.episode_id == "ep-round-trip"
    # Structural equality of all fields, including ones not spot-checked above.
    assert restored.model_dump() == log.model_dump()
def test_episode_log_nested_state_preserves_typed_fields() -> None:
    """An EpisodeState nested as final_state keeps its typed protocol/history."""
    nested_state = EpisodeState(
        current_protocol=_sample_protocol(),
        conversation_history=[_sample_conversation_entry()],
    )
    final_state = EpisodeLog(final_state=nested_state).final_state

    assert isinstance(final_state.current_protocol, Protocol)
    assert final_state.current_protocol.technique == "manual_cell_counting"
    assert isinstance(final_state.conversation_history[0], ConversationEntry)
def test_step_result_with_typed_info() -> None:
    """StepResult carrying a typed StepInfo survives a JSON round trip.

    The spot checks confirm that nested models come back typed. The final
    whole-dump comparison also catches changes to fields that are not
    spot-checked, such as reward, done and judge_notes.
    """
    breakdown = RewardBreakdown(rigor=0.8, feasibility=0.8, fidelity=0.8)
    info = StepInfo(
        agreement_reached=True,
        reward_breakdown=breakdown,
        judge_notes="All checks passed.",
        verdict="accept",
        round=3,
        stub=True,
    )
    result = StepResult(reward=5.0, done=True, info=info)
    dumped = result.model_dump_json()
    restored = StepResult.model_validate_json(dumped)
    assert isinstance(restored.info, StepInfo)
    assert isinstance(restored.info.reward_breakdown, RewardBreakdown)
    assert restored.info.agreement_reached is True
    assert restored.info.verdict == "accept"
    # Structural equality of all fields, including ones not spot-checked above.
    assert restored.model_dump() == result.model_dump()