File size: 7,893 Bytes
32ec139
6e72b95
32ec139
 
 
e75c8ce
32ec139
6e72b95
 
 
e75c8ce
d342897
32ec139
 
4f8cf04
32ec139
 
 
4f8cf04
32ec139
 
 
e75c8ce
 
 
 
32ec139
 
e75c8ce
 
 
 
32ec139
 
e75c8ce
 
32ec139
 
4f8cf04
32ec139
 
e75c8ce
 
4f8cf04
32ec139
 
e75c8ce
32ec139
e75c8ce
32ec139
 
 
 
 
4f8cf04
 
 
32ec139
 
 
4f8cf04
32ec139
4f8cf04
32ec139
4f8cf04
 
 
 
32ec139
4f8cf04
 
32ec139
4f8cf04
 
32ec139
 
e75c8ce
4f8cf04
e75c8ce
4f8cf04
e75c8ce
4f8cf04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32ec139
e75c8ce
4f8cf04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32ec139
e75c8ce
4f8cf04
32ec139
 
 
 
 
 
 
e75c8ce
32ec139
 
 
 
 
 
6e72b95
32ec139
 
 
 
 
 
 
6e72b95
 
32ec139
 
 
4f8cf04
6e72b95
 
e75c8ce
 
32ec139
 
 
4f8cf04
32ec139
 
 
 
e75c8ce
4f8cf04
32ec139
 
e75c8ce
 
32ec139
 
 
 
 
e75c8ce
 
32ec139
e75c8ce
32ec139
 
e75c8ce
32ec139
e75c8ce
 
 
32ec139
 
 
 
 
 
 
 
 
e75c8ce
 
32ec139
 
 
 
 
 
e75c8ce
32ec139
 
 
 
e75c8ce
 
 
32ec139
 
 
 
 
 
 
 
 
 
e75c8ce
32ec139
 
 
 
 
 
 
 
e75c8ce
32ec139
 
 
 
 
e75c8ce
32ec139
 
 
 
 
4f8cf04
6e72b95
 
e75c8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e72b95
32ec139
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import json
import os
import sys
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from openai import OpenAI

from env.grader import clamp_unit_interval

# Optionally load a .env file located next to this script. python-dotenv is not
# required: when it is missing, plain environment variables are used as-is.
try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent / ".env")
except ImportError:
    pass

# Credentials: HF_TOKEN takes precedence over API_KEY.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# `or` (like the settings above) so an empty ENV_URL also falls back to the local
# default instead of yielding relative URLs such as "/reset".
ENV_URL = (os.getenv("ENV_URL") or "http://127.0.0.1:7860").rstrip("/")
BENCHMARK = "cache_invalidation_env"

# Reproducibility (Phase 1 / baseline): fixed seed + task → deterministic heuristic run.
# Empty values fall back to the defaults (int("") would otherwise fail at import time).
EPISODE_SEED = int(os.getenv("EPISODE_SEED") or "42")
TASK_ID = os.getenv("TASK_ID") or "easy"

if not API_KEY:
    print(
        "WARNING: HF_TOKEN is not set. LLM calls will fail; the script will use the "
        "heuristic policy only.",
        file=sys.stderr,
    )

# The client is always constructed; without a key a placeholder is passed, and
# run() disables LLM calls when API_KEY is unset or equals that placeholder.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "hf-invalid")

# Per-key record of the last action taken in the current episode:
# {key: {"last_action": <action type>, "last_step": <step number>}}.
MEMORY: Dict[str, Any] = {}
# Key of the item most recently returned by select_item (None at episode start).
LAST_USED: Optional[str] = None

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are a cache invalidation agent. Given the environment observation (JSON), reply with exactly one JSON object
    on a single line, no markdown, with keys "type" and "key". type must be one of: invalidate, refresh, keep.
    key must match one of the item keys in observation["items"].
    """
).strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` marker line that opens an episode log on stdout."""
    pairs = (("task", task), ("env", env), ("model", model))
    line = "[START] " + " ".join(f"{label}={value}" for label, value in pairs)
    print(line, flush=True)


