File size: 5,862 Bytes
ceeb029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Hugging Face Space UI for the CDN Cache Optimizer."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

from env.cache import CDNCacheEnv, TASK_CONFIGS
from env.models import Action, Observation


@dataclass
class EpisodeMetrics:
    """Per-episode results collected by ``run_episode`` and consumed by the UI."""

    # ``result.reward.total`` for every environment step, in step order.
    rewards: List[float]
    # ``info["hit_rate"]`` reported after every environment step, in step order.
    hit_rates: List[float]
    # ``hit_rate`` from the last step's info dict (0.0 when the key is absent).
    final_hit_rate: float
    # Sum of all entries in ``rewards``.
    total_reward: float
    # ``bandwidth_saved_mb`` from the last step's info dict (0.0 when absent).
    bandwidth_saved_mb: float


def lru_baseline(obs: Observation) -> Action:
    """Least-recently-used baseline policy.

    Returns a no-op action (``evict_file_id=None``) when the observation
    reports a cache hit or lists no cached files. Otherwise nominates the
    cached file with the smallest ``last_accessed`` value for eviction; on
    ties the first such file in ``obs.cached_files`` wins.
    """
    if not obs.cache_hit and obs.cached_files:
        oldest = None
        for entry in obs.cached_files:
            # Strict "<" keeps the earliest entry among equally old files.
            if oldest is None or entry.last_accessed < oldest.last_accessed:
                oldest = entry
        return Action(evict_file_id=oldest.file_id)
    return Action(evict_file_id=None)


def smart_agent(obs: Observation) -> Action:
    """Heuristic eviction policy used as the "agent" in the demo.

    Returns a no-op action (``evict_file_id=None``) when the request was a
    cache hit, when no files are cached, or when ``obs.cache_fill_ratio`` is
    below 0.92. Otherwise evicts the cached file ranking lowest on, in order:
    presence in ``obs.queue_preview``, the ``is_viral`` flag,
    ``request_frequency``, and negated ``size_mb`` (so among otherwise equal
    files the largest one is evicted).
    """
    nothing_to_evict = obs.cache_hit or not obs.cached_files
    if nothing_to_evict or obs.cache_fill_ratio < 0.92:
        return Action(evict_file_id=None)

    upcoming = set(obs.queue_preview)

    # Lower tuples are evicted first; each component is a reason to keep.
    coldest = min(
        obs.cached_files,
        key=lambda entry: (
            int(entry.file_id in upcoming),
            int(bool(entry.is_viral)),
            entry.request_frequency,
            -entry.size_mb,
        ),
    )
    return Action(evict_file_id=coldest.file_id)


def run_episode(task_id: str, seed: int, policy: Callable[[Observation], Action]) -> EpisodeMetrics:
    """Play one full episode of ``CDNCacheEnv`` under ``policy``.

    Args:
        task_id: Task identifier forwarded to ``CDNCacheEnv``.
        seed: Seed forwarded to ``CDNCacheEnv``.
        policy: Callable mapping the current observation to the next action.

    Returns:
        An ``EpisodeMetrics`` holding the per-step reward totals and hit
        rates, plus the final hit rate, summed reward and bandwidth saved
        taken from the last step's info dict.

    Raises:
        KeyError: If a step's info dict has no ``"hit_rate"`` entry.
    """
    env = CDNCacheEnv(task_id=task_id, seed=seed)
    observation = env.reset()
    reward_trace: List[float] = []
    hit_rate_trace: List[float] = []
    last_info: Dict = {}

    # The environment is always stepped at least once before ``done`` is read.
    while True:
        step = env.step(policy(observation))
        observation, last_info = step.observation, step.info
        reward_trace.append(step.reward.total)
        hit_rate_trace.append(float(last_info["hit_rate"]))
        if step.done:
            break

    closing_hit_rate = float(last_info.get("hit_rate", 0.0))
    reward_sum = float(sum(reward_trace))
    saved_mb = float(last_info.get("bandwidth_saved_mb", 0.0))
    return EpisodeMetrics(
        rewards=reward_trace,
        hit_rates=hit_rate_trace,
        final_hit_rate=closing_hit_rate,
        total_reward=reward_sum,
        bandwidth_saved_mb=saved_mb,
    )


