Spaces:
Running
Running
| from __future__ import annotations | |
| import pytest | |
| from pydantic import ValidationError | |
| from replicalab.models import ( | |
| ConversationEntry, | |
| EpisodeLog, | |
| EpisodeState, | |
| LabManagerAction, | |
| LabManagerActionType, | |
| LabManagerObservation, | |
| Observation, | |
| Protocol, | |
| RewardBreakdown, | |
| ScientistAction, | |
| ScientistActionType, | |
| ScientistObservation, | |
| StepInfo, | |
| StepResult, | |
| ) | |
| def _valid_protocol_action() -> dict: | |
| return { | |
| "action_type": "propose_protocol", | |
| "sample_size": 48, | |
| "controls": ["vehicle_control", "positive_control"], | |
| "technique": "wst1_assay", | |
| "duration_days": 5, | |
| "required_equipment": ["plate_reader", "co2_incubator"], | |
| "required_reagents": ["wst1", "dmso", "drug_x"], | |
| "questions": [], | |
| "rationale": "Keeps the core readout while using common lab equipment.", | |
| } | |
| def _valid_lab_manager_action() -> dict: | |
| return { | |
| "action_type": "suggest_alternative", | |
| "feasible": False, | |
| "budget_ok": True, | |
| "equipment_ok": False, | |
| "reagents_ok": True, | |
| "schedule_ok": True, | |
| "staff_ok": True, | |
| "suggested_technique": "manual_cell_counting", | |
| "suggested_sample_size": 32, | |
| "suggested_controls": ["vehicle_control", "positive_control"], | |
| "explanation": "The plate reader is booked, so use manual counting.", | |
| } | |
| def _valid_observation_payload() -> dict: | |
| return { | |
| "scientist": { | |
| "paper_title": "Drug X reduces glioblastoma cell viability", | |
| "paper_hypothesis": "Drug X reduces viability in a dose-dependent manner.", | |
| "paper_method": "96-well viability assay with 24h incubation and absorbance readout.", | |
| "paper_key_finding": "The highest dose reduced viability by about 40 percent.", | |
| "experiment_goal": "Replicate the dose-response trend without dropping essential controls.", | |
| "conversation_history": [ | |
| { | |
| "role": "scientist", | |
| "message": "I propose a manual counting protocol.", | |
| "round_number": 0, | |
| "action_type": "propose_protocol", | |
| } | |
| ], | |
| "current_protocol": { | |
| "sample_size": 32, | |
| "controls": ["vehicle_control", "positive_control"], | |
| "technique": "manual_cell_counting", | |
| "duration_days": 5, | |
| "required_equipment": ["microscope", "co2_incubator"], | |
| "required_reagents": ["dmso", "drug_x", "culture_media"], | |
| "rationale": "Uses available equipment while preserving controls.", | |
| }, | |
| "round_number": 1, | |
| "max_rounds": 6, | |
| }, | |
| "lab_manager": { | |
| "budget_total": 1200.0, | |
| "budget_remaining": 850.0, | |
| "equipment_available": ["co2_incubator", "microscope"], | |
| "equipment_booked": ["plate_reader"], | |
| "reagents_in_stock": ["dmso", "drug_x", "culture_media"], | |
| "reagents_out_of_stock": ["wst1"], | |
| "staff_count": 2, | |
| "time_limit_days": 7, | |
| "safety_restrictions": ["no_radioactive_reagents"], | |
| "conversation_history": [ | |
| { | |
| "role": "lab_manager", | |
| "message": "The plate reader is unavailable.", | |
| "round_number": 1, | |
| "action_type": "suggest_alternative", | |
| } | |
| ], | |
| "current_protocol": { | |
| "sample_size": 32, | |
| "controls": ["vehicle_control", "positive_control"], | |
| "technique": "manual_cell_counting", | |
| "duration_days": 5, | |
| "required_equipment": ["microscope", "co2_incubator"], | |
| "required_reagents": ["dmso", "drug_x", "culture_media"], | |
| "rationale": "Uses available equipment while preserving controls.", | |
| }, | |
| "round_number": 1, | |
| "max_rounds": 6, | |
| }, | |
| } | |
def test_scientist_action_accepts_valid_protocol_payload() -> None:
    """A complete propose_protocol payload parses into a typed ScientistAction."""
    parsed = ScientistAction.model_validate(_valid_protocol_action())

    assert parsed.action_type is ScientistActionType.PROPOSE_PROTOCOL
    assert parsed.sample_size == 48
    assert parsed.questions == []
