cache-env / env /cache_environment.py
Parv Pareek
done
e75c8ce
"""OpenEnv Environment: cache invalidation under partial observability."""
from __future__ import annotations

import random
from types import ModuleType
from typing import Any, Optional

from openenv.core.env_server import Environment
from openenv.core.env_server.types import EnvironmentMetadata

from env.generator import generate_env
from env.grader import compute_step_reward, evaluate_episode
from env.models import CacheAction, CacheItem, CacheObservation, CacheState
from env.tasks import sample_task
class CacheInvalidationEnvironment(Environment[CacheAction, CacheObservation, CacheState]):
    """Stateful cache control: invalidate, refresh, or keep per step (one key).

    Each episode picks a task difficulty, generates cache items plus hidden
    per-item freshness parameters, and runs for ``MAX_STEPS`` steps. On every
    step the agent targets one key. The environment then:

    * decides (partly at random) whether that key was stale;
    * scores the action with ``compute_step_reward``;
    * applies the action's effect;
    * records a noisy ``last_result`` observation on the item.

    When the last step is taken, the whole history is graded with
    ``evaluate_episode``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False
    # Number of steps per episode; the episode is graded when it is reached.
    MAX_STEPS = 10
    # Task ids accepted verbatim by ``reset``; anything else is sampled.
    TASK_IDS = ("easy", "medium", "hard")

    def __init__(self) -> None:
        super().__init__()
        # Seeded resets install a private ``random.Random``. Unseeded resets
        # use the global ``random`` module itself, which is a module rather
        # than a ``Random`` instance.
        self._rng: random.Random | ModuleType = random
        # One {"action": ..., "is_stale": ...} record per step, used for grading.
        self.history: list[dict[str, Any]] = []
        self.task_id: str = "easy"
        # Hidden per-item ground truth, index-aligned with ``self._items``.
        self.hidden: list[dict[str, Any]] = []
        self.current_time: int = 0
        # Agent-visible item dicts, validated into ``CacheItem`` on output.
        self._items: list[dict[str, Any]] = []
        self._step: int = 0
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, episode generation uses a private
                ``random.Random(seed)`` so it is reproducible. Otherwise the
                global ``random`` module state is used.
            episode_id: Stored and reported by ``state``.
            task_id: One of ``TASK_IDS``. Any other value (or none) makes the
                task be sampled with ``sample_task``.
            task_name: Alias for ``task_id``, used only when ``task_id`` is
                falsy.
            **kwargs: Ignored.

        Returns:
            The initial observation, with ``reward=None`` and ``done=False``.
        """
        # ``task_id`` wins over its alias ``task_name``. Both are named
        # parameters, so neither can ever appear in ``kwargs``.
        tid = task_id or task_name
        self._reset_rubric()
        self._rng = random.Random(int(seed)) if seed is not None else random
        self._episode_id = episode_id
        self.history = []
        # Sampling happens before generation so the RNG draw order is stable
        # for a given seed.
        self.task_id = tid if tid in self.TASK_IDS else sample_task(self._rng)
        items, hidden, current_time = generate_env(self.task_id, rng=self._rng)
        self._items = items
        self.hidden = hidden
        self.current_time = current_time
        self._step = 0
        return self._observation(reward=None, done=False, final_score=None)

    def step(
        self,
        action: CacheAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Apply one cache action to a single key and advance time by one tick.

        Args:
            action: The target ``key`` and the action ``type``
                (``"invalidate"``, ``"refresh"`` or ``"keep"``).
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Ignored.

        Returns:
            An observation carrying this step's reward. Two cases are special:

            * An unknown key returns ``reward=-1.0`` and ``done=True``
              immediately. No history is recorded, and time and the step
              count do not advance.
            * Once ``MAX_STEPS`` steps have been taken, ``done`` is True and
              ``final_score`` holds the graded episode score.
        """
        key = action.key
        action_type = action.type
        item_index = next(
            (i for i, x in enumerate(self._items) if x["key"] == key), None
        )
        if item_index is None:
            return self._observation(reward=-1.0, done=True, final_score=None)

        hidden = self.hidden[item_index]
        item = self._items[item_index]

        # The item is stale if it is past its TTL. Otherwise it is stale with
        # probability ``update_freq``. The ``or`` short-circuits, so the RNG is
        # only drawn when the TTL has not expired.
        age = self.current_time - hidden["last_update"]
        is_stale = age > hidden["base_ttl"] or self._rng.random() < hidden["update_freq"]

        self.history.append({"action": action_type, "is_stale": is_stale})
        reward = compute_step_reward(action_type, is_stale)

        if action_type == "invalidate":
            hidden["last_update"] = self.current_time
            item["age"] = 0
        elif action_type == "refresh":
            hidden["last_update"] = self.current_time - 1
            item["age"] = 1
        elif action_type == "keep":
            item["age"] += 1

        # The observation reflects the staleness sampled *before* the action
        # was applied.
        item["last_result"] = self._sample_observation(is_stale)

        self.current_time += 1
        self._step += 1
        done = self._step >= self.MAX_STEPS
        final_score = evaluate_episode(self.history) if done else None
        return self._observation(reward=reward, done=done, final_score=final_score)

    def _sample_observation(self, is_stale: bool) -> str:
        """Return the noisy ``"hit"``/``"stale"`` signal for one request.

        Fresh items always read ``"hit"`` and consume no randomness. Stale
        items read ``"stale"`` with probability 0.7. Otherwise a second draw
        reports ``"hit"`` with probability 0.9. Overall, a stale item reads
        ``"stale"`` with probability 0.7 + 0.3 * 0.1 = 0.73.

        NOTE(review): the noise is one-sided, since a fresh item never reads
        ``"stale"``. If symmetric noise was intended, confirm before changing
        it: altering the draws changes seeded trajectories.
        """
        if not is_stale:
            return "hit"
        if self._rng.random() < 0.7:
            return "stale"
        return "hit" if self._rng.random() < 0.9 else "stale"

    @property
    def state(self) -> CacheState:
        """Snapshot of the episode id, step count, task id and visible items."""
        return CacheState(
            episode_id=self._episode_id,
            step_count=self._step,
            task_id=self.task_id,
            items=[CacheItem.model_validate(x) for x in self._items],
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="cache_invalidation_env",
            description=(
                "Cache invalidation under uncertainty: choose invalidate, refresh, or keep "
                "per step from noisy hit/stale observations."
            ),
            version="1.0.0",
        )

    def _observation(
        self,
        *,
        reward: float | None,
        done: bool,
        final_score: float | None,
    ) -> CacheObservation:
        """Build an observation from the current visible items and counters."""
        return CacheObservation(
            done=done,
            reward=reward,
            items=[CacheItem.model_validate(x) for x in self._items],
            step=self._step,
            task_id=self.task_id,
            final_score=final_score,
        )