Spaces:
Sleeping
Sleeping
File size: 4,700 Bytes
e75c8ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | """OpenEnv Environment: cache invalidation under partial observability."""
from __future__ import annotations

import random
from types import ModuleType
from typing import Any, Optional

from openenv.core.env_server import Environment
from openenv.core.env_server.types import EnvironmentMetadata

from env.generator import generate_env
from env.grader import compute_step_reward, evaluate_episode
from env.models import CacheAction, CacheItem, CacheObservation, CacheState
from env.tasks import sample_task
class CacheInvalidationEnvironment(Environment[CacheAction, CacheObservation, CacheState]):
    """Stateful cache control: invalidate, refresh, or keep one key per step.

    ``reset`` uses the requested task if it is one of ``KNOWN_TASKS``
    (otherwise one is drawn with ``sample_task``). It builds the visible
    cache items plus an index-aligned list of hidden per-item state with
    ``generate_env``. Each ``step`` applies one action to one key, scores it
    with ``compute_step_reward`` against the hidden staleness, and updates the
    item's visible ``age`` and noisy ``last_result``. Once ``MAX_STEPS`` steps
    have been taken, observations are marked done and carry
    ``evaluate_episode`` over the recorded history.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False
    # Number of step() calls after which observations report done=True.
    MAX_STEPS = 10
    # Task ids accepted verbatim by reset(); anything else is sampled instead.
    KNOWN_TASKS = ("easy", "medium", "hard")
    # Reward returned (with done=True) when an action names an unknown key.
    UNKNOWN_KEY_REWARD = -1.0

    def __init__(self) -> None:
        super().__init__()
        # Either a seeded random.Random (from reset(seed=...)) or the
        # module-level ``random`` functions, which share the process-global
        # generator.
        self._rng: random.Random | ModuleType = random
        # One {"action": ..., "is_stale": ...} record per scored step.
        self.history: list[dict[str, Any]] = []
        self.task_id: str = "easy"
        # Hidden per-item state, index-aligned with self._items.
        self.hidden: list[dict[str, Any]] = []
        self.current_time: int = 0
        # Visible item dicts; validated into CacheItem for observations.
        self._items: list[dict[str, Any]] = []
        self._step: int = 0
        # Episode id supplied to reset(), reported back through ``state``.
        self._episode_id: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, episode randomness comes from a fresh
                ``random.Random(int(seed))``. Otherwise the module-level
                ``random`` generator is used.
            episode_id: Stored and reported back through ``state``.
            task_id: Task to run; honoured only if it is in ``KNOWN_TASKS``.
            task_name: Alias for ``task_id``, used when ``task_id`` is falsy.
            **kwargs: Accepted for interface compatibility; ignored. Note that
                ``task_id`` / ``task_name`` always bind to the named
                parameters, so they can never appear here.

        Returns:
            The initial observation (``reward=None``, ``done=False``).
        """
        requested = task_id or task_name
        self._reset_rubric()
        self._rng = random.Random(int(seed)) if seed is not None else random
        self._episode_id = episode_id
        self.history = []
        # Missing or unrecognised task ids silently fall back to a sampled task.
        if requested in self.KNOWN_TASKS:
            self.task_id = requested
        else:
            self.task_id = sample_task(self._rng)
        self._items, self.hidden, self.current_time = generate_env(
            self.task_id, rng=self._rng
        )
        self._step = 0
        return self._observation(reward=None, done=False, final_score=None)

    def step(
        self,
        action: CacheAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Apply one action to one cache key and return the next observation.

        Args:
            action: Names the target ``key`` and the action ``type``
                ("invalidate", "refresh" or "keep").
            timeout_s: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            An observation carrying this step's reward. If ``action.key`` is
            not a known key, the observation has ``UNKNOWN_KEY_REWARD`` and
            ``done=True``. In that case no step is counted and nothing is
            added to the history.
        """
        index = self._find_item(action.key)
        if index is None:
            return self._observation(
                reward=self.UNKNOWN_KEY_REWARD, done=True, final_score=None
            )

        hidden = self.hidden[index]
        item = self._items[index]
        action_type = action.type

        # Ground truth: stale if past its TTL, or else with probability
        # update_freq. The random draw only happens when not already expired.
        age = self.current_time - hidden["last_update"]
        is_stale = age > hidden["base_ttl"] or self._rng.random() < hidden["update_freq"]
        self.history.append({"action": action_type, "is_stale": is_stale})
        reward = compute_step_reward(action_type, is_stale)

        if action_type == "invalidate":
            hidden["last_update"] = self.current_time
            item["age"] = 0
        elif action_type == "refresh":
            hidden["last_update"] = self.current_time - 1
            item["age"] = 1
        elif action_type == "keep":
            item["age"] += 1
        # NOTE(review): any other action type leaves the item unchanged here;
        # presumably CacheAction restricts ``type`` to the three above — confirm.

        item["last_result"] = self._observed_result(is_stale)

        self.current_time += 1
        self._step += 1
        done = self._step >= self.MAX_STEPS
        final_score = evaluate_episode(self.history) if done else None
        return self._observation(reward=reward, done=done, final_score=final_score)

    @property
    def state(self) -> CacheState:
        """Snapshot of the current episode: id, step count, task and items."""
        return CacheState(
            episode_id=self._episode_id,
            step_count=self._step,
            task_id=self.task_id,
            items=self._item_models(),
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description and version of this environment."""
        return EnvironmentMetadata(
            name="cache_invalidation_env",
            description=(
                "Cache invalidation under uncertainty: choose invalidate, refresh, or keep "
                "per step from noisy hit/stale observations."
            ),
            version="1.0.0",
        )

    def _find_item(self, key: Any) -> Optional[int]:
        """Return the index of the first item whose "key" equals ``key``, else None."""
        return next(
            (i for i, entry in enumerate(self._items) if entry["key"] == key), None
        )

    def _observed_result(self, is_stale: bool) -> str:
        """Return the noisy "hit"/"stale" signal shown to the agent for one item.

        A fresh item is always reported as "hit" and consumes no random draw.
        A stale item is reported "stale" with probability
        0.7 + 0.3 * 0.1 = 0.73 and "hit" otherwise. It consumes one or two
        draws from ``self._rng``.
        """
        if not is_stale:
            # NOTE(review): fresh items never yield a false "stale". If
            # symmetric noise was intended (e.g. 10% false alarms), it belongs
            # here, but adding it would also shift seeded trajectories.
            return "hit"
        if self._rng.random() < 0.7:
            return "stale"
        return "hit" if self._rng.random() < 0.9 else "stale"

    def _item_models(self) -> list[CacheItem]:
        """Validate the visible item dicts into CacheItem models."""
        return [CacheItem.model_validate(entry) for entry in self._items]

    def _observation(
        self,
        *,
        reward: float | None,
        done: bool,
        final_score: float | None,
    ) -> CacheObservation:
        """Build an observation from the current items, step and task."""
        return CacheObservation(
            done=done,
            reward=reward,
            items=self._item_models(),
            step=self._step,
            task_id=self.task_id,
            final_score=final_score,
        )
|