def make_plot(baseline: EpisodeMetrics, agent: EpisodeMetrics):
    """Build the two-panel benchmark figure for a baseline/agent pair.

    The left panel plots each policy's per-step hit rate; the right panel is
    a grouped bar chart of total reward, final hit rate (as a percentage) and
    bandwidth saved.

    Args:
        baseline: Metrics from the LRU baseline episode.
        agent: Metrics from the agent episode.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6), dpi=150)
    fig.patch.set_facecolor("#0b1220")

    # Shared dark theme for both panels.
    for ax in axes:
        ax.set_facecolor("#111827")
        ax.grid(True, alpha=0.25)
        ax.tick_params(colors="#d1d5db")
        ax.xaxis.label.set_color("#d1d5db")
        ax.yaxis.label.set_color("#d1d5db")
        ax.title.set_color("#f9fafb")

    hit_rate_series = (
        (baseline.hit_rates, "#fb923c", "Baseline LRU"),
        (agent.hit_rates, "#22c55e", "Fine-tuned Agent"),
    )
    for hit_rates, color, label in hit_rate_series:
        # Each curve gets its own 1-based step axis: sharing one axis sized
        # from a single episode makes plot() raise when the two episodes
        # differ in length.
        steps = np.arange(1, len(hit_rates) + 1)
        axes[0].plot(steps, hit_rates, color=color, lw=2, label=label)
    axes[0].set_title("Cache Hit Rate Over Episode")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Hit rate")
    axes[0].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    labels = ["Reward", "Hit Rate", "Bandwidth Saved"]
    # Hit rate is scaled to a percentage before being drawn as a bar.
    baseline_values = [baseline.total_reward, baseline.final_hit_rate * 100, baseline.bandwidth_saved_mb]
    agent_values = [agent.total_reward, agent.final_hit_rate * 100, agent.bandwidth_saved_mb]
    idx = np.arange(len(labels))
    width = 0.36
    # Offset the two bar groups by half a bar width around each tick.
    axes[1].bar(idx - width / 2, baseline_values, width, label="Baseline", color="#fb923c")
    axes[1].bar(idx + width / 2, agent_values, width, label="Agent", color="#22c55e")
    axes[1].set_xticks(idx)
    axes[1].set_xticklabels(labels, rotation=8, ha="right", color="#d1d5db")
    axes[1].set_title("Final Comparison")
    axes[1].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    fig.suptitle("CDN Cache Optimizer: OpenEnv Agent Benchmark", color="#f9fafb", fontweight="bold")
    fig.tight_layout()
    return fig


def run_demo(task_label: str, seed: int):
    """Run the baseline and agent on one task and build the UI outputs.

    Args:
        task_label: Dropdown label of the form ``"<task_id> - <name>"``; a
            bare task id is accepted as well.
        seed: Seed used for both episodes so they are run with identical
            environment settings.

    Returns:
        A ``(summary_markdown, figure)`` tuple for the Markdown and Plot
        components.

    Raises:
        gr.Error: If ``seed`` is ``None`` (the number field was cleared).
    """
    if seed is None:
        # gr.Number passes None for an empty field; int(None) would otherwise
        # surface as an opaque TypeError in the UI.
        raise gr.Error("Seed must be an integer.")
    seed = int(seed)

    # Split on the same " - " separator used to build the dropdown labels so
    # a task id is recovered intact even if it contains a space.
    task_id = task_label.partition(" - ")[0]
    baseline = run_episode(task_id, seed, lru_baseline)
    agent = run_episode(task_id, seed, smart_agent)
    uplift = agent.final_hit_rate - baseline.final_hit_rate
    reward_uplift = agent.total_reward - baseline.total_reward
    summary = (
        f"### Results for `{task_id}`\n"
        f"- Baseline LRU reward: **{baseline.total_reward:.2f}**, hit rate: **{baseline.final_hit_rate:.1%}**\n"
        f"- Fine-tuned agent reward: **{agent.total_reward:.2f}**, hit rate: **{agent.final_hit_rate:.1%}**\n"
        f"- Reward uplift: **{reward_uplift:+.2f}** | Hit-rate uplift: **{uplift:+.1%}**\n\n"
        "The agent keeps viral/previewed objects, evicts low-frequency cold content, "
        "and avoids unnecessary churn under cache pressure."
    )
    return summary, make_plot(baseline, agent)


# Dropdown labels, one per configured task, formatted "<task_id> - <name>".
# run_demo recovers the task id from the front of the selected label.
task_choices = [
    f"{task_id} - {cfg.name}" for task_id, cfg in TASK_CONFIGS.items()
]

# Gradio UI: intro text, a task/seed input row, a run button, and the
# Markdown summary plus plot produced by run_demo.
with gr.Blocks(title="CDN Cache Optimizer") as demo:
    gr.Markdown(
        """
        # CDN Cache Optimizer

        OpenEnv-compliant reinforcement-learning environment for edge CDN cache
        admission and eviction. The live demo compares an LRU baseline with a
        fine-tuned agent policy on realistic steady and viral traffic.
        """
    )
    with gr.Row():
        # Preselect the last entry of task_choices.
        task = gr.Dropdown(task_choices, value=task_choices[-1], label="OpenEnv task")
        # precision=0 restricts the field to whole numbers.
        seed = gr.Number(value=42, precision=0, label="Seed")
    run_btn = gr.Button("Run Benchmark", variant="primary")
    output = gr.Markdown()
    plot = gr.Plot()
    # Same handler for the button click and for the initial load, so the
    # outputs are populated with the default inputs before any interaction.
    run_btn.click(run_demo, inputs=[task, seed], outputs=[output, plot])
    demo.load(run_demo, inputs=[task, seed], outputs=[output, plot])


if __name__ == "__main__":
    import os

    # Bind address and port come from the environment, falling back to a
    # local-only server on port 7860.
    host = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
    port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_port=port, server_name=host)