Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import sys | |
| import textwrap | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import requests | |
| from openai import OpenAI | |
| from env.grader import clamp_unit_interval | |
# Load a sibling ``.env`` file (when python-dotenv is installed) *before* any
# setting is read below, so values defined there are visible to os.getenv.
try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent / ".env")
except ImportError:
    # python-dotenv is optional; fall back to the plain process environment.
    pass

# Every setting uses ``os.getenv(...) or default`` so that a variable that is
# set but empty (common in CI / container configs) behaves like an unset one
# instead of yielding an empty URL/task id or crashing ``int("")`` at import.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# Trailing slashes are stripped because endpoint paths are appended as "/reset", "/step".
ENV_URL = (os.getenv("ENV_URL") or "http://127.0.0.1:7860").rstrip("/")
BENCHMARK = "cache_invalidation_env"
# Reproducibility (Phase 1 / baseline): fixed seed + task → deterministic heuristic run.
EPISODE_SEED = int(os.getenv("EPISODE_SEED") or "42")
TASK_ID = os.getenv("TASK_ID") or "easy"
# Warn up front when no credential is configured. API_KEY is populated from
# either HF_TOKEN or API_KEY, so the message names both variables.
if not API_KEY:
    print(
        "WARNING: neither HF_TOKEN nor API_KEY is set. LLM calls will fail; the script will use the "
        "heuristic policy only.",
        file=sys.stderr,
    )

# "hf-invalid" is a placeholder so the client can always be constructed;
# run() treats a missing key (or this placeholder) as "LLM disabled".
# NOTE(review): presumably OpenAI() rejects a missing api_key at construction — confirm.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "hf-invalid")