def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` line on stdout.

    ``reward`` is shown with two decimals, ``done`` as lowercase
    ``true``/``false``, and any falsy ``error`` (``None`` or ``""``) as ``null``.
    """
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line on stdout.

    ``score`` uses three decimals; ``rewards`` are listed comma-separated with two
    decimals each (an empty list renders as nothing after ``rewards=``).
    """
    joined = ",".join(map("{:.2f}".format, rewards))
    template = "[END] success={} steps={} score={:.3f} rewards={}"
    print(template.format(str(success).lower(), steps, score, joined), flush=True)


def select_item(obs: Dict[str, Any], step: int) -> Dict[str, Any]:
    """Choose which cache item to act on at ``step``.

    Every item gets a priority: +3 if ``last_result == "stale"``, +2 if
    ``age > 5``, +1 if ``access_count > 10``. The highest-priority item wins,
    with the earliest item winning ties.

    On odd steps the highest-priority item is used only as a fallback. The
    first item (in observation order) whose key differs from the previous
    pick is returned instead. The fallback applies when every item shares
    that key.

    Side effect: stores the chosen key in the module-level ``LAST_USED``.
    """
    global LAST_USED
    candidates = obs["items"]

    def _priority(entry: Dict[str, Any]) -> int:
        return (
            3 * (entry["last_result"] == "stale")
            + 2 * (entry["age"] > 5)
            + (entry["access_count"] > 10)
        )

    top = max(candidates, key=_priority)

    if step % 2 == 1:
        alternative = next(
            (entry for entry in candidates if entry["key"] != LAST_USED), None
        )
        if alternative is not None:
            LAST_USED = alternative["key"]
            return alternative

    LAST_USED = top["key"]
    return top


def decide(item: Dict[str, Any], step: int) -> Dict[str, str]:
    """Heuristic fallback policy: pick an action for one cache item.

    Rules, first match wins:
      1. MEMORY says this key was invalidated fewer than 2 steps ago -> keep
      2. ``last_result == "stale"`` and ``age > 2``                 -> invalidate
      3. ``age >= 3``                                               -> refresh
      4. anything else (e.g. fresh hits, young stale items)         -> keep

    Returns ``{"type": <action>, "key": item["key"]}``.
    """
    key, outcome, age = item["key"], item["last_result"], item["age"]
    history = MEMORY.get(key, {})

    if history.get("last_action") == "invalidate" and step - history.get("last_step", -10) < 2:
        action_type = "keep"
    elif outcome == "stale" and age > 2:
        action_type = "invalidate"
    elif 3 <= age:
        # Covers both the 3..6 band and the "older than 6" case.
        action_type = "refresh"
    else:
        action_type = "keep"

    return {"type": action_type, "key": key}


def llm_action(obs: Dict[str, Any]) -> Optional[dict]:
    """Ask the LLM for a single action for the current observation.

    Sends SYSTEM_PROMPT plus the JSON-encoded observation to MODEL_NAME through
    the module-level ``client`` (temperature 0). A reply wrapped in a Markdown
    code fence (optionally tagged ``json``) is unwrapped before parsing.

    Returns:
        ``{"type": ..., "key": ...}`` when the reply is a JSON object whose
        ``type`` is one of ``invalidate``/``refresh``/``keep`` and whose ``key``
        equals an item key in ``obs["items"]`` — the contract stated in the
        system prompt. Otherwise ``None``, after a diagnostic on stderr, so the
        caller can fall back to the heuristic policy. Request and parse failures
        never propagate.
    """
    allowed_types = ("invalidate", "refresh", "keep")
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"Observation:\n{json.dumps(obs)}\n\n"
                        'Return JSON only: {"type": "...", "key": "..."}'
                    ),
                },
            ],
            temperature=0,
            max_tokens=150,
        )
        text = (completion.choices[0].message.content or "").strip()
        if text.startswith("```"):
            parts = text.split("```")
            text = parts[1] if len(parts) >= 2 else text
            text = text.strip()
            if text.lower().startswith("json"):
                text = text[4:].strip()
        action = json.loads(text)
    except Exception as exc:  # best-effort boundary: any failure -> heuristic fallback
        print(f"[LLM] request/parse failed: {exc}", file=sys.stderr)
        return None

    if not isinstance(action, dict) or "type" not in action or "key" not in action:
        print(f"[LLM] reply is not an action object: {text!r}", file=sys.stderr)
        return None

    # Validate against the observation so a hallucinated type/key falls back to
    # the heuristic instead of being sent to /step (where an HTTP error would
    # abort the episode). A list + `in` uses equality, so odd key types are safe.
    valid_keys = [
        entry.get("key") for entry in (obs.get("items") or []) if isinstance(entry, dict)
    ]
    action_type, action_key = action["type"], action["key"]
    if action_type not in allowed_types or action_key not in valid_keys:
        print(f"[LLM] rejected invalid action: {action!r}", file=sys.stderr)
        return None
    return {"type": action_type, "key": action_key}


