# cache-env / inference.py
# Parv Pareek
# done
# e75c8ce
import json
import os
import sys
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from openai import OpenAI
from env.grader import clamp_unit_interval
# Optionally load variables from a ".env" file that sits next to this script.
# python-dotenv is an optional dependency, so only its import is guarded; a
# failure inside load_dotenv itself is not mistaken for "dotenv not installed".
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv(Path(__file__).resolve().parent / ".env")
# Credentials and endpoints. HF_TOKEN takes precedence over API_KEY; empty
# values for the URL/model fall back to the defaults below.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# Trailing slashes are stripped so f"{ENV_URL}/reset" never doubles the slash.
ENV_URL = os.environ.get("ENV_URL", "http://127.0.0.1:7860").rstrip("/")
BENCHMARK = "cache_invalidation_env"

# Reproducibility (Phase 1 / baseline): a fixed seed + task id gives a
# deterministic heuristic run.
EPISODE_SEED = int(os.environ.get("EPISODE_SEED", "42"))
TASK_ID = os.environ.get("TASK_ID", "easy")

if not API_KEY:
    print(
        "WARNING: HF_TOKEN is not set. LLM calls will fail; "
        "the script will use the heuristic policy only.",
        file=sys.stderr,
    )

# The client is always built; without a key a placeholder is passed, and
# run() does not enable the LLM path in that case.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "hf-invalid")

