Spaces:
Sleeping
Sleeping
File size: 3,209 Bytes
80d8c84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | from __future__ import annotations
import json
from replicalab.cache import CachedOracle, ScenarioCache
from replicalab.oracle_models import Scenario
def _scenario_payload() -> dict:
    """Return a fresh, minimal scenario payload used by the cache tests.

    Every call builds brand-new nested dicts and lists, so a caller may
    mutate the result without leaking state into other tests. Keys are
    emitted in a fixed order so ``json.dumps`` of the result is stable.
    """
    gpu = "GPU cluster"
    dataset = "dataset snapshot"
    technique = "compact_model"
    control = "baseline"
    budget = 1200.0

    paper = {
        "title": "Cached benchmark",
        "domain": "ml_benchmark",
        "claim": "A small run remains useful under a tighter budget.",
        "method_summary": "Train a compact model and verify against a held-out split.",
        "original_sample_size": 1000,
        "original_duration_days": 2,
        "original_technique": technique,
        "required_controls": [control],
        "required_equipment": [gpu],
        "required_reagents": [dataset],
        "statistical_test": "accuracy_gap",
    }

    gpu_entry = dict(
        name=gpu,
        available=True,
        condition="operational",
        booking_conflicts=[],
        cost_per_use=100.0,
    )
    dataset_entry = dict(
        name=dataset,
        in_stock=True,
        quantity_available=1.0,
        unit="copy",
        lead_time_days=0,
        cost=0.0,
    )
    lab_constraints = {
        "budget_total": budget,
        "budget_remaining": budget,
        "equipment": [gpu_entry],
        "reagents": [dataset_entry],
        "staff": [],
        "max_duration_days": 3,
        "safety_rules": ["No external internet."],
        "valid_substitutions": [],
    }

    minimum_viable_spec = {
        "min_sample_size": 800,
        "must_keep_controls": [control],
        "acceptable_techniques": [technique],
        "min_duration_days": 1,
        "critical_equipment": [gpu],
        "flexible_equipment": [],
        "critical_reagents": [dataset],
        "flexible_reagents": [],
        "power_threshold": 0.75,
    }

    return {
        "paper": paper,
        "lab_constraints": lab_constraints,
        "minimum_viable_spec": minimum_viable_spec,
        "difficulty": "easy",
        "narrative_hook": "The benchmark owners tightened the reporting budget.",
    }


def test_scenario_cache_round_trips(tmp_path) -> None:
    """A scenario stored with ``put`` is returned unchanged by ``get``.

    ``tmp_path`` is pytest's per-test temporary directory, used as the cache
    root so the test never touches shared state.
    """
    store = ScenarioCache(tmp_path)
    original = Scenario.model_validate(_scenario_payload())
    key = (13, "easy", "ml_benchmark")  # (seed, difficulty, domain) as passed positionally

    written_to = store.put(*key, original)
    loaded = store.get(*key)

    assert written_to.exists()
    assert loaded is not None
    # Compare JSON-mode dumps so the check is on serialized content, not object identity.
    loaded_dump = loaded.model_dump(mode="json")
    expected_dump = original.model_dump(mode="json")
    assert loaded_dump == expected_dump


def test_cached_oracle_uses_cache_after_first_generation(tmp_path) -> None:
    """Two identical ``generate_scenario`` requests invoke the client only once.

    The second request for the same (seed, difficulty, domain) must be served
    from the ``ScenarioCache`` rooted at ``tmp_path`` and yield an equal scenario.
    """
    client_calls: list[tuple[str, str, str]] = []

    def fake_client(system: str, user: str, model: str) -> str:
        # Record every invocation; the final assertion counts cache misses.
        client_calls.append((system, user, model))
        return json.dumps(_scenario_payload())

    oracle = CachedOracle(fake_client, cache=ScenarioCache(tmp_path))
    seed, difficulty, domain = 9, "easy", "ml_benchmark"

    generated = oracle.generate_scenario(seed, difficulty, domain)
    replayed = oracle.generate_scenario(seed, difficulty, domain)

    assert generated.model_dump(mode="json") == replayed.model_dump(mode="json")
    assert len(client_calls) == 1
|