File size: 3,209 Bytes
80d8c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import json

from replicalab.cache import CachedOracle, ScenarioCache
from replicalab.oracle_models import Scenario


def _scenario_payload() -> dict:
    return {
        "paper": {
            "title": "Cached benchmark",
            "domain": "ml_benchmark",
            "claim": "A small run remains useful under a tighter budget.",
            "method_summary": "Train a compact model and verify against a held-out split.",
            "original_sample_size": 1000,
            "original_duration_days": 2,
            "original_technique": "compact_model",
            "required_controls": ["baseline"],
            "required_equipment": ["GPU cluster"],
            "required_reagents": ["dataset snapshot"],
            "statistical_test": "accuracy_gap",
        },
        "lab_constraints": {
            "budget_total": 1200.0,
            "budget_remaining": 1200.0,
            "equipment": [
                {
                    "name": "GPU cluster",
                    "available": True,
                    "condition": "operational",
                    "booking_conflicts": [],
                    "cost_per_use": 100.0,
                }
            ],
            "reagents": [
                {
                    "name": "dataset snapshot",
                    "in_stock": True,
                    "quantity_available": 1.0,
                    "unit": "copy",
                    "lead_time_days": 0,
                    "cost": 0.0,
                }
            ],
            "staff": [],
            "max_duration_days": 3,
            "safety_rules": ["No external internet."],
            "valid_substitutions": [],
        },
        "minimum_viable_spec": {
            "min_sample_size": 800,
            "must_keep_controls": ["baseline"],
            "acceptable_techniques": ["compact_model"],
            "min_duration_days": 1,
            "critical_equipment": ["GPU cluster"],
            "flexible_equipment": [],
            "critical_reagents": ["dataset snapshot"],
            "flexible_reagents": [],
            "power_threshold": 0.75,
        },
        "difficulty": "easy",
        "narrative_hook": "The benchmark owners tightened the reporting budget.",
    }


def test_scenario_cache_round_trips(tmp_path) -> None:
    """A scenario written with ``put`` is read back unchanged by ``get``.

    The scenario is stored under the arguments ``(13, "easy", "ml_benchmark")``
    in a ``ScenarioCache`` rooted at pytest's ``tmp_path``. The test checks:

    - the path returned by ``put`` exists;
    - ``get`` with the same arguments returns a scenario;
    - that scenario's JSON-mode dump equals the original's.
    """
    original = Scenario.model_validate(_scenario_payload())
    store = ScenarioCache(tmp_path)

    written_to = store.put(13, "easy", "ml_benchmark", original)
    loaded = store.get(13, "easy", "ml_benchmark")

    assert written_to.exists()
    assert loaded is not None
    # Compare JSON-mode dumps so the check is about serialized content,
    # not object identity.
    expected_json = original.model_dump(mode="json")
    assert loaded.model_dump(mode="json") == expected_json


def test_cached_oracle_uses_cache_after_first_generation(tmp_path) -> None:
    """A repeated identical request must not reach the underlying client.

    A fake client records every call and always answers with the JSON form
    of ``_scenario_payload()``. ``generate_scenario`` is called twice with the
    same arguments. Both results must dump to the same JSON, and the client
    must have been called exactly once.
    """
    invocations: list[tuple[str, str, str]] = []

    # Parameter names are kept as-is in case CachedOracle passes them by keyword.
    def fake_client(system: str, user: str, model: str) -> str:
        invocations.append((system, user, model))
        return json.dumps(_scenario_payload())

    oracle = CachedOracle(fake_client, cache=ScenarioCache(tmp_path))

    first, second = [
        oracle.generate_scenario(9, "easy", "ml_benchmark") for _ in range(2)
    ]

    assert first.model_dump(mode="json") == second.model_dump(mode="json")
    assert len(invocations) == 1