"""Baseline agent for the cache invalidation OpenEnv environment.

Drives episodes over the environment's HTTP API (``/reset`` and ``/step``),
choosing each action with an LLM (OpenAI-compatible endpoint) when an API key
is configured, and with a deterministic heuristic otherwise or whenever the LLM
call fails or returns an unusable action. Progress is printed to stdout as
``[START]`` / ``[STEP]`` / ``[END]`` lines; diagnostics go to stderr.
"""

import json
import os
import sys
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from openai import OpenAI

from env.grader import clamp_unit_interval

try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent / ".env")
except ImportError:
    # python-dotenv is optional; without it only the process environment is used.
    pass

API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
ENV_URL = os.getenv(
    "ENV_URL",
    "http://127.0.0.1:7860",
).rstrip("/")
BENCHMARK = "cache_invalidation_env"

# Reproducibility (Phase 1 / baseline): fixed seed + task → deterministic heuristic run.
EPISODE_SEED = int(os.getenv("EPISODE_SEED", "42"))
TASK_ID = os.getenv("TASK_ID", "easy")

# Upper bound on /step calls per episode; the loop also stops early on done.
MAX_STEPS = 10
# An episode counts as a success when its mean step reward exceeds this value.
SUCCESS_THRESHOLD = 0.3
# Timeout (seconds) for each HTTP request to the environment.
REQUEST_TIMEOUT_S = 60
# Action types accepted from the LLM; mirrors the list in SYSTEM_PROMPT.
VALID_ACTION_TYPES = ("invalidate", "refresh", "keep")

if not API_KEY:
    print(
        "WARNING: HF_TOKEN is not set. LLM calls will fail; the script will use the "
        "heuristic policy only.",
        file=sys.stderr,
    )

client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "hf-invalid")

# Per-key record of the last action taken: {"last_action": str, "last_step": int}.
MEMORY: Dict[str, Any] = {}
# Key returned by select_item on the previous step (reset at each episode start).
LAST_USED: Optional[str] = None

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are a cache invalidation agent. Given the environment observation (JSON),
    reply with exactly one JSON object on a single line, no markdown, with keys
    "type" and "key". type must be one of: invalidate, refresh, keep.
    key must match one of the item keys in observation["items"].
    """
).strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` line; ``error`` is rendered as ``null`` when falsy."""
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line with the outcome and the comma-joined step rewards."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


def select_item(obs: Dict[str, Any], step: int) -> Dict[str, Any]:
    """Choose which cache item to act on at ``step``.

    Items are scored +3 when ``last_result == "stale"``, +2 when ``age > 5`` and
    +1 when ``access_count > 10``. On odd steps the first item whose key differs
    from the previous pick is returned regardless of score (presumably to spread
    actions across keys). Otherwise the highest-scored item is returned; ties go
    to the earliest item, per ``max``. Updates the module-level ``LAST_USED``.

    Raises:
        ValueError: if ``obs["items"]`` is empty.
    """
    global LAST_USED
    items = obs["items"]
    if not items:
        raise ValueError("observation contains no items to act on")

    def score(item: Dict[str, Any]) -> int:
        s = 0
        if item["last_result"] == "stale":
            s += 3
        if item["age"] > 5:
            s += 2
        if item["access_count"] > 10:
            s += 1
        return s

    best = max(items, key=score)
    if step % 2 == 1:
        for item in items:
            if item["key"] != LAST_USED:
                LAST_USED = item["key"]
                return item
    # Even step, or every item shares the last-used key: take the best score.
    LAST_USED = best["key"]
    return best


def decide(item: Dict[str, Any], step: int) -> Dict[str, str]:
    """Heuristic action for ``item``; the first matching rule wins.

    1. ``keep`` if this key was invalidated fewer than 2 steps ago (per MEMORY).
    2. ``invalidate`` if the last result was ``"stale"`` and ``age > 2``.
    3. ``refresh`` if ``3 <= age <= 6``.
    4. ``keep`` if the last result was ``"hit"`` and ``age < 3``.
    5. ``refresh`` if ``age > 6``.
    6. ``keep`` otherwise.
    """
    key = item["key"]
    last_result = item["last_result"]
    age = item["age"]
    mem = MEMORY.get(key, {})

    if mem.get("last_action") == "invalidate" and step - mem.get("last_step", -10) < 2:
        return {"type": "keep", "key": key}
    if last_result == "stale" and age > 2:
        return {"type": "invalidate", "key": key}
    if 3 <= age <= 6:
        return {"type": "refresh", "key": key}
    if last_result == "hit" and age < 3:
        return {"type": "keep", "key": key}
    if age > 6:
        return {"type": "refresh", "key": key}
    return {"type": "keep", "key": key}


def _parse_llm_text(text: str) -> Any:
    """Strip optional Markdown code fences and a leading ``json`` tag, then decode JSON."""
    text = text.strip()
    if text.startswith("```"):
        parts = text.split("```")
        text = parts[1] if len(parts) >= 2 else text
    text = text.strip()
    if text.lower().startswith("json"):
        text = text[4:].strip()
    return json.loads(text)