# Per-episode heuristic state, cleared by run_episode at the start of each
# episode: MEMORY maps an item key to its last action/step, LAST_USED holds
# the key most recently chosen by select_item.
MEMORY: Dict[str, Any] = {}
LAST_USED: Optional[str] = None
# System instructions for the LLM policy: one single-line JSON object with
# "type" (invalidate | refresh | keep) and "key" (an observed item key).
SYSTEM_PROMPT = "\n".join(
    (
        "You are a cache invalidation agent. Given the environment observation (JSON), "
        "reply with exactly one JSON object",
        'on a single line, no markdown, with keys "type" and "key". '
        "type must be one of: invalidate, refresh, keep.",
        'key must match one of the item keys in observation["items"].',
    )
)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's stdout log."""
    fields = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    print(" ".join(fields), flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` line for a completed environment step.

    ``reward`` is shown with two decimals, ``done`` is lower-cased, and a
    missing/empty ``error`` is rendered as the literal ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line: outcome, step count, score, per-step rewards."""
    reward_list = ",".join(map("{:.2f}".format, rewards))
    summary = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(), steps, score, reward_list
    )
    print(summary, flush=True)
def select_item(obs: Dict[str, Any], step: int) -> Dict[str, Any]:
    """Choose which cache item to act on at ``step``.

    On odd steps, rotate to the first item (in observation order) whose key
    differs from the previously chosen one. Otherwise, or when every item
    shares that key, take the highest-priority item. Priority: stale +3,
    age > 5 +2, access_count > 10 +1; ``max`` keeps the first on ties.
    The chosen key is stored in the module-level ``LAST_USED``.
    """
    global LAST_USED

    candidates = obs["items"]

    def priority(entry: Dict[str, Any]) -> int:
        stale_pts = 3 if entry["last_result"] == "stale" else 0
        age_pts = 2 if entry["age"] > 5 else 0
        popular_pts = 1 if entry["access_count"] > 10 else 0
        return stale_pts + age_pts + popular_pts

    top = max(candidates, key=priority)

    chosen = None
    if step % 2 == 1:
        chosen = next((c for c in candidates if c["key"] != LAST_USED), None)
    if chosen is None:
        chosen = top

    LAST_USED = chosen["key"]
    return chosen
def decide(item: Dict[str, Any], step: int) -> Dict[str, str]:
    """Heuristic policy: map one cache item to an action dict.

    Rules, first match wins:
      * invalidated within the previous step (per MEMORY) -> keep (cooldown)
      * stale and age > 2                                 -> invalidate
      * age >= 3                                          -> refresh
      * anything else                                     -> keep
    """
    key = item["key"]
    status = item["last_result"]
    age = item["age"]
    history = MEMORY.get(key, {})

    cooling_down = (
        history.get("last_action") == "invalidate"
        and step - history.get("last_step", -10) < 2
    )
    if cooling_down:
        kind = "keep"
    elif status == "stale" and age > 2:
        kind = "invalidate"
    elif age >= 3:
        # Covers both the 3..6 band and anything older than 6.
        kind = "refresh"
    else:
        # Young items that were not stale-and-old are left alone.
        kind = "keep"
    return {"type": kind, "key": key}
# Action types the environment accepts (mirrors the list in SYSTEM_PROMPT).
_LLM_ACTION_TYPES = ("invalidate", "refresh", "keep")


def _parse_llm_action(text: str, obs: Dict[str, Any]) -> Optional[dict]:
    """Parse and validate a model reply into a ``{"type", "key"}`` action.

    Tolerates a surrounding markdown fence (optionally tagged ``json``) and
    stray prose around the JSON object. Returns ``None`` when the reply is not
    a JSON object, when ``type`` is not an allowed action, or when ``key`` is
    not one of the item keys in ``obs["items"]``.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Unwrap a fence such as ```json\n{...}\n```.
        parts = cleaned.split("```")
        cleaned = parts[1].strip() if len(parts) >= 2 else cleaned
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[4:].strip()
    # Keep only the outermost {...} span in case the model added prose around it.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start : end + 1]

    try:
        action = json.loads(cleaned)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        print(f"[LLM] reply is not valid JSON: {exc}", file=sys.stderr)
        return None
    if not isinstance(action, dict):
        print(f"[LLM] reply is not a JSON object: {action!r}", file=sys.stderr)
        return None

    action_type = action.get("type")
    key = action.get("key")
    # A list (not a set) so an unhashable value from the model cannot raise.
    valid_keys = [
        item["key"]
        for item in (obs.get("items") or [])
        if isinstance(item, dict) and "key" in item
    ]
    if action_type not in _LLM_ACTION_TYPES or key not in valid_keys:
        # An invalid action would make /step fail and abort the episode;
        # returning None lets the caller use the heuristic instead.
        print(f"[LLM] rejected action: {action!r}", file=sys.stderr)
        return None
    return {"type": action_type, "key": key}


def llm_action(obs: Dict[str, Any]) -> Optional[dict]:
    """Ask the LLM for the next action given the raw observation.

    Returns a validated ``{"type": ..., "key": ...}`` dict. Returns ``None``
    if the request fails or the reply is unusable, so the caller can fall
    back to the heuristic policy.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"Observation:\n{json.dumps(obs)}\n\n"
                        'Return JSON only: {"type": "...", "key": "..."}'
                    ),
                },
            ],
            temperature=0,
            max_tokens=150,
        )
        text = completion.choices[0].message.content or ""
    except Exception as exc:  # SDK/network/serialization errors: best-effort only
        print(f"[LLM] request failed: {exc}", file=sys.stderr)
        return None
    return _parse_llm_action(text, obs)
def _post_json(
    session: "requests.Session", url: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
    """POST ``payload`` as JSON to ``url`` and return the decoded JSON body.

    Raises ``requests.HTTPError`` on a non-2xx status (via ``raise_for_status``).
    """
    res = session.post(
        url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=60,
    )
    res.raise_for_status()
    return res.json()


def run_episode(
    *,
    env_url: str,
    task_id: str,
    seed: int,
    use_llm: bool,
    max_steps: int = 10,
) -> None:
    """Run one episode over the OpenEnv HTTP API (wrapped action + observation).

    POSTs ``/reset`` with ``seed``/``task_id``, then up to ``max_steps`` POSTs
    to ``/step`` with ``{"action": ...}``. The action comes from the LLM when
    ``use_llm`` is set and it returns something usable; otherwise it comes from
    the heuristic ``decide``. Emits ``[START]`` after a successful reset, one
    ``[STEP]`` per step, and always a final ``[END]``.

    Score: the environment's ``final_score`` if it reports one; otherwise the
    mean reward mapped via ``(avg + 1) / 2`` (this presumes rewards in
    [-1, 1] — confirm against the env), clamped to [0, 1]. Success means the
    mean reward is above 0.3. Any exception is reported on stderr and marks the
    episode as failed.
    """
    global LAST_USED
    LAST_USED = None
    MEMORY.clear()
    rewards: List[float] = []
    steps_taken = 0
    episode_score = 0.0
    success = False
    score_from_env = False
    try:
        # One session for the whole episode: reuses the connection and is
        # closed even if a request fails.
        with requests.Session() as session:
            body = _post_json(
                session, f"{env_url}/reset", {"seed": seed, "task_id": task_id}
            )
            obs = body.get("observation", body)
            log_start(
                task=str(obs.get("task_id", task_id)), env=BENCHMARK, model=MODEL_NAME
            )
            for step in range(1, max_steps + 1):
                item = select_item(obs, step)
                action: Optional[dict] = llm_action(obs) if use_llm else None
                if action is None:
                    action = decide(item, step)
                # Record under the key actually acted on: the LLM may pick a
                # different item than select_item, and decide() reads this
                # entry for its invalidate cooldown.
                MEMORY[action["key"]] = {
                    "last_action": action["type"],
                    "last_step": step,
                }
                data = _post_json(session, f"{env_url}/step", {"action": action})
                raw_reward = data.get("reward")
                reward = float(raw_reward) if raw_reward is not None else 0.0
                done = bool(data.get("done", False))
                rewards.append(reward)
                steps_taken = step
                inner = data.get("observation", {})
                if inner.get("final_score") is not None:
                    episode_score = float(inner["final_score"])
                    score_from_env = True
                log_step(
                    step=step,
                    action=json.dumps(action),
                    reward=reward,
                    done=done,
                    error=None,
                )
                obs = inner
                if done:
                    break
        if rewards:
            avg_r = sum(rewards) / len(rewards)
            success = avg_r > 0.3
            if not score_from_env:
                episode_score = clamp_unit_interval((avg_r + 1.0) / 2.0)
    except Exception as exc:
        success = False
        print(f"[RUN] fatal: {exc}", file=sys.stderr)
    finally:
        episode_score = clamp_unit_interval(episode_score)
        log_end(
            success=success,
            steps=steps_taken,
            score=episode_score,
            rewards=rewards,
        )
def run() -> None:
    """Script entry point.

    Runs TASK_ID, or the easy/medium/hard tasks in turn when RUN_ALL_TASKS is
    "1", "true" or "yes" (case-insensitive). The LLM policy is enabled only
    when a real API key is configured.
    """
    use_llm = bool(API_KEY) and API_KEY != "hf-invalid"
    run_all = os.getenv("RUN_ALL_TASKS", "").lower() in ("1", "true", "yes")
    task_ids = ("easy", "medium", "hard") if run_all else (TASK_ID,)
    for tid in task_ids:
        run_episode(
            env_url=ENV_URL,
            task_id=tid,
            seed=EPISODE_SEED,
            use_llm=use_llm,
        )


if __name__ == "__main__":
    run()