def test_scientist_action_rejects_unknown_action_type() -> None:
    """An action_type that is not a known ScientistActionType fails validation."""
    bogus = {**_valid_protocol_action(), "action_type": "banana"}

    with pytest.raises(ValidationError):
        ScientistAction.model_validate(bogus)
def test_scientist_action_rejects_request_info_without_questions() -> None:
    """request_info with an empty question list is rejected."""
    empty_request = dict(
        action_type="request_info",
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=[],
        rationale="",
    )

    with pytest.raises(ValidationError, match="questions must contain at least one item"):
        ScientistAction.model_validate(empty_request)
def test_scientist_action_rejects_protocol_payload_for_request_info() -> None:
    """request_info must not carry protocol fields (here a non-zero sample_size)."""
    mixed_request = dict(
        action_type="request_info",
        sample_size=24,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=["What plate reader is available?"],
        rationale="",
    )

    with pytest.raises(ValidationError, match="request_info cannot include protocol"):
        ScientistAction.model_validate(mixed_request)
def test_scientist_action_rejects_protocol_with_zero_sample_size() -> None:
    """A proposed protocol with sample_size 0 is rejected."""
    zero_sized = {**_valid_protocol_action(), "sample_size": 0}

    with pytest.raises(ValidationError, match="sample_size must be >= 1"):
        ScientistAction.model_validate(zero_sized)
def test_scientist_action_rejects_extra_fields() -> None:
    """Unknown keys are refused rather than silently ignored."""
    with_extra = {**_valid_protocol_action(), "unexpected": "value"}

    with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
        ScientistAction.model_validate(with_extra)
def test_lab_manager_action_accepts_valid_suggestion_payload() -> None:
    """A consistent suggest_alternative payload parses into a LabManagerAction."""
    parsed = LabManagerAction.model_validate(_valid_lab_manager_action())

    assert parsed.action_type is LabManagerActionType.SUGGEST_ALTERNATIVE
    assert parsed.feasible is False
    assert parsed.suggested_sample_size == 32
def test_lab_manager_action_rejects_feasible_flag_mismatch() -> None:
    """feasible=False is rejected once every individual *_ok flag is True."""
    # The base payload only fails on equipment; flipping that flag makes every
    # check pass, so feasible=False no longer matches the combined flags.
    inconsistent = {**_valid_lab_manager_action(), "equipment_ok": True}

    with pytest.raises(ValidationError, match="feasible must equal the logical AND"):
        LabManagerAction.model_validate(inconsistent)
def test_lab_manager_action_rejects_missing_suggestion_fields() -> None:
    """suggest_alternative with every suggestion field blanked out is rejected."""
    payload = _valid_lab_manager_action()
    payload.update(
        suggested_technique="",
        suggested_sample_size=0,
        suggested_controls=[],
    )

    with pytest.raises(ValidationError, match="requires at least one suggestion field"):
        LabManagerAction.model_validate(payload)
def test_lab_manager_action_rejects_suggestions_for_report_feasibility() -> None:
    """report_feasibility must not carry the suggestion fields of the base payload."""
    report = {**_valid_lab_manager_action(), "action_type": "report_feasibility"}

    with pytest.raises(ValidationError, match="suggestion fields are only allowed"):
        LabManagerAction.model_validate(report)
def test_observation_coerces_nested_dicts_to_typed_models() -> None:
    """Raw nested dicts are promoted to the typed observation sub-models."""
    obs = Observation.model_validate(_valid_observation_payload())
    scientist_view = obs.scientist

    assert isinstance(scientist_view, ScientistObservation)
    assert isinstance(obs.lab_manager, LabManagerObservation)
    assert isinstance(scientist_view.conversation_history[0], ConversationEntry)
    assert isinstance(scientist_view.current_protocol, Protocol)
def test_observation_rejects_invalid_conversation_role() -> None:
    """A conversation entry whose role is not a known participant fails validation."""
    payload = _valid_observation_payload()
    first_turn = payload["scientist"]["conversation_history"][0]
    first_turn["role"] = "reviewer"

    with pytest.raises(ValidationError):
        Observation.model_validate(payload)
def test_observation_rejects_negative_budget() -> None:
    """A negative lab-manager budget_total fails validation."""
    payload = _valid_observation_payload()
    lab_view = payload["lab_manager"]
    lab_view["budget_total"] = -1.0

    with pytest.raises(ValidationError):
        Observation.model_validate(payload)
