"""
inference.py - CDN Cache Optimizer Baseline Agent
Uses OpenAI client to run an LLM agent against the environment.
Emits structured [START], [STEP], [END] logs to stdout.

Required env vars:
  API_BASE_URL  - LLM API endpoint
  MODEL_NAME    - model identifier
  HF_TOKEN      - Hugging Face / API key
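
Example invocation (illustrative values):
  API_BASE_URL=https://api.openai.com/v1 MODEL_NAME=gpt-4o-mini \
  HF_TOKEN=hf_xxx python inference.py

Expected stdout shape (numbers illustrative):
  [START] task=task_easy
  [STEP] step=0 reward=0.5
  ...
  [END] task=task_easy score=0.8 steps=100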
"""

import os
import sys
import json
from openai import OpenAI
from env.cache import CDNCacheEnv, TASK_CONFIGS
from env.models import Action, Observation

# ─────────────────────────────────────────────
# Config from environment (required by OpenEnv spec)
# ─────────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME   = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN     = os.getenv("HF_TOKEN", "")

if not HF_TOKEN:
    print("[WARN] HF_TOKEN not set. Using API_BASE_URL without auth header override.", file=sys.stderr)

# The OpenAI client works against any OpenAI-compatible endpoint;
# the HF token doubles as the API key.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN or "placeholder",
)

TASKS = ["task_easy", "task_medium", "task_hard"]
SEED  = 42  # fixed seed so episodes are reproducible

# Target hit rate per task; final scores are normalized against these.
SCORE_THRESHOLDS = {"task_easy": 0.60, "task_medium": 0.55, "task_hard": 0.45}

# ─────────────────────────────────────────────
# LLM Agent
# ─────────────────────────────────────────────

SYSTEM_PROMPT = """You are an intelligent CDN cache management agent.

At each step you receive the current cache state and an incoming file request.
Your job: decide which file to evict (if any) to make room for new content.

Rules:
- Only evict a file if the cache is nearly full and the incoming file is NOT already cached
- Prefer evicting files with LOW request_frequency and NOT viral
- Never evict a file that was only just added to the cache (causes cache thrashing)
- If cache has space, respond with null (no eviction needed)

You MUST respond with ONLY valid JSON in this exact format:
{"evict_file_id": "<file_id>" or null}

No explanation. No markdown. Only the JSON object."""


def build_user_prompt(obs: Observation) -> str:
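    """Render the observation as a compact plain-text prompt for the LLM."""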
    cached_summary = []
    for f in obs.cached_files:
        cached_summary.append(
            f"  - {f.file_id}: size={f.size_mb}MB freq={f.request_frequency:.1f} "
            f"viral={f.is_viral} last_accessed=step_{f.last_accessed}"
        )
    cached_str = "\n".join(cached_summary) if cached_summary else "  (empty)"

    space_needed = obs.incoming_file_size_mb
    space_free   = obs.cache_capacity_mb - obs.cache_used_mb

    return f"""Step {obs.step} | Time of day: {obs.time_of_day:.2f} | Hit rate: {obs.recent_hit_rate:.2f}

Cache: {obs.cache_used_mb:.1f}MB / {obs.cache_capacity_mb:.1f}MB used ({obs.cache_fill_ratio*100:.1f}% full)
Free space: {space_free:.1f}MB

Incoming request:
  file_id: {obs.incoming_file_id}
  size: {obs.incoming_file_size_mb}MB
  viral: {obs.incoming_file_is_viral}
  already_cached: {obs.cache_hit}
  space_needed_to_cache: {"none (fits)" if space_free >= space_needed else f"{space_needed - space_free:.1f}MB deficit"}

Next 3 requests preview: {obs.queue_preview}

Currently cached files ({len(obs.cached_files)} files):
{cached_str}

Decide: which file to evict? (null if no eviction needed)"""


def llm_action(obs: Observation, step_num: int) -> Action:
    """Call LLM and parse action. Fall back to LRU on failure."""
    prompt = build_user_prompt(obs)
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user",   "content": prompt},
            ],
            max_tokens=50,
            temperature=0.0,
        )
        raw = (response.choices[0].message.content or "").strip()
        # Strip a markdown fence if the model added one despite the prompt.
        if raw.startswith("```"):
            raw = raw.strip("`").removeprefix("json").strip()
        parsed = json.loads(raw)
        return Action(evict_file_id=parsed.get("evict_file_id"))
    except Exception as e:
        # Fallback: evict the least-recently-used file so the episode keeps going.
        print(f"[WARN] step={step_num}: LLM action failed ({e}); falling back to LRU",
              file=sys.stderr)
        if obs.cached_files:
            lru = min(obs.cached_files, key=lambda f: f.last_accessed)
            return Action(evict_file_id=lru.file_id)
        return Action(evict_file_id=None)


# ─────────────────────────────────────────────
# Run one task episode
# ─────────────────────────────────────────────

def run_task(task_id: str) -> dict:
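    """Run one seeded episode of task_id and return its summary metrics."""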
    if task_id not in TASK_CONFIGS:
        raise ValueError(f"Unknown task_id: {task_id}")
    env = CDNCacheEnv(task_id=task_id, seed=SEED)
    obs = env.reset()

    total_reward = 0.0
    step_num     = 0

    # ── [START] ──
    print(f"[START] task={task_id}", flush=True)

    while True:
        action = llm_action(obs, step_num)
        result = env.step(action)

        total_reward += result.reward.total

        # ── [STEP] ──
        print(f"[STEP] step={step_num} reward={round(result.reward.total, 4)}", flush=True)

        obs      = result.observation
        step_num += 1

        if result.done:
            break

    final_state    = env.state()
    final_hit_rate = final_state["hit_rate"]
    # Full credit once the task's target hit rate is reached, capped at 1.0.
    score = round(min(1.0, final_hit_rate / SCORE_THRESHOLDS[task_id]), 4)

    # ── [END] ──
    print(f"[END] task={task_id} score={score} steps={step_num}", flush=True)

    return {
        "task_id":        task_id,
        "total_reward":   round(total_reward, 4),
        "final_hit_rate": round(final_hit_rate, 4),
        "score":          score,
    }


# ─────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────

if __name__ == "__main__":
    print(f"[INFO] Starting CDN Cache Optimizer inference", file=sys.stderr)
    print(f"[INFO] Model: {MODEL_NAME} | API: {API_BASE_URL}", file=sys.stderr)

    results = []
    for task_id in TASKS:
        print(f"\n[INFO] Running {task_id}...", file=sys.stderr)
        r = run_task(task_id)
        results.append(r)
        print(f"[INFO] {task_id} done | score={r['score']} hit_rate={r['final_hit_rate']}", file=sys.stderr)

    print("\n[INFO] === FINAL RESULTS ===", file=sys.stderr)
    for r in results:
        print(f"[INFO] {r['task_id']}: score={r['score']} reward={r['total_reward']}", file=sys.stderr)

    overall = round(sum(r["score"] for r in results) / len(results), 4)
    print(f"[INFO] Overall score: {overall}", file=sys.stderr)