from __future__ import annotations import json from replicalab.cache import CachedOracle, ScenarioCache from replicalab.oracle_models import Scenario def _scenario_payload() -> dict: return { "paper": { "title": "Cached benchmark", "domain": "ml_benchmark", "claim": "A small run remains useful under a tighter budget.", "method_summary": "Train a compact model and verify against a held-out split.", "original_sample_size": 1000, "original_duration_days": 2, "original_technique": "compact_model", "required_controls": ["baseline"], "required_equipment": ["GPU cluster"], "required_reagents": ["dataset snapshot"], "statistical_test": "accuracy_gap", }, "lab_constraints": { "budget_total": 1200.0, "budget_remaining": 1200.0, "equipment": [ { "name": "GPU cluster", "available": True, "condition": "operational", "booking_conflicts": [], "cost_per_use": 100.0, } ], "reagents": [ { "name": "dataset snapshot", "in_stock": True, "quantity_available": 1.0, "unit": "copy", "lead_time_days": 0, "cost": 0.0, } ], "staff": [], "max_duration_days": 3, "safety_rules": ["No external internet."], "valid_substitutions": [], }, "minimum_viable_spec": { "min_sample_size": 800, "must_keep_controls": ["baseline"], "acceptable_techniques": ["compact_model"], "min_duration_days": 1, "critical_equipment": ["GPU cluster"], "flexible_equipment": [], "critical_reagents": ["dataset snapshot"], "flexible_reagents": [], "power_threshold": 0.75, }, "difficulty": "easy", "narrative_hook": "The benchmark owners tightened the reporting budget.", } def test_scenario_cache_round_trips(tmp_path) -> None: cache = ScenarioCache(tmp_path) scenario = Scenario.model_validate(_scenario_payload()) path = cache.put(13, "easy", "ml_benchmark", scenario) restored = cache.get(13, "easy", "ml_benchmark") assert path.exists() assert restored is not None assert restored.model_dump(mode="json") == scenario.model_dump(mode="json") def test_cached_oracle_uses_cache_after_first_generation(tmp_path) -> None: calls = {"count": 0} def fake_client(system: str, user: str, model: str) -> str: calls["count"] += 1 return json.dumps(_scenario_payload()) oracle = CachedOracle(fake_client, cache=ScenarioCache(tmp_path)) first = oracle.generate_scenario(9, "easy", "ml_benchmark") second = oracle.generate_scenario(9, "easy", "ml_benchmark") assert first.model_dump(mode="json") == second.model_dump(mode="json") assert calls["count"] == 1