| # --------------------------------------------------------------------------- | |
| # MOD 04 — Typed EpisodeState and EpisodeLog | |
| # --------------------------------------------------------------------------- | |
def _sample_protocol() -> Protocol:
    """Return the manual-counting Protocol used by the EpisodeState/EpisodeLog tests."""
    fields = {
        "sample_size": 32,
        "controls": ["vehicle_control", "positive_control"],
        "technique": "manual_cell_counting",
        "duration_days": 5,
        "required_equipment": ["microscope", "co2_incubator"],
        "required_reagents": ["dmso", "drug_x", "culture_media"],
        "rationale": "Uses available equipment while preserving controls.",
    }
    return Protocol(**fields)
def _sample_conversation_entry() -> ConversationEntry:
    """Return a round-1 scientist proposal entry for the EpisodeState/EpisodeLog tests."""
    fields = {
        "role": "scientist",
        "message": "I propose a manual counting protocol.",
        "round_number": 1,
        "action_type": "propose_protocol",
    }
    return ConversationEntry(**fields)
def test_episode_state_accepts_typed_protocol_and_history() -> None:
    """EpisodeState keeps typed Protocol and ConversationEntry objects."""
    state = EpisodeState(
        seed=42,
        current_protocol=_sample_protocol(),
        conversation_history=[_sample_conversation_entry()],
        round_number=1,
        max_rounds=6,
    )
    proto = state.current_protocol
    first_turn = state.conversation_history[0]

    assert isinstance(proto, Protocol)
    assert proto.technique == "manual_cell_counting"
    assert isinstance(first_turn, ConversationEntry)
    assert first_turn.role == "scientist"
def test_episode_state_accepts_none_protocol() -> None:
    """An EpisodeState may hold no protocol and an empty history."""
    blank = EpisodeState(current_protocol=None, conversation_history=[])

    assert blank.current_protocol is None
    assert blank.conversation_history == []
def test_episode_state_json_round_trip() -> None:
    """EpisodeState survives a JSON dump/load with every populated field intact.

    Besides checking that nested payloads come back as typed models, this
    verifies the scalar metadata (template, difficulty, title, round counters)
    that the state is built with, so a field silently dropped or reset by
    serialization is caught.
    """
    protocol = _sample_protocol()
    entry = _sample_conversation_entry()
    state = EpisodeState(
        seed=7,
        scenario_template="math_reasoning",
        difficulty="hard",
        paper_title="Test Paper",
        current_protocol=protocol,
        conversation_history=[entry],
        round_number=2,
        max_rounds=6,
    )
    dumped = state.model_dump_json()
    restored = EpisodeState.model_validate_json(dumped)

    # Nested structures come back as typed models, not plain dicts.
    assert isinstance(restored.current_protocol, Protocol)
    assert restored.current_protocol.sample_size == 32
    assert restored.current_protocol == state.current_protocol
    assert isinstance(restored.conversation_history[0], ConversationEntry)
    assert restored.conversation_history[0].action_type == "propose_protocol"
    assert restored.conversation_history == state.conversation_history

    # Scalar fields must survive too. Compare against the original object so
    # the check holds whether these fields are plain strings or enums.
    assert restored.seed == 7
    assert restored.scenario_template == state.scenario_template
    assert restored.difficulty == state.difficulty
    assert restored.paper_title == state.paper_title
    assert restored.round_number == state.round_number
    assert restored.max_rounds == state.max_rounds
def test_episode_log_accepts_typed_fields() -> None:
    """EpisodeLog keeps typed transcript entries and a typed reward breakdown."""
    first_turn = _sample_conversation_entry()
    scores = RewardBreakdown(rigor=0.8, feasibility=0.7, fidelity=0.9)
    episode = EpisodeLog(
        episode_id="ep-001",
        seed=42,
        transcript=[first_turn],
        reward_breakdown=scores,
        total_reward=5.0,
        rounds_used=3,
        agreement_reached=True,
    )
    kept_scores = episode.reward_breakdown

    assert isinstance(episode.transcript[0], ConversationEntry)
    assert isinstance(kept_scores, RewardBreakdown)
    assert kept_scores.rigor == 0.8
