"""
inference.py - CDN Cache Optimizer Baseline Agent
Uses OpenAI client to run an LLM agent against the environment.
Emits structured [START], [STEP], [END] logs to stdout.
Required env vars:
    API_BASE_URL - LLM API endpoint (default: https://api.openai.com/v1)
    MODEL_NAME   - model identifier (default: gpt-4o-mini)
    HF_TOKEN     - Hugging Face / API key (client falls back to "placeholder" if unset)
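
Example run (endpoint and values shown are illustrative):
    API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o-mini HF_TOKEN=... python inference.py

Log shape emitted per task:
    [START] task=task_easy
    [STEP] step=0 reward=0.25
    ...
    [END] task=task_easy score=0.91 steps=120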
"""
import os
import sys
import json
from openai import OpenAI
from env.cache import CDNCacheEnv, TASK_CONFIGS
from env.models import Action, Observation

# ─────────────────────────────────────────────
# Config from environment (required by OpenEnv spec)
# ─────────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN", "")
if not HF_TOKEN:
    print("[WARN] HF_TOKEN not set. Using API_BASE_URL without auth header override.", file=sys.stderr)

client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN or "placeholder",  # the OpenAI client requires a non-empty api_key
)
TASKS = ["task_easy", "task_medium", "task_hard"]
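# Fixed seed: every CDNCacheEnv episode below is seeded identically, so runs are reproducible.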
SEED = 42

# ─────────────────────────────────────────────
# LLM Agent
# ─────────────────────────────────────────────
SYSTEM_PROMPT = """You are an intelligent CDN cache management agent.
At each step you receive the current cache state and an incoming file request.
Your job: decide which file to evict (if any) to make room for new content.
Rules:
- Only evict a file if the cache is nearly full and the incoming file is NOT already cached
- Prefer evicting files with LOW request_frequency and NOT viral
- Never evict a file that was just evicted (cache thrashing)
- If cache has space, respond with null (no eviction needed)
You MUST respond with ONLY valid JSON in this exact format:
{"evict_file_id": "<file_id>" or null}
No explanation. No markdown. Only the JSON object."""
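
# Models sometimes wrap the JSON reply in markdown fences despite the prompt.
# Minimal defensive parser for that case (a sketch; tighten it to whatever your
# model actually emits):
def extract_json(raw: str) -> dict:
    """Parse a model reply as JSON, tolerating ```json ... ``` fences."""
    text = raw.strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text.startswith("json"):
            text = text[len("json"):]
    return json.loads(text)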

def build_user_prompt(obs: Observation) -> str:
    """Render the cache state and incoming request as the user prompt."""
    cached_summary = []
    for f in obs.cached_files:
        cached_summary.append(
            f" - {f.file_id}: size={f.size_mb}MB freq={f.request_frequency:.1f} "
            f"viral={f.is_viral} last_accessed=step_{f.last_accessed}"
        )
    cached_str = "\n".join(cached_summary) if cached_summary else " (empty)"
    space_needed = obs.incoming_file_size_mb
    space_free = obs.cache_capacity_mb - obs.cache_used_mb
    return f"""Step {obs.step} | Time of day: {obs.time_of_day:.2f} | Hit rate: {obs.recent_hit_rate:.2f}
Cache: {obs.cache_used_mb:.1f}MB / {obs.cache_capacity_mb:.1f}MB used ({obs.cache_fill_ratio*100:.1f}% full)
Free space: {space_free:.1f}MB
Incoming request:
file_id: {obs.incoming_file_id}
size: {obs.incoming_file_size_mb}MB
viral: {obs.incoming_file_is_viral}
already_cached: {obs.cache_hit}
space_needed_to_cache: {"none (fits)" if space_free >= space_needed else f"{space_needed - space_free:.1f}MB deficit"}
Next 3 requests preview: {obs.queue_preview}
Currently cached files ({len(obs.cached_files)} files):
{cached_str}
Decide: which file to evict? (null if no eviction needed)"""

def llm_action(obs: Observation, step_num: int) -> Action:
    """Call the LLM and parse its action; fall back to LRU eviction on any failure."""
    prompt = build_user_prompt(obs)
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=50,
            temperature=0.0,
        )
        raw = response.choices[0].message.content.strip()
        parsed = extract_json(raw)
        return Action(evict_file_id=parsed.get("evict_file_id"))
    except Exception as e:
        # Fallback: evict the least-recently-used file so the episode can continue.
        print(f"[WARN] step={step_num}: LLM action failed ({e}); falling back to LRU", file=sys.stderr)
        if obs.cached_files:
            lru = min(obs.cached_files, key=lambda f: f.last_accessed)
            return Action(evict_file_id=lru.file_id)
        return Action(evict_file_id=None)
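
# Replies the parser accepts, per SYSTEM_PROMPT (the file id below is hypothetical):
#   {"evict_file_id": "file_042"}  or  {"evict_file_id": null}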

# ─────────────────────────────────────────────
# Run one task episode
# ─────────────────────────────────────────────
# Per-task hit-rate targets used to normalise the final score.
SCORE_TARGETS = {"task_easy": 0.60, "task_medium": 0.55, "task_hard": 0.45}

def run_task(task_id: str) -> dict:
    config = TASK_CONFIGS[task_id]  # validates task_id; raises KeyError on unknown tasks
    env = CDNCacheEnv(task_id=task_id, seed=SEED)
    obs = env.reset()
    total_reward = 0.0
    step_num = 0
    # ── [START] ──
    print(f"[START] task={task_id}", flush=True)
    while True:
        action = llm_action(obs, step_num)
        result = env.step(action)
        total_reward += result.reward.total
        # ── [STEP] ──
        print(f"[STEP] step={step_num} reward={round(result.reward.total, 4)}", flush=True)
        obs = result.observation
        step_num += 1
        if result.done:
            break
    final_state = env.state()
    final_hit_rate = final_state["hit_rate"]
    # Score = final hit rate / per-task target, capped at 1.0.
    score = round(min(1.0, final_hit_rate / SCORE_TARGETS[task_id]), 4)
    # ── [END] ──
    print(f"[END] task={task_id} score={score} steps={step_num}", flush=True)
    return {
        "task_id": task_id,
        "total_reward": round(total_reward, 4),
        "final_hit_rate": round(final_hit_rate, 4),
        "score": score,
    }
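
# Illustrative return value (numbers are made up):
#   {"task_id": "task_easy", "total_reward": 12.3456, "final_hit_rate": 0.61, "score": 1.0}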

# ─────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────
if __name__ == "__main__":
    print("[INFO] Starting CDN Cache Optimizer inference", file=sys.stderr)
    print(f"[INFO] Model: {MODEL_NAME} | API: {API_BASE_URL}", file=sys.stderr)
    results = []
    for task_id in TASKS:
        print(f"\n[INFO] Running {task_id}...", file=sys.stderr)
        r = run_task(task_id)
        results.append(r)
        print(f"[INFO] {task_id} done | score={r['score']} hit_rate={r['final_hit_rate']}", file=sys.stderr)
    print("\n[INFO] === FINAL RESULTS ===", file=sys.stderr)
    for r in results:
        print(f"[INFO] {r['task_id']}: score={r['score']} reward={r['total_reward']}", file=sys.stderr)
    overall = round(sum(r["score"] for r in results) / len(results), 4)
    print(f"[INFO] Overall score: {overall}", file=sys.stderr)