Spaces:
Running
Running
File size: 3,306 Bytes
from __future__ import annotations
from replicalab.models import ScientistAction, ScientistObservation
from replicalab.scoring import score_paper_understanding
def _observation() -> ScientistObservation:
    """Build the observation fixture shared by the tests in this module.

    The fixture describes a ResNet / CIFAR-10 augmentation paper at round 0
    of 6, with an empty conversation history and no current protocol.
    """
    # Paper-describing text fields, grouped apart from the round bookkeeping.
    paper_fields = {
        "paper_title": "ResNet augmentation study",
        "paper_hypothesis": "Data augmentation improves CIFAR-10 classification accuracy.",
        "paper_method": "Train a ResNet model on CIFAR-10 with and without augmentation.",
        "paper_key_finding": "The augmented run improves top-1 accuracy over the baseline.",
        "experiment_goal": "Replicate the augmentation-driven accuracy gain on CIFAR-10.",
    }
    return ScientistObservation(
        **paper_fields,
        conversation_history=[],
        current_protocol=None,
        round_number=0,
        max_rounds=6,
    )
def test_grounded_protocol_scores_higher_than_generic_protocol() -> None:
    """A paper-specific protocol must score strictly above a generic one.

    Both proposals use the same sample size, duration, equipment, reagents
    and (empty) question list; only controls, technique and rationale differ.
    """

    def _propose(controls: list[str], technique: str, rationale: str) -> ScientistAction:
        # Fresh list literals on every call so the two actions share no state.
        return ScientistAction(
            action_type="propose_protocol",
            sample_size=8,
            controls=controls,
            technique=technique,
            duration_days=2,
            required_equipment=["gpu_h100"],
            required_reagents=[],
            questions=[],
            rationale=rationale,
        )

    obs = _observation()
    paper_specific = _propose(
        ["no_augmentation_baseline"],
        "ResNet replication on CIFAR-10 with augmentation ablation",
        "Train the CIFAR-10 ResNet baseline and the augmented variant to test "
        "whether augmentation improves top-1 accuracy.",
    )
    boilerplate = _propose(
        ["baseline"],
        "general benchmark workflow",
        "Run a generic experiment and compare the outputs.",
    )

    # Score in the same order as the comparison reads: grounded first.
    paper_specific_score = score_paper_understanding(obs, paper_specific)
    boilerplate_score = score_paper_understanding(obs, boilerplate)
    assert paper_specific_score > boilerplate_score
def test_relevant_blocking_question_scores_above_irrelevant_question() -> None:
    """A question about the paper must score strictly above an off-topic one.

    Both actions are ``request_info`` actions with every protocol field left
    empty or zero; the single question is the only thing that differs.
    """

    def _ask(question: str) -> ScientistAction:
        # An info request carrying one question and no protocol content.
        return ScientistAction(
            action_type="request_info",
            sample_size=0,
            controls=[],
            technique="",
            duration_days=0,
            required_equipment=[],
            required_reagents=[],
            questions=[question],
            rationale="",
        )

    obs = _observation()
    on_topic = _ask("Which CIFAR-10 augmentation setting was used in the paper?")
    off_topic = _ask("What is the office Wi-Fi password?")

    on_topic_score = score_paper_understanding(obs, on_topic)
    off_topic_score = score_paper_understanding(obs, off_topic)
    assert on_topic_score > off_topic_score
def test_score_is_bounded() -> None:
    """The score for a plausible protocol proposal lies within [0.0, 1.0]."""
    obs = _observation()
    proposal = ScientistAction(
        action_type="propose_protocol",
        sample_size=4,
        controls=["baseline"],
        technique="ResNet augmentation replication",
        duration_days=1,
        required_equipment=["gpu_h100"],
        required_reagents=[],
        questions=[],
        rationale="Replicate the CIFAR-10 augmentation finding.",
    )
    result = score_paper_understanding(obs, proposal)
    # Both bounds are inclusive.
    assert 0.0 <= result <= 1.0
|