# Heuristic state shared by select_item()/decide(); reset at the start of each episode.
# MEMORY maps an item key to {"last_action": ..., "last_step": ...}.
MEMORY: Dict[str, Any] = {}
# Key of the item most recently returned by select_item().
LAST_USED: Optional[str] = None
# System message sent with every LLM request: three newline-separated lines
# describing the required single-line JSON reply.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a cache invalidation agent. Given the environment observation (JSON), reply with exactly one JSON object",
        'on a single line, no markdown, with keys "type" and "key". type must be one of: invalidate, refresh, keep.',
        'key must match one of the item keys in observation["items"].',
    )
)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record for an episode to stdout (flushed).

    Args:
        task: Task identifier of the episode.
        env: Benchmark / environment name.
        model: Name of the model driving the agent.
    """
    fields = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    print(" ".join(fields), flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` record to stdout (flushed).

    Args:
        step: Step number within the episode.
        action: Already-serialized action text.
        reward: Reward for the step; printed with two decimals.
        done: Whether the episode ended; printed lower-cased.
        error: Error text, or a falsy value, which is printed as ``null``.
    """
    fields = (
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print(" ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record summarizing an episode to stdout (flushed).

    Args:
        success: Whether the episode counted as successful; printed lower-cased.
        steps: Number of steps taken.
        score: Final episode score; printed with three decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    fields = (
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        "rewards=" + ",".join(formatted_rewards),
    )
    print(" ".join(fields), flush=True)
def select_item(obs: Dict[str, Any], step: int) -> Dict[str, Any]:
    """Pick the cache item the heuristic policy should act on this step.

    On odd steps the first item whose key differs from the previously selected
    one is chosen, so the policy does not keep acting on a single item. On even
    steps, or when no such item exists, the highest-priority item is chosen:
    +3 if its last result was "stale", +2 if its age exceeds 5, +1 if its
    access count exceeds 10 (ties resolve to the earliest item).

    Args:
        obs: Environment observation; must contain a non-empty ``"items"`` list.
        step: 1-based step number within the episode.

    Returns:
        The selected item dict (the same object that is in ``obs["items"]``).

    Raises:
        ValueError: If the observation has no items to choose from.

    Side effects:
        Updates the module-level ``LAST_USED`` with the selected item's key.
    """
    global LAST_USED

    items = obs.get("items")
    if not items:
        # Fail with an explicit message rather than an opaque KeyError or
        # "max() arg is an empty sequence".
        raise ValueError("observation contains no items to select from")

    def priority(item: Dict[str, Any]) -> int:
        points = 0
        if item["last_result"] == "stale":
            points += 3
        if item["age"] > 5:
            points += 2
        if item["access_count"] > 10:
            points += 1
        return points

    chosen: Optional[Dict[str, Any]] = None
    if step % 2 == 1:
        previous = LAST_USED
        chosen = next((item for item in items if item["key"] != previous), None)
    if chosen is None:
        # Even step, or every item carries the previously used key.
        chosen = max(items, key=priority)

    LAST_USED = chosen["key"]
    return chosen
def decide(item: Dict[str, Any], step: int) -> Dict[str, str]:
    """Heuristic policy: choose keep / invalidate / refresh for one item.

    Rules, in order:
      1. If this key was invalidated fewer than 2 steps ago (per ``MEMORY``), keep.
      2. If the last result was "stale" and age is above 2, invalidate.
      3. If age is 3 or more, refresh.
      4. Otherwise keep.

    Args:
        item: Item dict with ``"key"``, ``"last_result"`` and ``"age"``.
        step: Current step number, compared with the remembered step.

    Returns:
        An action dict ``{"type": <action>, "key": <item key>}``.
    """
    item_key = item["key"]
    result = item["last_result"]
    age = item["age"]

    def act(kind: str) -> Dict[str, str]:
        return {"type": kind, "key": item_key}

    history = MEMORY.get(item_key, {})
    if history.get("last_action") == "invalidate":
        # Cool-down: do not touch an item right after invalidating it.
        if step - history.get("last_step", -10) < 2:
            return act("keep")

    if result == "stale" and age > 2:
        return act("invalidate")

    # Covers both the 3..6 band and everything older than 6; anything
    # younger than 3 is kept whatever its last result was.
    if age >= 3:
        return act("refresh")
    return act("keep")
# Action types the system prompt allows the model to return.
_ALLOWED_ACTION_TYPES = ("invalidate", "refresh", "keep")


def _parse_llm_reply(text: str, obs: Dict[str, Any]) -> Optional[dict]:
    """Turn the model's raw reply into a validated action, or None if unusable."""
    text = text.strip()
    if text.startswith("```"):
        # Keep only the content of the first fenced block.
        text = text.split("```")[1].strip()
    if text.lower().startswith("json"):
        # Drop a leading language tag such as the one in "```json".
        text = text[4:].strip()

    action = json.loads(text)
    if not isinstance(action, dict) or "type" not in action or "key" not in action:
        return None

    action_type, action_key = action["type"], action["key"]
    if action_type not in _ALLOWED_ACTION_TYPES:
        print(f"[LLM] rejected action type: {action_type!r}", file=sys.stderr)
        return None

    # The key must name an item of the current observation. A list (not a set)
    # is used so that an unhashable value from the model cannot raise here.
    items = obs.get("items") if isinstance(obs, dict) else None
    known_keys = [it.get("key") for it in items or [] if isinstance(it, dict)]
    if known_keys and action_key not in known_keys:
        print(f"[LLM] rejected unknown key: {action_key!r}", file=sys.stderr)
        return None

    return {"type": action_type, "key": action_key}


def llm_action(obs: Dict[str, Any]) -> Optional[dict]:
    """Ask the LLM for the next action and return it only if it is valid.

    The observation is sent as JSON together with ``SYSTEM_PROMPT``. The reply
    is accepted when it is a JSON object whose ``"type"`` is one of
    invalidate / refresh / keep and whose ``"key"`` matches an item key in the
    observation (the key check is skipped when the observation lists no keys).

    Args:
        obs: Current environment observation (must be JSON-serializable).

    Returns:
        ``{"type": ..., "key": ...}`` on success, otherwise ``None`` so the
        caller can fall back to the heuristic policy. Never raises: request,
        parsing and validation failures are reported on stderr.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"Observation:\n{json.dumps(obs)}\n\n"
                        'Return JSON only: {"type": "...", "key": "..."}'
                    ),
                },
            ],
            temperature=0,
            max_tokens=150,
        )
        reply = completion.choices[0].message.content or ""
        return _parse_llm_reply(reply, obs)
    except Exception as exc:
        # Best-effort boundary: any API or parsing failure means "no LLM action".
        print(f"[LLM] request/parse failed: {exc}", file=sys.stderr)
    return None
def _post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST *payload* as JSON to *url*; return the decoded JSON body or raise on HTTP error."""
    response = requests.post(
        url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=60,
    )
    response.raise_for_status()
    return response.json()


def run_episode(
    *, env_url: str, task_id: str, seed: int, use_llm: bool, max_steps: int = 10
) -> None:
    """One episode over OpenEnv HTTP API (wrapped action + observation).

    Resets the environment, then for up to ``max_steps`` steps picks an action
    (LLM first when enabled, heuristic otherwise), posts it to ``/step`` and
    logs the result. An ``[END]`` record is always printed, even after a fatal
    error; errors are reported on stderr and never propagate to the caller.

    Args:
        env_url: Base URL of the environment server (no trailing slash).
        task_id: Task to request on reset.
        seed: Seed sent with the reset request.
        use_llm: Whether to query the LLM before falling back to the heuristic.
        max_steps: Upper bound on the number of steps (default 10).

    Scoring:
        The episode counts as a success when the mean reward exceeds 0.3. The
        score is the environment's ``final_score`` when one is reported,
        otherwise the mean reward mapped through ``(mean + 1) / 2``; it is
        passed through ``clamp_unit_interval`` before being logged.
    """
    global LAST_USED
    # Heuristic state must not leak between episodes.
    LAST_USED = None
    MEMORY.clear()

    rewards: List[float] = []
    steps_taken = 0
    episode_score = 0.0
    success = False
    score_from_env = False
    try:
        body = _post_json(f"{env_url}/reset", {"seed": seed, "task_id": task_id})
        # The observation may be wrapped under "observation" or be the body itself.
        obs = body.get("observation", body)
        tid = str(obs.get("task_id", task_id))
        log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)

        for step in range(1, max_steps + 1):
            # Always run the selector: it provides the heuristic fallback target.
            item = select_item(obs, step)
            action: Optional[dict] = llm_action(obs) if use_llm else None
            if action is None:
                action = decide(item, step)

            # Remember the action under the key it was actually applied to. An
            # LLM action may target a different item than the heuristic's pick;
            # filing it under item["key"] would make decide()'s cool-down check
            # look at the wrong item.
            acted_key = action.get("key")
            if not isinstance(acted_key, str):
                acted_key = item["key"]
            MEMORY[acted_key] = {
                "last_action": action["type"],
                "last_step": step,
            }

            data = _post_json(f"{env_url}/step", {"action": action})
            reward = float(data["reward"] if data["reward"] is not None else 0.0)
            done = bool(data["done"])
            rewards.append(reward)
            steps_taken = step

            # Tolerate a missing or null observation in the step response.
            inner = data.get("observation") or {}
            if inner.get("final_score") is not None:
                episode_score = float(inner["final_score"])
                score_from_env = True

            log_step(
                step=step,
                action=json.dumps(action),
                reward=reward,
                done=done,
                error=None,
            )
            obs = inner
            if done:
                break

        if rewards:
            mean_reward = sum(rewards) / len(rewards)
            success = mean_reward > 0.3
            if not score_from_env:
                episode_score = clamp_unit_interval((mean_reward + 1.0) / 2.0)
    except Exception as exc:
        # Top-level boundary for the episode: report and still emit [END].
        success = False
        print(f"[RUN] fatal: {exc}", file=sys.stderr)
    finally:
        episode_score = clamp_unit_interval(episode_score)
        log_end(
            success=success,
            steps=steps_taken,
            score=episode_score,
            rewards=rewards,
        )
def run() -> None:
    """Run the configured episode(s).

    With ``RUN_ALL_TASKS`` set to ``1``, ``true`` or ``yes`` (case-insensitive)
    the easy, medium and hard tasks are run in that order; otherwise only
    ``TASK_ID`` is run. The LLM is used only when a real API key is configured.
    """
    llm_enabled = bool(API_KEY and API_KEY != "hf-invalid")
    run_all = os.getenv("RUN_ALL_TASKS", "").lower() in ("1", "true", "yes")
    task_ids = ("easy", "medium", "hard") if run_all else (TASK_ID,)
    for tid in task_ids:
        run_episode(
            env_url=ENV_URL,
            task_id=tid,
            seed=EPISODE_SEED,
            use_llm=llm_enabled,
        )


if __name__ == "__main__":
    run()