Spaces:
Sleeping
Sleeping
| """OpenEnv Environment: cache invalidation under partial observability.""" | |
| from __future__ import annotations | |
| import random | |
| from typing import Any, Optional | |
| from openenv.core.env_server import Environment | |
| from openenv.core.env_server.types import EnvironmentMetadata | |
| from env.generator import generate_env | |
| from env.grader import compute_step_reward, evaluate_episode | |
| from env.models import CacheAction, CacheItem, CacheObservation, CacheState | |
| from env.tasks import sample_task | |
class CacheInvalidationEnvironment(Environment[CacheAction, CacheObservation, CacheState]):
    """Stateful cache control: invalidate, refresh, or keep per step (one key).

    Each episode is built by ``generate_env`` for one task id ("easy",
    "medium" or "hard"). Observations expose only the ``CacheItem`` views in
    ``self._items``. The ground truth used to decide staleness (per-item
    ``last_update``, ``base_ttl`` and ``update_freq``) lives in the parallel
    list ``self.hidden`` and is never put into an observation.

    An episode ends after ``MAX_STEPS`` actions on known keys, or immediately
    when an action names a key that is not in the cache.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Number of actions on known keys after which an episode is done and scored.
    MAX_STEPS: int = 10

    def __init__(self) -> None:
        super().__init__()
        # Randomness source: the global ``random`` module until ``reset`` is
        # given a seed, then a dedicated ``random.Random`` instance. Both
        # expose ``.random()``, which is all this class needs. (The previous
        # annotation ``type[random]`` was invalid: ``random`` is a module.)
        self._rng: Any = random
        # One entry per valid step: {"action": <type>, "is_stale": <bool>}.
        self.history: list[dict[str, Any]] = []
        self.task_id: str = "easy"
        # Hidden per-item state, index-aligned with ``self._items``.
        self.hidden: list[dict[str, Any]] = []
        # Simulated clock, advanced by one per valid step.
        self.current_time: int = 0
        # Agent-visible item dicts, validated into ``CacheItem`` on output.
        self._items: list[dict[str, Any]] = []
        self._step: int = 0
        # Identifier passed to the most recent ``reset`` (reported by ``state``).
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: When given, a dedicated ``random.Random(int(seed))`` drives
                task sampling, environment generation and all step-time
                noise. When ``None``, the global ``random`` module is used.
            episode_id: Caller-chosen identifier, reported back by ``state()``.
            task_id: Task to run. Checked before ``task_name`` and before the
                same keys in ``kwargs``.
            task_name: Alias for ``task_id``.
            **kwargs: Only the keys ``"task_id"`` and ``"task_name"`` are read.

        A task id other than "easy", "medium" or "hard" (including no task id
        at all) falls back to ``sample_task``.

        Returns:
            The initial observation, with ``reward=None``, ``done=False`` and
            ``final_score=None``.
        """
        # The first truthy spelling wins: explicit parameters first, then the
        # kwargs aliases.
        tid = task_id or task_name or kwargs.get("task_id") or kwargs.get("task_name")
        self._reset_rubric()
        if seed is not None:
            self._rng = random.Random(int(seed))
        else:
            self._rng = random
        self._episode_id = episode_id
        self.history = []
        # Unknown or missing task ids silently fall back to a sampled task.
        if tid in ("easy", "medium", "hard"):
            self.task_id = tid
        else:
            self.task_id = sample_task(self._rng)
        items, hidden, current_time = generate_env(self.task_id, rng=self._rng)
        self._items = items
        self.hidden = hidden
        self.current_time = current_time
        self._step = 0
        return self._observation(reward=None, done=False, final_score=None)

    def step(
        self,
        action: CacheAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Apply one action to one cache key and return the next observation.

        Args:
            action: Target ``key`` plus an action ``type``. The types
                "invalidate", "refresh" and "keep" update the item; any other
                type is still rewarded and recorded, but leaves the item
                unchanged.
            timeout_s: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The observation after the action. If the key is unknown, returns
            ``reward=-1.0`` and ``done=True`` without advancing the episode.
            Otherwise returns ``compute_step_reward``'s reward, and once
            ``MAX_STEPS`` valid steps have been taken, ``done=True`` with
            ``final_score`` taken from ``evaluate_episode(self.history)``.
        """
        # NOTE(review): nothing stops callers from stepping after ``done``;
        # further steps keep mutating state and re-scoring the history.
        key = action.key
        action_type = action.type
        item_index = next(
            (i for i, x in enumerate(self._items) if x["key"] == key), None
        )
        if item_index is None:
            # Unknown key: penalise and end the episode without a score. The
            # step counter, the history and the clock are left untouched.
            return self._observation(reward=-1.0, done=True, final_score=None)

        hidden = self.hidden[item_index]
        item = self._items[item_index]

        # Staleness is judged on the hidden state *before* this action. The
        # RNG is drawn only when the key is still within its TTL, because
        # ``or`` short-circuits.
        age = self.current_time - hidden["last_update"]
        is_stale = age > hidden["base_ttl"] or self._rng.random() < hidden["update_freq"]
        self.history.append({"action": action_type, "is_stale": is_stale})
        reward = compute_step_reward(action_type, is_stale)

        if action_type == "invalidate":
            hidden["last_update"] = self.current_time
            item["age"] = 0
        elif action_type == "refresh":
            hidden["last_update"] = self.current_time - 1
            item["age"] = 1
        elif action_type == "keep":
            item["age"] += 1

        # The visible signal reflects the pre-action staleness, with noise.
        item["last_result"] = self._noisy_result(is_stale)

        self.current_time += 1
        self._step += 1
        done = self._step >= self.MAX_STEPS
        final_score = evaluate_episode(self.history) if done else None
        return self._observation(reward=reward, done=done, final_score=final_score)

    def _noisy_result(self, is_stale: bool) -> str:
        """Return the noisy "hit"/"stale" signal shown for an accessed key.

        A fresh key always reads "hit" and consumes no randomness. A stale key
        reads "stale" with probability 0.7. Otherwise a second draw reports
        "hit" with probability 0.9 and "stale" with probability 0.1, giving
        about 0.73 "stale" overall. The draw order matches the original nested
        conditional expression exactly.
        """
        # NOTE(review): only stale keys are noisy — a fresh key is never
        # reported as "stale". If symmetric noise was intended (for example
        # fresh -> "hit" 90% of the time), change it here. It is kept as-is
        # to preserve the current observation model and RNG stream.
        if not is_stale:
            return "hit"
        if self._rng.random() < 0.7:
            return "stale"
        return "hit" if self._rng.random() < 0.9 else "stale"

    def state(self) -> CacheState:
        """Return a snapshot: episode id, step count, task id and visible items."""
        # NOTE(review): if the OpenEnv ``Environment`` base declares ``state``
        # as a property, this plain method shadows it and ``env.state`` would
        # yield a bound method. Confirm against the base class.
        return CacheState(
            episode_id=self._episode_id,
            step_count=self._step,
            task_id=self.task_id,
            items=[CacheItem.model_validate(x) for x in self._items],
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description and version of this environment."""
        return EnvironmentMetadata(
            name="cache_invalidation_env",
            description=(
                "Cache invalidation under uncertainty: choose invalidate, refresh, or keep "
                "per step from noisy hit/stale observations."
            ),
            version="1.0.0",
        )

    def _observation(
        self,
        *,
        reward: float | None,
        done: bool,
        final_score: float | None,
    ) -> CacheObservation:
        """Build an observation from the visible items and the current step/task."""
        return CacheObservation(
            done=done,
            reward=reward,
            items=[CacheItem.model_validate(x) for x in self._items],
            step=self._step,
            task_id=self.task_id,
            final_score=final_score,
        )