# replicalab/tests/test_cache.py
# maxxie114
# Initial HF Spaces deployment
# 80d8c84
from __future__ import annotations
import json
from replicalab.cache import CachedOracle, ScenarioCache
from replicalab.oracle_models import Scenario
def _scenario_payload() -> dict:
return {
"paper": {
"title": "Cached benchmark",
"domain": "ml_benchmark",
"claim": "A small run remains useful under a tighter budget.",
"method_summary": "Train a compact model and verify against a held-out split.",
"original_sample_size": 1000,
"original_duration_days": 2,
"original_technique": "compact_model",
"required_controls": ["baseline"],
"required_equipment": ["GPU cluster"],
"required_reagents": ["dataset snapshot"],
"statistical_test": "accuracy_gap",
},
"lab_constraints": {
"budget_total": 1200.0,
"budget_remaining": 1200.0,
"equipment": [
{
"name": "GPU cluster",
"available": True,
"condition": "operational",
"booking_conflicts": [],
"cost_per_use": 100.0,
}
],
"reagents": [
{
"name": "dataset snapshot",
"in_stock": True,
"quantity_available": 1.0,
"unit": "copy",
"lead_time_days": 0,
"cost": 0.0,
}
],
"staff": [],
"max_duration_days": 3,
"safety_rules": ["No external internet."],
"valid_substitutions": [],
},
"minimum_viable_spec": {
"min_sample_size": 800,
"must_keep_controls": ["baseline"],
"acceptable_techniques": ["compact_model"],
"min_duration_days": 1,
"critical_equipment": ["GPU cluster"],
"flexible_equipment": [],
"critical_reagents": ["dataset snapshot"],
"flexible_reagents": [],
"power_threshold": 0.75,
},
"difficulty": "easy",
"narrative_hook": "The benchmark owners tightened the reporting budget.",
}
def test_scenario_cache_round_trips(tmp_path) -> None:
    """A scenario stored with ``put`` is returned unchanged by ``get``.

    The cache is rooted at pytest's per-test ``tmp_path``. The same
    ``(13, "easy", "ml_benchmark")`` key is used for the write and the read.
    """
    cache = ScenarioCache(tmp_path)
    scenario = Scenario.model_validate(_scenario_payload())

    path = cache.put(13, "easy", "ml_benchmark", scenario)
    # ``put`` returns the location it wrote to; that location must exist.
    assert path.exists()

    restored = cache.get(13, "easy", "ml_benchmark")
    assert restored is not None
    # Compare JSON-mode dumps so equality is judged on the serialized form.
    assert restored.model_dump(mode="json") == scenario.model_dump(mode="json")
def test_cached_oracle_uses_cache_after_first_generation(tmp_path) -> None:
    """A repeated request must be served without calling the client again.

    A counting fake client stands in for the real one. After two identical
    ``generate_scenario`` calls the client must have been invoked exactly
    once, and both calls must yield the same scenario.
    """
    calls = {"count": 0}

    def fake_client(system: str, user: str, model: str) -> str:
        """Record the invocation and return the fixture payload as JSON."""
        calls["count"] += 1
        return json.dumps(_scenario_payload())

    oracle = CachedOracle(fake_client, cache=ScenarioCache(tmp_path))

    first = oracle.generate_scenario(9, "easy", "ml_benchmark")
    # The cache directory starts empty, so the first request must hit the
    # client exactly once; checking here pinpoints which call misbehaved.
    assert calls["count"] == 1

    second = oracle.generate_scenario(9, "easy", "ml_benchmark")

    assert first.model_dump(mode="json") == second.model_dump(mode="json")
    assert calls["count"] == 1