def _validate_action(action: Any, obs: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Return a normalized action, or None if it is not usable for ``obs``.

    Usable means: a JSON object whose ``type`` is one of VALID_ACTION_TYPES
    (case-insensitive) and whose ``key`` equals the key of an item in
    ``obs["items"]``.
    """
    if not isinstance(action, dict):
        return None
    action_type = str(action.get("type", "")).strip().lower()
    key = action.get("key")
    if action_type not in VALID_ACTION_TYPES:
        return None
    valid_keys = {item.get("key") for item in obs.get("items", []) if isinstance(item, dict)}
    if key not in valid_keys:
        return None
    return {"type": action_type, "key": key}


def llm_action(obs: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Ask the LLM for an action on ``obs``.

    Returns:
        A validated ``{"type", "key"}`` dict, or None when the request fails, the
        reply is not valid JSON, or the action is not usable for this observation.
        The caller then falls back to the heuristic.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"Observation:\n{json.dumps(obs)}\n\n"
                        'Return JSON only: {"type": "...", "key": "..."}'
                    ),
                },
            ],
            temperature=0,
            max_tokens=150,
        )
        raw = _parse_llm_text(completion.choices[0].message.content or "")
    except Exception as exc:  # best-effort: any API/parse failure → heuristic
        print(f"[LLM] request/parse failed: {exc}", file=sys.stderr)
        return None

    action = _validate_action(raw, obs)
    if action is None:
        print(f"[LLM] rejected invalid action: {raw!r}", file=sys.stderr)
    return action


def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``url``; raise on HTTP errors; return the decoded body."""
    res = requests.post(
        url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=REQUEST_TIMEOUT_S,
    )
    res.raise_for_status()
    return res.json()


def _unwrap_observation(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``payload["observation"]`` when it is a dict, else the payload itself."""
    obs = payload.get("observation")
    return obs if isinstance(obs, dict) else payload


def run_episode(
    *,
    env_url: str,
    task_id: str,
    seed: int,
    use_llm: bool,
    max_steps: int = MAX_STEPS,
) -> None:
    """Run one episode over the OpenEnv HTTP API (wrapped action + observation).

    Resets the environment with ``seed``/``task_id``, then takes up to
    ``max_steps`` steps, stopping early when the environment reports done.
    A ``[START]`` line is always paired with an ``[END]`` line, even if
    ``/reset`` fails. Any exception is reported to stderr and ends the episode
    as unsuccessful.

    The score is the environment's ``final_score`` when one is reported.
    Otherwise it is the mean step reward mapped from [-1, 1] to [0, 1].
    In both cases it is clamped to the unit interval.
    """
    global LAST_USED
    LAST_USED = None
    MEMORY.clear()

    rewards: List[float] = []
    steps_taken = 0
    episode_score = 0.0
    success = False
    score_from_env = False
    started = False

    try:
        body = _post_json(f"{env_url}/reset", {"seed": seed, "task_id": task_id})
        obs = _unwrap_observation(body)
        tid = str(obs.get("task_id", task_id))
        log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
        started = True

        for step in range(1, max_steps + 1):
            # Always run select_item so the heuristic fallback (and LAST_USED
            # rotation) stays in sync whether or not the LLM answers.
            item = select_item(obs, step)
            action: Optional[Dict[str, str]] = llm_action(obs) if use_llm else None
            if action is None:
                action = decide(item, step)
            # Record against the key actually acted on (the LLM may pick another item).
            MEMORY[action["key"]] = {
                "last_action": action["type"],
                "last_step": step,
            }

            data = _post_json(f"{env_url}/step", {"action": action})
            raw_reward = data.get("reward")
            reward = float(raw_reward) if raw_reward is not None else 0.0
            done = bool(data.get("done", False))
            rewards.append(reward)
            steps_taken = step

            inner = _unwrap_observation(data)
            if inner.get("final_score") is not None:
                episode_score = float(inner["final_score"])
                score_from_env = True

            log_step(
                step=step,
                # Compact JSON: the [STEP] line is space-delimited key=value pairs.
                action=json.dumps(action, separators=(",", ":")),
                reward=reward,
                done=done,
                error=None,
            )
            obs = inner
            if done:
                break

        if rewards:
            avg_r = sum(rewards) / len(rewards)
            success = avg_r > SUCCESS_THRESHOLD
            if not score_from_env:
                # Map a mean reward in [-1, 1] onto [0, 1]; clamped below.
                episode_score = (avg_r + 1.0) / 2.0
    except Exception as exc:  # top-level boundary: report and end the episode
        success = False
        print(f"[RUN] fatal: {exc}", file=sys.stderr)
    finally:
        if not started:
            # /reset failed before [START] was printed; keep the log lines paired.
            log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
        episode_score = clamp_unit_interval(episode_score)
        log_end(
            success=success,
            steps=steps_taken,
            score=episode_score,
            rewards=rewards,
        )


def run() -> None:
    """Run the baseline on TASK_ID, or on easy/medium/hard when RUN_ALL_TASKS is truthy."""
    use_llm = bool(API_KEY and API_KEY != "hf-invalid")
    if os.getenv("RUN_ALL_TASKS", "").lower() in ("1", "true", "yes"):
        task_ids = ("easy", "medium", "hard")
    else:
        task_ids = (TASK_ID,)
    for tid in task_ids:
        run_episode(
            env_url=ENV_URL,
            task_id=tid,
            seed=EPISODE_SEED,
            use_llm=use_llm,
        )


if __name__ == "__main__":
    run()