def test_episode_log_none_reward_breakdown() -> None:
    """An EpisodeLog built from just an id has no breakdown and an empty transcript."""
    minimal = EpisodeLog(episode_id="ep-002")

    assert minimal.reward_breakdown is None
    assert minimal.transcript == []
def test_episode_log_json_round_trip() -> None:
    """EpisodeLog, including a nested final EpisodeState, survives a JSON round trip.

    Checks that nested payloads come back as typed models, and also that the
    scalar values set on the log and on its final state are preserved. The
    latter would otherwise go unverified even though the test populates them.
    """
    entry = _sample_conversation_entry()
    breakdown = RewardBreakdown(
        rigor=0.6,
        feasibility=0.5,
        fidelity=0.7,
        efficiency_bonus=0.1,
        communication_bonus=0.05,
        penalties={"timeout": 0.02},
    )
    state = EpisodeState(
        seed=99,
        current_protocol=_sample_protocol(),
        conversation_history=[entry],
        round_number=3,
        max_rounds=6,
        done=True,
        agreement_reached=True,
        reward=5.0,
        rigor_score=0.6,
    )
    log = EpisodeLog(
        episode_id="ep-round-trip",
        seed=99,
        final_state=state,
        transcript=[entry],
        reward_breakdown=breakdown,
        total_reward=5.0,
        rounds_used=3,
        agreement_reached=True,
        judge_notes="Good protocol.",
        verdict="accept",
    )
    dumped = log.model_dump_json()
    restored = EpisodeLog.model_validate_json(dumped)

    # Nested structures come back as typed models, not plain dicts.
    assert isinstance(restored.final_state, EpisodeState)
    assert isinstance(restored.final_state.current_protocol, Protocol)
    assert isinstance(restored.final_state.conversation_history[0], ConversationEntry)
    assert isinstance(restored.transcript[0], ConversationEntry)
    assert isinstance(restored.reward_breakdown, RewardBreakdown)
    assert restored.reward_breakdown.penalties == {"timeout": 0.02}
    assert restored.reward_breakdown == log.reward_breakdown

    # Top-level scalars. Compare against the original object so the check
    # holds whatever concrete type (str/enum/float) each field uses.
    assert restored.episode_id == "ep-round-trip"
    assert restored.seed == log.seed
    assert restored.total_reward == log.total_reward
    assert restored.rounds_used == log.rounds_used
    assert restored.agreement_reached is True
    assert restored.judge_notes == log.judge_notes
    assert restored.verdict == log.verdict

    # Scalars on the nested final state.
    final = restored.final_state
    assert final.seed == state.seed
    assert final.round_number == state.round_number
    assert final.max_rounds == state.max_rounds
    assert final.done is True
    assert final.agreement_reached is True
    assert final.reward == state.reward
    assert final.rigor_score == state.rigor_score
def test_episode_log_nested_state_preserves_typed_fields() -> None:
    """An EpisodeState nested inside an EpisodeLog keeps its typed sub-models."""
    nested = EpisodeState(
        current_protocol=_sample_protocol(),
        conversation_history=[_sample_conversation_entry()],
    )
    final = EpisodeLog(final_state=nested).final_state

    assert isinstance(final.current_protocol, Protocol)
    assert final.current_protocol.technique == "manual_cell_counting"
    assert isinstance(final.conversation_history[0], ConversationEntry)
def test_step_result_with_typed_info() -> None:
    """StepResult with a typed StepInfo survives a JSON round trip.

    Verifies that the nested info and breakdown come back typed, and also
    that the top-level reward/done values and the judge notes are preserved.
    """
    breakdown = RewardBreakdown(rigor=0.8, feasibility=0.8, fidelity=0.8)
    info = StepInfo(
        agreement_reached=True,
        reward_breakdown=breakdown,
        judge_notes="All checks passed.",
        verdict="accept",
        round=3,
        stub=True,
    )
    result = StepResult(reward=5.0, done=True, info=info)
    dumped = result.model_dump_json()
    restored = StepResult.model_validate_json(dumped)

    # Nested structures come back as typed models, not plain dicts.
    assert isinstance(restored.info, StepInfo)
    assert isinstance(restored.info.reward_breakdown, RewardBreakdown)
    assert restored.info.reward_breakdown == info.reward_breakdown
    assert restored.info.agreement_reached is True
    assert restored.info.verdict == "accept"
    assert restored.info.judge_notes == info.judge_notes

    # Top-level step outcome must survive as well.
    assert restored.reward == result.reward
    assert restored.done is True