File size: 3,306 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations

from replicalab.models import ScientistAction, ScientistObservation
from replicalab.scoring import score_paper_understanding


def _observation() -> ScientistObservation:
    """Build the shared ResNet / CIFAR-10 observation used by every test below.

    The paper fields describe an augmentation study on CIFAR-10. The episode
    starts at round 0 (of at most 6), with an empty conversation history and
    no protocol on the table yet.
    """
    # What the scientist is told about the paper being replicated.
    paper = {
        "paper_title": "ResNet augmentation study",
        "paper_hypothesis": "Data augmentation improves CIFAR-10 classification accuracy.",
        "paper_method": "Train a ResNet model on CIFAR-10 with and without augmentation.",
        "paper_key_finding": "The augmented run improves top-1 accuracy over the baseline.",
    }
    return ScientistObservation(
        **paper,
        experiment_goal="Replicate the augmentation-driven accuracy gain on CIFAR-10.",
        # Fresh episode: no dialogue so far and no protocol proposed yet.
        conversation_history=[],
        current_protocol=None,
        round_number=0,
        max_rounds=6,
    )


def test_grounded_protocol_scores_higher_than_generic_protocol() -> None:
    """A protocol that echoes the paper's specifics must outscore a vague one.

    Both proposals share action type, sample size, duration, equipment,
    reagents and questions. Only ``controls``, ``technique`` and
    ``rationale`` differ, so any score gap is attributable to those fields.
    """

    def _proposal(*, controls: list[str], technique: str, rationale: str) -> ScientistAction:
        # Each call builds brand-new list objects so the two actions share no state.
        return ScientistAction(
            action_type="propose_protocol",
            sample_size=8,
            controls=controls,
            technique=technique,
            duration_days=2,
            required_equipment=["gpu_h100"],
            required_reagents=[],
            questions=[],
            rationale=rationale,
        )

    observation = _observation()
    grounded = _proposal(
        controls=["no_augmentation_baseline"],
        technique="ResNet replication on CIFAR-10 with augmentation ablation",
        rationale=(
            "Train the CIFAR-10 ResNet baseline and the augmented variant to test "
            "whether augmentation improves top-1 accuracy."
        ),
    )
    generic = _proposal(
        controls=["baseline"],
        technique="general benchmark workflow",
        rationale="Run a generic experiment and compare the outputs.",
    )

    grounded_score = score_paper_understanding(observation, grounded)
    generic_score = score_paper_understanding(observation, generic)
    assert grounded_score > generic_score


def test_relevant_blocking_question_scores_above_irrelevant_question() -> None:
    """An on-topic clarifying question must outscore an off-topic one.

    Both actions are ``request_info`` with every protocol field zeroed or
    emptied, so the only difference between them is the question text.
    """

    def _ask(question: str) -> ScientistAction:
        # Protocol fields are deliberately blank; only ``questions`` varies.
        return ScientistAction(
            action_type="request_info",
            sample_size=0,
            controls=[],
            technique="",
            duration_days=0,
            required_equipment=[],
            required_reagents=[],
            questions=[question],
            rationale="",
        )

    observation = _observation()
    relevant = _ask("Which CIFAR-10 augmentation setting was used in the paper?")
    irrelevant = _ask("What is the office Wi-Fi password?")

    relevant_score = score_paper_understanding(observation, relevant)
    irrelevant_score = score_paper_understanding(observation, irrelevant)
    assert relevant_score > irrelevant_score


def test_score_is_bounded() -> None:
    """The paper-understanding score for a typical proposal lies in [0.0, 1.0]."""
    observation = _observation()
    action = ScientistAction(
        action_type="propose_protocol",
        sample_size=4,
        controls=["baseline"],
        technique="ResNet augmentation replication",
        duration_days=1,
        required_equipment=["gpu_h100"],
        required_reagents=[],
        questions=[],
        rationale="Replicate the CIFAR-10 augmentation finding.",
    )

    score = score_paper_understanding(observation, action)

    # Closed interval: both endpoints are acceptable scores.
    assert 0.0 <= score <= 1.0