def _post_json(url: str, payload: Dict[str, Any]) -> Any:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON response.

    Raises ``requests.HTTPError`` on a non-2xx status.
    """
    res = requests.post(
        url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=60,
    )
    res.raise_for_status()
    return res.json()


def run_episode(
    *, env_url: str, task_id: str, seed: int, use_llm: bool, max_steps: int = 10
) -> None:
    """One episode over OpenEnv HTTP API (wrapped action + observation).

    POSTs ``/reset`` with ``seed``/``task_id``, then performs up to ``max_steps``
    ``/step`` calls. Each action comes from the LLM when ``use_llm`` is set,
    falling back to the ``decide`` heuristic whenever the LLM yields nothing.
    The episode stops early when the environment reports ``done``.

    Logging: ``[START]`` after reset, one ``[STEP]`` per step, and always a
    final ``[END]`` line, even when an exception aborts the episode (the
    exception text goes to stderr).

    Score: the last ``final_score`` reported in a step observation if any,
    otherwise ``(mean_reward + 1) / 2``. Either value is passed through
    ``clamp_unit_interval``. Success: mean reward > 0.3.
    """
    global LAST_USED
    LAST_USED = None
    MEMORY.clear()

    rewards: List[float] = []
    steps_taken = 0
    episode_score = 0.0
    success = False
    score_from_env = False

    try:
        body = _post_json(f"{env_url}/reset", {"seed": seed, "task_id": task_id})
        obs = body.get("observation", body)
        tid = str(obs.get("task_id", task_id))

        log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)

        for step in range(1, max_steps + 1):
            item = select_item(obs, step)

            action: Optional[dict] = llm_action(obs) if use_llm else None
            if action is None:
                action = decide(item, step)

            # Record under the key actually acted on: an LLM action may target a
            # different item than the heuristic selection, and decide()'s
            # invalidate cooldown looks up MEMORY by the item's own key.
            MEMORY[action["key"]] = {
                "last_action": action["type"],
                "last_step": step,
            }

            data = _post_json(f"{env_url}/step", {"action": action})

            raw_reward = data.get("reward")
            reward = float(raw_reward) if raw_reward is not None else 0.0
            done = bool(data.get("done", False))
            rewards.append(reward)
            steps_taken = step

            inner = data.get("observation", {})
            if inner.get("final_score") is not None:
                episode_score = float(inner["final_score"])
                score_from_env = True

            log_step(
                step=step,
                action=json.dumps(action),
                reward=reward,
                done=done,
                error=None,
            )

            obs = inner
            if done:
                break

        if rewards:
            avg_r = sum(rewards) / len(rewards)
            success = avg_r > 0.3
            if not score_from_env:
                episode_score = clamp_unit_interval((avg_r + 1.0) / 2.0)

    except Exception as exc:
        # Top-level boundary: report and still emit the [END] line below.
        success = False
        print(f"[RUN] fatal: {exc}", file=sys.stderr)
    finally:
        episode_score = clamp_unit_interval(episode_score)
        log_end(
            success=success,
            steps=steps_taken,
            score=episode_score,
            rewards=rewards,
        )


def run() -> None:
    """Entry point: run one episode per selected task against ENV_URL.

    With RUN_ALL_TASKS set to 1/true/yes (case-insensitive), the tasks easy,
    medium and hard are run in that order. Otherwise only TASK_ID is run.
    The LLM is used only when API_KEY is set and is not the "hf-invalid"
    placeholder.
    """
    llm_enabled = bool(API_KEY) and API_KEY != "hf-invalid"
    run_all = os.getenv("RUN_ALL_TASKS", "").lower() in ("1", "true", "yes")
    task_ids = ("easy", "medium", "hard") if run_all else (TASK_ID,)
    for tid in task_ids:
        run_episode(
            env_url=ENV_URL,
            task_id=tid,
            seed=EPISODE_SEED,
            use_llm=llm_enabled,
        )


if __name__ == "__main__":
    run()