"""OpenEnv Environment: cache invalidation under partial observability."""

from __future__ import annotations

import random
from types import ModuleType
from typing import Any, Optional

from openenv.core.env_server import Environment
from openenv.core.env_server.types import EnvironmentMetadata

from env.generator import generate_env
from env.grader import compute_step_reward, evaluate_episode
from env.models import CacheAction, CacheItem, CacheObservation, CacheState
from env.tasks import sample_task


class CacheInvalidationEnvironment(Environment[CacheAction, CacheObservation, CacheState]):
    """Stateful cache control: invalidate, refresh, or keep per step (one key).

    Each episode holds a list of visible cache items (``self._items``) and a
    parallel list of hidden per-item parameters (``self.hidden``), both
    produced by ``generate_env``. Every ``step`` acts on exactly one key and
    the episode ends after ``MAX_STEPS`` valid steps, or immediately when the
    action names a key that is not in the cache.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Number of valid steps after which an episode is reported as done.
    # Subclasses may override this to change the horizon.
    MAX_STEPS: int = 10

    def __init__(self) -> None:
        super().__init__()
        # Either a seeded ``random.Random`` instance or the ``random`` module
        # itself (process-global RNG) when no seed was supplied.
        self._rng: random.Random | ModuleType = random
        # One {"action", "is_stale"} record per valid step; fed to the grader.
        self.history: list[dict[str, Any]] = []
        self.task_id: str = "easy"
        # Hidden per-item parameters, index-aligned with ``self._items``.
        self.hidden: list[dict[str, Any]] = []
        self.current_time: int = 0
        self._items: list[dict[str, Any]] = []
        self._step: int = 0
        # Identifier handed to ``reset``; surfaced again through ``state``.
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: When given, a dedicated ``random.Random(int(seed))`` drives
                the episode; otherwise the global ``random`` module is used.
            episode_id: Optional caller-supplied identifier, reported back by
                ``state``.
            task_id: Task to run; ``task_name`` is accepted as an alias.
                Anything other than ``"easy"``, ``"medium"`` or ``"hard"``
                (including ``None``) makes ``sample_task`` pick the task.
            task_name: Alias for ``task_id`` (``task_id`` wins if both set).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The first observation, with ``reward`` and ``final_score`` unset
            and ``done`` false.
        """
        # ``task_id`` / ``task_name`` are explicit parameters, so they can
        # never arrive through ``kwargs``; no extra lookup there is needed.
        requested_task = task_id or task_name

        # NOTE(review): ``_reset_rubric`` is not defined in this file;
        # presumably inherited from ``Environment`` — confirm.
        self._reset_rubric()

        self._rng = random.Random(int(seed)) if seed is not None else random
        self._episode_id = episode_id
        self.history = []

        if requested_task in ("easy", "medium", "hard"):
            self.task_id = requested_task
        else:
            self.task_id = sample_task(self._rng)

        self._items, self.hidden, self.current_time = generate_env(
            self.task_id, rng=self._rng
        )
        self._step = 0

        return self._observation(reward=None, done=False, final_score=None)

    def step(
        self,
        action: CacheAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Apply one action to a single cache key and advance the clock.

        Args:
            action: Carries the target ``key`` and the action ``type``
                (``"invalidate"``, ``"refresh"`` or ``"keep"``).
            timeout_s: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation after the action. If ``action.key`` matches no
            item, the episode ends at once with reward ``-1.0`` and no final
            score, and neither the step counter nor the clock advances.
        """
        item_index = next(
            (i for i, entry in enumerate(self._items) if entry["key"] == action.key),
            None,
        )
        if item_index is None:
            return self._observation(reward=-1.0, done=True, final_score=None)

        action_type = action.type
        hidden = self.hidden[item_index]
        item = self._items[item_index]

        # Stale when older than its TTL; otherwise one RNG draw is compared
        # against ``update_freq`` (the draw is skipped if the TTL test passes).
        age = self.current_time - hidden["last_update"]
        is_stale = (
            age > hidden["base_ttl"]
            or self._rng.random() < hidden["update_freq"]
        )

        self.history.append({"action": action_type, "is_stale": is_stale})
        reward = compute_step_reward(action_type, is_stale)

        if action_type == "invalidate":
            hidden["last_update"] = self.current_time
            item["age"] = 0
        elif action_type == "refresh":
            hidden["last_update"] = self.current_time - 1
            item["age"] = 1
        elif action_type == "keep":
            item["age"] += 1

        # The visible result reflects staleness as judged *before* the action.
        item["last_result"] = self._noisy_result(is_stale)

        self.current_time += 1
        self._step += 1

        done = self._step >= self.MAX_STEPS
        final_score = evaluate_episode(self.history) if done else None
        return self._observation(reward=reward, done=done, final_score=final_score)

    def _noisy_result(self, is_stale: bool) -> str:
        """Return the noisy ``"hit"``/``"stale"`` signal shown to the agent.

        A fresh item always reads ``"hit"`` and consumes no RNG draws. A stale
        item reads ``"stale"`` if the first draw is below 0.7; otherwise a
        second draw below 0.9 yields ``"hit"``, else ``"stale"``.
        """
        if not is_stale:
            return "hit"
        if self._rng.random() < 0.7:
            return "stale"
        if self._rng.random() < 0.9:
            return "hit"
        return "stale"

    @property
    def state(self) -> CacheState:
        """Snapshot of the episode: id, step count, task and visible items."""
        return CacheState(
            episode_id=self._episode_id,
            step_count=self._step,
            task_id=self.task_id,
            items=[CacheItem.model_validate(entry) for entry in self._items],
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description and version of this environment."""
        return EnvironmentMetadata(
            name="cache_invalidation_env",
            description=(
                "Cache invalidation under uncertainty: choose invalidate, refresh, or keep "
                "per step from noisy hit/stale observations."
            ),
            version="1.0.0",
        )

    def _observation(
        self,
        *,
        reward: float | None,
        done: bool,
        final_score: float | None,
    ) -> CacheObservation:
        """Build an observation from the current items, step and task.

        Only the visible ``self._items`` are included; ``self.hidden`` is
        never copied into the observation.
        """
        return CacheObservation(
            done=done,
            reward=reward,
            items=[CacheItem.model_validate(entry) for entry in self._items],
            step=self._step,
            task_id=self.task_id,
            final_score=final_score,
        )