File size: 4,700 Bytes
e75c8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""OpenEnv Environment: cache invalidation under partial observability."""

from __future__ import annotations

import random
from typing import Any, Optional

from openenv.core.env_server import Environment
from openenv.core.env_server.types import EnvironmentMetadata

from env.generator import generate_env
from env.grader import compute_step_reward, evaluate_episode
from env.models import CacheAction, CacheItem, CacheObservation, CacheState
from env.tasks import sample_task


class CacheInvalidationEnvironment(Environment[CacheAction, CacheObservation, CacheState]):
    """Stateful cache control: invalidate, refresh, or keep per step (one key).

    Each episode holds a list of agent-visible item dicts (``self._items``) and
    a list of hidden ground-truth records (``self.hidden``) that ``step``
    indexes in lockstep with the items. On every ``step`` the agent names one
    item by key and picks an action. The reward comes from
    ``compute_step_reward`` given the action and whether the item was actually
    stale. The agent only sees a noisy ``last_result`` signal for that item.

    The episode ends after ``MAX_STEPS`` valid steps, or immediately when the
    action names a key that matches no item.
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    # Number of valid steps per episode; ``step`` reports ``done`` once reached.
    MAX_STEPS: int = 10

    # Task ids ``reset`` accepts verbatim; any other value falls back to ``sample_task``.
    KNOWN_TASKS: tuple[str, ...] = ("easy", "medium", "hard")

    def __init__(self) -> None:
        super().__init__()
        # Either a seeded ``random.Random`` instance or the ``random`` module
        # itself (unseeded reset); both expose the ``random()`` call used here.
        self._rng: Any = random
        # Episode id supplied to the most recent ``reset`` (None if not given).
        self._episode_id: Optional[str] = None
        # Per-step records {"action": ..., "is_stale": ...} passed to evaluate_episode.
        self.history: list[dict[str, Any]] = []
        self.task_id: str = "easy"
        # Hidden ground truth; ``step`` reads it with the same index as ``self._items``.
        self.hidden: list[dict[str, Any]] = []
        self.current_time: int = 0
        # Agent-visible item dicts; validated into ``CacheItem`` for observations/state.
        self._items: list[dict[str, Any]] = []
        self._step: int = 0

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        task_name: Optional[str] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: When given, this environment draws from a private
                ``random.Random(int(seed))``. Otherwise it uses the global
                ``random`` module.
            episode_id: Stored and reported back through ``state``.
            task_id: One of ``KNOWN_TASKS``. Any other value, including None,
                lets ``sample_task`` choose the task.
            task_name: Alias for ``task_id``, consulted only when ``task_id``
                is falsy.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Observation with ``reward=None``, ``done=False`` and no final score.
        """
        # ``task_id``/``task_name`` are named parameters, so they can never
        # also appear in ``kwargs``; only the explicit arguments need checking.
        requested = task_id or task_name
        self._reset_rubric()

        self._rng = random.Random(int(seed)) if seed is not None else random
        self._episode_id = episode_id
        self.history = []
        if requested in self.KNOWN_TASKS:
            self.task_id = requested
        else:
            self.task_id = sample_task(self._rng)

        self._items, self.hidden, self.current_time = generate_env(
            self.task_id, rng=self._rng
        )
        self._step = 0

        return self._observation(reward=None, done=False, final_score=None)

    def step(
        self,
        action: CacheAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CacheObservation:
        """Apply ``action`` to the item it names and advance the clock by one.

        If ``action.key`` matches no item, returns reward -1.0 with
        ``done=True`` and no ``final_score``. In that case the step counter,
        clock and history are left untouched.

        Otherwise returns the ``compute_step_reward`` value. On the
        ``MAX_STEPS``-th step it also sets ``done=True`` and
        ``final_score=evaluate_episode(self.history)``.

        ``timeout_s`` and ``kwargs`` are accepted for interface compatibility
        and ignored.
        """
        key = action.key
        action_type = action.type

        item_index = next(
            (i for i, x in enumerate(self._items) if x["key"] == key), None
        )
        if item_index is None:
            return self._observation(reward=-1.0, done=True, final_score=None)

        hidden = self.hidden[item_index]
        item = self._items[item_index]

        # Ground truth: stale once older than ``base_ttl``; otherwise stale at
        # random with probability ``update_freq``. The ``or`` short-circuits,
        # so the RNG is only drawn for items still within their TTL.
        age = self.current_time - hidden["last_update"]
        is_stale = age > hidden["base_ttl"] or self._rng.random() < hidden["update_freq"]

        self.history.append({"action": action_type, "is_stale": is_stale})
        reward = compute_step_reward(action_type, is_stale)

        if action_type == "invalidate":
            hidden["last_update"] = self.current_time
            item["age"] = 0
        elif action_type == "refresh":
            hidden["last_update"] = self.current_time - 1
            item["age"] = 1
        elif action_type == "keep":
            item["age"] += 1

        # The signal is derived from staleness *before* this step's action.
        item["last_result"] = self._noisy_result(is_stale)

        self.current_time += 1
        self._step += 1

        done = self._step >= self.MAX_STEPS
        final_score = evaluate_episode(self.history) if done else None
        return self._observation(reward=reward, done=done, final_score=final_score)

    def _noisy_result(self, is_stale: bool) -> str:
        """Return the agent-visible ``"hit"``/``"stale"`` signal for one access.

        Fresh items always read ``"hit"`` and consume no RNG draw. Stale items
        read ``"stale"`` with probability 0.7 + 0.3 * 0.1 = 0.73, else ``"hit"``.
        """
        # NOTE(review): the 0.9 draw only happens for stale items, so a fresh
        # item can never produce a false "stale". If a 10% false-positive rate
        # on fresh items was intended, this model needs changing — confirm.
        if not is_stale:
            return "hit"
        if self._rng.random() < 0.7:
            return "stale"
        return "hit" if self._rng.random() < 0.9 else "stale"

    @property
    def state(self) -> CacheState:
        """Snapshot of the episode: id, step count, task id and visible items."""
        return CacheState(
            episode_id=self._episode_id,
            step_count=self._step,
            task_id=self.task_id,
            items=self._item_models(),
        )

    def get_metadata(self) -> EnvironmentMetadata:
        """Static name/description/version for this environment."""
        return EnvironmentMetadata(
            name="cache_invalidation_env",
            description=(
                "Cache invalidation under uncertainty: choose invalidate, refresh, or keep "
                "per step from noisy hit/stale observations."
            ),
            version="1.0.0",
        )

    def _item_models(self) -> list[CacheItem]:
        """Validate the visible item dicts into ``CacheItem`` models."""
        return [CacheItem.model_validate(x) for x in self._items]

    def _observation(
        self,
        *,
        reward: float | None,
        done: bool,
        final_score: float | None,
    ) -> CacheObservation:
        """Build an observation of the current visible items, step and task."""
        return CacheObservation(
            done=done,
            reward=reward,
            items=self._item_models(),
            step=self._step,
            task_id=self.task_id,
            final_score=final_score,
        )